Add files using upload-large-folder tool
Browse files. This view is limited to 50 files because it contains too many changes.
See raw diff
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/autotune_process.py +656 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/coordinate_descent_tuner.py +315 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/debug.py +655 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/decomposition.py +678 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/test_operators.py +24 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/virtualized.py +351 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/Activation.h +98 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/AdaptivePooling.h +39 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/BucketizationUtils.h +173 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/ConvUtils.h +446 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/Cross.h +14 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/DistributionTemplates.h +394 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/Histogram.h +16 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/IndexKernel.h +41 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/IndexingUtils.h +160 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/MathBitsFallback.h +157 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/MaxPooling.h +97 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/NonEmptyUtils.h +27 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/Padding.h +62 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/PointwiseOps.h +28 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/Pool.h +340 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/RNN.h +53 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/Repeat.h +48 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/Resize.h +173 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/ResizeCommon.h +75 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/SharedReduceOps.h +544 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/SparseTensorUtils.h +190 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/StridedRandomAccessor.h +301 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/TensorAdvancedIndexingUtils.h +92 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/TensorDimApply.h +55 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/TensorFactories.h +142 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/TensorIteratorDynamicCasting.h +52 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/TensorProperties.h +12 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/TensorShape.h +105 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/TopKImpl.h +98 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/TransposeType.h +23 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/Unfold3d.h +49 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/UnfoldBackward.h +112 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/UpSample.h +506 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/BinaryInternal.h +48 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/CompositeRandomAccessor.h +35 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/CuFFTPlanCache.h +494 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/DeviceSqrt.cuh +25 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/DistributionTemplates.h +672 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/Distributions.h +25 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/EmbeddingBackwardKernel.cuh +22 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/ForeachFunctors.cuh +681 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/ForeachMinMaxFunctors.cuh +22 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/GridSampler.cuh +321 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/JitLoops.cuh +187 -0
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/autotune_process.py
ADDED
|
@@ -0,0 +1,656 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import contextlib
|
| 4 |
+
import dataclasses
|
| 5 |
+
import functools
|
| 6 |
+
import logging
|
| 7 |
+
import os
|
| 8 |
+
import queue
|
| 9 |
+
import time
|
| 10 |
+
import warnings
|
| 11 |
+
from concurrent.futures import ThreadPoolExecutor
|
| 12 |
+
from ctypes import byref, c_size_t, c_void_p
|
| 13 |
+
from multiprocessing.process import BaseProcess
|
| 14 |
+
from multiprocessing.queues import Queue
|
| 15 |
+
from typing import (
|
| 16 |
+
Any,
|
| 17 |
+
Callable,
|
| 18 |
+
Dict,
|
| 19 |
+
Iterable,
|
| 20 |
+
List,
|
| 21 |
+
Optional,
|
| 22 |
+
Sequence,
|
| 23 |
+
TYPE_CHECKING,
|
| 24 |
+
Union,
|
| 25 |
+
)
|
| 26 |
+
|
| 27 |
+
import torch
|
| 28 |
+
from torch import multiprocessing
|
| 29 |
+
from torch._dynamo.testing import rand_strided
|
| 30 |
+
|
| 31 |
+
from torch._inductor import ir
|
| 32 |
+
from torch._inductor.codecache import CUDACodeCache, DLLWrapper, PyCodeCache
|
| 33 |
+
|
| 34 |
+
if TYPE_CHECKING:
|
| 35 |
+
from torch._inductor.select_algorithm import TritonTemplateCaller
|
| 36 |
+
|
| 37 |
+
from . import config
|
| 38 |
+
from .utils import do_bench
|
| 39 |
+
from .virtualized import V
|
| 40 |
+
|
| 41 |
+
CUDA_VISIBLE_DEVICES = "CUDA_VISIBLE_DEVICES"
|
| 42 |
+
EXIT_HANDLER_REGISTERED = False
|
| 43 |
+
|
| 44 |
+
log = logging.getLogger(__name__)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
# Used to synchronize between parent and child processes
class Ping:
    """Sentinel request sent to a child process to confirm it is alive."""
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
class Pong:
    """Sentinel response returned by a child process to confirm liveness."""
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
@contextlib.contextmanager
def set_cuda_visible_device(device: Optional[int]):
    """
    Context manager to set the CUDA_VISIBLE_DEVICES environment variable to
    the specified single device. If device is None, the environment is not
    touched. On exit the previous value (or absence) of the variable is
    restored.
    """
    if device is None:
        yield
        return

    saved = os.environ.get(CUDA_VISIBLE_DEVICES)
    os.environ[CUDA_VISIBLE_DEVICES] = str(device)
    try:
        yield
    finally:
        # Restore the exact prior state: delete if it was unset, else put back.
        if saved is None:
            os.environ.pop(CUDA_VISIBLE_DEVICES)
        else:
            os.environ[CUDA_VISIBLE_DEVICES] = saved
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
@dataclasses.dataclass
class TuningProcess:
    """
    Abstraction for launching a helper process to benchmark kernels. Spawns
    the parent process and uses multiprocessing queues to send benchmark
    requests and return results.
    """

    # Device id made visible to the child via CUDA_VISIBLE_DEVICES;
    # None leaves the environment untouched.
    device: Optional[int] = None
    # Spawned child process handle; populated by initialize().
    process: Optional[BaseProcess] = None
    # Parent -> child work items (Ping, BenchmarkRequest, or the None sentinel).
    request_queue: Optional[Queue[Any]] = None
    # Child -> parent responses (Pong or benchmark results).
    response_queue: Optional[Queue[Any]] = None

    @staticmethod
    def process_main(
        request_queue: Queue[Any],
        response_queue: Queue[Any],
    ) -> None:
        """
        Entry point for the child process.
        """
        log.debug(
            "Entering TuningProcess child. Visible devices = %s",
            os.environ.get(CUDA_VISIBLE_DEVICES),
        )
        try:
            TuningProcess.workloop(request_queue, response_queue)
        except Exception as ex:
            # Log rather than propagate so the child exits instead of dying
            # noisily; the parent detects a dead child via exitcode in get().
            log.exception("Exception in TuningProcess: %s", ex)

    @staticmethod
    def workloop(request_queue: Queue[Any], response_queue: Queue[Any]) -> None:
        """
        Work loop for the benchmarking subprocess: pull requests off the
        queue and push one response per request until the None sentinel.
        """
        while True:
            obj = request_queue.get()

            if obj is None:
                break  # None is a sentinel for the child to terminate
            elif isinstance(obj, Ping):
                # Health-check handshake used during warm up.
                response_queue.put(Pong())
            elif isinstance(obj, BenchmarkRequest):
                response_queue.put(obj.benchmark())
            else:
                raise RuntimeError(f"Invalid request type {type(obj)}")

    def valid(self) -> bool:
        """
        True if the sub-process has been initialized.
        """
        return (
            self.process is not None
            and self.request_queue is not None
            and self.response_queue is not None
        )

    def clear(self) -> None:
        """
        Reset to an uninitialized state.
        """
        self.process = self.request_queue = self.response_queue = None

    def initialize(self) -> None:
        """
        Create child process, request/response queues, and do the warm up.
        Set the environment to make only the provided GPU device visible
        to the process.
        """
        if self.valid():
            return

        # cuda runtime does not work with "fork", use "spawn" to start processes.
        ctx = multiprocessing.get_context("spawn")
        self.request_queue = ctx.Queue()
        self.response_queue = ctx.Queue()

        self.process = ctx.Process(
            target=self.process_main,
            args=(
                self.request_queue,
                self.response_queue,
            ),
        )
        assert self.process is not None
        # Restrict device visibility only around start(): the child inherits
        # the environment as it exists at spawn time.
        with set_cuda_visible_device(self.device):
            self.process.start()

    def put(self, obj: Any) -> None:
        """
        Push a work item to the child process.
        """
        # In case of a prior crash, ensure the subprocess is running
        self.initialize()
        assert self.request_queue is not None
        self.request_queue.put(obj)

    def get(self) -> Any:
        """
        Get a response from the child process. Polls with a short timeout so
        a crashed child (non-None exitcode) is detected instead of blocking
        forever; on crash, resets state and re-raises queue.Empty.
        """
        assert self.process is not None
        assert self.response_queue is not None
        while True:
            try:
                return self.response_queue.get(timeout=1.0)
            except queue.Empty:
                status = self.process.exitcode
                if status is None:
                    # child process is still running
                    continue
                # child process crashed
                self.clear()
                raise

    def terminate(self) -> None:
        """
        Signal the child process to terminate (via the None sentinel).
        """
        if self.valid():
            assert self.process is not None
            assert self.request_queue is not None
            self.request_queue.put(None)

    def wait(self) -> None:
        """
        Wait for the child process to exit.
        """
        if self.process is not None:
            self.process.join()
            self.clear()
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
@dataclasses.dataclass
class TuningProcessPool:
    """
    Maintains a pool of TuningProcesses to benchmark kernels in parallel
    across devices. By default, we create one TuningProcess per device and
    set the sub-process environment to make only that device visible.
    """

    # FIFO of idle worker processes; None until initialize() runs.
    processes: Optional[queue.Queue[TuningProcess]] = None
    # Helper threads that feed work to the subprocesses; one per device.
    executor: Optional[ThreadPoolExecutor] = None

    def initialize(self) -> None:
        """
        Start the child processes. Idempotent: returns immediately if the
        pool is already initialized.
        """
        # processes and executor are always set/cleared together.
        assert (self.processes is None) == (self.executor is None)
        if self.processes is not None:
            return

        devices = self.get_device_list()
        log.debug("Sub-process autotune device list: %s", devices)

        # Launch the child processes and push a msg to "warm up"
        self.processes = queue.Queue()
        for device in devices:
            p = TuningProcess(device=device)
            p.initialize()
            p.put(Ping())
            self.processes.put(p)

        # Wait for the initialization to finish
        for p in self.processes.queue:
            assert isinstance(p.get(), Pong)

        # Use a thread pool to manage distributing work to the subprocesses.
        # Threads block on an available process, so it makes sense to match
        # the number of threads with the number of devices.
        self.executor = ThreadPoolExecutor(max_workers=len(devices))

        # Register the exit handler for the parent process so it will terminate
        # the child processes.
        global EXIT_HANDLER_REGISTERED
        if not EXIT_HANDLER_REGISTERED:
            EXIT_HANDLER_REGISTERED = True
            import atexit

            atexit.register(self.terminate)

    def get_device_list(self) -> Sequence[Optional[int]]:
        """
        Gather the list of devices to be used in the pool. Returns [None]
        (a single unrestricted worker) unless multi-device autotuning is on.
        """
        if not config.autotune_multi_device:
            # Don't use multiple devices
            return [None]

        count = torch.cuda.device_count()

        # If the user specified the visible devices in the env, use those.
        if CUDA_VISIBLE_DEVICES in os.environ:
            devices = [int(d) for d in os.environ[CUDA_VISIBLE_DEVICES].split(",")]
            assert len(devices) <= count
            return devices

        return list(range(count))

    def terminate(self) -> None:
        """
        Signal all child processes to terminate.
        """
        if self.executor is not None:
            self.executor.shutdown()
            self.executor = None

        if self.processes is not None:
            # Send the shutdown sentinel to every worker first, then join,
            # so the children can exit in parallel.
            for p in self.processes.queue:
                p.terminate()
            for p in self.processes.queue:
                p.wait()
            self.processes = None

    def target(self, choice: TritonTemplateCaller) -> float:
        """
        Entry point for the thread-pool helper threads: Wait for an open TuningProcess,
        remove it from the queue, execute the benchmark in that subprocess, and return
        the TuningProcess to the queue.
        """
        assert choice.bmreq is not None
        assert self.processes is not None

        process = self.processes.get()
        process.put(choice.bmreq)
        try:
            return process.get()
        except queue.Empty:
            # The worker crashed while benchmarking this choice (see
            # TuningProcess.get); skip the choice rather than fail autotuning.
            warnings.warn(
                f"Failed to benchmark choice '{choice}'. It will be ignored. "
                "Please debug the root cause in case the choice can bring perf gains."
            )
            # set to INF so this choice will be ignored
            return float("inf")
        finally:
            self.processes.put(process)

    def benchmark(
        self,
        choices: List[TritonTemplateCaller],
    ) -> Dict[TritonTemplateCaller, float]:
        """
        Benchmark each choice in a separate process. Returns a mapping from
        choice to measured time (inf for choices that failed to run).
        """
        assert self.processes is not None, "Tuning process pool is not initialized"
        assert self.executor is not None

        results = {}

        # Use a ThreadExecutorPool to spread the work across the subprocesses and
        # to grab subprocesses as soon as they're free.
        for choice, result in zip(choices, self.executor.map(self.target, choices)):
            results[choice] = result

        return results
|
| 332 |
+
|
| 333 |
+
|
| 334 |
+
# Process-wide singleton pool shared by all autotuning callers.
tuning_pool = TuningProcessPool()
|
| 335 |
+
|
| 336 |
+
|
| 337 |
+
# Type alias: TensorMeta.from_irnodes accepts either an IR layout or a buffer.
LayoutOrBuffer = Union[ir.Layout, ir.Buffer]
|
| 338 |
+
|
| 339 |
+
|
| 340 |
+
@dataclasses.dataclass
class TensorMeta:
    """
    Serializable description of a tensor (device, dtype, sizes, strides,
    storage offset) that can cross process boundaries; to_tensor()
    materializes a matching randomly-filled tensor.
    """

    device: torch.device
    dtype: torch.dtype
    sizes: torch._prims_common.ShapeType
    strides: torch._prims_common.StrideType
    offset: int

    @classmethod
    def from_irnodes(
        cls, irnodes: Union[LayoutOrBuffer, Sequence[LayoutOrBuffer]]
    ) -> Union[TensorMeta, List[TensorMeta]]:
        """Build metadata for a single IR node, or a list for a sequence."""
        if isinstance(irnodes, Sequence):
            metas: List[Any] = [cls.from_irnodes(item) for item in irnodes]
            assert all(isinstance(meta, TensorMeta) for meta in metas)
            return metas

        node = irnodes
        # A bare layout carries no buffer; wrap it so the accessors below work.
        if isinstance(node, ir.Layout):
            node = ir.Buffer("fake", node)

        dtype = node.get_dtype()
        assert dtype is not None

        # Resolve symbolic sizes/strides/offset to concrete integer hints.
        fallback = config.unbacked_symint_fallback
        return TensorMeta(
            device=node.get_device(),
            dtype=dtype,
            sizes=V.graph.sizevars.size_hints(node.get_size(), fallback=fallback),
            strides=V.graph.sizevars.size_hints(node.get_stride(), fallback=fallback),
            offset=V.graph.sizevars.size_hint(
                node.get_layout().offset, fallback=fallback
            ),
        )

    def to_tensor(self) -> torch.Tensor:
        """Create a random tensor matching this metadata."""
        return rand_strided(
            self.sizes,
            self.strides,
            device=self.device,
            dtype=self.dtype,
            extra_size=self.offset,
        )
|
| 389 |
+
|
| 390 |
+
|
| 391 |
+
@dataclasses.dataclass
class BenchmarkRequest:
    """
    Only handle triton template benchmark for now. The extern kernel benchmark
    can be done inside the same process since they usually don't cause crash.

    Important: Instances of this class and subclasses have to be serializable
    across process boundaries. Do not put CUDA Tensors in here!
    """

    def __init__(
        self,
        kernel_name: str,
        input_tensor_meta: Union[TensorMeta, List[TensorMeta]],
        output_tensor_meta: Union[TensorMeta, List[TensorMeta]],
        extra_args: Iterable[Any],
    ):
        # the kernel name defined in the module
        self.kernel_name = kernel_name

        # Normalize inputs to a list so benchmark() can always iterate.
        self.input_tensor_meta = (
            [input_tensor_meta]
            if isinstance(input_tensor_meta, TensorMeta)
            else input_tensor_meta
        )

        # Exactly one output is supported; unwrap a singleton sequence.
        if isinstance(output_tensor_meta, (tuple, list)):
            assert len(output_tensor_meta) == 1
            output_tensor_meta = output_tensor_meta[0]
        self.output_tensor_meta = output_tensor_meta

        self.extra_args = extra_args

    def make_run_fn(
        self, *input_tensors: torch.Tensor, output_tensor: torch.Tensor
    ) -> Callable[[], None]:
        """Subclasses return a zero-arg callable that runs the kernel once."""
        raise NotImplementedError()

    def cleanup_run_fn(self) -> None:
        """Hook for subclasses to release resources created by make_run_fn."""
        pass

    def benchmark(
        self,
        *input_tensors: torch.Tensor,
        output_tensor: Optional[torch.Tensor] = None,
    ) -> float:
        """
        Time the kernel described by this request and return the measurement.

        When output_tensor is None (the usual path inside the child process),
        inputs and output are materialized from the stored metadata.
        """
        verbose = log.isEnabledFor(logging.DEBUG)
        if verbose:
            start_ts = time.time()

        # create args and out tensor
        if output_tensor is None:
            assert len(input_tensors) == 0
            input_tensors = tuple(meta.to_tensor() for meta in self.input_tensor_meta)
            output_tensor = self.output_tensor_meta.to_tensor()

        if verbose:
            create_tensor_elapse = time.time() - start_ts  # type: ignore[possibly-undefined]
            start_ts = time.time()

        fn = self.make_run_fn(*input_tensors, output_tensor=output_tensor)

        if verbose:
            load_elapse = time.time() - start_ts  # type: ignore[possibly-undefined]
            start_ts = time.time()

        out = do_bench(fn)
        torch.cuda.synchronize()  # shake out any CUDA errors

        if verbose:
            bench_elapse = time.time() - start_ts  # type: ignore[possibly-undefined]
            log.debug(
                "InChildProcess %s: load %f, create tensor %f, bench %f",
                str(self),
                load_elapse,  # type: ignore[possibly-undefined]
                create_tensor_elapse,  # type: ignore[possibly-undefined]
                bench_elapse,
            )
        self.cleanup_run_fn()
        return out
|
| 469 |
+
|
| 470 |
+
|
| 471 |
+
class TestBenchmarkRequest(BenchmarkRequest):
    """
    Supports unit testing. Defined in this file so that the TuningProcess
    sub-process knows how to unpickle these objects.
    """

    def __init__(self, value: Optional[float] = None) -> None:
        # Deliberately skip BenchmarkRequest.__init__: tests need no tensors.
        self.value = value

    def benchmark(
        self, *input_tensors: torch.Tensor, output_tensor: Optional[torch.Tensor] = None
    ) -> float:
        # A None value simulates a benchmark that fails to execute.
        if self.value is not None:
            return self.value
        raise Exception("Failed to run")
|
| 486 |
+
|
| 487 |
+
|
| 488 |
+
class TritonBenchmarkRequest(BenchmarkRequest):
    """
    Benchmark request for a generated Triton kernel, loaded from the
    on-disk module cache by key/path inside the child process.
    """

    # Important: Instances of this class have to be serializable
    # across process boundaries. Do not put CUDA Tensors in here!

    def __init__(
        self,
        kernel_name: str,
        input_tensor_meta: Union[TensorMeta, List[TensorMeta]],
        output_tensor_meta: Union[TensorMeta, List[TensorMeta]],
        extra_args: Iterable[Any],
        module_path: str,  # the path of the module defining the triton kernel
        module_cache_key: str,
        grid: List[int],
        num_stages: int,
        num_warps: int,
        matrix_instr_nonkdim: int = 0,  # only used for hip to choose the shape of mfma instruction.
    ):
        super().__init__(kernel_name, input_tensor_meta, output_tensor_meta, extra_args)
        self.module_path = module_path
        self.module_cache_key = module_cache_key
        self.grid = grid
        self.num_stages = num_stages
        self.num_warps = num_warps
        self.matrix_instr_nonkdim = matrix_instr_nonkdim

    def make_run_fn(
        self, *input_tensors: torch.Tensor, output_tensor: torch.Tensor
    ) -> Callable[[], None]:
        """
        Load the cached Triton module and return a zero-arg callable that
        launches the kernel with the stored launch configuration.
        """
        mod = PyCodeCache.load_by_key_path(self.module_cache_key, self.module_path)
        log.debug(
            "benchmark module key: %s, path: %s",
            self.module_cache_key,
            self.module_path,
        )

        run_method = getattr(mod, self.kernel_name).run

        # Newer versions of triton add a warmup argument to JITFunction.run.
        # This code handles backward-compatibility.
        import inspect

        optional_kwargs: Dict[str, Any] = {}
        if "warmup" in inspect.signature(run_method).parameters:
            optional_kwargs["warmup"] = False

        # Only hip uses the mfma instruction shape; other backends would
        # reject the unknown keyword.
        if torch.version.hip and self.matrix_instr_nonkdim != 0:
            optional_kwargs["matrix_instr_nonkdim"] = self.matrix_instr_nonkdim

        return functools.partial(
            run_method,
            *input_tensors,
            output_tensor,
            *self.extra_args,
            grid=self.grid,
            num_stages=self.num_stages,
            num_warps=self.num_warps,
            **optional_kwargs,
        )

    def __str__(self) -> str:
        return f"{self.kernel_name=}, {self.module_path=}, {self.module_cache_key=}"
|
| 560 |
+
|
| 561 |
+
|
| 562 |
+
class CUDABenchmarkRequest(BenchmarkRequest):
    """Benchmark request for a standalone CUDA kernel compiled from source.

    The kernel source is written/compiled to a shared library through
    ``CUDACodeCache`` and invoked via ctypes function pointers.
    """

    # Important: Instances of this class have to be serializable
    # across process boundaries. Do not put CUDA Tensors in here!

    def __init__(
        self,
        kernel_name: str,
        input_tensor_meta: Union[TensorMeta, List[TensorMeta]],
        output_tensor_meta: Union[TensorMeta, List[TensorMeta]],
        extra_args: Iterable[Any],
        source_code: str,
    ):
        super().__init__(kernel_name, input_tensor_meta, output_tensor_meta, extra_args)
        self.source_code = source_code
        # Scratch-memory size reported by the kernel; must currently be 0
        # (enforced in make_run_fn below).
        self.workspace_size: int = 0
        self.workspace: Optional[torch.Tensor] = None
        # Handle to the compiled .so; populated lazily by make_run_fn.
        self.DLL: Optional[DLLWrapper] = None
        self.hash_key: str = ""
        self.source_file: str = ""
        # Write the source eagerly so hash_key/source_file are stable even
        # before any compilation happens (e.g. for __str__/logging).
        self.hash_key, self.source_file = CUDACodeCache.write(self.source_code, "so")

    def precompile(self):
        """Compile the kernel ahead of time to warm the CUDACodeCache."""
        # Prepopulate CUDACodeCache
        # may happen in separate Threadpool
        log.debug("Precompiling %s", self)
        CUDACodeCache.load(self.source_code, "so")
        log.debug("Done precompiling %s", self)

    def make_run_fn(
        self, *input_tensors: torch.Tensor, output_tensor: torch.Tensor
    ) -> Callable[[], None]:
        """Load the compiled kernel and return a zero-arg callable that runs it.

        Also queries the kernel's required workspace size (by calling it with a
        size pointer and a null workspace pointer) and asserts it is zero.
        """
        self.DLL, self.hash_key, self.source_file = CUDACodeCache.load(
            self.source_code, "so"
        )
        # Raw device pointers for all inputs followed by the output.
        args = [
            c_void_p(tensor.data_ptr())
            for tensor in list(input_tensors) + [output_tensor]
        ]
        log.debug(
            "make_run_fn: self.kernel_name=%s, self.source_file=%s, self.hash_key=%s, self.DLL=%s, args=%s, self.extra_args=%s",
            self.kernel_name,
            self.source_file,
            self.hash_key,
            self.DLL,
            args,
            self.extra_args,
        )
        run_method = getattr(self.DLL, self.kernel_name)
        stream_ptr = c_void_p(torch.cuda.current_stream().cuda_stream)

        # Retrieve workspace_size and initialize workspace.
        c_workspace_size = c_size_t()
        run_method(
            *args,  # input ptrs and output ptrs
            *self.extra_args,
            byref(
                c_workspace_size
            ),  # set workspace size ptr to retrieve workspace size
            None,  # null workspace ptr
            stream_ptr,
        )
        self.workspace_size = c_workspace_size.value
        # TODO: Support non-zero workspace_size.
        assert self.workspace_size == 0, (
            "Things need to be fixed to support non-zero workspace_size: "
            "1) max autotune cache needs to store workspace size; "
            "2) memory allocation needs to allocate / deallocate workspace correctly; "
        )

        # Generate partial function.
        return functools.partial(
            run_method,
            *args,
            *self.extra_args,
            None,  # null workspace size ptr
            None,  # set workspace ptr, TODO: update it to a real ptr if workspace_size > 0
            stream_ptr,
        )

    def cleanup_run_fn(self) -> None:
        """Release the loaded DLL handle and any workspace tensor."""
        if self.DLL is not None:
            self.DLL.close()
        self.workspace = None

    def __str__(self) -> str:
        return f"{self.kernel_name=}, {self.source_file=}, {self.hash_key=}"
|
| 648 |
+
|
| 649 |
+
|
| 650 |
+
def benchmark_in_sub_process(
    choices: List[TritonTemplateCaller],
) -> Dict[TritonTemplateCaller, float]:
    """
    Do benchmarking in a subprocess and return the perf number (latency).
    """
    # tuning_pool is defined earlier in this module; it performs the actual
    # benchmarking in a child process (presumably so a crashing candidate
    # kernel cannot take down the parent — confirm against its definition).
    return tuning_pool.benchmark(choices)
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/coordinate_descent_tuner.py
ADDED
|
@@ -0,0 +1,315 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
import itertools
|
| 3 |
+
import logging
|
| 4 |
+
from typing import Callable, Optional
|
| 5 |
+
|
| 6 |
+
from torch.utils._triton import has_triton
|
| 7 |
+
from .utils import red_text, triton_config_to_hashable
|
| 8 |
+
|
| 9 |
+
if has_triton():
|
| 10 |
+
import triton
|
| 11 |
+
else:
|
| 12 |
+
triton = None
|
| 13 |
+
|
| 14 |
+
from . import config as inductor_config
|
| 15 |
+
|
| 16 |
+
log = logging.getLogger(__name__)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def get_field(config, name):
    """Read a tunable field from a triton config.

    ``num_warps`` and ``num_stages`` live as attributes on the config object;
    every other field (XBLOCK, RBLOCK, ...) lives in ``config.kwargs``.
    Returns None when a kwargs-backed field is absent.
    """
    if name in ("num_warps", "num_stages"):
        return getattr(config, name)
    return config.kwargs.get(name, None)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def set_field(config, name, value):
    """Write a tunable field on a triton config.

    Mirrors get_field: ``num_warps``/``num_stages`` are plain attributes,
    everything else is stored in ``config.kwargs``.
    """
    if name in ("num_warps", "num_stages"):
        setattr(config, name, value)
    else:
        config.kwargs[name] = value
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class CoordescTuner:
    """
    The coordinate descent tuner. Tune one field/coordinate at a time.

    TODO will it be necessary to tune multiple fields simultaneously.

    TODO: what if both increasing and decreasing a field can improve perf.
    i.e., there are multiple local optima..
    """

    def __init__(self, is_mm=False, name="unknown", size_hints=None):
        self.is_mm = is_mm  # we will tune num_stages for mm
        # hashable(config) -> measured latency; ensures each config is
        # benchmarked at most once across the whole search.
        self.cached_benchmark_results = {}
        self.name = name
        # Optional per-dimension sizes used to cap how large the block
        # fields are allowed to grow.
        self.size_hints = size_hints

    def get_xmax(self):
        # Upper bound for XBLOCK: global config limit, clamped by size hint.
        xmax = inductor_config.triton.max_block["X"]
        if self.size_hints and len(self.size_hints) > 0:
            xmax = min(xmax, self.size_hints[0])
        return xmax

    def get_ymax(self):
        # Upper bound for YBLOCK: global config limit, clamped by size hint.
        ymax = inductor_config.triton.max_block["Y"]
        if self.size_hints and len(self.size_hints) > 1:
            ymax = min(ymax, self.size_hints[1])
        return ymax

    def get_zmax(self):
        # Upper bound for ZBLOCK: global config limit, clamped by size hint.
        zmax = inductor_config.triton.max_block["Z"]
        if self.size_hints and len(self.size_hints) > 2:
            zmax = min(zmax, self.size_hints[2])
        return zmax

    def get_rmax(self):
        # Upper bound for RBLOCK.
        if self.size_hints and len(self.size_hints) > 0:
            return self.size_hints[-1]  # the last one is for reduction
        else:
            # large enough. We should not pick this large RBLOCK anyway
            return 2**30

    def get_warpsmax(self):
        # Currently, CUDA has a maximum of 1024 threads, so 32 is the max
        # number of warps.
        return 1024 // 32

    def cache_benchmark_result(self, config, timing):
        self.cached_benchmark_results[triton_config_to_hashable(config)] = timing

    def lookup_in_cache(self, config):
        # Returns the cached timing for `config`, or None if never measured.
        return self.cached_benchmark_results.get(triton_config_to_hashable(config))

    def call_func(self, func, config):
        """Benchmark `config` with `func`, memoizing the result."""
        found = self.lookup_in_cache(config)
        if found is not None:
            log.debug(" CACHED")
            return found
        timing = func(config)
        self.cache_benchmark_result(config, timing)
        return timing

    @property
    def tunable_fields(self):
        # Fields coordinate descent may perturb. Fields missing from a given
        # config (get_field returns None) are skipped during tuning.
        out = [
            "XBLOCK",
            "YBLOCK",
            "ZBLOCK",
            # NOTE: we should not tune RBLOCK for persistent reduction.
            # We rely on the fact that persistent reduction's triton.Config
            # does not have the RBLOCK field to guarantee that.
            "RBLOCK",
            # the following 3 are for mm
            "BLOCK_M",
            "BLOCK_N",
            "BLOCK_K",
            "num_warps",
        ]
        if self.is_mm:
            out.append("num_stages")

        return out

    def value_too_large(self, name, val):
        # True when `val` exceeds the per-field upper bound; fields without a
        # known bound are never considered too large.
        if name == "XBLOCK":
            return val > self.get_xmax()
        if name == "YBLOCK":
            return val > self.get_ymax()
        if name == "ZBLOCK":
            return val > self.get_zmax()
        if name == "RBLOCK":
            return val > self.get_rmax()
        if name == "num_warps":
            return val > self.get_warpsmax()

        return False

    def get_neighbour_values(self, name, orig_val, radius=1, include_self=False):
        """
        Get neighbour values in 'radius' steps. The original value is not
        returned as its own neighbour (unless include_self is True).
        """
        assert radius >= 1

        def update(cur_val, inc=True):
            # num_stages steps by +/-1; block sizes and num_warps step by
            # powers of two.
            if name == "num_stages":
                if inc:
                    return cur_val + 1
                else:
                    return cur_val - 1
            else:
                if inc:
                    return cur_val * 2
                else:
                    return cur_val // 2

        out = []
        # increment loop
        cur_val = orig_val
        for _ in range(radius):
            cur_val = update(cur_val, True)
            if self.value_too_large(name, cur_val):
                break
            out.append(cur_val)

        # decrement loop
        cur_val = orig_val
        for _ in range(radius):
            cur_val = update(cur_val, False)
            if cur_val <= 0:
                break
            out.append(cur_val)

        if include_self:
            out.append(orig_val)
        return out

    @staticmethod
    def has_improvement(baseline, test):
        # Require a strictly-better timing by more than the noise threshold.
        threshold = 0.001  # 0.1%
        return test is not None and test < baseline * (1 - threshold)

    def check_all_tuning_directions(
        self,
        func: Callable[["triton.Config"], float],
        best_config,
        best_timing,
    ):
        """
        Check all directions. We only do this once the regular coordinate
        descent tuning find no better choices any more.
        We only have a few tunable fields, so this should be fine.
        """
        candidate_values_list = []
        effective_fields = []
        for field in self.tunable_fields:
            old_value = get_field(best_config, field)
            if old_value is None:
                continue
            candidate_values = self.get_neighbour_values(
                field,
                old_value,
                radius=inductor_config.coordinate_descent_search_radius,
                include_self=True,
            )
            candidate_values_list.append(candidate_values)
            effective_fields.append(field)

        # Cartesian product over all fields' neighbourhoods (exponential in
        # the number of effective fields, acceptable because it is small).
        choices = itertools.product(*candidate_values_list)
        improved = False
        for choice in choices:
            assert len(choice) == len(effective_fields)
            candidate_config = copy.deepcopy(best_config)
            for new_val, field in zip(choice, effective_fields):
                set_field(candidate_config, field, new_val)
            cmp_res, candidate_timing = self.compare_config(
                func, candidate_config, best_config, best_timing
            )
            if cmp_res:
                improved = True
                best_config = candidate_config
                best_timing = candidate_timing

        return improved, best_config, best_timing

    def compare_config(self, func, candidate_config, best_config, best_timing):
        """
        Check if candidate_config is better than best_config.

        Return a tuple of (compare_result, candidate_timing).
        compare_result is true iff candidate_config is better.
        """
        log.debug("Try config %s", candidate_config)
        try:
            candidate_timing = self.call_func(func, candidate_config)
        except Exception as e:
            # A failing config (e.g. compile error, OOM) is treated as
            # infinitely slow rather than aborting the search.
            log.debug("Got exception %s", e)
            return False, float("inf")

        if self.has_improvement(best_timing, candidate_timing):
            log.debug(
                "Tune from %s %f -> %s %f",
                best_config,
                best_timing,
                candidate_config,
                candidate_timing,
            )

            return True, candidate_timing
        return False, candidate_timing

    def autotune(
        self,
        func: Callable[["triton.Config"], float],
        baseline_config: "triton.Config",
        baseline_timing: Optional[float] = None,
    ) -> "triton.Config":
        """Run coordinate descent from baseline_config; return the best config.

        Repeatedly sweeps each tunable field's neighbourhood until a full
        pass yields no improvement; optionally then searches all directions
        jointly if enabled in config.
        """
        if baseline_timing is None:
            baseline_timing = self.call_func(func, baseline_config)

        log.debug("= Do coordinate descent tuning for %s =", self.name)
        log.debug(
            "Baseline Config %s, baseline timing %f", baseline_config, baseline_timing
        )
        improved = True
        best_config = baseline_config
        best_timing = baseline_timing
        tunable_fields = self.tunable_fields

        while improved:
            improved = False

            for name in tunable_fields:
                cur_val = get_field(best_config, name)
                # some kernel don't have RBLOCK/YBLOCK/ZBLOCK. So cur_val may be None
                if cur_val is None:
                    continue

                # It's possible that candidate_values is empty.
                # E.g., if XBLOCK is 1 initially and size_hint for x is also 1.
                # We would not try either larger or smaller XBLOCK in this case.
                candidate_values = self.get_neighbour_values(name, cur_val)

                for next_val in candidate_values:
                    candidate_config = copy.deepcopy(best_config)
                    set_field(candidate_config, name, next_val)

                    cmp_res, candidate_timing = self.compare_config(
                        func, candidate_config, best_config, best_timing
                    )
                    if cmp_res:
                        improved = True
                        best_config, best_timing = candidate_config, candidate_timing

            if not improved and inductor_config.coordinate_descent_check_all_directions:
                old_best_timing = best_timing
                improved, best_config, best_timing = self.check_all_tuning_directions(
                    func, best_config, best_timing
                )

                if improved:
                    msg = red_text(
                        "Coordinate descend tuning found improvement of %.3fx by looking in all directions."
                    )
                    log.debug(
                        msg,
                        old_best_timing / best_timing,
                    )

        log.debug(
            "Improve from %s %f -> %s %f, %.3fx",
            baseline_config,
            baseline_timing,
            best_config,
            best_timing,
            baseline_timing / best_timing,
        )

        return best_config
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/debug.py
ADDED
|
@@ -0,0 +1,655 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import collections
|
| 2 |
+
import contextlib
|
| 3 |
+
import cProfile
|
| 4 |
+
import dataclasses
|
| 5 |
+
import functools
|
| 6 |
+
import itertools
|
| 7 |
+
import logging
|
| 8 |
+
import os
|
| 9 |
+
import os.path
|
| 10 |
+
import pickle
|
| 11 |
+
import pstats
|
| 12 |
+
import shutil
|
| 13 |
+
import subprocess
|
| 14 |
+
from typing import Any, Dict, List, Optional
|
| 15 |
+
from unittest.mock import patch
|
| 16 |
+
|
| 17 |
+
from functorch.compile import draw_graph, get_aot_graph_name, get_graph_being_compiled
|
| 18 |
+
|
| 19 |
+
import torch
|
| 20 |
+
from torch import fx as fx
|
| 21 |
+
|
| 22 |
+
from torch._dynamo.repro.after_aot import save_graph_repro, wrap_compiler_debug
|
| 23 |
+
from torch._dynamo.utils import get_debug_dir
|
| 24 |
+
from torch.fx.graph_module import GraphModule
|
| 25 |
+
from torch.fx.passes.shape_prop import _extract_tensor_metadata, TensorMetadata
|
| 26 |
+
from torch.fx.passes.tools_common import legalize_graph
|
| 27 |
+
from torch.utils._pytree import tree_map
|
| 28 |
+
|
| 29 |
+
from . import config, ir # noqa: F811, this is needed
|
| 30 |
+
from .scheduler import (
|
| 31 |
+
BaseSchedulerNode,
|
| 32 |
+
FusedSchedulerNode,
|
| 33 |
+
NopKernelSchedulerNode,
|
| 34 |
+
OutputNode,
|
| 35 |
+
SchedulerNode,
|
| 36 |
+
)
|
| 37 |
+
from .virtualized import V
|
| 38 |
+
|
| 39 |
+
log = logging.getLogger(__name__)
|
| 40 |
+
|
| 41 |
+
SchedulerNodeList = List[Any]
|
| 42 |
+
BufMeta = collections.namedtuple("BufMeta", ["name", "n_origin"])
|
| 43 |
+
GRAPHVIZ_COMMAND_SCALABLE = ["dot", "-Gnslimit=2", "-Gnslimit1=2", "-Gmaxiter=5000"]
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
@functools.lru_cache(None)
def has_dot() -> bool:
    """Return True if the graphviz ``dot`` executable is available on PATH.

    The result is cached for the lifetime of the process.
    """
    # shutil.which avoids spawning a subprocess and, unlike shelling out to
    # the POSIX `which` binary, works on Windows and cannot raise
    # FileNotFoundError when `which` itself is missing — that error is an
    # OSError, which the previous `except subprocess.SubprocessError`
    # handler did not catch.
    return shutil.which("dot") is not None
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def draw_buffers(nodes: List[BaseSchedulerNode], print_graph=False, fname=None):
    """
    Draw a graph in fname.svg.

    Builds an FX graph mirroring the scheduler nodes, attaches tensor
    metadata so the renderer can show shapes/dtypes, and hands the result to
    functorch's draw_graph. No-op (with a warning) when graphviz is absent.
    """
    if not has_dot():
        log.warning("draw_buffers() requires `graphviz` package")
        return

    if fname is None:
        # Default to the name of the graph currently being compiled.
        fname = get_graph_being_compiled()

    graph = create_fx_from_snodes(nodes)

    for node in graph.nodes:
        if "fusion_meta" not in node.meta:
            continue
        group = node.meta["fusion_meta"].group
        # Normalize group into a shape-like tuple for TensorMetadata below.
        if isinstance(group, tuple):
            if isinstance(group[1], int):
                group = (group[1],)
            else:
                group = group[1]

        # gather meta data
        dtype = None
        if isinstance(node, ir.ComputedBuffer):
            dtype = node.data.dtype

        # Only group (as "shape") and dtype are meaningful here; the
        # remaining TensorMetadata fields are unused by the drawer.
        metadata = TensorMetadata(group, dtype, None, None, None, None, None)  # type: ignore[arg-type]
        node.meta["tensor_meta"] = metadata

    if print_graph:
        print(graph)

    gm = GraphModule({}, graph)
    # Topologically re-sort and validate before rendering.
    legalize_graph(gm)
    gm.graph.lint()
    draw_graph(
        gm, fname, clear_meta=False, dot_graph_shape=config.trace.dot_graph_shape
    )
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def create_fx_from_snodes(snodes: List[BaseSchedulerNode]) -> fx.Graph:
    """
    Creates a FX Graph from a list of SchedulerNode objects.

    Each scheduler node becomes a call_function node targeting a fake
    zero-returning function (the graph is for visualization only, never
    executed); read dependencies become edges, and reads of buffers produced
    outside this node list become placeholders.
    """

    def get_fake_func(name):
        # Stand-in callable so the FX node has a printable target name.
        def func1(*args):
            return 0

        func1.__name__ = name
        return func1

    FusionMeta = collections.namedtuple("FusionMeta", ["group", "snode", "type"])

    buf_to_fx_node = {}
    graph = torch.fx.Graph()
    first_node = None

    outputs = []
    group: Any = None
    # create call_function node for each Buffer and Kernel
    for snode in snodes:
        if snode.is_extern():
            node_type = "extern"
            group = node_type
        elif snode.is_template():
            node_type = "template"
            group = node_type
        elif isinstance(snode, NopKernelSchedulerNode):
            node_type = "nop"
            group = node_type
        elif isinstance(snode, SchedulerNode):
            node_type = "compute"
            group = snode.group
        elif isinstance(snode, FusedSchedulerNode):
            node_type = "fused"
            group = snode.group
        else:
            raise RuntimeError("Unknown node type")

        fused_name = torch._inductor.utils.get_fused_kernel_name(
            snode.get_nodes(), "original_aten"
        )
        func_name = f"{node_type}: {fused_name}"
        node_func = get_fake_func(func_name)
        kwargs = {}
        if hasattr(snode, "get_device"):
            kwargs = {"device": snode.get_device()}
        fx_node = graph.call_function(node_func, args=(), kwargs=kwargs)

        def in_output(snode):
            # A fused node is an output if any of its constituents is.
            if isinstance(snode, FusedSchedulerNode):
                return any(in_output(x) for x in snode.snodes)
            return any(isinstance(user.node, OutputNode) for user in snode.users)

        if in_output(snode):
            outputs.append(fx_node)
        name = snode.get_name()
        fx_node.name = name

        fx_node.meta["fusion_meta"] = FusionMeta(group, snode, node_type)

        # Map every constituent buffer name to this fx node so edge
        # creation below can resolve reads of fused sub-buffers.
        if isinstance(snode, FusedSchedulerNode):
            for x in snode.snodes:
                buf_to_fx_node[x.get_name()] = fx_node
        buf_to_fx_node[name] = fx_node

        if first_node is None:
            first_node = fx_node

    # create edges between nodes
    for snode in snodes:
        name = snode.get_name()
        deps = snode.read_writes.reads

        fx_node = buf_to_fx_node[name]
        new_args = []
        for dep in deps:
            if dep.name in buf_to_fx_node:
                dep_node = buf_to_fx_node[dep.name]
            else:
                # Buffer produced outside this snode list (e.g. a graph
                # input): represent it as a placeholder at the top.
                with graph.inserting_before(first_node):
                    dep_node = graph.placeholder(dep.name)
                    buf_to_fx_node[dep.name] = dep_node
            new_args.append(dep_node)

        fx_node.args = tuple(new_args)

    graph.output(outputs[0] if len(outputs) == 1 else tuple(outputs))
    return graph
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
def update_orig_fx_node_name_to_buf_name(
    nodes: SchedulerNodeList,
    node_name_to_buf_name: Dict[str, str],
    parent_buf_name: Optional[str] = None,
    n_origins: int = 0,
):
    """Populate node_name_to_buf_name mapping original fx node names to the
    (outermost fused) buffer that produced them.

    Mutates node_name_to_buf_name in place; first writer wins for nodes that
    originate multiple buffers. (n_origins is currently unused.)
    """
    if nodes is None:
        return
    for node in nodes:
        # for FusedSchedulerNode, traverse recursively into get_nodes()
        buf_name = node.get_name()
        children_nodes = node.get_nodes()
        if children_nodes is not None and len(children_nodes) > 1:
            update_orig_fx_node_name_to_buf_name(
                children_nodes,
                node_name_to_buf_name,
                # Keep attributing to the outermost fused buffer.
                buf_name if parent_buf_name is None else parent_buf_name,
            )
            continue
        else:
            # A leaf must report exactly itself from get_nodes().
            assert len(children_nodes) == 1 and children_nodes[0] == node

        ir_node = node.node
        if ir_node is None or ir_node.origins is None:
            continue
        for origin in ir_node.origins:
            node_name = origin.name
            # when buf1 and buf2 both have origin=node1
            # we draw node1 according to buf1
            if node_name not in node_name_to_buf_name:
                node_name_to_buf_name[node_name] = (
                    buf_name if parent_buf_name is None else parent_buf_name
                )
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
def get_node_name_to_buf_meta(node_name_to_buf_name: Dict[str, str]):
    """For each fx node name, build a BufMeta of its buffer name and the
    number of distinct fx nodes that map to that same buffer."""
    # Group node names by the buffer they belong to.
    buf_name_to_nodes: Dict[str, set] = {}
    for node_name, buf_name in node_name_to_buf_name.items():
        buf_name_to_nodes.setdefault(buf_name, set()).add(node_name)

    # Attach the per-buffer fan-in count to every node.
    return {
        node_name: BufMeta(buf_name, len(buf_name_to_nodes[buf_name]))
        for node_name, buf_name in node_name_to_buf_name.items()
    }
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
def annotate_orig_fx_with_snodes(
    gm: torch.fx.GraphModule, snodes: SchedulerNodeList
) -> None:
    """
    Annotate each node of the original fx graph with the scheduler buffer
    (BufMeta) it was lowered into, stored under node.meta["buf_meta"].
    """
    node_name_to_buf_name: Dict[str, str] = {}
    update_orig_fx_node_name_to_buf_name(snodes, node_name_to_buf_name)
    # NOTE(review): this check is effectively dead — the dict is created
    # just above and is never reassigned to None.
    if node_name_to_buf_name is None:
        return
    node_name_to_buf_meta = get_node_name_to_buf_meta(node_name_to_buf_name)
    for node in gm.graph.nodes:
        if node.name in node_name_to_buf_meta:
            node.meta["buf_meta"] = node_name_to_buf_meta.get(node.name)
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
@contextlib.contextmanager
def enable_aot_logging():
    """Context manager that, when TORCH_COMPILE_DEBUG=1, routes aot_autograd
    debug logs to a per-graph file under the inductor debug directory.

    When the env var is unset this is a no-op passthrough.
    """
    compile_debug = os.environ.get("TORCH_COMPILE_DEBUG", "0") == "1"

    import torch._functorch.aot_autograd

    # Shadow the module-level `log` on purpose: we configure aot_autograd's
    # logger, not this module's.
    log = logging.getLogger(torch._functorch.aot_autograd.__name__)

    stack = contextlib.ExitStack()
    if not compile_debug:
        try:
            yield
        finally:
            stack.close()
        return

    # Enable all graphs to be logged to a file by setting the flags to True
    # and the log level of the file logger to DEBUG
    stack.enter_context(patch("functorch.compile.config.debug_partitioner", True))

    path = os.path.join(get_debug_dir(), "torchinductor")
    os.makedirs(path, exist_ok=True)

    fh = logging.FileHandler(
        os.path.join(
            path,
            f"aot_{get_aot_graph_name()}_debug.log",
        )
    )
    fh.setLevel(logging.DEBUG)
    fh.setFormatter(
        logging.Formatter("[%(filename)s:%(lineno)d %(levelname)s] %(message)s")
    )
    log.addHandler(fh)
    try:
        yield
    finally:
        # Always detach the handler so subsequent compilations do not keep
        # appending to this graph's log file.
        log.removeHandler(fh)
        stack.close()
|
| 294 |
+
|
| 295 |
+
|
| 296 |
+
class DebugContext:
    """Per-compilation debug scope for inductor.

    While active (as a context manager), this object is installed as the
    global debug handler via ``V.set_debug_handler`` and, when
    ``config.trace.enabled``, owns a fresh on-disk debug directory into
    which trace artifacts (logs, profiles, graph dumps) are written.
    """

    # Shared monotonic counter used to pick a fresh, unused debug
    # directory suffix across all DebugContext instances in this process.
    _counter = itertools.count()

    @staticmethod
    def wrap(fn):
        """Wrap ``fn`` so every call runs inside its own DebugContext.

        The result is additionally wrapped with ``wrap_compiler_debug`` so
        dynamo's compiler-debug tooling sees it as the "inductor" compiler.
        """

        @functools.wraps(fn)
        def inner(*args, **kwargs):
            with DebugContext():
                return fn(*args, **kwargs)

        return wrap_compiler_debug(inner, compiler_name="inductor")

    @staticmethod
    def create_debug_dir(folder_name: str) -> Optional[str]:
        """Create and return a new directory ``<debug_dir>/torchinductor/<folder_name>.<n>``.

        Increments the shared counter until an unused suffix is found.
        Returns None only if the (infinite) counter is exhausted, i.e.
        effectively never.
        """
        debug_dir = config.trace.debug_dir or get_debug_dir()
        for n in DebugContext._counter:
            dirname = os.path.join(
                debug_dir,
                "torchinductor",
                f"{folder_name}.{n}",
            )
            if not os.path.exists(dirname):
                os.makedirs(dirname)
                return dirname
        return None

    def __init__(self):
        # cProfile.Profile when config.trace.compile_profile is on.
        self._prof = None
        # Debug directory path; None until (and unless) tracing is enabled.
        self._path = None
        # Collects all cleanup (log handlers, debug-handler reset, ...).
        self._stack = contextlib.ExitStack()

    def copy(self, new_path: str):
        """Copy the debug directory to ``new_path`` (must end in ".debug").

        Best-effort: an existing destination is replaced, and OS-level copy
        failures are logged rather than raised.
        """
        if not self._path:
            return
        assert new_path.endswith(".debug"), new_path
        if os.path.exists(new_path):
            shutil.rmtree(new_path)
        try:
            shutil.copytree(self._path, new_path)
            self._path = new_path
        except OSError:
            log.warning(
                "Failed to copy debug files from %s to %s", self._path, new_path
            )
            pass

    def fopen(self, filename: str, write_mode: str = "w", *args, **kwargs):
        """Open ``filename`` inside the debug directory (caller closes it)."""
        assert self._path
        return open(os.path.join(self._path, filename), write_mode, *args, **kwargs)

    @contextlib.contextmanager
    def fopen_context(self, filename: str, write_mode: str = "w", *args, **kwargs):
        """Context-manager variant of fopen(); closes the file on exit."""
        assert self._path
        with open(os.path.join(self._path, filename), write_mode, *args, **kwargs) as f:
            yield f

    def filename(self, suffix: str):
        """Return the absolute path of ``suffix`` inside the debug directory."""
        assert self._path
        return os.path.join(self._path, suffix)

    def upload_tar(self):
        """If config.trace.upload_tar is set, tar.gz the debug dir and hand
        the archive path to that callback."""
        if config.trace.upload_tar is not None:
            import tarfile

            assert self._path
            tar_file = os.path.join(
                self._path, f"{os.path.basename(self._path)}.tar.gz"
            )
            with tarfile.open(tar_file, "w:gz") as tar:
                tar.add(self._path, arcname=os.path.basename(self._path))
            config.trace.upload_tar(tar_file)

    def __enter__(self):
        if config.debug:
            # Temporarily force torch._dynamo logging to DEBUG; the previous
            # level is restored via the exit stack. NB: this local `log`
            # deliberately shadows the module-level logger.
            log = logging.getLogger("torch._dynamo")
            prev_level = log.level
            log.setLevel(logging.DEBUG)

            def reset_log_level(level):
                log.setLevel(level)

            self._stack.callback(reset_log_level, prev_level)

        # Make this context the active debug handler for the duration.
        self._stack.enter_context(V.set_debug_handler(self))

        if not config.trace.enabled:
            return

        self._path = self.create_debug_dir(get_aot_graph_name())

        if config.trace.debug_log:
            self._setup_log_capture("debug.log", logging.DEBUG)
        if config.trace.info_log:
            self._setup_log_capture("info.log", logging.INFO)
        if config.trace.compile_profile:
            self._prof = cProfile.Profile()
            self._prof.enable()

    def _setup_log_capture(self, filename: str, level: int):
        """Tee torch._inductor log records at ``level`` into ``filename``.

        The handler (and the open file, via fopen) are unregistered/closed
        by the exit stack when the context exits.
        """
        log = logging.getLogger("torch._inductor")
        fd = self._stack.enter_context(self.fopen(filename))
        ch = logging.StreamHandler(fd)
        ch.setLevel(level)
        ch.setFormatter(
            logging.Formatter("[%(filename)s:%(lineno)d %(levelname)s] %(message)s")
        )
        log.addHandler(ch)
        # Lower (never raise) the logger level so records reach the handler.
        log.setLevel(min(log.level, level))
        self._stack.callback(log.removeHandler, ch)

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self._prof:
            self._prof.disable()
            self._save_profile_data()

        if self._path:
            self.upload_tar()
            log.warning("%s debug trace: %s", get_graph_being_compiled(), self._path)
        self._stack.close()

    def _save_profile_data(self):
        """Dump the compile-time profile as both raw stats and a text report."""
        assert self._prof
        self._prof.dump_stats(self.filename("compile.prof"))
        with self.fopen("compile.stats") as fd:
            stats = pstats.Stats(self._prof, stream=fd)
            stats.strip_dirs()
            stats.sort_stats("cumtime")
            stats.print_stats(100)
            stats.sort_stats("tottime")
            stats.print_stats(100)

    def __getattr__(self, name):
        # Dispatch unknown attributes to a DebugFormatter method when the
        # corresponding config.trace.<name> flag is on; otherwise hand back
        # a no-op so call sites need no flag checks.
        if config.trace.enabled and getattr(config.trace, name):
            try:
                return getattr(DebugFormatter(self), name)
            except Exception:
                # NOTE(review): on exception this falls through and returns
                # None, so the caller's subsequent call raises TypeError —
                # confirm whether the no-op fallback was intended here too.
                log.warning("Ignoring exception in debug code", exc_info=True)
        else:

            def ignored(*args, **kwargs):
                pass

            return ignored
|
| 439 |
+
|
| 440 |
+
|
| 441 |
+
class DebugFormatter:
    """Writes per-graph debug artifacts (FX graph dumps, scheduler IR,
    diagrams, autotuning logs) into the directory owned by a DebugContext.

    Instances are created lazily by ``DebugContext.__getattr__``; each
    public method here corresponds to a ``config.trace.*`` toggle.
    """

    def __init__(self, handler):
        # Borrow the file helpers from the owning DebugContext.
        self.fopen = handler.fopen
        self.fopen_context = handler.fopen_context
        self.filename = handler.filename
        self.handler = handler

    def fx_graph(self, gm: torch.fx.GraphModule, inputs: List[torch.Tensor]):
        """Dump the input FX graph as a runnable repro and as readable text."""
        with self.fopen("fx_graph_runnable.py") as fd:
            save_graph_repro(fd, gm, inputs, "inductor")

        with self.fopen("fx_graph_readable.py") as fd:
            fd.write(gm.print_readable(print_output=False))

    def fx_graph_transformed(
        self, gm: torch.fx.GraphModule, inputs: List[torch.Tensor]
    ):
        """Dump the FX graph after inductor's pre-lowering transformations."""
        with self.fopen("fx_graph_transformed.py") as fd:
            fd.write(gm.print_readable(print_output=False))

    def ir_pre_fusion(self, nodes: SchedulerNodeList):
        """Dump scheduler IR before fusion."""
        self._write_ir("ir_pre_fusion.txt", nodes)

    def ir_post_fusion(self, nodes: SchedulerNodeList):
        """Dump scheduler IR after fusion."""
        self._write_ir("ir_post_fusion.txt", nodes)

    def _write_ir(self, filename: str, nodes: SchedulerNodeList):
        # One debug_str() per node, separated by blank lines.
        with self.fopen(filename) as fd:
            log.info("Writing debug ir to  %s", fd.name)
            for node in nodes:
                fd.write(node.debug_str())
                fd.write("\n\n\n")

    def graph_diagram(self, nodes: SchedulerNodeList):
        """Render the scheduler buffer graph as an SVG."""
        draw_buffers(nodes, fname=self.filename("graph_diagram.svg"))

    def draw_orig_fx_graph(self, gm: torch.fx.GraphModule, nodes: SchedulerNodeList):
        """Render the original FX graph annotated with scheduler nodes."""
        annotate_orig_fx_with_snodes(gm, nodes)
        draw_graph(
            gm,
            fname=self.filename("orig_fx_graph_diagram.svg"),
            clear_meta=False,
            prog=GRAPHVIZ_COMMAND_SCALABLE,
            parse_stack_trace=True,
            dot_graph_shape=config.trace.dot_graph_shape,
        )

    def output_code(self, filename):
        """Snapshot the generated output-code file into the debug dir."""
        shutil.copy(filename, self.filename("output_code.py"))

    def log_autotuning_results(
        self,
        name: str,
        input_nodes: List[ir.IRNode],
        timings: Dict["ChoiceCaller", float],  # type: ignore[name-defined] # noqa: F821
        elapse: float,
    ):
        """Append one JSON line per autotuning candidate, combining the
        candidate's info_dict with op/device/input metadata and its
        benchmark time.

        All metadata extraction below is deliberately best-effort: any
        attribute an IR node does not expose is simply omitted rather than
        aborting the log. (Fix vs. original: the unused ``as e`` bindings
        on the swallowed exceptions have been removed.)
        """
        import json

        from .ir import FixedLayout

        def build_node_info(node: ir.IRNode):
            if hasattr(node, "name"):
                node_name = node.name
            else:
                node_name = ""
            node_info = {
                "name": node_name,
                "type": type(node).__name__,
            }
            try:
                layout = node.get_layout()
                if isinstance(layout, FixedLayout):
                    # Resolve a possibly-symbolic offset to a concrete hint
                    # (falling back to 0) so the layout string is static.
                    offset = 0
                    try:
                        offset = int(layout.offset)
                    except Exception:
                        try:
                            offset = V.graph.sizevars.size_hint(
                                layout.offset, fallback=0
                            )
                        except Exception:
                            pass
                    static_layout = FixedLayout(
                        layout.device,
                        dtype=layout.dtype,
                        size=list(V.graph.sizevars.size_hints(layout.size)),
                        stride=list(V.graph.sizevars.size_hints(layout.stride)),
                        offset=offset,
                    )
                    node_info["layout"] = str(static_layout)
                else:
                    node_info["layout"] = str(node.get_layout())
            except Exception:
                pass
            try:
                node_info["dtype"] = str(node.get_dtype())
            except Exception:
                pass
            try:
                node_info["device"] = str(node.get_device())
            except Exception:
                pass
            try:
                node_info["stride"] = str(
                    V.graph.sizevars.size_hints(node.get_stride())
                )
            except Exception:
                pass
            try:
                node_info["size"] = str(V.graph.sizevars.size_hints(node.get_size()))
            except Exception:
                pass
            try:
                node_info["numel"] = str(V.graph.sizevars.size_hint(node.get_numel()))
            except Exception:
                pass
            if hasattr(node, "data") and isinstance(node.data, ir.IRNode):
                node_info["data"] = build_node_info(node.data)
            return node_info

        general_properties = {
            "op_name": name,
            "cuda_device_name": torch.cuda.get_device_name(),
            "cuda_device_count": torch.cuda.device_count(),
            "input_nodes": [build_node_info(node) for node in input_nodes],
            "autotuning_time": elapse,
        }
        with self.fopen_context(
            "autotuning_result_json_list.txt", "at", encoding="utf-8"
        ) as fd:
            for caller, time in timings.items():
                info_dict = dict(caller.info_dict())
                info_dict.update(general_properties)
                info_dict["benchmark_result"] = time
                json.dump(info_dict, fd)
                fd.write("\n")
|
| 578 |
+
|
| 579 |
+
|
| 580 |
+
@dataclasses.dataclass
class TensorMetadataHolder:
    """Pickle-friendly stand-in for a tensor: its extracted metadata plus
    the device it lived on (FakeTensors themselves cannot be pickled)."""

    # Metadata produced by _extract_tensor_metadata (shape/stride/dtype/...).
    tensor_metadata: TensorMetadata
    # Original device, kept separately so the tensor can be re-materialized.
    device: torch.device
|
| 584 |
+
|
| 585 |
+
|
| 586 |
+
# Monotonic counter giving each saved compile_fx_inner argument pickle a
# unique filename within the save folder.
save_args_cnt = itertools.count()
|
| 587 |
+
|
| 588 |
+
|
| 589 |
+
def save_args_for_compile_fx_inner(*args, **kwargs):
    """
    This function is used to save arguments for a compile_fx_inner function call
    to the file system.  Later on one can replay the compile_fx_inner call
    with the saved arguments using load_args_and_run_compile_fx_inner.
    """

    # NOTE(review): a fixed world-writable /tmp path and pickle are fine for
    # local debugging, but the replay side unpickles this file — do not
    # point load_args_and_run_compile_fx_inner at untrusted pickles.
    folder = "/tmp/inductor_saved_args"
    # makedirs(exist_ok=True) replaces the racy
    # `if not os.path.exists(folder): os.mkdir(folder)` check-then-create,
    # which could raise FileExistsError when two compiles hit this at once.
    os.makedirs(folder, exist_ok=True)

    def handle_tensor(x):
        """
        Pickle FakeTensor will result in error:
        AttributeError: Can't pickle local object 'WeakValueDictionary.__init__.<locals>.remove'

        Convert all Tensor to metadata. This may also makes pickle faster.
        """
        if isinstance(x, torch.Tensor):
            return TensorMetadataHolder(_extract_tensor_metadata(x), x.device)
        else:
            return x

    args_to_save, kwargs_to_save = tree_map(handle_tensor, (args, kwargs))

    fn_name = "compile_fx_inner"
    path = f"{folder}/{fn_name}_{next(save_args_cnt)}.pkl"
    with open(path, "wb") as f:
        pickle.dump((args_to_save, kwargs_to_save), f)

    if log.isEnabledFor(logging.DEBUG):
        message = f"""
Arguments for a compile_fx_inner call is saved to {path}. To replay the call,
run the following:

from torch._inductor.debug import load_args_and_run_compile_fx_inner
load_args_and_run_compile_fx_inner({path!r})
        """
        # call print rather than log.debug. log.debug will print message
        # prefix for each line which makes the code snippet harder to be
        # copied.
        # Not a big deal since the code is already been guarded by checking
        # the log level.
        print(message)
|
| 633 |
+
|
| 634 |
+
|
| 635 |
+
def load_args_and_run_compile_fx_inner(path: str):
    """Replay a compile_fx_inner call whose arguments were previously saved
    by save_args_for_compile_fx_inner.

    Tensors were pickled as TensorMetadataHolder; rebuild matching fake
    tensors, then invoke compile_fx_inner under a FakeTensorMode with
    ``save_args`` patched off so the replay does not re-save itself.
    """
    from torch._inductor.compile_fx import compile_fx_inner

    def materialize(entry):
        # Rebuild a strided tensor matching the recorded metadata; anything
        # that is not a metadata holder passes through untouched.
        if not isinstance(entry, TensorMetadataHolder):
            return entry
        meta = entry.tensor_metadata
        return torch._dynamo.testing.rand_strided(
            meta.shape,
            meta.stride,
            meta.dtype,
            entry.device,
        )

    with open(path, "rb") as f:
        saved_args, saved_kwargs = pickle.load(f)

    fake_mode = torch._subclasses.FakeTensorMode(allow_non_fake_inputs=True)
    with fake_mode, config.patch("save_args", False):
        saved_args, saved_kwargs = tree_map(materialize, (saved_args, saved_kwargs))
        return compile_fx_inner(*saved_args, **saved_kwargs)
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/decomposition.py
ADDED
|
@@ -0,0 +1,678 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import functools
|
| 2 |
+
import logging
|
| 3 |
+
import math
|
| 4 |
+
import sys
|
| 5 |
+
import typing
|
| 6 |
+
from typing import Optional
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
import torch._decomp as decomp
|
| 10 |
+
import torch._prims_common as utils
|
| 11 |
+
import torch.ao.quantization.fx._decomposed
|
| 12 |
+
from torch._decomp import (
|
| 13 |
+
core_aten_decompositions,
|
| 14 |
+
get_decompositions,
|
| 15 |
+
remove_decompositions,
|
| 16 |
+
)
|
| 17 |
+
from torch._decomp.decompositions import (
|
| 18 |
+
_grid_sampler_2d as decomp_grid_sampler_2d,
|
| 19 |
+
pw_cast_for_opmath,
|
| 20 |
+
)
|
| 21 |
+
from torch._decomp.decompositions_for_rng import extra_random_decomps
|
| 22 |
+
from torch._higher_order_ops.out_dtype import out_dtype
|
| 23 |
+
from torch._prims_common import (
|
| 24 |
+
elementwise_dtypes,
|
| 25 |
+
ELEMENTWISE_TYPE_PROMOTION_KIND,
|
| 26 |
+
type_to_dtype,
|
| 27 |
+
)
|
| 28 |
+
|
| 29 |
+
from . import config, inductor_prims
|
| 30 |
+
|
| 31 |
+
# Module logger plus short aliases for the operator namespaces this file
# registers decompositions against.
log = logging.getLogger(__name__)
aten = torch.ops.aten
prims = torch.ops.prims
quantized_decomposed = torch.ops.quantized_decomposed
|
| 35 |
+
|
| 36 |
+
# Ops inductor prefers to see decomposed into simpler aten/prims ops before
# lowering, beyond what the core ATen decomposition set provides.
inductor_decompositions = get_decompositions(
    [
        aten._adaptive_avg_pool2d_backward,
        aten.arange,
        aten.bitwise_and_,
        aten.bitwise_or_,
        aten.clamp_min_,
        aten.dist,
        aten.empty_like,
        aten.flip,
        aten.gelu,
        aten.hardtanh,
        aten.index_select,
        aten.lcm,
        aten.leaky_relu,
        aten.linalg_vector_norm,
        aten._log_softmax,
        aten.max_pool2d_with_indices_backward,
        aten._native_batch_norm_legit,
        aten._native_batch_norm_legit_functional,
        aten._native_batch_norm_legit_no_training,
        aten.native_batch_norm,
        aten.native_group_norm,
        aten.native_layer_norm,
        aten.nll_loss2d_backward,
        aten._softmax,
        aten.sin_,
        aten.sqrt_,
        out_dtype,
        aten._to_copy,
        aten.tril_indices,
        aten.triu_indices,
        aten.upsample_bilinear2d.vec,
    ]
)
# Full inductor table: core ATen decomps overlaid with the inductor-specific
# ones above (later entries win on key conflicts).
decompositions = {**core_aten_decompositions(), **inductor_decompositions}

# Remove unwanted decompositions included via the core ATen decompositions from
# the Inductor decomp table.
decomps_to_exclude = [
    aten._unsafe_index,
    aten._scaled_dot_product_flash_attention_for_cpu.default,  # See comments in torch/_decomp/decompositions.py
    aten.clamp_max,
    aten.clamp_min,
    aten.glu,  # inductor lowers this directly
    aten.split.Tensor,  # inductor lowers this directly
    aten.squeeze,  # inductor lowers this directly
    aten.sum,  # inductor lowers this directly
]

remove_decompositions(decompositions, decomps_to_exclude)
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def register_decomposition(ops):
    """Decorator registering a function as the inductor decomposition for
    each op in ``ops`` (a single op or a list of ops).

    Warns when an op already has an entry in ``decompositions`` — the new
    registration will overwrite it.
    """
    for op in [ops] if callable(ops) else ops:
        if op in decompositions:
            # Fix vs. original: log the specific colliding op, not the whole
            # `ops` argument, so the duplicate is identifiable when several
            # ops are registered by one decorator.
            log.warning("duplicate decomp: %s", op)
    return decomp.register_decomposition(ops, decompositions)
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
# TODO: for now, inductor doesn't handle asserts
# because the condition is symbool -> tensor in the graph.
@register_decomposition([aten._assert_async.msg])
def assert_async_msg_decomp(tensor, msg):
    # Deliberate no-op: drop the runtime assert from the compiled graph.
    return
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
# Following `assert_async_msg_decomp` and implement as non-op.
@register_decomposition([aten._functional_assert_async.msg])
def functional_assert_async_msg_decomp(tensor, msg):
    # Deliberate no-op: the functional assert is likewise dropped.
    return
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
@register_decomposition([aten.sym_constrain_range_for_size.default])
def sym_constrain_range_for_size(symbol, *, min=None, max=None):
    # Deliberate no-op: the range constraint carries no runtime behavior
    # inductor needs to emit code for.
    return
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
@register_decomposition([aten.clamp])
@pw_cast_for_opmath
def clamp(x, min=None, max=None):
    """Decompose aten.clamp into clamp_min / clamp_max.

    The lower bound is applied first and the upper bound second, matching
    aten semantics (when min > max, the max bound wins).
    """
    result = x
    if min is not None:
        result = result.clamp_min(min)
    if max is not None:
        result = result.clamp_max(max)
    return result
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
@register_decomposition([aten.full])
def full(size, fill_value, **kwargs):
    # If no dtype was supplied, pin the dtype implied by fill_value's Python
    # type and redispatch so downstream passes always see an explicit dtype.
    dtype = kwargs.get("dtype")
    if dtype is None:
        kwargs["dtype"] = type_to_dtype(type(fill_value))
        return aten.full(size, fill_value, **kwargs)
    # dtype already explicit: returning NotImplemented declines this
    # decomposition and lets the default lowering handle the op.
    return NotImplemented
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
# Not really sure how to put this into the main library. PrimTorch wants
# empty_permuted to go to the prim, and typically users don't really want
# to decompose to empty_strided (but inductor is OK with it, because we are
# cool with strides and everything goes to empty_strided)
@register_decomposition([aten.empty_permuted.default])
def empty_permuted(size, physical_layout, **kwargs):
    # Build the inverse permutation: perm maps each physical position back
    # to its logical dimension.
    perm = [0] * len(size)
    for p, l in enumerate(physical_layout):
        perm[l] = p
    # Allocate contiguously in physical order, then view it in logical order.
    return torch.empty([size[l] for l in physical_layout], **kwargs).permute(perm)
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
@register_decomposition([aten.convolution_backward])
def convolution_backward(
    grad_output,
    input,
    weight,
    bias_sizes,
    stride,
    padding,
    dilation,
    transposed,
    output_padding,
    groups,
    output_mask,
):
    # Only take over when a bias gradient is requested (output_mask[2]) and
    # we're on CUDA; everything else defers to the default implementation.
    if not output_mask[2] or grad_output.device.type != "cuda":
        return NotImplemented
    # Bias grad is just grad_output summed over batch and spatial dims,
    # which inductor can fuse; re-dispatch for the input/weight grads with
    # the bias entry of the mask cleared.
    grad_bias = aten.sum(grad_output, [0] + list(range(2, grad_output.dim())))
    grad_inp, grad_weight, _ = aten.convolution_backward(
        grad_output,
        input,
        weight,
        bias_sizes,
        stride,
        padding,
        dilation,
        transposed,
        output_padding,
        groups,
        [output_mask[0], output_mask[1], False],
    )
    return (grad_inp, grad_weight, grad_bias)
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
@register_decomposition([aten.log2])
def log2(x):
    """log2(x) as ln(x) scaled by the constant 1/ln(2)."""
    inv_ln2 = 1.0 / math.log(2.0)
    return torch.log(x) * inv_ln2
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
@register_decomposition([aten.round.decimals])
def round_dec(x, decimals=0):
    """Round to `decimals` places: scale up, round to integer, scale back."""
    scale = 10.0**decimals
    return aten.round(x * scale) * (1.0 / scale)
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
@register_decomposition([aten.bmm])
@pw_cast_for_opmath
def bmm(self, batch2):
    # With coordinate-descent tuning on, degenerate (matrix-vector-like)
    # batched matmuls are rewritten as broadcast-multiply + reduce, which
    # inductor can codegen directly.
    if config.coordinate_descent_tuning:
        if self.shape[1] == 1 or batch2.shape[2] == 1:
            out = (self.unsqueeze(-1) * batch2.unsqueeze(1)).sum(dim=2)
            return out
    if self.device.type == "cpu":
        # (B, 1, K) @ (B, K, 1): a dot product per batch element.
        if self.size(1) == 1 and batch2.size(-1) == 1:
            return torch.sum(
                self.squeeze(1) * batch2.squeeze(-1), dim=1, keepdim=True
            ).unsqueeze(1)
    # Otherwise decline and use the default lowering.
    return NotImplemented
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
@register_decomposition([aten.addmm])
@pw_cast_for_opmath
def addmm(self, mat1, mat2, beta=1, alpha=1):
    # CPU-only rewrites of addmm for tiny/degenerate shapes where a
    # multiply+reduce beats dispatching to a BLAS gemm.
    if self.device.type == "cpu":
        # (1, K) @ (K, 1): single dot product.
        if mat1.size(0) == 1 and mat2.size(-1) == 1:
            out = torch.sum(
                mat1.squeeze(0) * mat2.squeeze(-1), dim=0, keepdim=True
            ).unsqueeze(0)
            return alpha * out + beta * self
        # (1, K) @ (K, N) with both dims tiny: broadcast multiply + sum.
        if mat1.size(0) == 1 and mat2.size(0) <= 16 and mat2.size(1) <= 16:
            out = (mat1.T * mat2).sum(dim=0, keepdim=True)
            return alpha * out + beta * self
    return NotImplemented
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
@register_decomposition([aten.mm])
@pw_cast_for_opmath
def mm(self, input2):
    from torch.fx.experimental.symbolic_shapes import (
        definitely_true,
        guard_size_oblivious,
    )

    # Our matrix vector multiplies only achieve peak bandwidth with coordinate descent tuning.
    # todo: Look into why and fix it (hopefully)
    if config.coordinate_descent_tuning:
        # Matrix-vector shapes: rewrite as broadcast multiply + reduce.
        if self.shape[0] == 1 or input2.shape[1] == 1:
            return (self.unsqueeze(2) * input2.unsqueeze(0)).sum(dim=1)
    if self.device.type == "cpu":
        # (M, 1) @ (1, N) with very few total elements: an outer product,
        # done row-by-row. guard_size_oblivious/definitely_true keep these
        # shape checks from adding guards on unbacked symbolic sizes.
        if (
            guard_size_oblivious(self.size(-1) == 1)
            and guard_size_oblivious(self.size(0) > 0)
            and guard_size_oblivious(input2.size(0) == 1)
            and (self.dtype == input2.dtype)
            and definitely_true((torch.numel(self) + torch.numel(input2)) <= 32)
        ):
            return torch.cat([self[i, :] * input2 for i in range(self.size(0))])
        # (1, K) @ (K, 1): a single dot product.
        if guard_size_oblivious(self.size(0) == 1) and guard_size_oblivious(
            input2.size(-1) == 1
        ):
            return torch.sum(
                self.squeeze(0) * input2.squeeze(-1), dim=0, keepdim=True
            ).unsqueeze(0)
    return NotImplemented
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
# This pass does two things:
# - Eliminate cat when there is only one tensor input
# - Normalize cat calls, so that legacy empty 1-D tensors are removed (NB: we
#   don't remove ALL empty tensors, only the naughty ones)
@register_decomposition([aten.cat.default])
def cat(tensors, dim=0):
    from torch.fx.experimental.symbolic_shapes import guard_size_oblivious

    def non_empty_tensor(x):
        # For better or worse, this is a valid cat:
        #
        #   torch.cat([torch.randn(2, 2, 4), torch.randn(0), torch.randn(3, 2, 4)])
        #
        # We'd like to eliminate naughtiness like this for downstream passes
        # like split_cat.  The easiest way is to just drop such inputs
        # (guarding that they are non-zero).
        #
        # Is it permissible for this filtering to be size-oblivious?  A case
        # where this could matter is cat([(2, 2), (u0,)], dim=0); if u0
        # happened to be zero, we would have liked to have filtered it out.
        # But actually, the ONLY way this could have passed is if u0 == 0,
        # so by the time we get here we have already installed a deferred
        # runtime assert forcing u0 to be zero.  So if this hasn't happened,
        # we know that the unbacked SymInt has appropriate size and there are
        # no problems.
        return len(x.shape) != 1 or guard_size_oblivious(x.shape[0] > 0)

    filtered_tensors = list(filter(non_empty_tensor, tensors))

    if len(filtered_tensors) == 1:
        # Single survivor: cat is the identity (modulo a fresh tensor).
        return filtered_tensors[0].clone()
    elif 1 < len(filtered_tensors) < len(tensors):
        # on the first call, when we remove empty tensors, we redispatch recursively
        return aten.cat.default(filtered_tensors, dim)
    # when no 'filtering' has occurred, we raise to prevent infinite recursion (no more decomposition needed)
    return NotImplemented
|
| 287 |
+
|
| 288 |
+
|
| 289 |
+
@register_decomposition([aten.angle])
def angle(x):
    # Complex input: the phase is atan2(imag, real); propagate NaN from the
    # real component explicitly.
    if x.is_complex():
        return torch.where(
            torch.isnan(x.real), float("nan"), torch.atan2(x.imag, x.real)
        )

    # when x is real number
    #   if x >= 0, return 0
    #   if x < 0, return pi
    #   if x is nan, return nan
    _, dtype = elementwise_dtypes(
        x,
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
    )
    pi = torch.scalar_tensor(math.pi, dtype=dtype, device=x.device)
    ret = torch.where(x < 0, pi, 0.0)
    return torch.where(torch.isnan(x), float("nan"), ret)
|
| 307 |
+
|
| 308 |
+
|
| 309 |
+
@register_decomposition([aten.add])
def add(x, y, *, alpha=None):
    # Only handle the complex-tensor + complex-tensor case; everything else
    # uses the default lowering.
    x_is_complex_tensor = torch.is_tensor(x) and x.is_complex()
    y_is_complex_tensor = torch.is_tensor(y) and y.is_complex()
    if not x_is_complex_tensor or not y_is_complex_tensor:
        return NotImplemented
    z = y
    if alpha is not None:
        z = alpha * y
    complex_type = torch.promote_types(x.dtype, y.dtype)
    # Reinterpret each complex tensor as its underlying real dtype (pairs of
    # re/im components), add element-wise, and view the result back as complex.
    return (x.view(x.real.dtype) + z.view(y.real.dtype)).view(complex_type)
|
| 320 |
+
|
| 321 |
+
|
| 322 |
+
@register_decomposition([aten.conj_physical])
def conj_physical(self):
    # Conjugation is the identity for real tensors; complex support is
    # intentionally unimplemented here.
    assert not self.is_complex(), "TODO: implement this"
    return self
|
| 326 |
+
|
| 327 |
+
|
| 328 |
+
@register_decomposition([aten.lift, aten.detach_])
def lift(self):
    # Both ops are no-ops from inductor's point of view.
    return self
|
| 331 |
+
|
| 332 |
+
|
| 333 |
+
@register_decomposition([aten.bernoulli.default])
def bernoulli(self, *, generator=None):
    # Explicit generators are not supported by this decomposition.
    assert generator is None
    # Draw uniform [0, 1) samples and threshold against the per-element
    # probabilities in `self`, then cast back to the input dtype.
    return (torch.rand_like(self, dtype=torch.float32) < self).to(self.dtype)
|
| 337 |
+
|
| 338 |
+
|
| 339 |
+
@register_decomposition([aten.fmin, prims.fmin])
def fmin(self, other):
    # fmin semantics: prefer `self` when `other` is NaN or larger.
    return torch.where(torch.isnan(other) | (other > self), self, other)
|
| 342 |
+
|
| 343 |
+
|
| 344 |
+
@register_decomposition([aten.fmax, prims.fmax])
def fmax(self, other):
    # fmax semantics: prefer `self` when `other` is NaN or smaller.
    return torch.where(torch.isnan(other) | (other < self), self, other)
|
| 347 |
+
|
| 348 |
+
|
| 349 |
+
@register_decomposition(aten.amax)
def amax(self, dim=None, keepdim=False):
    # For bool tensors, max-reduce is logical OR.
    if self.dtype == torch.bool:
        return torch.any(self, dim=dim, keepdim=keepdim)
    return NotImplemented
|
| 354 |
+
|
| 355 |
+
|
| 356 |
+
@register_decomposition(aten.amin)
def amin(self, dim=None, keepdim=False):
    # For bool tensors, min-reduce is logical AND.
    if self.dtype == torch.bool:
        return torch.all(self, dim=dim, keepdim=keepdim)
    return NotImplemented
|
| 361 |
+
|
| 362 |
+
|
| 363 |
+
@register_decomposition([aten.narrow_copy])
def narrow_copy(self, dim, start, length):
    """narrow_copy = narrow (a view) followed by a materializing clone."""
    window = torch.narrow(self, dim, start, length)
    return window.clone()
|
| 366 |
+
|
| 367 |
+
|
| 368 |
+
@register_decomposition([aten.expand_copy])
def expand_copy(self, size, *, implicit=False):
    """expand_copy = expand (a view) followed by a materializing clone."""
    expanded = aten.expand(self, size, implicit=implicit)
    return expanded.clone()
|
| 371 |
+
|
| 372 |
+
|
| 373 |
+
@register_decomposition([aten.view_copy.default])
def view_copy_default(self, size):
    """view_copy = view followed by a materializing clone."""
    viewed = aten.view(self, size)
    return viewed.clone()
|
| 376 |
+
|
| 377 |
+
|
| 378 |
+
@register_decomposition([aten.view_copy.dtype])
def view_copy_dtype(self, dtype):
    """dtype overload of view_copy: reinterpret via .to(), then clone."""
    converted = self.to(dtype)
    return converted.clone()
|
| 381 |
+
|
| 382 |
+
|
| 383 |
+
def get_like_layout(
    tensor: torch.Tensor, memory_format: Optional[torch.memory_format]
) -> torch.memory_format:
    """Resolve the memory format a *_like op should produce for `tensor`.

    `preserve_format` (or an unspecified format) is mapped to the format the
    input tensor suggests; any explicit format passes through unchanged.
    """
    # TODO: _to_copy tensor to stride permutation
    if memory_format is None or memory_format is torch.preserve_format:
        return utils.suggest_memory_format(tensor)
    return memory_format
|
| 391 |
+
|
| 392 |
+
|
| 393 |
+
@register_decomposition(aten.rand_like)
def rand_like(self, *, dtype=None, device=None, memory_format=None, **kwargs):
    """Decompose rand_like as rand + a memory-format preserving cast."""
    result = torch.rand(
        list(self.size()),
        dtype=dtype or self.dtype,
        device=device or self.device,
        **kwargs,
    )
    return result.to(memory_format=get_like_layout(self, memory_format))
|
| 401 |
+
|
| 402 |
+
|
| 403 |
+
@register_decomposition(aten.randn_like)
def randn_like(self, *, dtype=None, device=None, memory_format=None, **kwargs):
    """Decompose randn_like as randn + a memory-format preserving cast."""
    result = torch.randn(
        list(self.size()),
        dtype=dtype or self.dtype,
        device=device or self.device,
        **kwargs,
    )
    return result.to(memory_format=get_like_layout(self, memory_format))
|
| 411 |
+
|
| 412 |
+
|
| 413 |
+
@register_decomposition(aten.full_like)
def full_like(
    self,
    fill_value,
    *,
    dtype=None,
    layout=None,
    device=None,
    pin_memory=False,
    requires_grad=False,
    memory_format=torch.preserve_format,
):
    # Decompose full_like as torch.full + a memory-format preserving cast.
    # dtype/layout/device default to the input tensor's attributes.
    # NOTE(review): `pin_memory` is accepted for schema compatibility but not
    # forwarded to torch.full — confirm whether that is intentional.
    return torch.full(
        [*self.size()],
        fill_value,
        dtype=dtype or self.dtype,
        layout=layout or self.layout,
        device=device or self.device,
        requires_grad=requires_grad,
    ).to(memory_format=get_like_layout(self, memory_format))
|
| 433 |
+
|
| 434 |
+
|
| 435 |
+
@register_decomposition(aten.randint_like.default)
def randint_like(self, high, *, dtype=None, device=None, memory_format=None, **kwargs):
    # randint_like(high) == randint in [0, high) with the input's shape,
    # then cast to the resolved memory format.
    return aten.randint.low(
        0,
        high,
        [*self.size()],
        dtype=dtype or self.dtype,
        device=device or self.device,
        **kwargs,
    ).to(memory_format=get_like_layout(self, memory_format))
|
| 445 |
+
|
| 446 |
+
|
| 447 |
+
@register_decomposition(aten.randint_like.low_dtype)
def randint_like_low(
    self, low, high, *, dtype=None, device=None, memory_format=None, **kwargs
):
    # randint_like(low, high) == randint in [low, high) with the input's
    # shape, then cast to the resolved memory format.
    return aten.randint.low(
        low,
        high,
        [*self.size()],
        dtype=dtype or self.dtype,
        device=device or self.device,
        **kwargs,
    ).to(memory_format=get_like_layout(self, memory_format))
|
| 459 |
+
|
| 460 |
+
|
| 461 |
+
@register_decomposition(aten.randint.default)
def randint(high, size, **kwargs):
    # randint(high, size) is randint.low with an implicit lower bound of 0.
    return aten.randint.low(0, high, size, **kwargs)
|
| 464 |
+
|
| 465 |
+
|
| 466 |
+
# The difference between quantize_per_tensor.default and quantize_per_tensor.tensor is
|
| 467 |
+
# scale and zero_point is scalar or scalar tensor
|
| 468 |
+
@register_decomposition(quantized_decomposed.quantize_per_tensor.default)
def quantize_per_tensor_default_decomp_impl(
    input: torch.Tensor,
    scale: float,
    zero_point: int,
    quant_min: int,
    quant_max: int,
    dtype: torch.dtype,
) -> torch.Tensor:
    # Affine quantization with scalar scale/zero_point:
    #   q = clamp(round(x / scale) + zero_point, quant_min, quant_max).to(dtype)
    # bfloat16 input is upcast first so the rounding happens in float32.
    if input.dtype == torch.bfloat16:
        input = input.to(torch.float32)
    inv_scale = 1.0 / scale
    return torch.clamp(
        torch.round(input * inv_scale) + zero_point, quant_min, quant_max
    ).to(dtype)
|
| 483 |
+
|
| 484 |
+
|
| 485 |
+
# The difference between dequantize_per_tensor.default and dequantize_per_tensor.tensor is
|
| 486 |
+
# scale and zero_point is scalar or scalar tensor
|
| 487 |
+
@register_decomposition(quantized_decomposed.dequantize_per_tensor.default)
def dequantize_per_tensor_default_decomp_impl(
    input: torch.Tensor,
    scale: float,
    zero_point: int,
    quant_min: int,
    quant_max: int,
    dtype: torch.dtype,
) -> torch.Tensor:
    """Affine dequantization with scalar parameters: (q - zero_point) * scale.

    quant_min/quant_max/dtype are part of the op schema but unused here.
    """
    shifted = input.to(torch.float32) - zero_point
    return shifted * scale
|
| 497 |
+
|
| 498 |
+
|
| 499 |
+
@register_decomposition(quantized_decomposed.quantize_per_tensor.tensor)
def quantize_per_tensor_tensor_decomp_impl(
    input: torch.Tensor,
    scale: torch.Tensor,
    zero_point: torch.Tensor,
    quant_min: int,
    quant_max: int,
    dtype: torch.dtype,
) -> torch.Tensor:
    # Same math as the .default overload, but scale/zero_point arrive as
    # scalar tensors instead of Python scalars.
    if input.dtype == torch.bfloat16:
        input = input.to(torch.float32)
    inv_scale = 1.0 / scale
    return torch.clamp(
        torch.round(input * inv_scale) + zero_point, quant_min, quant_max
    ).to(dtype)
|
| 514 |
+
|
| 515 |
+
|
| 516 |
+
@register_decomposition(quantized_decomposed.dequantize_per_tensor.tensor)
def dequantize_per_tensor_tensor_decomp_impl(
    input: torch.Tensor,
    scale: torch.Tensor,
    zero_point: torch.Tensor,
    quant_min: int,
    quant_max: int,
    dtype: torch.dtype,
) -> torch.Tensor:
    # Tensor-overload dequantization: (q - zero_point) * scale, computed in
    # float32 (zero_point is cast via int32 first).
    # quant_min/quant_max/dtype are part of the op schema but unused here.
    return (input.to(torch.float32) - zero_point.to(torch.int32)) * scale.to(
        torch.float32
    )
|
| 528 |
+
|
| 529 |
+
|
| 530 |
+
@register_decomposition(torch.ops.quantized.embedding_bag_byte_unpack)
def q_embedding_bag_byte_unpack_decomp(packed):
    """Unpack a byte-quantized embedding bag.

    Each packed row is laid out as [uint8 quantized values..., 4 bytes scale,
    4 bytes offset], where scale/offset are float32 stored as raw bytes. The
    unpacked value is q * scale + offset.
    """

    def bitcast_u8_to_f32(u8):
        # Reassemble 4 consecutive bytes into an int32 and bitcast the result
        # to float32, honoring the host byte order.
        x, y, z, w = (u8[..., n].to(torch.int32) for n in (0, 1, 2, 3))
        if sys.byteorder == "little":
            return (x + (y << 8) + (z << 16) + (w << 24)).view(torch.float32)[..., None]
        else:
            return ((x << 24) + (y << 16) + (z << 8) + w).view(torch.float32)[..., None]

    scales = bitcast_u8_to_f32(packed[..., -8:-4])
    offsets = bitcast_u8_to_f32(packed[..., -4:])
    return packed[..., :-8].to(torch.float32) * scales + offsets
|
| 542 |
+
|
| 543 |
+
|
| 544 |
+
@register_decomposition([aten.grid_sampler_2d])
@pw_cast_for_opmath
def grid_sampler_2d(
    a: torch.Tensor,
    grid: torch.Tensor,
    interpolation_mode: int = 0,
    padding_mode: int = 0,
    align_corners: bool = False,
) -> torch.Tensor:
    # Thin wrapper over the reference decomposition; the only logic added
    # here is the `_expand_grid` performance heuristic below.
    #
    # We do not expand the grid (_expand_grid=False) on cpu for performance reasons
    # Experimenting locally it was found that compiled CUDA code is accelerated by ~5x
    # and CPU code by ~2x on bicubic mode, if we expand the grid from (N, H, W, 2) into (N, C, H, W, 2)
    # However, this leads to a slowdown around ~0.8x on CPU bilinear mode, channels first.
    # Thus we apply this hack to not expand the grid for this case.
    _expand_grid = not (
        a.device == torch.device("cpu")
        and interpolation_mode == 0
        and a.is_contiguous(memory_format=torch.contiguous_format)
    )

    output = decomp_grid_sampler_2d(
        a,
        grid=grid,
        interpolation_mode=interpolation_mode,
        padding_mode=padding_mode,
        align_corners=align_corners,
        _expand_grid=_expand_grid,
    )
    return output
|
| 573 |
+
|
| 574 |
+
|
| 575 |
+
@register_decomposition(aten._foreach_addcmul.Scalar)
def _foreach_addcmul_scalar(self, left_tensors, right_tensors, scalar=1):
    """Per-tensor: self[i] + scalar * (left_tensors[i] * right_tensors[i])."""
    products = aten._foreach_mul.List(left_tensors, right_tensors)
    return aten._foreach_add.List(self, products, alpha=scalar)
|
| 580 |
+
|
| 581 |
+
|
| 582 |
+
@register_decomposition(aten._foreach_addcdiv.Scalar)
def _foreach_addcdiv_scalar(self, left_tensors, right_tensors, scalar=1):
    """Per-tensor: self[i] + scalar * (left_tensors[i] / right_tensors[i])."""
    quotients = aten._foreach_div.List(left_tensors, right_tensors)
    return aten._foreach_add.List(self, quotients, alpha=scalar)
|
| 587 |
+
|
| 588 |
+
|
| 589 |
+
@register_decomposition(aten._foreach_lerp.Scalar)
def _foreach_lerp_scalar(start_tensors, end_tensors, weight):
    """Per-tensor linear interpolation: start[i] + weight * (end[i] - start[i])."""
    deltas = aten._foreach_sub.List(end_tensors, start_tensors)
    scaled = aten._foreach_mul.Scalar(deltas, weight)
    return aten._foreach_add.List(start_tensors, scaled)
|
| 597 |
+
|
| 598 |
+
|
| 599 |
+
@aten.miopen_batch_norm.default.py_impl(torch._C.DispatchKey.Autograd)
@register_decomposition(aten.miopen_batch_norm)
def miopen_batch_norm(
    input: torch.Tensor,
    weight: torch.Tensor,
    bias: typing.Optional[torch.Tensor],
    running_mean: typing.Optional[torch.Tensor],
    running_var: typing.Optional[torch.Tensor],
    training: bool,
    exponential_average_factor: float,
    epsilon: float,
):
    # Decompose the MIOpen (ROCm) batch-norm op into native_batch_norm.
    # In inference mode the saved-statistics outputs are returned as empty
    # tensors (matching the zero-sized outputs expected of the MIOpen op —
    # NOTE(review): confirm against the eager MIOpen implementation).
    a, b, c = aten.native_batch_norm(
        input,
        weight,
        bias,
        running_mean,
        running_var,
        training,
        exponential_average_factor,
        epsilon,
    )

    if training:
        return (a, b, c)
    return (
        a,
        weight.new_zeros((0,)),
        weight.new_zeros((0,)),
    )
|
| 629 |
+
|
| 630 |
+
|
| 631 |
+
@functools.lru_cache(None)
def fast_random_decomps():
    # Cached merge of the default decomposition table with the extra
    # random-op decompositions; entries in extra_random_decomps win.
    return {**decompositions, **extra_random_decomps}
|
| 634 |
+
|
| 635 |
+
|
| 636 |
+
def select_decomp_table():
    """decomps can change based on config"""
    # fallback_random keeps the stock table; otherwise include the extra
    # (cached) random-op decompositions.
    return decompositions if config.fallback_random else fast_random_decomps()
|
| 641 |
+
|
| 642 |
+
|
| 643 |
+
@register_decomposition(aten.masked_scatter)
def masked_scatter(self, mask, source):
    # CUDA-only decomposition; other devices fall back to the eager kernel.
    if self.device.type == "cuda":
        # This two-step algorithm is the same as eager CUDA, for eager CPU we
        # use a 1-shot serial iteration.
        self, mask = aten.broadcast_tensors([self, mask])
        # For each True position, its 0-based rank among True elements.
        # False positions get stale indices, which are presumably ignored by
        # masked_scatter_with_index — behavior defined by that prim.
        source_idx = mask.reshape(-1).cumsum(0) - 1
        return inductor_prims.masked_scatter_with_index(self, mask, source_idx, source)
    return NotImplemented
|
| 652 |
+
|
| 653 |
+
|
| 654 |
+
@register_decomposition(quantized_decomposed.choose_qparams.tensor)
def choose_qparams_tensor(
    input: torch.Tensor, quant_min: int, quant_max: int, eps: float, dtype: torch.dtype
):
    """Compute (scale, zero_point) for affine quantization of `input`.

    scale spans [min, max] over the quant range (clamped below by eps);
    zero_point is the integer offset mapping 0.0 into the quant range.
    Returns (scale as float64, zero_point as int64).
    """
    min_val, max_val = torch.aminmax(input)
    scale = (max_val - min_val) / float(quant_max - quant_min)
    # Bug fix: torch.Tensor([eps]) always allocates a CPU float32 tensor,
    # which raises a device-mismatch error when `input` lives on an
    # accelerator. Create the eps tensor on the input's device instead.
    scale = torch.max(scale, torch.tensor([eps], device=input.device))
    zero_point = quant_min - torch.round(min_val / scale).to(torch.int)
    zero_point = torch.clamp(zero_point, quant_min, quant_max)
    return scale.to(torch.float64), zero_point.to(torch.int64)
|
| 664 |
+
|
| 665 |
+
|
| 666 |
+
@register_decomposition(aten.put)
def put(self, index, source, accumulate=False):
    """Decompose aten.put: flatten, index_put with linear indices, reshape back."""
    flat = torch.index_put(
        self.flatten(), [index], source.reshape(index.shape), accumulate
    )
    return flat.reshape(self.shape)
|
| 673 |
+
|
| 674 |
+
|
| 675 |
+
@register_decomposition(aten.put_)
def put_(self, index, source, accumulate=False):
    """In-place variant: run the out-of-place put, then copy back into self."""
    return self.copy_(aten.put(self, index, source, accumulate=accumulate))
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/test_operators.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.library
|
| 2 |
+
from torch import Tensor
|
| 3 |
+
from torch.autograd import Function
|
| 4 |
+
|
| 5 |
+
# Register a test-only custom op `_inductor_test::realize`. It is an identity
# op (implemented as clone) used to force Inductor to materialize a buffer.
_test_lib_def = torch.library.Library("_inductor_test", "DEF")
_test_lib_def.define("realize(Tensor self) -> Tensor", tags=torch.Tag.pt2_compliant_tag)

# Same clone-based implementation for every backend we test against.
_test_lib_impl = torch.library.Library("_inductor_test", "IMPL")
for dispatch_key in ("CPU", "CUDA", "Meta"):
    _test_lib_impl.impl("realize", lambda x: x.clone(), dispatch_key)
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class Realize(Function):
    """Autograd wrapper around _inductor_test::realize (identity-like op)."""

    @staticmethod
    def forward(ctx, x):
        # Dispatch through the custom op so the graph contains a real node.
        return torch.ops._inductor_test.realize(x)

    @staticmethod
    def backward(ctx, grad_output):
        # Gradient of an identity op is the incoming gradient unchanged.
        return grad_output
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def realize(x: Tensor) -> Tensor:
    """Force `x` through the realize op (autograd-aware entry point)."""
    return Realize.apply(x)
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/virtualized.py
ADDED
|
@@ -0,0 +1,351 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
This file provides a number of "global" variables/handlers that are actually
|
| 3 |
+
thread local and dynamically scoped, with Inductor patching them to various
|
| 4 |
+
implementations depending on the situation.
|
| 5 |
+
|
| 6 |
+
These handlers are interacted with in a fairly stylized way. Typically,
|
| 7 |
+
we will import V from this module::
|
| 8 |
+
|
| 9 |
+
from .virtualized import V
|
| 10 |
+
|
| 11 |
+
Various handlers are accessible as attributes on this module; for example,
|
| 12 |
+
you might access ``V.graph.sizevars.size_hint`` to resolve a size hint associated with
|
| 13 |
+
a number.
|
| 14 |
+
|
| 15 |
+
There are a few distinct usage patterns for virtualized global variables:
|
| 16 |
+
|
| 17 |
+
1. Implicit argument passing. Examples: ``V.current_node``, ``V.aot_compilation``.
|
| 18 |
+
Use ``V.set_current_node`` to change what the current node is while we're
|
| 19 |
+
executing some region of code, so code inside that region can query ``V.current_node``
|
| 20 |
+
to find out what it is. This is often more convenient than manually threading
|
| 21 |
+
the current node as an argument through all call stacks.
|
| 22 |
+
|
| 23 |
+
2. Per-compilation global state. Examples: ``V.fake_mode``, ``V.graph``. For a
|
| 24 |
+
given ``compile_fx`` invocation, these typically don't change, but they are
|
| 25 |
+
associated with some internal state so they cannot just be global functions.
|
| 26 |
+
We install these objects at the beginning of compilation and then you can
|
| 27 |
+
conveniently access them without having to pass them around.
|
| 28 |
+
|
| 29 |
+
3. Alternate define-by-run interpretations. Examples: ``V.ops``, ``V.kernel``.
|
| 30 |
+
A commonly used IR in Inductor is define-by-run: instead of maintaining
|
| 31 |
+
explicit syntax data structures, we instead represent loop bodies as
|
| 32 |
+
callable functions, which internally invoke operations defined on
|
| 33 |
+
``V.ops``. To perform semantic analysis, print or code generate these
|
| 34 |
+
operations, we dynamically patch ``V.ops`` with an alternate handler with
|
| 35 |
+
the intended semantics and then run the callable function. For example, to
|
| 36 |
+
extract out a traditional (FX) graph representation of the define-by-run
|
| 37 |
+
IR, simply install a handler that records each ``ops`` call to a graph.
|
| 38 |
+
|
| 39 |
+
TODO: Define a parent class / protocol that defines all of the operations
|
| 40 |
+
V.ops is expected to support.
|
| 41 |
+
|
| 42 |
+
It is typically an error to access a virtualized global without having installed
|
| 43 |
+
an appropriate handler (you will get a NullHandler), although in some cases we
|
| 44 |
+
provide a default implementation.
|
| 45 |
+
|
| 46 |
+
One last thing: although most virtualized globals are accessed via ``V``, ``ops`` is
|
| 47 |
+
ubiquitous enough to have its own top level variable, so you will typically see
|
| 48 |
+
``ops.constant(...)`` rather than ``V.ops.constant(...)``. In fact, these are not
|
| 49 |
+
equivalent; the former interface supports arithmetic overloads like ``x + y``
|
| 50 |
+
instead of forcing ``ops.add(x, y)``, so it should be preferred.
|
| 51 |
+
|
| 52 |
+
Some operators are seemingly unused, but they are implicitly used by ops_wrapper.
|
| 53 |
+
In particular, we typically have an operator for every basic pointwise PyTorch operation
|
| 54 |
+
supported.
|
| 55 |
+
"""
|
| 56 |
+
|
| 57 |
+
from __future__ import annotations
|
| 58 |
+
|
| 59 |
+
from contextlib import AbstractContextManager, contextmanager
|
| 60 |
+
from threading import local
|
| 61 |
+
from typing import Any, Callable, Generic, List, Type, TYPE_CHECKING, TypeVar, Union
|
| 62 |
+
|
| 63 |
+
from .ops_handler import ( # noqa: F401
|
| 64 |
+
KernelFormatterHandler,
|
| 65 |
+
MockHandler,
|
| 66 |
+
OpsHandler,
|
| 67 |
+
ReductionType,
|
| 68 |
+
StoreMode,
|
| 69 |
+
WrapperHandler,
|
| 70 |
+
)
|
| 71 |
+
|
| 72 |
+
if TYPE_CHECKING:
|
| 73 |
+
import torch
|
| 74 |
+
from torch._inductor.debug import DebugContext
|
| 75 |
+
from torch._inductor.graph import GraphLowering
|
| 76 |
+
from torch._inductor.ir import InterpreterShim
|
| 77 |
+
from torch._subclasses import FakeTensorMode
|
| 78 |
+
|
| 79 |
+
threadlocal = local()
|
| 80 |
+
|
| 81 |
+
T = TypeVar("T")
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
class NullHandler:
    """
    Sentinel indicating that a global variable is unset ala None. Typically,
    attempting to access the global variable before it's set is an error, but with
    NullHandler it won't fail until you try to access an attribute on it.
    """

    # Intentionally empty: the class exists only as a recognizable sentinel.
    pass
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
class Virtualized(Generic[T]):
    """
    Implements a global variable that redirects via thread local variable
    (NB: construct this class to create the global variable; this is not
    a singleton class!)

    This allows us to swap in different op implementations in codegen.

    NB: Despite the fact that we typically call these "handlers" (e.g., NullHandler is
    the default value of the variable), we sometimes use these variables to
    store other things, like booleans.
    """

    def __init__(self, vname: str, default: Union[Callable[[], T], Type[NullHandler]]):
        # Thread-local attribute name under which the handler is stored.
        self._key: str = f"__torchinductor_{vname}"
        # Factory invoked when the variable has not been set on this thread.
        self._default = default

    def _set_handler(self, value: T) -> AbstractContextManager[None]:
        # Install `value` immediately; the returned context manager restores
        # the previous handler on exit (entering it is optional — the new
        # handler is already active by the time this function returns).
        prior = self._get_handler()
        setattr(threadlocal, self._key, value)

        @contextmanager
        def ctx():
            try:
                yield
            finally:
                # Restore whatever was active before this _set_handler call.
                self._set_handler(prior)

        return ctx()

    def _get_handler(self) -> T:
        # Return the current thread's handler, or a fresh default if unset.
        try:
            return getattr(threadlocal, self._key)
        except AttributeError:
            # TODO: To be honest, I feel we probably should just error in this
            # case, instead of making a null handler that will probably error
            # when you getattr on it
            return self._default()  # type: ignore[return-value]

    def __getattr__(self, name: str) -> Any:
        # Proxy attribute access straight through to the active handler.
        return getattr(self._get_handler(), name)
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
class NullKernelHandler(NullHandler):
    """Fallback kernel handler used when no kernel is in context (e.g. while
    codegening the wrapper). DeferredLine reads ``V.kernel.removed_buffers``,
    so ``removed_buffers`` and ``inplaced_to_remove`` are initialized
    explicitly here instead of relying on typo-prone getattr-with-default
    lookups at every call site.
    """

    def __init__(self):
        super().__init__()
        self.index_dtype = "tl.int64"
        self.removed_buffers = set()
        self.inplaced_to_remove = set()
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
# Module-level dynamically-scoped globals (see module docstring). Each is
# installed/read via the corresponding V.set_*/V.* accessors below. `ops`
# defaults to MockHandler; everything else defaults to the NullHandler
# sentinel (error on attribute access until explicitly set).
_ops: Virtualized[OpsHandler[Any]] = Virtualized("ops", MockHandler)
_graph: Virtualized[GraphLowering] = Virtualized("graph", NullHandler)
_real_inputs: Virtualized[List[torch.Tensor]] = Virtualized("real_inputs", NullHandler)
_fake_mode: Virtualized[FakeTensorMode] = Virtualized("fake_mode", NullHandler)
_kernel: Virtualized[NullKernelHandler] = Virtualized(
    "kernel", NullKernelHandler
)  # TODO: improve type
_debug: Virtualized[DebugContext] = Virtualized("debug", NullHandler)
_interpreter: Virtualized[InterpreterShim] = Virtualized("interpreter", NullHandler)
_aot_compilation: Virtualized[bool] = Virtualized("aot_compilation", NullHandler)
_current_node: Virtualized[torch.fx.Node] = Virtualized("current_node", NullHandler)
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
class OpsValue:
    """The return type of most ops calls.

    This exists so we can overload magic methods, and write mathematical
    expressions much more fluently. So instead of

        ops.add(ops.mul(ops.mul(ops.sub(ops.mul(_Ap2, x), _Ap3), x), x), _1)

    we can write

        (_Ap2 * x - _Ap3) * x * x + _1

    """

    # The underlying (handler-specific) IR value being wrapped.
    value: Any

    def __init__(self, value):
        self.value = value

    def __str__(self):
        return str(self.value)

    def __repr__(self):
        return f"OpsValue({self.value!r})"

    def __add__(self, other):
        return ops.add(self, other)

    def __mul__(self, other):
        return ops.mul(self, other)

    def __sub__(self, other):
        return ops.sub(self, other)

    def __neg__(self):
        return ops.neg(self)

    def __truediv__(self, other):
        return ops.truediv(self, other)

    def __floordiv__(self, other):
        return ops.floordiv(self, other)

    def __mod__(self, other):
        return ops.mod(self, other)

    def __pow__(self, other):
        return ops.pow(self, other)

    def __lt__(self, other):
        return ops.lt(self, other)

    def __le__(self, other):
        return ops.le(self, other)

    def __eq__(self, other):
        return ops.eq(self, other)

    def __ne__(self, other):
        return ops.ne(self, other)

    def __gt__(self, other):
        return ops.gt(self, other)

    def __ge__(self, other):
        return ops.ge(self, other)

    def __and__(self, other):
        return ops.bitwise_and(self, other)

    def __or__(self, other):
        return ops.bitwise_or(self, other)

    def __xor__(self, other):
        return ops.bitwise_xor(self, other)

    def __invert__(self):
        return ops.bitwise_not(self)

    def __rshift__(self, n):
        # Bug fix: this method was previously misspelled ``__rshfit__``, so
        # the ``>>`` operator never dispatched to bitwise_right_shift.
        return ops.bitwise_right_shift(self, n)

    # Keep the old misspelled name as an alias in case anything called it
    # directly.
    __rshfit__ = __rshift__

    def __lshift__(self, n):
        return ops.bitwise_left_shift(self, n)
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
class OpsWrapper:
    """This wraps any returned IR values into an `OpsValue` instance, so that we
    can overload the magic methods for writing mathematical expressions fluently.
    """

    def __getattr__(self, name):
        # Forward any op name to the active handler (_ops), unwrapping
        # OpsValue arguments on the way in and wrapping results on the way out.
        def inner(*args, **kwargs):
            new_args = [OpsWrapper._unwrap(a) for a in args]
            new_kwargs = {k: OpsWrapper._unwrap(v) for k, v in kwargs.items()}
            return OpsWrapper._wrap(getattr(_ops, name)(*new_args, **new_kwargs))

        return inner

    @staticmethod
    def _unwrap(x):
        # Recursively strip OpsValue wrappers; lists and tuples both come
        # back as tuples.
        if isinstance(x, (list, tuple)):
            return tuple(OpsWrapper._unwrap(v) for v in x)
        if isinstance(x, OpsValue):
            return x.value
        return x

    @staticmethod
    def _wrap(x):
        # Wrap a single value — or each element of a sequence — in OpsValue.
        if isinstance(x, (list, tuple)):
            return tuple(OpsValue(v) for v in x)
        return OpsValue(x)

    @staticmethod
    def indirect_indexing(index, size, check=True):
        # Returns a sympy value, not IR value
        index = OpsWrapper._unwrap(index)
        return _ops.indirect_indexing(index, size, check)
|
| 284 |
+
|
| 285 |
+
|
| 286 |
+
# Module-level entry point for define-by-run IR construction; prefer this
# over V.ops since it supports operator overloading (see module docstring).
ops = OpsWrapper()
|
| 287 |
+
|
| 288 |
+
|
| 289 |
+
class _V:
    # The single access point for all virtualized globals (exported as `V`).
    # Convenience re-exports so callers can write V.MockHandler etc.
    MockHandler = MockHandler
    KernelFormatterHandler = KernelFormatterHandler
    WrapperHandler = WrapperHandler

    # Setter/getter aliases bound to the module-level Virtualized instances.
    # Setters return a context manager that restores the prior handler.
    set_ops_handler: Callable[[Any], Any] = _ops._set_handler
    get_ops_handler: Callable[[], Any] = _ops._get_handler
    set_graph_handler: Callable[[GraphLowering], Any] = _graph._set_handler
    set_real_inputs: Callable[[Any], Any] = _real_inputs._set_handler
    get_real_inputs: Callable[[], Any] = _real_inputs._get_handler
    set_fake_mode: Callable[[Any], Any] = _fake_mode._set_handler
    get_fake_mode: Callable[[], Any] = _fake_mode._get_handler
    set_kernel_handler: Callable[[Any], Any] = _kernel._set_handler
    set_debug_handler: Callable[[Any], Any] = _debug._set_handler
    set_interpreter_handler: Callable[[Any], Any] = _interpreter._set_handler
    set_aot_compilation: Callable[[bool], Any] = _aot_compilation._set_handler
    get_aot_compilation: Callable[[], Any] = _aot_compilation._get_handler
    set_current_node: Callable[[Any], Any] = _current_node._set_handler
    get_current_node: Callable[[], Any] = _current_node._get_handler

    @property
    def ops(self) -> OpsHandler[Any]:
        """The operator handler specific to the current codegen task"""
        return _ops._get_handler()

    @property
    def graph(self) -> GraphLowering:
        """The graph currently being generated"""
        return _graph._get_handler()

    @property
    def real_inputs(self):
        """non-fake example inputs"""
        return _real_inputs._get_handler()

    @property
    def fake_mode(self):
        """The FakeTensorMode in effect for the current compilation"""
        return _fake_mode._get_handler()

    @property
    def kernel(self):
        """The kernel currently being generated"""
        return _kernel._get_handler()

    @property
    def debug(self):
        """The current DebugContext handler"""
        return _debug._get_handler()

    @property
    def interpreter(self):
        """The InterpreterShim currently executing, if one is installed"""
        return _interpreter._get_handler()

    @property
    def aot_compilation(self):
        """Whether we are compiling ahead-of-time (bool-valued global)"""
        return _aot_compilation._get_handler()

    @property
    def current_node(self):
        """The FX node currently being lowered"""
        return _current_node._get_handler()
|
| 349 |
+
|
| 350 |
+
|
| 351 |
+
V = _V()
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/Activation.h
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/native/DispatchStub.h>
|
| 4 |
+
#include <c10/util/Exception.h>
|
| 5 |
+
#include <c10/util/string_view.h>
|
| 6 |
+
|
| 7 |
+
namespace c10 {
|
| 8 |
+
class Scalar;
|
| 9 |
+
}
|
| 10 |
+
|
| 11 |
+
namespace at {
|
| 12 |
+
struct TensorIterator;
|
| 13 |
+
struct TensorIteratorBase;
|
| 14 |
+
class TensorBase;
|
| 15 |
+
}
|
| 16 |
+
|
| 17 |
+
namespace at::native {
|
| 18 |
+
|
| 19 |
+
// These constants control the approximation behavior of gelu function.
|
| 20 |
+
// These constants control the approximation behavior of gelu function.
enum class GeluType {
  None,             // Baseline Gelu
  Tanh,             // Tanh Gelu Approximation
  END               // Sentinel marking the end of valid GeluType values
};
|
| 25 |
+
|
| 26 |
+
// Map the string `approximate` argument of gelu to its GeluType enum value.
// Errors (via TORCH_CHECK) on any string other than "none" or "tanh".
static GeluType get_gelutype_enum(const c10::string_view approximate) {
  if (approximate == "none") {
    return GeluType::None;
  }
  if (approximate == "tanh") {
    return GeluType::Tanh;
  }
  TORCH_CHECK(false, "approximate argument must be either none or tanh.");
}
|
| 35 |
+
|
| 36 |
+
// Inverse of get_gelutype_enum: render a GeluType as its string form.
// Errors (via TORCH_CHECK) on values outside {None, Tanh}, including END.
static std::string gelutype_to_string(const GeluType type) {
  switch(type) {
    case GeluType::None: return "none";
    case GeluType::Tanh: return "tanh";
    default: TORCH_CHECK(false, "unknown GELU type: ", static_cast<int>(type));
  }
}
|
| 43 |
+
|
| 44 |
+
// Kernel function-pointer signatures for activation ops. "structured_"
// variants (and those taking TensorIteratorBase) run under structured
// kernels; the TensorIterator ones are the legacy form. Scalar parameters
// carry per-op attributes (e.g. beta/threshold for softplus, lambd for the
// shrink ops, alpha/scale/input_scale for elu, negative slope for leaky_relu).
using structured_activation_fn = void (*)(TensorIteratorBase&);
using structured_activation_backward_fn = void (*)(TensorIteratorBase&);

using activation_fn = void (*)(TensorIterator&);
using activation_backward_fn = void (*)(TensorIterator&);
using softplus_fn = void (*)(TensorIteratorBase&, const c10::Scalar&, const c10::Scalar&);
using softplus_backward_fn = void (*)(TensorIteratorBase&, const c10::Scalar&, const c10::Scalar&);
using threshold_fn = void (*)(TensorIteratorBase&, const c10::Scalar&, const c10::Scalar&);
using hardtanh_backward_fn = void (*)(TensorIterator&, const c10::Scalar&, const c10::Scalar&);
using hardsigmoid_fn = void(*)(TensorIteratorBase&);
using hardsigmoid_backward_fn = void(*)(TensorIteratorBase&);
using hardswish_fn = void(*)(TensorIterator&);
using hardswish_backward_fn = void(*)(TensorIterator&);
using shrink_fn = void (*)(TensorIteratorBase&, const c10::Scalar&);
using softshrink_fn = void (*)(TensorIteratorBase&, const c10::Scalar&);
using shrink_backward_fn = void (*)(TensorIteratorBase&, const c10::Scalar&);
using elu_fn = void (*)(TensorIteratorBase&, const c10::Scalar&, const c10::Scalar&, const c10::Scalar&);
using elu_backward_fn = void (*)(TensorIteratorBase&, const c10::Scalar&, const c10::Scalar&, const c10::Scalar&, bool);
using leaky_relu_fn = void (*)(TensorIteratorBase&, const c10::Scalar&);
using leaky_relu_backward_fn = void (*)(TensorIteratorBase&, const c10::Scalar&);
// log_sigmoid CPU writes two outputs (result + buffer) from one input.
using log_sigmoid_cpu_fn = void (*)(TensorBase&, TensorBase&, const TensorBase&);
using gelu_fn = void (*)(TensorIteratorBase&, GeluType);
using gelu_backward_fn = void (*)(TensorIteratorBase&, GeluType);
using glu_jvp_fn = void (*)(TensorIteratorBase&);

// Per-device kernel registration points (see ATen/native/DispatchStub.h);
// backend implementations register themselves with REGISTER_DISPATCH.
DECLARE_DISPATCH(elu_fn, elu_stub);
DECLARE_DISPATCH(elu_backward_fn, elu_backward_stub);
DECLARE_DISPATCH(softplus_fn, softplus_stub);
DECLARE_DISPATCH(softplus_backward_fn, softplus_backward_stub);
DECLARE_DISPATCH(log_sigmoid_cpu_fn, log_sigmoid_cpu_stub);
DECLARE_DISPATCH(activation_backward_fn, log_sigmoid_backward_stub);
DECLARE_DISPATCH(threshold_fn, threshold_stub);
DECLARE_DISPATCH(gelu_fn, GeluKernel);
DECLARE_DISPATCH(gelu_backward_fn, GeluBackwardKernel);
DECLARE_DISPATCH(hardtanh_backward_fn, hardtanh_backward_stub);
DECLARE_DISPATCH(hardsigmoid_fn, hardsigmoid_stub);
DECLARE_DISPATCH(hardsigmoid_backward_fn, hardsigmoid_backward_stub);
DECLARE_DISPATCH(hardswish_fn, hardswish_stub);
DECLARE_DISPATCH(hardswish_backward_fn, hardswish_backward_stub);
DECLARE_DISPATCH(shrink_fn, hardshrink_stub);
DECLARE_DISPATCH(softshrink_fn, softshrink_stub);
DECLARE_DISPATCH(shrink_backward_fn, shrink_backward_stub);
DECLARE_DISPATCH(leaky_relu_fn, leaky_relu_stub);
DECLARE_DISPATCH(leaky_relu_backward_fn, leaky_relu_backward_stub);
DECLARE_DISPATCH(structured_activation_fn, glu_stub);
DECLARE_DISPATCH(activation_backward_fn, glu_backward_stub);
DECLARE_DISPATCH(glu_jvp_fn, glu_jvp_stub);
DECLARE_DISPATCH(structured_activation_fn, silu_stub);
DECLARE_DISPATCH(structured_activation_backward_fn, silu_backward_stub);
DECLARE_DISPATCH(structured_activation_fn, mish_stub);
DECLARE_DISPATCH(activation_backward_fn, mish_backward_stub);
DECLARE_DISPATCH(activation_fn, prelu_stub);
DECLARE_DISPATCH(activation_backward_fn, prelu_backward_stub);
|
| 97 |
+
|
| 98 |
+
} // namespace at::native
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/AdaptivePooling.h
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/core/Tensor.h>
|
| 4 |
+
#include <ATen/native/DispatchStub.h>
|
| 5 |
+
#include <c10/util/ArrayRef.h>
|
| 6 |
+
#include <c10/util/irange.h>
|
| 7 |
+
#include <cmath>
|
| 8 |
+
|
| 9 |
+
namespace at::native {
|
| 10 |
+
|
| 11 |
+
// Kernel entry points for 2-d adaptive average pooling; the backward
// variant recovers the input geometry from grad_input itself.
using adaptive_avg_pooling_fn = void(*)(Tensor& output, const Tensor& input, IntArrayRef output_size);
using adaptive_avg_pooling_backward_fn = void(*)(Tensor& grad_input, const Tensor& grad_output);
DECLARE_DISPATCH(adaptive_avg_pooling_fn, adaptive_avg_pool2d_kernel);
DECLARE_DISPATCH(adaptive_avg_pooling_backward_fn, adaptive_avg_pool2d_backward_kernel);

// Adaptive max pooling additionally produces/consumes an `indices` tensor
// recording which input element won each output cell.
using adaptive_max_pooling_fn = void(*)(const Tensor& output, const Tensor& indices, const Tensor& input, IntArrayRef output_size);
using adaptive_max_pooling_backward_fn = void(*)(const Tensor& grad_input, const Tensor& grad_output, const Tensor& indices);
DECLARE_DISPATCH(adaptive_max_pooling_fn, adaptive_max_pool2d_kernel);
DECLARE_DISPATCH(adaptive_max_pooling_backward_fn, adaptive_max_pool2d_backward_kernel);
|
| 20 |
+
|
| 21 |
+
// First input index covered by output cell `a`, for `b` output cells over
// `c` input elements: floor(a * c / b), decomposed so the intermediate
// product cannot overflow the way a direct `a * c` could.
static inline int64_t start_index(int64_t a, int64_t b, int64_t c) {
  const int64_t quot = a / b;
  const int64_t rem = a % b;
  return quot * c + (rem * c) / b;
}
|
| 24 |
+
|
| 25 |
+
// One past the last input index covered by output cell `a`:
// ceil((a + 1) * c / b), written as 1 + floor(((a + 1) * c - 1) / b).
static inline int64_t end_index(int64_t a, int64_t b, int64_t c) {
  const int64_t next_cell = (a + 1) * c;
  return 1 + (next_cell - 1) / b;
}
|
| 28 |
+
|
| 29 |
+
static inline void adaptive_pool_empty_output_check(const Tensor& gradOutput_, const char* arg_name) {
|
| 30 |
+
int64_t ndim = gradOutput_.ndimension();
|
| 31 |
+
for (const auto i : c10::irange(1, ndim)) {
|
| 32 |
+
TORCH_CHECK(gradOutput_.size(i) > 0,
|
| 33 |
+
arg_name, "(): Expected grad_output to have non-zero size for non-batch dimensions, "
|
| 34 |
+
"but grad_output has sizes ", gradOutput_.sizes(), " with dimension ", i,
|
| 35 |
+
" being empty");
|
| 36 |
+
}
|
| 37 |
+
}
|
| 38 |
+
|
| 39 |
+
} // namespace at::native
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/BucketizationUtils.h
ADDED
|
@@ -0,0 +1,173 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/core/Tensor.h>
|
| 4 |
+
#include <ATen/native/TypeProperties.h>
|
| 5 |
+
#include <ATen/ScalarOps.h>
|
| 6 |
+
|
| 7 |
+
#ifndef AT_PER_OPERATOR_HEADERS
|
| 8 |
+
#include <ATen/NativeFunctions.h>
|
| 9 |
+
#else
|
| 10 |
+
#include <ATen/ops/result_type.h>
|
| 11 |
+
#endif
|
| 12 |
+
|
| 13 |
+
namespace at::native {
|
| 14 |
+
|
| 15 |
+
// original values given by raw_*. If an original value is not contiguous, will make a contiguous copy to
// the corresponding trimmed_* value. Additionally, if the dtypes of the boundary and input tensor do not
// match, will change them to be a common super type so comparisons are done between the same types.
// For any trimmed_* tensor, if its outgoing value matches what it was incoming (typically null), then the
// corresponding raw_* version should be used since it was already contiguous of the right type.
inline void searchsorted_maybe_trim_input_tensors(
    Tensor& trimmed_input,
    Tensor& trimmed_boundaries,
    Tensor& trimmed_sorter,
    const Tensor& raw_input,
    const Tensor& raw_boundaries,
    const Tensor& raw_sorter) {
  bool in_is_contiguous = raw_input.is_contiguous();
  bool bd_is_contiguous = raw_boundaries.is_contiguous();
  bool sort_is_contiguous = raw_sorter.is_contiguous();

  // Each non-contiguous input is copied (with a one-time perf warning);
  // contiguous inputs leave the corresponding trimmed_* untouched.
  if (!in_is_contiguous) {
    TORCH_WARN_ONCE("torch.searchsorted(): input value tensor is non-contiguous, this will lower the performance due "
      "to extra data copy when converting non-contiguous tensor to contiguous, please use contiguous input value "
      "tensor if possible. This message will only appear once per program.");
    trimmed_input = raw_input.contiguous();
  }
  if (!bd_is_contiguous) {
    TORCH_WARN_ONCE("torch.searchsorted(): boundary tensor is non-contiguous, this will lower the performance due "
      "to extra data copy when converting non-contiguous tensor to contiguous, please use contiguous boundary "
      "tensor if possible. This message will only appear once per program.");
    trimmed_boundaries = raw_boundaries.contiguous();
  }
  if (!sort_is_contiguous) {
    TORCH_WARN_ONCE("torch.searchsorted(): sorter tensor is non-contiguous, this will lower the performance due "
      "to extra data copy when converting non-contiguous tensor to contiguous, please use contiguous sorter "
      "tensor if possible. This message will only appear once per program.");
    trimmed_sorter = raw_sorter.contiguous();
  }
  // Promote input/boundaries to a common dtype (binary-op promotion rules)
  // so element comparisons inside the kernel are homogeneous.
  if (raw_input.dtype() != raw_boundaries.dtype()) {
    at::native::ResultTypeState state = {};
    state = at::native::update_result_type_state(raw_boundaries, state);
    state = at::native::update_result_type_state(raw_input, state);
    ScalarType common_stype = at::native::result_type(state);

    TORCH_INTERNAL_ASSERT(common_stype != ScalarType::Undefined);
    if (common_stype != raw_input.scalar_type()) {
      // If a contiguous copy was already made above, convert that copy;
      // otherwise convert straight from the raw tensor.
      trimmed_input = in_is_contiguous ? raw_input.to(common_stype) : trimmed_input.to(common_stype);
    }
    if (common_stype != raw_boundaries.scalar_type()) {
      trimmed_boundaries = bd_is_contiguous ? raw_boundaries.to(common_stype) : trimmed_boundaries.to(common_stype);
    }
  }
}
|
| 64 |
+
|
| 65 |
+
/* unused but needed for internal jagged tensor class */
|
| 66 |
+
inline void searchsorted_maybe_trim_input_tensors(
|
| 67 |
+
Tensor& trimmed_input,
|
| 68 |
+
Tensor& trimmed_boundaries,
|
| 69 |
+
const Tensor& raw_input,
|
| 70 |
+
const Tensor& raw_boundaries) {
|
| 71 |
+
Tensor trimmed_sorter;
|
| 72 |
+
Tensor raw_sorter;
|
| 73 |
+
return searchsorted_maybe_trim_input_tensors(
|
| 74 |
+
trimmed_input,
|
| 75 |
+
trimmed_boundaries,
|
| 76 |
+
trimmed_sorter,
|
| 77 |
+
raw_input,
|
| 78 |
+
raw_boundaries,
|
| 79 |
+
raw_sorter);
|
| 80 |
+
}
|
| 81 |
+
|
| 82 |
+
inline bool searchsorted_dims_matched_before_last_dim(const Tensor& boundaries, const Tensor& input) {
|
| 83 |
+
if (boundaries.dim() != input.dim()) {
|
| 84 |
+
return false;
|
| 85 |
+
}
|
| 86 |
+
const auto& dims_bd = boundaries.sizes();
|
| 87 |
+
const auto& dims_in = input.sizes();
|
| 88 |
+
for (int64_t dim = 0; dim + 1 < boundaries.dim(); ++dim) {
|
| 89 |
+
if (dims_bd[dim] != dims_in[dim]) {
|
| 90 |
+
return false;
|
| 91 |
+
}
|
| 92 |
+
}
|
| 93 |
+
return true;
|
| 94 |
+
}
|
| 95 |
+
|
| 96 |
+
// Wraps a Scalar into a 0-dim tensor on `device`, flagged as a "wrapped
// number" so it participates in scalar type promotion.
inline Tensor searchsorted_scalar_tensor(const Scalar& scalar, const c10::Device& device) {
  auto tensor = c10::scalar_to_tensor(scalar, device);
  // This is to adopt the scalar promotion rules defined in native/TypeProperties.h
  // So we have the same type promotion rules as binary operations.
  tensor.unsafeGetTensorImpl()->set_wrapped_number(true);
  return tensor;
}
|
| 103 |
+
|
| 104 |
+
// Validates all torch.searchsorted()/bucketize() arguments before kernel
// dispatch: side/right consistency, device agreement, sorter shape/dtype
// and index range, scalar-input rules, boundary/input shape compatibility,
// and output dtype vs the out_int32 flag. Every failure raises via
// TORCH_CHECK with a user-facing message.
inline void searchsorted_pre_check(
    const Tensor& boundaries,
    const Tensor& input,
    const Tensor& output,
    const bool out_int32,
    const bool right,
    const c10::optional<c10::string_view> side_opt,
    const Tensor& sorter) {
  if (side_opt) {
    const c10::string_view side = *side_opt;
    TORCH_CHECK(side == "left" || side == "right", "torch.searchsorted(): side can only be 'left' or 'right' but ",
      "got ", side);

    // assume the user has not explicitly set (right=False, side="right")
    TORCH_CHECK(!right || side == "right", "torch.searchsorted(): side and right can't be set to opposites, got side "
      "of ", side, " while right was True");
  }

  TORCH_CHECK(boundaries.device() == input.device(), "torch.searchsorted(): boundaries and input value tensors ",
    "should have same device type, but got boundaries tensor device type ", boundaries.device(), " and input value ",
    "tensor device type ", input.device());

  // sorter is optional; when given it must mirror boundaries and hold
  // valid int64 indices into the last dimension.
  if (sorter.defined()) {
    TORCH_CHECK(sorter.device() == boundaries.device(), "torch.searchsorted(): sorter and boundary tensors should ",
      "have same device type, but got sorter tensor device type ", sorter.device(), " and input value tensor ",
      "device type ", boundaries.device());

    // NOTE(review): message is missing a space before "and got sorter" --
    // kept byte-identical here; flagging for a separate message fix.
    TORCH_CHECK(sorter.sizes() == boundaries.sizes(), "torch.searchsorted(): boundary and sorter must have the same "
      "size, but got boundary tensor ", boundaries.sizes(), "and got sorter tensor ", sorter.sizes());

    TORCH_CHECK(sorter.scalar_type() == ScalarType::Long, "torch.searchsorted(): sorter must be a tensor of long ",
      "dtype but got dtype ", sorter.scalar_type());

    if (sorter.numel() > 0) {
      // One aminmax pass bounds every sorter entry at once.
      auto minmax = sorter.aminmax();
      int64_t vmin = std::get<0>(minmax).item().toLong();
      int64_t vmax = std::get<1>(minmax).item().toLong();
      TORCH_CHECK(vmin >= 0 && vmax < sorter.sizes().back(), "torch.searchsorted(): sorter index out of range");
    }
  }

  TORCH_CHECK(input.dim() > 0 || (input.dim() == 0 && input.numel() == 1 && boundaries.dim() == 1),
    "torch.searchsorted(): input value can be a scalar only when boundaries tensor dimension is 1, but we got ",
    "boundaries tensor dim(", boundaries.dim(), ") and input value's dim(", input.dim(), ") numel(",
    input.numel(), ")");

  TORCH_CHECK(boundaries.dim() != 0, "torch.searchsorted(): boundaries tensor should have positive dimension, but ",
    "got 0 dimension");

  TORCH_CHECK(boundaries.dim() == 1 || searchsorted_dims_matched_before_last_dim(boundaries, input),
    "torch.searchsorted(): boundaries tensor should be 1 dimension or the first N-1 dimensions of boundaries tensor ",
    "and input value tensor must match, but we got boundaries tensor ", boundaries.sizes(), " and input value tensor ",
    input.sizes());

  // Output dtype must agree with the out_int32 flag exactly.
  ScalarType output_dtype = output.scalar_type();
  TORCH_CHECK(
      (output_dtype == ScalarType::Long && !out_int32) ||
      (output_dtype == ScalarType::Int && out_int32),
      "torch.searchsorted(): output tensor's dtype is wrong, it can only be Int(int32) or Long(int64) depending on ",
      "whether out_int32 flag is True, but we got output tensor's dtype ", output_dtype,
      " and out_int32 flag is ", (out_int32 ? "True" : "False"));

  if (out_int32) {
    // int32 results could not index a last dimension of INT_MAX or more.
    TORCH_CHECK(boundaries.sizes().back() < INT_MAX,
      "torch.searchsorted(): the size of boundaries' last dimension should be less than ", INT_MAX, ", but we got ",
      boundaries.sizes().back());
  }
}
|
| 172 |
+
|
| 173 |
+
} // namespace at::native
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/ConvUtils.h
ADDED
|
@@ -0,0 +1,446 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
#include <ATen/core/Tensor.h>
|
| 3 |
+
#include <ATen/TensorUtils.h>
|
| 4 |
+
#include <ATen/detail/CUDAHooksInterface.h>
|
| 5 |
+
#include <ATen/native/DispatchStub.h>
|
| 6 |
+
#include <c10/util/env.h>
|
| 7 |
+
#include <c10/util/irange.h>
|
| 8 |
+
|
| 9 |
+
namespace at::native {
|
| 10 |
+
|
| 11 |
+
// Backend-specific convolution-backward kernel signatures and their
// dispatch-stub registration points. The trailing std::array<bool,N>
// selects which gradients (input / weight [/ bias]) to compute.
using conv_depthwise2d_backward_fn = std::tuple<at::Tensor,at::Tensor>(*)(
    const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
    at::IntArrayRef, at::IntArrayRef, std::array<bool, 2>);
DECLARE_DISPATCH(conv_depthwise2d_backward_fn, conv_depthwise2d_backward_stub);
using conv_depthwise3d_backward_fn = std::tuple<at::Tensor,at::Tensor,at::Tensor>(*)(
    const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
    at::IntArrayRef, at::IntArrayRef, std::array<bool, 3>);
DECLARE_DISPATCH(conv_depthwise3d_backward_fn, conv_depthwise3d_backward_stub);
using cudnn_convolution_backward_fn = std::tuple<at::Tensor,at::Tensor>(*)(
    const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
    at::IntArrayRef, int64_t, bool, bool, bool, std::array<bool,2>);
DECLARE_DISPATCH(cudnn_convolution_backward_fn, cudnn_convolution_backward_stub);
using mps_convolution_backward_fn = std::tuple<at::Tensor,at::Tensor,at::Tensor>(*)(
    const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
    at::IntArrayRef, int64_t, std::array<bool,3>);
DECLARE_DISPATCH(mps_convolution_backward_fn, mps_convolution_backward_stub);
using cudnn_convolution_transpose_backward_fn = std::tuple<at::Tensor,at::Tensor>(*)(
    const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
    at::IntArrayRef, at::IntArrayRef, int64_t, bool, bool, bool, std::array<bool,2>);
DECLARE_DISPATCH(cudnn_convolution_transpose_backward_fn, cudnn_convolution_transpose_backward_stub);
using miopen_convolution_backward_fn = std::tuple<at::Tensor,at::Tensor,at::Tensor>(*)(
    const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
    at::IntArrayRef, int64_t, bool, bool, std::array<bool,3>);
DECLARE_DISPATCH(miopen_convolution_backward_fn, miopen_convolution_backward_stub);
using miopen_convolution_transpose_backward_fn = std::tuple<at::Tensor,at::Tensor,at::Tensor>(*)(
    const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
    at::IntArrayRef, at::IntArrayRef, int64_t, bool, bool, std::array<bool,3>);
DECLARE_DISPATCH(miopen_convolution_transpose_backward_fn, miopen_convolution_transpose_backward_stub);
using miopen_depthwise_convolution_backward_fn = std::tuple<at::Tensor,at::Tensor,at::Tensor>(*)(
    const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
    at::IntArrayRef, int64_t, bool, bool, std::array<bool,3>);
DECLARE_DISPATCH(miopen_depthwise_convolution_backward_fn, miopen_depthwise_convolution_backward_stub);
using mkldnn_convolution_backward_fn = std::tuple<at::Tensor,at::Tensor,at::Tensor>(*)(
    const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
    at::IntArrayRef, int64_t, std::array<bool,3>);
DECLARE_DISPATCH(mkldnn_convolution_backward_fn, mkldnn_convolution_backward_stub);
// Forward (not backward) entry point for mkldnn transposed convolution.
using mkldnn_convolution_transpose_fn = Tensor(*)(const Tensor&, const Tensor&, const c10::optional<Tensor>&,
    IntArrayRef, IntArrayRef, IntArrayRef, IntArrayRef, int64_t);
DECLARE_DISPATCH(mkldnn_convolution_transpose_fn, mkldnn_convolution_transpose_stub);
using mkldnn_convolution_transpose_backward_fn = std::tuple<at::Tensor,at::Tensor,at::Tensor>(*)(
    const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
    at::IntArrayRef, at::IntArrayRef, int64_t, std::array<bool,3>);
DECLARE_DISPATCH(mkldnn_convolution_transpose_backward_fn, mkldnn_convolution_transpose_backward_stub);
using slow_conv_dilated2d_backward_fn = std::tuple<at::Tensor,at::Tensor,at::Tensor>(*)(
    const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
    at::IntArrayRef, at::IntArrayRef, std::array<bool, 3>);
DECLARE_DISPATCH(slow_conv_dilated2d_backward_fn, slow_conv_dilated2d_backward_stub);
using slow_conv_dilated3d_backward_fn = std::tuple<at::Tensor,at::Tensor,at::Tensor>(*)(
    const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
    at::IntArrayRef, at::IntArrayRef, std::array<bool, 3>);
DECLARE_DISPATCH(slow_conv_dilated3d_backward_fn, slow_conv_dilated3d_backward_stub);
using slow_conv_transpose2d_backward_fn = std::tuple<at::Tensor,at::Tensor,at::Tensor>(*)(
    const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
    at::IntArrayRef, at::IntArrayRef, at::IntArrayRef, std::array<bool,3>);
DECLARE_DISPATCH(slow_conv_transpose2d_backward_fn, slow_conv_transpose2d_backward_stub);
using slow_conv_transpose3d_backward_fn = std::tuple<at::Tensor,at::Tensor,at::Tensor>(*)(
    const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
    at::IntArrayRef, at::IntArrayRef, at::IntArrayRef, std::array<bool,3>);
DECLARE_DISPATCH(slow_conv_transpose3d_backward_fn, slow_conv_transpose3d_backward_stub);
|
| 70 |
+
|
| 71 |
+
namespace {
// Read once at static-init time: opt-in for cuDNN v8 heuristic mode B via
// the TORCH_CUDNN_USE_HEURISTIC_MODE_B environment variable.
// NOTE(review): an anonymous namespace in a header gives every including
// translation unit its own copy of this flag -- confirm that is intended
// (an `inline` namespace-scope variable would share one instance).
static bool cudnnv8_heuristic_mode_b = c10::utils::check_env("TORCH_CUDNN_USE_HEURISTIC_MODE_B") == true;
}
|
| 74 |
+
|
| 75 |
+
static inline bool cudnnv8_enabled_check_debug() {
|
| 76 |
+
static bool cudnnv8_flag = c10::utils::check_env("TORCH_CUDNN_V8_API_DISABLED") != true;
|
| 77 |
+
static bool cudnnv8_debug = c10::utils::check_env("TORCH_CUDNN_V8_API_DEBUG") == true;
|
| 78 |
+
static uint8_t cudnnv8_debugcount = 0;
|
| 79 |
+
if (cudnnv8_debug == 1 && cudnnv8_debugcount < 10) {
|
| 80 |
+
TORCH_WARN("TORCH_CUDNN_V8_DEBUG ON, V8 ON: ", cudnnv8_flag, " TORCH_CUDNN_USE_HEURISTIC_MODE B: ", cudnnv8_heuristic_mode_b);
|
| 81 |
+
cudnnv8_debugcount++;
|
| 82 |
+
}
|
| 83 |
+
return cudnnv8_flag == 1;
|
| 84 |
+
}
|
| 85 |
+
|
| 86 |
+
// True when TORCH_CUDNN_USE_HEURISTIC_MODE_B was set at startup
// (cached in cudnnv8_heuristic_mode_b; never re-reads the environment).
static inline bool cudnnv8_use_heur_mode_b() {
  return cudnnv8_heuristic_mode_b;
}
|
| 89 |
+
|
| 90 |
+
// Keep in sync with py::enum_ in Module.cpp
// Convolution implementation chosen by select_conv_backend() for a given
// call (device, layout, transposed-ness, and available libraries).
enum class ConvBackend {
  CudaDepthwise2d,
  CudaDepthwise3d,
  Cudnn,
  CudnnTranspose,
  Empty,
  Miopen,
  MiopenDepthwise,
  MiopenTranspose,
  Mkldnn,
  MkldnnTranspose,
  MkldnnEmpty,
  NnpackSpatial,
  Overrideable,
  Slow2d,
  Slow3d,
  SlowDilated2d,
  SlowDilated3d,
  SlowTranspose2d,
  SlowTranspose3d,
  Winograd3x3Depthwise,
  Xnnpack2d,
  Mps,
  MpsTranspose,
};
|
| 116 |
+
|
| 117 |
+
// Overload for selecting the convolution backend from the full set of convolution inputs.
// This overload is exposed to python for testing, etc.
TORCH_API ConvBackend select_conv_backend(
    const Tensor& input, const Tensor& weight, const c10::optional<Tensor>& bias_opt,
    SymIntArrayRef stride, SymIntArrayRef padding, SymIntArrayRef dilation,
    bool transposed, SymIntArrayRef output_padding, c10::SymInt groups, const at::OptionalSymIntArrayRef bias_sizes_opt);

// Chooses the memory format (e.g. contiguous vs channels-last) the given
// backend should run with for this input/weight pair.
TORCH_API at::MemoryFormat _determine_backend_memory_format(const Tensor& input,
    const Tensor& weight,
    const ConvBackend backend);
|
| 127 |
+
|
| 128 |
+
// ---------------------------------------------------------------------
//
// Math
//
// ---------------------------------------------------------------------

// Canonical dimension indices for NC... convolution tensors.
constexpr int input_batch_size_dim = 0;  // also grad_input
constexpr int input_channels_dim = 1;
constexpr int output_batch_size_dim = 0;  // also grad_output
constexpr int output_channels_dim = 1;
constexpr int weight_output_channels_dim = 0;
constexpr int weight_input_channels_dim = 1;

// Often written as 2 + max_dim (extra dims for batch size and channels)
constexpr int max_dim = 3;
|
| 143 |
+
|
| 144 |
+
// ---------------------------------------------------------------------
|
| 145 |
+
//
|
| 146 |
+
// Checking
|
| 147 |
+
//
|
| 148 |
+
// ---------------------------------------------------------------------
|
| 149 |
+
|
| 150 |
+
// Used on pad, stride and dilation
|
| 151 |
+
static void check_args(CheckedFrom c, IntArrayRef args, size_t expected_size, const char* arg_name)
|
| 152 |
+
{
|
| 153 |
+
TORCH_CHECK(args.size() <= expected_size,
|
| 154 |
+
"Too many ", arg_name, " values (", args.size(), ") supplied, expecting ",
|
| 155 |
+
expected_size, " (while checking arguments for ", c, ")");
|
| 156 |
+
TORCH_CHECK(args.size() >= expected_size,
|
| 157 |
+
"Not enough ", arg_name, " values (", args.size(), ") supplied, expecting ",
|
| 158 |
+
expected_size, " (while checking arguments for ", c, ")");
|
| 159 |
+
|
| 160 |
+
auto num_negative_values = std::count_if(args.begin(), args.end(), [](int x){return x < 0;});
|
| 161 |
+
if (num_negative_values > 0){
|
| 162 |
+
std::stringstream ss;
|
| 163 |
+
ss << arg_name << " should be greater than zero but got (";
|
| 164 |
+
std::copy(args.begin(), args.end() - 1, std::ostream_iterator<int>(ss,", "));
|
| 165 |
+
ss << args.back() << ")" << " (while checking arguments for " << c << ")";
|
| 166 |
+
AT_ERROR(ss.str());
|
| 167 |
+
}
|
| 168 |
+
}
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
// NOTE [ Convolution checks ]
|
| 172 |
+
//
|
| 173 |
+
// NB: For many call sites, it is not strictly necessary to check all of
|
| 174 |
+
// these relationships (for example, for forward convolution, we compute
|
| 175 |
+
// the size of output ourselves, so we don't actually need to check
|
| 176 |
+
// output. However, writing a single function that does everything
|
| 177 |
+
// means we get to reuse it for both forwards and all backwards
|
| 178 |
+
// variants, even when the set of "real" inputs varies. The magic of
|
| 179 |
+
// relational computing!
|
| 180 |
+
//
|
| 181 |
+
// (There is one downside, which is that it is slightly harder to write
|
| 182 |
+
// error messages which are able to distinguish between real inputs
|
| 183 |
+
// (which the user can change) and computed inputs (which the user can
|
| 184 |
+
// only indirectly affect). It would be an interesting exercise to
|
| 185 |
+
// come up with a general framework to handle such situations.)
|
| 186 |
+
// Cross-checks the geometry of a convolution's input, weight and output
// tensors together with its padding/stride/dilation/groups arguments
// (see NOTE [ Convolution checks ] above for why all relationships are
// validated in one place).
static void convolution_shape_check(
    CheckedFrom c,
    const TensorGeometryArg& input, const TensorGeometryArg& weight, const TensorGeometryArg& output,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups)
{
  // padding's length is derived from input rank; stride/dilation must match it.
  check_args(c, padding, input->dim() - 2, "padding");
  check_args(c, stride, padding.size(), "stride");
  check_args(c, dilation, padding.size(), "dilation");

  // Input
  checkDimRange(c, input, 3, 6 /* exclusive */);
  checkSize_symint(c, input, input_channels_dim, weight->size(1) * groups);

  // Weight
  checkSameDim(c, input, weight);

  // TODO: check that output->size() matches output_sizes
  // TODO: check that weight matches output->sizes()
  checkSameDim(c, input, output);
}
|
| 206 |
+
|
| 207 |
+
// NB: conv_output_size and conv_input_size are not bijections,
|
| 208 |
+
// as conv_output_size loses information; this is why conv_input_size
|
| 209 |
+
// takes an extra output_padding argument to resolve the ambiguity.
|
| 210 |
+
|
| 211 |
+
template <typename T>
|
| 212 |
+
static inline std::vector<T> _conv_output_size(
|
| 213 |
+
ArrayRef<T> input_size, ArrayRef<T> weight_size,
|
| 214 |
+
ArrayRef<T> padding, ArrayRef<T> stride, ArrayRef<T> dilation = ArrayRef<T>()
|
| 215 |
+
) {
|
| 216 |
+
// ASSERT(input_size.size() > 2)
|
| 217 |
+
// ASSERT(input_size.size() == weight_size.size())
|
| 218 |
+
bool has_dilation = !dilation.empty();
|
| 219 |
+
auto dim = input_size.size();
|
| 220 |
+
std::vector<T> output_size(dim);
|
| 221 |
+
output_size[0] = input_size[input_batch_size_dim];
|
| 222 |
+
output_size[1] = weight_size[weight_output_channels_dim];
|
| 223 |
+
for (const auto d : c10::irange(2, dim)) {
|
| 224 |
+
auto dilation_ = has_dilation ? dilation[d - 2] : 1;
|
| 225 |
+
auto kernel = dilation_ * (weight_size[d] - 1) + 1;
|
| 226 |
+
output_size[d] = (input_size[d] + (2 * padding[d - 2]) - kernel) / stride[d - 2] + 1;
|
| 227 |
+
}
|
| 228 |
+
return output_size;
|
| 229 |
+
}
|
| 230 |
+
|
| 231 |
+
static inline std::vector<int64_t> conv_output_size(
|
| 232 |
+
IntArrayRef input_size, IntArrayRef weight_size,
|
| 233 |
+
IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation = IntArrayRef()
|
| 234 |
+
) {
|
| 235 |
+
return _conv_output_size(input_size, weight_size, padding, stride, dilation);
|
| 236 |
+
}
|
| 237 |
+
|
| 238 |
+
static inline std::vector<c10::SymInt> conv_output_size(
|
| 239 |
+
SymIntArrayRef input_size, SymIntArrayRef weight_size,
|
| 240 |
+
SymIntArrayRef padding, SymIntArrayRef stride, SymIntArrayRef dilation = SymIntArrayRef()
|
| 241 |
+
) {
|
| 242 |
+
return _conv_output_size(input_size, weight_size, padding, stride, dilation);
|
| 243 |
+
}
|
| 244 |
+
|
| 245 |
+
template <typename T>
|
| 246 |
+
std::vector<T> _conv_input_size(
|
| 247 |
+
ArrayRef<T> output_size, ArrayRef<T> weight_size,
|
| 248 |
+
ArrayRef<T> padding, ArrayRef<T> output_padding, ArrayRef<T> stride, ArrayRef<T> dilation, T groups
|
| 249 |
+
) {
|
| 250 |
+
// ASSERT(output_size.size() > 2)
|
| 251 |
+
// ASSERT(output_size.size() == weight_size.size())
|
| 252 |
+
auto dim = output_size.size();
|
| 253 |
+
std::vector<T> input_size(dim);
|
| 254 |
+
input_size[0] = output_size[output_batch_size_dim];
|
| 255 |
+
input_size[1] = weight_size[weight_input_channels_dim] * groups;
|
| 256 |
+
for (const auto d : c10::irange(2, dim)) {
|
| 257 |
+
auto kernel = (weight_size[d] - 1) * dilation[d - 2] + 1;
|
| 258 |
+
input_size[d] = (output_size[d] - 1) * stride[d - 2] - (padding[d - 2] * 2) +
|
| 259 |
+
kernel + output_padding[d - 2];
|
| 260 |
+
}
|
| 261 |
+
return input_size;
|
| 262 |
+
}
|
| 263 |
+
|
| 264 |
+
static inline std::vector<c10::SymInt> conv_input_size(
|
| 265 |
+
SymIntArrayRef output_size, SymIntArrayRef weight_size,
|
| 266 |
+
SymIntArrayRef padding, SymIntArrayRef output_padding, SymIntArrayRef stride, SymIntArrayRef dilation, c10::SymInt groups
|
| 267 |
+
) {
|
| 268 |
+
return _conv_input_size(output_size, weight_size, padding, output_padding, stride, dilation, groups);
|
| 269 |
+
}
|
| 270 |
+
|
| 271 |
+
static inline std::vector<int64_t> conv_input_size(
|
| 272 |
+
IntArrayRef output_size, IntArrayRef weight_size,
|
| 273 |
+
IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups
|
| 274 |
+
) {
|
| 275 |
+
return _conv_input_size(output_size, weight_size, padding, output_padding, stride, dilation, groups);
|
| 276 |
+
}
|
| 277 |
+
|
| 278 |
+
template <typename T>
|
| 279 |
+
std::vector<T> _conv_weight_size(
|
| 280 |
+
ArrayRef<T> input_size, ArrayRef<T> output_size,
|
| 281 |
+
ArrayRef<T> padding, ArrayRef<T> output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups
|
| 282 |
+
) {
|
| 283 |
+
auto dim = input_size.size();
|
| 284 |
+
std::vector<T> weight_size(dim);
|
| 285 |
+
weight_size[0] = output_size[1];
|
| 286 |
+
weight_size[1] = input_size[1] / groups;
|
| 287 |
+
for (const auto d : c10::irange(2, dim)) {
|
| 288 |
+
auto kernel = input_size[d] - (output_size[d] - 1) * stride[d - 2]
|
| 289 |
+
+ padding[d - 2] * 2 - output_padding[d - 2];
|
| 290 |
+
weight_size[d] = (kernel - 1) / dilation[d - 2] + 1;
|
| 291 |
+
}
|
| 292 |
+
return weight_size;
|
| 293 |
+
}
|
| 294 |
+
|
| 295 |
+
static inline std::vector<c10::SymInt> conv_weight_size(
|
| 296 |
+
SymIntArrayRef input_size, SymIntArrayRef output_size,
|
| 297 |
+
SymIntArrayRef padding, SymIntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups
|
| 298 |
+
) {
|
| 299 |
+
return _conv_weight_size(input_size, output_size, padding, output_padding, stride, dilation, groups);
|
| 300 |
+
}
|
| 301 |
+
|
| 302 |
+
static inline std::vector<int64_t> conv_weight_size(
|
| 303 |
+
IntArrayRef input_size, IntArrayRef output_size,
|
| 304 |
+
IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups
|
| 305 |
+
) {
|
| 306 |
+
return _conv_weight_size(input_size, output_size, padding, output_padding, stride, dilation, groups);
|
| 307 |
+
}
|
| 308 |
+
|
| 309 |
+
static inline Tensor reshape_bias(int64_t dim, const Tensor& bias) {
|
| 310 |
+
std::vector<int64_t> shape(dim, 1);
|
| 311 |
+
shape[1] = -1;
|
| 312 |
+
return bias.reshape(shape);
|
| 313 |
+
}
|
| 314 |
+
|
| 315 |
+
static inline at::MemoryFormat cudnn_conv_suggest_memory_format(const at::Tensor& input, const at::Tensor& weight) {
|
| 316 |
+
// disable NHWC for float64 input.
|
| 317 |
+
if (!at::detail::getCUDAHooks().compiledWithCuDNN() ||
|
| 318 |
+
input.scalar_type() == at::kDouble ||
|
| 319 |
+
weight.scalar_type() == at::kDouble) {
|
| 320 |
+
return at::MemoryFormat::Contiguous;
|
| 321 |
+
}
|
| 322 |
+
long cudnn_version = at::detail::getCUDAHooks().versionCuDNN();
|
| 323 |
+
auto input_memory_format = input.suggest_memory_format();
|
| 324 |
+
auto weight_memory_format = weight.suggest_memory_format();
|
| 325 |
+
auto weight_ndim = weight.ndimension();
|
| 326 |
+
|
| 327 |
+
bool can_use_cudnn_channels_last_2d = (cudnn_version >= 7603) && (weight_ndim == 4) && (
|
| 328 |
+
(input_memory_format == at::MemoryFormat::ChannelsLast) ||
|
| 329 |
+
(weight_memory_format == at::MemoryFormat::ChannelsLast)
|
| 330 |
+
);
|
| 331 |
+
if (can_use_cudnn_channels_last_2d) {
|
| 332 |
+
return at::MemoryFormat::ChannelsLast;
|
| 333 |
+
}
|
| 334 |
+
|
| 335 |
+
bool can_use_cudnn_channels_last_3d = (cudnn_version >= 8005) && (weight_ndim == 5) && (
|
| 336 |
+
(input_memory_format == at::MemoryFormat::ChannelsLast3d) ||
|
| 337 |
+
(weight_memory_format == at::MemoryFormat::ChannelsLast3d)
|
| 338 |
+
);
|
| 339 |
+
if (can_use_cudnn_channels_last_3d) {
|
| 340 |
+
return at::MemoryFormat::ChannelsLast3d;
|
| 341 |
+
}
|
| 342 |
+
|
| 343 |
+
return at::MemoryFormat::Contiguous;
|
| 344 |
+
}
|
| 345 |
+
|
| 346 |
+
// controls whether emptyCache will be called following cudnn conv benchmarking
|
| 347 |
+
TORCH_API void _cudnn_set_conv_benchmark_empty_cache(bool enable);
|
| 348 |
+
TORCH_API bool _cudnn_get_conv_benchmark_empty_cache();
|
| 349 |
+
|
| 350 |
+
|
| 351 |
+
static inline bool miopen_conv_use_channels_last(const at::Tensor& input, const at::Tensor& weight) {
|
| 352 |
+
|
| 353 |
+
// disable NHWC for float64 input.
|
| 354 |
+
if (!at::detail::getCUDAHooks().compiledWithMIOpen() ||
|
| 355 |
+
input.scalar_type() == at::kDouble ||
|
| 356 |
+
weight.scalar_type() == at::kDouble) {
|
| 357 |
+
return false;
|
| 358 |
+
}
|
| 359 |
+
|
| 360 |
+
bool can_use_miopen_channels_last_2d = false;
|
| 361 |
+
#if defined(USE_ROCM) && (ROCM_VERSION >= 40300)
|
| 362 |
+
// TODO: Remove PYTORCH_MIOPEN_SUGGEST_NHWC once ROCm officially supports NHWC in MIOpen
|
| 363 |
+
// See #64427
|
| 364 |
+
static c10::optional<bool> PYTORCH_MIOPEN_SUGGEST_NHWC = c10::utils::check_env("PYTORCH_MIOPEN_SUGGEST_NHWC");
|
| 365 |
+
|
| 366 |
+
auto input_memory_format = input.suggest_memory_format();
|
| 367 |
+
auto weight_memory_format = weight.suggest_memory_format();
|
| 368 |
+
|
| 369 |
+
can_use_miopen_channels_last_2d = PYTORCH_MIOPEN_SUGGEST_NHWC && *PYTORCH_MIOPEN_SUGGEST_NHWC && (
|
| 370 |
+
( (input_memory_format == at::MemoryFormat::ChannelsLast) ||
|
| 371 |
+
(weight_memory_format == at::MemoryFormat::ChannelsLast) )
|
| 372 |
+
);
|
| 373 |
+
#endif
|
| 374 |
+
|
| 375 |
+
bool can_use_miopen_channels_last_3d = false;
|
| 376 |
+
|
| 377 |
+
return can_use_miopen_channels_last_2d || can_use_miopen_channels_last_3d;
|
| 378 |
+
}
|
| 379 |
+
|
| 380 |
+
static inline bool mkldnn_conv_use_channels_last(const at::Tensor& input, const at::Tensor& weight) {
|
| 381 |
+
|
| 382 |
+
// disable NHWC for float64 input.
|
| 383 |
+
if (input.scalar_type() == at::kDouble ||
|
| 384 |
+
weight.scalar_type() == at::kDouble) {
|
| 385 |
+
return false;
|
| 386 |
+
}
|
| 387 |
+
|
| 388 |
+
// disable NHWC for MkldnnCPU tensor.
|
| 389 |
+
if (input.is_mkldnn() || weight.is_mkldnn()) {
|
| 390 |
+
return false;
|
| 391 |
+
}
|
| 392 |
+
|
| 393 |
+
auto input_memory_format = input.suggest_memory_format();
|
| 394 |
+
auto weight_memory_format = weight.suggest_memory_format();
|
| 395 |
+
|
| 396 |
+
bool can_use_mkldnn_channels_last_2d =
|
| 397 |
+
(input_memory_format == at::MemoryFormat::ChannelsLast) ||
|
| 398 |
+
(weight_memory_format == at::MemoryFormat::ChannelsLast);
|
| 399 |
+
|
| 400 |
+
bool can_use_mkldnn_channels_last_3d =
|
| 401 |
+
(input_memory_format == at::MemoryFormat::ChannelsLast3d) ||
|
| 402 |
+
(weight_memory_format == at::MemoryFormat::ChannelsLast3d);
|
| 403 |
+
|
| 404 |
+
return can_use_mkldnn_channels_last_2d || can_use_mkldnn_channels_last_3d;
|
| 405 |
+
}
|
| 406 |
+
|
| 407 |
+
static inline bool thnn_conv_use_channels_last(const at::Tensor& input, const at::Tensor& weight) {
|
| 408 |
+
|
| 409 |
+
auto input_memory_format = input.suggest_memory_format();
|
| 410 |
+
auto weight_memory_format = weight.suggest_memory_format();
|
| 411 |
+
|
| 412 |
+
bool can_use_thnn_channels_last_2d = input.device().is_cpu() && (
|
| 413 |
+
(input_memory_format == at::MemoryFormat::ChannelsLast) || (
|
| 414 |
+
weight_memory_format == at::MemoryFormat::ChannelsLast));
|
| 415 |
+
|
| 416 |
+
return can_use_thnn_channels_last_2d;
|
| 417 |
+
}
|
| 418 |
+
|
| 419 |
+
static inline bool xpu_conv_use_channels_last(const at::Tensor& input, const at::Tensor& weight) {
|
| 420 |
+
|
| 421 |
+
// check layout only for xpu tensor.
|
| 422 |
+
if (!input.is_xpu() || !weight.is_xpu()) {
|
| 423 |
+
return false;
|
| 424 |
+
}
|
| 425 |
+
|
| 426 |
+
// disable NHWC for float64 input.
|
| 427 |
+
if (input.scalar_type() == at::kDouble ||
|
| 428 |
+
weight.scalar_type() == at::kDouble) {
|
| 429 |
+
return false;
|
| 430 |
+
}
|
| 431 |
+
|
| 432 |
+
auto input_memory_format = input.suggest_memory_format();
|
| 433 |
+
auto weight_memory_format = weight.suggest_memory_format();
|
| 434 |
+
|
| 435 |
+
bool can_use_xpu_channels_last_2d =
|
| 436 |
+
(input_memory_format == at::MemoryFormat::ChannelsLast) ||
|
| 437 |
+
(weight_memory_format == at::MemoryFormat::ChannelsLast);
|
| 438 |
+
|
| 439 |
+
bool can_use_xpu_channels_last_3d =
|
| 440 |
+
(input_memory_format == at::MemoryFormat::ChannelsLast3d) ||
|
| 441 |
+
(weight_memory_format == at::MemoryFormat::ChannelsLast3d);
|
| 442 |
+
|
| 443 |
+
return can_use_xpu_channels_last_2d || can_use_xpu_channels_last_3d;
|
| 444 |
+
}
|
| 445 |
+
|
| 446 |
+
} // namespace at::native
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/Cross.h
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/native/DispatchStub.h>
|
| 4 |
+
|
| 5 |
+
namespace at {
|
| 6 |
+
class Tensor;
|
| 7 |
+
|
| 8 |
+
namespace native {
|
| 9 |
+
|
| 10 |
+
using cross_fn = void(*)(const Tensor&, const Tensor&, const Tensor&, const int64_t d);
|
| 11 |
+
|
| 12 |
+
DECLARE_DISPATCH(cross_fn, cross_stub);
|
| 13 |
+
|
| 14 |
+
}} // namespace at::native
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/DistributionTemplates.h
ADDED
|
@@ -0,0 +1,394 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/core/Tensor.h>
|
| 4 |
+
#include <ATen/Dispatch.h>
|
| 5 |
+
#include <ATen/Dispatch_v2.h>
|
| 6 |
+
#include <ATen/Generator.h>
|
| 7 |
+
#include <ATen/ExpandUtils.h>
|
| 8 |
+
#include <ATen/Tensor.h>
|
| 9 |
+
#include <ATen/MemoryOverlap.h>
|
| 10 |
+
#include <ATen/NamedTensorUtils.h>
|
| 11 |
+
#include <ATen/native/Resize.h>
|
| 12 |
+
#include <ATen/native/TensorIterator.h>
|
| 13 |
+
#include <c10/util/Optional.h>
|
| 14 |
+
#include <limits>
|
| 15 |
+
#include <cmath>
|
| 16 |
+
|
| 17 |
+
#ifndef AT_PER_OPERATOR_HEADERS
|
| 18 |
+
#include <ATen/Functions.h>
|
| 19 |
+
#else
|
| 20 |
+
#include <ATen/ops/empty_like.h>
|
| 21 |
+
#include <ATen/ops/empty.h>
|
| 22 |
+
#include <ATen/ops/full.h>
|
| 23 |
+
#include <ATen/ops/view_as_real.h>
|
| 24 |
+
#endif
|
| 25 |
+
|
| 26 |
+
namespace at::native::templates {
|
| 27 |
+
|
| 28 |
+
// ==================================================== Random ========================================================
|
| 29 |
+
|
| 30 |
+
// The purpose of `update_from` and `update_to` is to find the closest valid int64_t number that can be used as actual `from`.
|
| 31 |
+
// The current implementation of `random_` uses uint64_t arithmetics and casts the result to the target dtype(scalar_t).
|
| 32 |
+
// This casting can result in generating numbers that happen to be greater or equal to `to` value. For instance:
|
| 33 |
+
//
|
| 34 |
+
// auto actual = torch::empty({3, 3}, torch::half);
|
| 35 |
+
// actual.random_(0, 65504);
|
| 36 |
+
//
|
| 37 |
+
// If random's uint64_t arithmetics produces 65503 as a random value after casting to torch::half it becomes 65504
|
| 38 |
+
// and violates the requirement that random value must be less than `to`. To resolve this issue `update_from` and `update_to`
|
| 39 |
+
// moves `from` to the right and `to` to the left to the next closest value that won't go outside [from, to) after casting to
|
| 40 |
+
// the target dtype. For `to` = 65504 it moves left for (1 << (log2(to) - 11 + 1)) = 32 and becomes 65472, which is previous
|
| 41 |
+
// available number for torch::half dtype.
|
| 42 |
+
template<typename scalar_t>
|
| 43 |
+
int64_t update_from(int64_t from) {
|
| 44 |
+
static_assert(
|
| 45 |
+
std::is_floating_point<scalar_t>::value ||
|
| 46 |
+
std::is_same<scalar_t, at::Half>::value ||
|
| 47 |
+
std::is_same<scalar_t, at::BFloat16>::value, "scalar_t must be floating-point type");
|
| 48 |
+
const auto from_plus_1 = static_cast<int64_t>(static_cast<scalar_t>(from + 1));
|
| 49 |
+
if (from_plus_1 < from) {
|
| 50 |
+
int64_t from_ = std::abs(from + 1);
|
| 51 |
+
int n = 0;
|
| 52 |
+
while (from_ >>= 1) ++n;
|
| 53 |
+
// NOLINTNEXTLINE(clang-analyzer-core.UndefinedBinaryOperatorResult)
|
| 54 |
+
from = from_plus_1 + (1LL << (n - std::numeric_limits<scalar_t>::digits + 1));
|
| 55 |
+
}
|
| 56 |
+
return from;
|
| 57 |
+
}
|
| 58 |
+
|
| 59 |
+
template<typename scalar_t>
|
| 60 |
+
int64_t update_to(int64_t to) {
|
| 61 |
+
static_assert(
|
| 62 |
+
std::is_floating_point<scalar_t>::value ||
|
| 63 |
+
std::is_same<scalar_t, at::Half>::value ||
|
| 64 |
+
std::is_same<scalar_t, at::BFloat16>::value, "scalar_t must be floating-point type");
|
| 65 |
+
const auto to_minus_1 = static_cast<int64_t>(static_cast<scalar_t>(to - 1));
|
| 66 |
+
if (to_minus_1 >= to) {
|
| 67 |
+
int64_t to_ = std::abs(to - 1);
|
| 68 |
+
int n = 0;
|
| 69 |
+
while (to_ >>= 1) ++n;
|
| 70 |
+
// NOLINTNEXTLINE(clang-analyzer-core.UndefinedBinaryOperatorResult)
|
| 71 |
+
to = to_minus_1 - (1LL << (n - std::numeric_limits<scalar_t>::digits + 1));
|
| 72 |
+
}
|
| 73 |
+
return to;
|
| 74 |
+
}
|
| 75 |
+
|
| 76 |
+
// Return earlier for not invoking kernel.
|
| 77 |
+
// See https://github.com/pytorch/pytorch/issues/103418 for more details
|
| 78 |
+
#define CHECK_EMPTY_AND_RETURN(tensor) \
|
| 79 |
+
if (tensor.numel() == 0) { \
|
| 80 |
+
return tensor; \
|
| 81 |
+
}
|
| 82 |
+
|
| 83 |
+
template<template<typename> class random_kernel, typename RNG>
|
| 84 |
+
at::Tensor& random_impl(at::Tensor& self, c10::optional<Generator> generator) {
|
| 85 |
+
CHECK_EMPTY_AND_RETURN(self);
|
| 86 |
+
auto iter = at::TensorIterator::borrowing_nullary_op(self);
|
| 87 |
+
random_kernel<RNG>()(iter, generator);
|
| 88 |
+
return self;
|
| 89 |
+
}
|
| 90 |
+
|
| 91 |
+
#define CHECK_OUT_OF_BOUNDS(var, name, min, max, dtype) \
|
| 92 |
+
TORCH_CHECK(var >= min && var <= max, name , " is out of bounds for ", dtype); \
|
| 93 |
+
|
| 94 |
+
#define WARN_OUT_OF_BOUNDS(var, name, digits, dtype) \
|
| 95 |
+
if (var < -(1LL << digits) || var > (1LL << digits)) { \
|
| 96 |
+
TORCH_WARN(name , " is out of bounds [-(2^", digits, "), 2^", digits, "]. ", \
|
| 97 |
+
"Due to precision limitations ", dtype, " can support discrete uniform distribution only within this range. ", \
|
| 98 |
+
"This warning will become an error in version 1.7 release, please fix the code in advance"); \
|
| 99 |
+
}
|
| 100 |
+
|
| 101 |
+
static void check_from_to_in_range(int64_t from, int64_t to_inc, caffe2::TypeMeta dtype) {
|
| 102 |
+
const auto scalar_type = typeMetaToScalarType(dtype);
|
| 103 |
+
if (isFloatingType(scalar_type)) {
|
| 104 |
+
AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, scalar_type, "check_random_fp_bounds", [&] {
|
| 105 |
+
const auto min = static_cast<double>(std::numeric_limits<scalar_t>::lowest());
|
| 106 |
+
const auto max = static_cast<double>(std::numeric_limits<scalar_t>::max());
|
| 107 |
+
CHECK_OUT_OF_BOUNDS(from, "from", min, max, dtype);
|
| 108 |
+
CHECK_OUT_OF_BOUNDS(to_inc, "to - 1", min, max, dtype);
|
| 109 |
+
|
| 110 |
+
constexpr auto digits = std::numeric_limits<scalar_t>::digits;
|
| 111 |
+
WARN_OUT_OF_BOUNDS(from, "from", digits, dtype);
|
| 112 |
+
WARN_OUT_OF_BOUNDS(to_inc, "to - 1", digits, dtype);
|
| 113 |
+
});
|
| 114 |
+
} else if (scalar_type == kUInt64) {
|
| 115 |
+
// When you do a comparison between int64_t and uint64_t, the usual
|
| 116 |
+
// arithmetic conversions say that the int64_t value is promoted to
|
| 117 |
+
// unsigned. But this conversion wraps around: if I had -1 as my int64_t,
|
| 118 |
+
// then it will promote to 0xFFFFFFFFFFFFFFFF in uint64_t. This is never
|
| 119 |
+
// the right thing to do.
|
| 120 |
+
CHECK_OUT_OF_BOUNDS(from, "from", 0, INT64_MAX, dtype);
|
| 121 |
+
CHECK_OUT_OF_BOUNDS(to_inc, "to - 1", 0, INT64_MAX, dtype);
|
| 122 |
+
} else if (isIntegralType(scalar_type, /*includeBool=*/true)) {
|
| 123 |
+
AT_DISPATCH_V2(scalar_type, "check_random_integral_bounds", AT_WRAP([&]() {
|
| 124 |
+
const auto min = static_cast<int64_t>(std::numeric_limits<scalar_t>::lowest());
|
| 125 |
+
const auto max = static_cast<int64_t>(std::numeric_limits<scalar_t>::max());
|
| 126 |
+
CHECK_OUT_OF_BOUNDS(from, "from", min, max, dtype);
|
| 127 |
+
CHECK_OUT_OF_BOUNDS(to_inc, "to - 1", min, max, dtype);
|
| 128 |
+
}), AT_EXPAND(AT_INTEGRAL_TYPES), kUInt16, kUInt32, kBool);
|
| 129 |
+
} else {
|
| 130 |
+
TORCH_CHECK(false, "check_random_bounds handles only integral, floating-point and boolean types");
|
| 131 |
+
}
|
| 132 |
+
}
|
| 133 |
+
|
| 134 |
+
template<template<typename> class random_from_to_kernel, typename RNG>
|
| 135 |
+
at::Tensor& random_from_to_impl(at::Tensor& self, int64_t from, c10::optional<int64_t> to_opt, c10::optional<Generator> generator) {
|
| 136 |
+
uint64_t range = 0;
|
| 137 |
+
auto iter = at::TensorIterator::borrowing_nullary_op(self);
|
| 138 |
+
if (to_opt.has_value()) {
|
| 139 |
+
// [from, to)
|
| 140 |
+
int64_t to = *to_opt;
|
| 141 |
+
TORCH_CHECK(from < to, "random_ expects 'from' to be less than 'to', but got from=", from, " >= to=", to);
|
| 142 |
+
if (isFloatingType(iter.dtype())) {
|
| 143 |
+
AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, self.scalar_type(), "random_update_from_to", [&] {
|
| 144 |
+
from = update_from<scalar_t>(from);
|
| 145 |
+
to = update_to<scalar_t>(to);
|
| 146 |
+
TORCH_CHECK(from < to, "random_ expects 'from' casted to dtype to be less than 'to' casted to dtype, but got from=", from, " >= to=", to);
|
| 147 |
+
});
|
| 148 |
+
}
|
| 149 |
+
check_from_to_in_range(from, to - 1, self.dtype());
|
| 150 |
+
CHECK_EMPTY_AND_RETURN(self);
|
| 151 |
+
range = static_cast<uint64_t>(to) - static_cast<uint64_t>(from);
|
| 152 |
+
random_from_to_kernel<RNG>()(iter, range, from, generator);
|
| 153 |
+
} else if (from != std::numeric_limits<int64_t>::lowest()) {
|
| 154 |
+
// [from, std::numeric_limits<int64_t>::max()]
|
| 155 |
+
int64_t to_inc = 0;
|
| 156 |
+
if (isFloatingType(iter.dtype())) {
|
| 157 |
+
AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, self.scalar_type(), "random_from_to_range_calc", [&] {
|
| 158 |
+
constexpr int64_t scalar_t_max = static_cast<int64_t>(1) << std::numeric_limits<scalar_t>::digits;
|
| 159 |
+
to_inc = scalar_t_max > std::numeric_limits<int64_t>::max() ? std::numeric_limits<int64_t>::max() : static_cast<int64_t>(scalar_t_max);
|
| 160 |
+
from = update_from<scalar_t>(from);
|
| 161 |
+
TORCH_CHECK(from < to_inc, "random_ expects 'from' casted to dtype to be less than or equal to 'to_inc' casted to dtype, but got from=", from, " > to_inc=", to_inc);
|
| 162 |
+
});
|
| 163 |
+
} else if (isIntegralType(iter.dtype(), /*includeBool=*/true)) {
|
| 164 |
+
AT_DISPATCH_V2(self.scalar_type(), "random_from_to_range_calc", AT_WRAP([&] {
|
| 165 |
+
if constexpr (std::is_same_v<scalar_t, bool>) {
|
| 166 |
+
to_inc = static_cast<int64_t>(true);
|
| 167 |
+
} else {
|
| 168 |
+
to_inc = static_cast<int64_t>(std::numeric_limits<scalar_t>::max());
|
| 169 |
+
}
|
| 170 |
+
}), AT_EXPAND(AT_INTEGRAL_TYPES_V2), kBool);
|
| 171 |
+
} else {
|
| 172 |
+
TORCH_CHECK(false, "random_from_to_impl handles only integral, floating-point and boolean types");
|
| 173 |
+
}
|
| 174 |
+
check_from_to_in_range(from, to_inc, self.dtype());
|
| 175 |
+
CHECK_EMPTY_AND_RETURN(self);
|
| 176 |
+
range = static_cast<uint64_t>(to_inc) - static_cast<uint64_t>(from) + 1;
|
| 177 |
+
random_from_to_kernel<RNG>()(iter, range, from, generator);
|
| 178 |
+
} else {
|
| 179 |
+
// [std::numeric_limits<int64_t>::lowest(), std::numeric_limits<int64_t>::max()]
|
| 180 |
+
// range = 2^64
|
| 181 |
+
CHECK_EMPTY_AND_RETURN(self);
|
| 182 |
+
random_from_to_kernel<RNG>()(iter, generator);
|
| 183 |
+
}
|
| 184 |
+
return self;
|
| 185 |
+
}
|
| 186 |
+
|
| 187 |
+
// ==================================================== Normal ========================================================
|
| 188 |
+
|
| 189 |
+
#define CHECK_NORMAL_TENSOR_STD(std) \
|
| 190 |
+
do { \
|
| 191 |
+
TORCH_CHECK( \
|
| 192 |
+
!std.is_complex(), \
|
| 193 |
+
"normal expects standard deviation to be non-complex"); \
|
| 194 |
+
TORCH_CHECK( \
|
| 195 |
+
std.numel() == 0 || std.is_meta() || std.min().ge(0).item<bool>(), \
|
| 196 |
+
"normal expects all elements of std >= 0.0"); \
|
| 197 |
+
} while (0)
|
| 198 |
+
|
| 199 |
+
#define CHECK_NORMAL_STD(std) \
|
| 200 |
+
TORCH_CHECK(std >= 0.0, "normal expects std >= 0.0, but found std ", std);
|
| 201 |
+
|
| 202 |
+
template<template<typename> class normal_kernel, typename RNG>
|
| 203 |
+
Tensor& normal_impl_(Tensor& self, double mean, double std, c10::optional<Generator> gen) {
|
| 204 |
+
CHECK_NORMAL_STD(std);
|
| 205 |
+
CHECK_EMPTY_AND_RETURN(self);
|
| 206 |
+
|
| 207 |
+
if (self.is_complex()) {
|
| 208 |
+
auto float_tensor = at::view_as_real(self);
|
| 209 |
+
// variance for normal distribution of the real and imaginary values
|
| 210 |
+
// is half of the input variance
|
| 211 |
+
normal_kernel<RNG>()(float_tensor, mean, std/(std::sqrt(2)), gen);
|
| 212 |
+
} else {
|
| 213 |
+
normal_kernel<RNG>()(self, mean, std, gen);
|
| 214 |
+
}
|
| 215 |
+
return self;
|
| 216 |
+
}
|
| 217 |
+
|
| 218 |
+
template<template<typename> class normal_kernel, typename RNG>
|
| 219 |
+
Tensor& normal_out_impl(Tensor& output, const Tensor& mean, double std, c10::optional<Generator> gen) {
|
| 220 |
+
CHECK_NORMAL_STD(std);
|
| 221 |
+
auto std_tensor = at::empty_like(output, MemoryFormat::Contiguous);
|
| 222 |
+
auto shape = at::infer_size(mean.sizes(), std_tensor.sizes());
|
| 223 |
+
at::native::resize_output(output, shape);
|
| 224 |
+
normal_impl_<normal_kernel, RNG>(output, 0, std, gen);
|
| 225 |
+
output.add_(mean);
|
| 226 |
+
return output;
|
| 227 |
+
}
|
| 228 |
+
|
| 229 |
+
template<template<typename> class normal_kernel, typename RNG>
|
| 230 |
+
Tensor& normal_out_impl(Tensor& output, double mean, const Tensor& std, c10::optional<Generator> gen) {
|
| 231 |
+
CHECK_NORMAL_TENSOR_STD(std);
|
| 232 |
+
auto mean_tensor = at::full({}, mean, output.options());
|
| 233 |
+
auto shape = at::infer_size(mean_tensor.sizes(), std.sizes());
|
| 234 |
+
at::native::resize_output(output, shape);
|
| 235 |
+
normal_impl_<normal_kernel, RNG>(output, 0, 1, gen);
|
| 236 |
+
// CUDA NB: addcmul_out copies the tensor to be added into the output.
|
| 237 |
+
// The previous function here was addcmul_out(output, mean_tensor, output, std, 1);
|
| 238 |
+
// The third argument is not a constant reference and hence the samples in output are overwritten.
|
| 239 |
+
// Consequently, the computation performed is mean_tensor + mean_tensor * std instead of mean_tensor + output * std
|
| 240 |
+
output.mul_(std).add_(mean_tensor);
|
| 241 |
+
return output;
|
| 242 |
+
}
|
| 243 |
+
|
| 244 |
+
template<template<typename> class normal_kernel, typename RNG>
|
| 245 |
+
Tensor& normal_out_impl(Tensor& output, const Tensor& mean, const Tensor& std, c10::optional<Generator> gen) {
|
| 246 |
+
CHECK_NORMAL_TENSOR_STD(std);
|
| 247 |
+
auto shape = at::infer_size(mean.sizes(), std.sizes());
|
| 248 |
+
at::native::resize_output(output, shape);
|
| 249 |
+
normal_impl_<normal_kernel, RNG>(output, 0, 1, gen);
|
| 250 |
+
// CUDA NB: addcmul_out copies the tensor to be added into the output.
|
| 251 |
+
// The previous function here was addcmul_out(output, mean, output, std, 1);
|
| 252 |
+
// The third argument is not a constant reference and hence the samples in output are overwritten.
|
| 253 |
+
// Consequently, the computation performed is mean + mean * std instead of mean + output * std
|
| 254 |
+
output.mul_(std).add_(mean);
|
| 255 |
+
return output;
|
| 256 |
+
}
|
| 257 |
+
|
| 258 |
+
template<template<typename> class normal_kernel, typename RNG>
|
| 259 |
+
Tensor normal_impl(const Tensor& mean, double std, c10::optional<Generator> gen) {
|
| 260 |
+
CHECK_NORMAL_STD(std);
|
| 261 |
+
Tensor ret = at::empty_like(mean, MemoryFormat::Contiguous);
|
| 262 |
+
normal_out_impl<normal_kernel, RNG>(ret, mean, std, gen);
|
| 263 |
+
return ret;
|
| 264 |
+
}
|
| 265 |
+
|
| 266 |
+
template<template<typename> class normal_kernel, typename RNG>
|
| 267 |
+
Tensor normal_impl(double mean, const Tensor& std, c10::optional<Generator> gen) {
|
| 268 |
+
CHECK_NORMAL_TENSOR_STD(std);
|
| 269 |
+
Tensor ret = at::empty_like(std, MemoryFormat::Contiguous);
|
| 270 |
+
normal_out_impl<normal_kernel, RNG>(ret, mean, std, gen);
|
| 271 |
+
return ret;
|
| 272 |
+
}
|
| 273 |
+
|
| 274 |
+
template<template<typename> class normal_kernel, typename RNG>
|
| 275 |
+
Tensor normal_impl(const Tensor& mean, const Tensor& std, c10::optional<Generator> gen) {
|
| 276 |
+
CHECK_NORMAL_TENSOR_STD(std);
|
| 277 |
+
auto shape = at::infer_size(mean.sizes(), std.sizes());
|
| 278 |
+
Tensor ret = at::empty(shape, mean.options(), MemoryFormat::Contiguous);
|
| 279 |
+
normal_out_impl<normal_kernel, RNG>(ret, mean, std, gen);
|
| 280 |
+
return ret;
|
| 281 |
+
}
|
| 282 |
+
|
| 283 |
+
// ==================================================== Uniform =======================================================
|
| 284 |
+
|
| 285 |
+
template<template<typename> class uniform_kernel, typename RNG>
|
| 286 |
+
at::Tensor& uniform_impl_(at::Tensor& self, double from, double to, c10::optional<Generator> generator) {
|
| 287 |
+
if (self.is_complex()) {
|
| 288 |
+
CHECK_EMPTY_AND_RETURN(self);
|
| 289 |
+
auto float_tensor = at::view_as_real(self);
|
| 290 |
+
uniform_impl_<uniform_kernel, RNG>(float_tensor, from, to, generator);
|
| 291 |
+
} else {
|
| 292 |
+
AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, self.scalar_type(), "check_uniform_bounds", [&] {
|
| 293 |
+
const auto dtype = self.dtype();
|
| 294 |
+
const auto min = static_cast<double>(std::numeric_limits<scalar_t>::lowest());
|
| 295 |
+
const auto max = static_cast<double>(std::numeric_limits<scalar_t>::max());
|
| 296 |
+
CHECK_OUT_OF_BOUNDS(from, "from", min, max, dtype);
|
| 297 |
+
CHECK_OUT_OF_BOUNDS(to, "to", min, max, dtype);
|
| 298 |
+
TORCH_CHECK(from <= to, "uniform_ expects to return a [from, to) range, but found from=", from, " > to=", to);
|
| 299 |
+
TORCH_CHECK((to - from) <= std::numeric_limits<scalar_t>::max(),
|
| 300 |
+
"uniform_ expects to-from <= std::numeric_limits<", toString(self.scalar_type()),
|
| 301 |
+
">::max(), but found to=", to, " and from=", from,
|
| 302 |
+
" which result in to-from to exceed the limit");
|
| 303 |
+
from = std::min(std::max(from, min), max);
|
| 304 |
+
to = std::max(std::min(to, max), min);
|
| 305 |
+
});
|
| 306 |
+
CHECK_EMPTY_AND_RETURN(self);
|
| 307 |
+
auto iter = at::TensorIterator::borrowing_nullary_op(self);
|
| 308 |
+
uniform_kernel<RNG>()(iter, from, to, generator);
|
| 309 |
+
}
|
| 310 |
+
return self;
|
| 311 |
+
}
|
| 312 |
+
|
| 313 |
+
// ================================================== LogNormal =======================================================
|
| 314 |
+
|
| 315 |
+
template<template<typename> class log_normal_kernel, typename RNG>
|
| 316 |
+
at::Tensor& log_normal_impl_(at::Tensor& self, double mean, double std, c10::optional<Generator> gen) {
|
| 317 |
+
TORCH_CHECK(std > 0.0, "log_normal_ expects std > 0.0, but found std=", std);
|
| 318 |
+
CHECK_EMPTY_AND_RETURN(self);
|
| 319 |
+
auto iter = TensorIterator::borrowing_nullary_op(self);
|
| 320 |
+
log_normal_kernel<RNG>()(iter, mean, std, gen);
|
| 321 |
+
return self;
|
| 322 |
+
}
|
| 323 |
+
|
| 324 |
+
// =================================================== Geometric ======================================================
|
| 325 |
+
|
| 326 |
+
template<template<typename> class geometric_kernel, typename RNG>
|
| 327 |
+
Tensor& geometric_impl_(Tensor& self, double p, c10::optional<Generator> gen) {
|
| 328 |
+
TORCH_CHECK(0 < p && p < 1, "geometric_ expects p to be in (0, 1), but got p=", p);
|
| 329 |
+
CHECK_EMPTY_AND_RETURN(self);
|
| 330 |
+
auto iter = TensorIterator::borrowing_nullary_op(self);
|
| 331 |
+
geometric_kernel<RNG>()(iter, p, gen);
|
| 332 |
+
return self;
|
| 333 |
+
}
|
| 334 |
+
|
| 335 |
+
// ================================================== Exponential =====================================================
|
| 336 |
+
|
| 337 |
+
template<template<typename> class exponential_kernel, typename RNG>
|
| 338 |
+
Tensor& exponential_impl_(Tensor& self, double lambda, c10::optional<Generator> gen) {
|
| 339 |
+
TORCH_CHECK(lambda > 0.0, "exponential_ expects lambda > 0.0, but found lambda=", lambda);
|
| 340 |
+
CHECK_EMPTY_AND_RETURN(self);
|
| 341 |
+
auto iter = TensorIterator::borrowing_nullary_op(self);
|
| 342 |
+
exponential_kernel<RNG>()(iter, lambda, gen);
|
| 343 |
+
return self;
|
| 344 |
+
}
|
| 345 |
+
|
| 346 |
+
// ==================================================== Cauchy ========================================================
|
| 347 |
+
|
| 348 |
+
template<template<typename> class cauchy_kernel, typename RNG>
|
| 349 |
+
Tensor& cauchy_impl_(Tensor& self, double median, double sigma, c10::optional<Generator> gen) {
|
| 350 |
+
// TODO: instead of variable name 'sigma', use 'gamma' or 'scale'
|
| 351 |
+
// the variance, squared sigma, is undefined for cauchy distribution
|
| 352 |
+
TORCH_CHECK(sigma > 0.0, "cauchy_ expects sigma > 0.0, but found sigma=", sigma);
|
| 353 |
+
TORCH_CHECK(at::isFloatingType(self.scalar_type()), "Cauchy distribution is a continuous probability distribution. dtype must be a floating point but you specified ", self.dtype());
|
| 354 |
+
CHECK_EMPTY_AND_RETURN(self);
|
| 355 |
+
auto iter = TensorIterator::borrowing_nullary_op(self);
|
| 356 |
+
cauchy_kernel<RNG>()(iter, median, sigma, gen);
|
| 357 |
+
return self;
|
| 358 |
+
}
|
| 359 |
+
|
| 360 |
+
// ==================================================== Bernoulli =====================================================
|
| 361 |
+
|
| 362 |
+
template<template<typename> class bernoulli_tensor_kernel, typename RNG>
|
| 363 |
+
Tensor& bernoulli_impl_(Tensor& self, const Tensor& p_, c10::optional<Generator> gen) {
|
| 364 |
+
CHECK_EMPTY_AND_RETURN(self);
|
| 365 |
+
NoNamesGuard guard;
|
| 366 |
+
at::assert_no_internal_overlap(self);
|
| 367 |
+
bernoulli_tensor_kernel<RNG>()(self, p_, gen);
|
| 368 |
+
return self;
|
| 369 |
+
}
|
| 370 |
+
|
| 371 |
+
template<template<typename> class bernoulli_scalar_kernel, typename RNG>
|
| 372 |
+
Tensor& bernoulli_impl_(Tensor& self, double p, c10::optional<Generator> gen) {
|
| 373 |
+
TORCH_CHECK(0 <= p && p <= 1, "bernoulli_ expects p to be in [0, 1], but got p=", p);
|
| 374 |
+
CHECK_EMPTY_AND_RETURN(self);
|
| 375 |
+
at::assert_no_internal_overlap(self);
|
| 376 |
+
bernoulli_scalar_kernel<RNG>()(self, p, gen);
|
| 377 |
+
return self;
|
| 378 |
+
}
|
| 379 |
+
|
| 380 |
+
template<template<typename> class bernoulli_tensor_kernel, typename RNG>
|
| 381 |
+
Tensor& bernoulli_out_impl(Tensor& result, const Tensor& self, c10::optional<Generator> gen) {
|
| 382 |
+
// result.resize_as_(self) requires self to have same dtype as result, so we
|
| 383 |
+
// use resize_ instead.
|
| 384 |
+
// TODO: Fix resize_as_. See pytorch/pytorch#11665.
|
| 385 |
+
result.resize_(self.sizes());
|
| 386 |
+
bernoulli_impl_<bernoulli_tensor_kernel, RNG>(result, self, gen);
|
| 387 |
+
namedinference::propagate_names(result, self);
|
| 388 |
+
return result;
|
| 389 |
+
}
|
| 390 |
+
|
| 391 |
+
#undef CHECK_OUT_OF_BOUNDS
|
| 392 |
+
#undef WARN_OUT_OF_BOUNDS
|
| 393 |
+
|
| 394 |
+
} // namespace at::native::templates
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/Histogram.h
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/core/Tensor.h>
|
| 4 |
+
#include <ATen/native/DispatchStub.h>
|
| 5 |
+
|
| 6 |
+
namespace at::native {
|
| 7 |
+
|
| 8 |
+
using histogramdd_fn = void(*)(const Tensor&, const c10::optional<Tensor>&, bool, Tensor&, const TensorList&);
|
| 9 |
+
using histogramdd_linear_fn = void(*)(const Tensor&, const c10::optional<Tensor>&, bool, Tensor&, const TensorList&, bool);
|
| 10 |
+
using histogram_select_outer_bin_edges_fn = void(*)(const Tensor& input, const int64_t N, std::vector<double> &leftmost_edges, std::vector<double> &rightmost_edges);
|
| 11 |
+
|
| 12 |
+
DECLARE_DISPATCH(histogramdd_fn, histogramdd_stub);
|
| 13 |
+
DECLARE_DISPATCH(histogramdd_linear_fn, histogramdd_linear_stub);
|
| 14 |
+
DECLARE_DISPATCH(histogram_select_outer_bin_edges_fn, histogram_select_outer_bin_edges_stub);
|
| 15 |
+
|
| 16 |
+
} // namespace at::native
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/IndexKernel.h
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
#include <ATen/native/DispatchStub.h>
|
| 3 |
+
#include <c10/util/ArrayRef.h>
|
| 4 |
+
|
| 5 |
+
namespace at {
|
| 6 |
+
class Tensor;
|
| 7 |
+
class TensorBase;
|
| 8 |
+
struct TensorIterator;
|
| 9 |
+
struct TensorIteratorBase;
|
| 10 |
+
}
|
| 11 |
+
|
| 12 |
+
namespace c10 {
|
| 13 |
+
class Scalar;
|
| 14 |
+
}
|
| 15 |
+
|
| 16 |
+
namespace at::native {
|
| 17 |
+
|
| 18 |
+
using index_fn = void(*)(TensorIteratorBase &, IntArrayRef indexed_sizes, IntArrayRef indexed_strides);
|
| 19 |
+
using index_fill_fn = void(*)(TensorIterator & iter, int64_t dim, int64_t self_dim_size, int64_t self_dim_stride, const Scalar& source);
|
| 20 |
+
using index_copy_fn = void(*)(TensorIterator & iter, int64_t dim, int64_t self_dim_size, int64_t self_dim_stride);
|
| 21 |
+
using index_put_fn = void(*)(TensorIterator &, IntArrayRef indexed_sizes, IntArrayRef indexed_strides, bool accumulate);
|
| 22 |
+
using put_fn = void(*)(TensorIterator & iter, const TensorBase& self, const bool accumulate);
|
| 23 |
+
using take_fn = void(*)(TensorIterator & iter, const TensorBase& input);
|
| 24 |
+
using flip_fn = void(*)(TensorIterator &, const bool);
|
| 25 |
+
using masked_fill_fn = void(*)(TensorIterator &, const Scalar& scalar);
|
| 26 |
+
using masked_select_fn = void(*)(TensorIterator &, int64_t orig_stride);
|
| 27 |
+
using masked_scatter_fn = void(*)(TensorIterator &, const TensorBase &);
|
| 28 |
+
|
| 29 |
+
DECLARE_DISPATCH(index_fn, index_stub);
|
| 30 |
+
DECLARE_DISPATCH(index_fill_fn, index_fill_stub);
|
| 31 |
+
DECLARE_DISPATCH(index_copy_fn, index_copy_stub);
|
| 32 |
+
DECLARE_DISPATCH(index_put_fn, index_put_stub);
|
| 33 |
+
DECLARE_DISPATCH(put_fn, put_stub);
|
| 34 |
+
DECLARE_DISPATCH(take_fn, take_stub);
|
| 35 |
+
DECLARE_DISPATCH(flip_fn, flip_stub);
|
| 36 |
+
DECLARE_DISPATCH(masked_fill_fn, masked_fill_stub);
|
| 37 |
+
DECLARE_DISPATCH(masked_select_fn, masked_select_serial_stub);
|
| 38 |
+
DECLARE_DISPATCH(masked_select_fn, masked_select_stub);
|
| 39 |
+
DECLARE_DISPATCH(masked_scatter_fn, masked_scatter_stub);
|
| 40 |
+
|
| 41 |
+
} // namespace at::native
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/IndexingUtils.h
ADDED
|
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
#include <ATen/ExpandUtils.h>
|
| 3 |
+
#include <ATen/native/CanUse32BitIndexMath.h>
|
| 4 |
+
#include <ATen/native/TensorIterator.h>
|
| 5 |
+
#include <ATen/core/IListRef.h>
|
| 6 |
+
#include <c10/util/irange.h>
|
| 7 |
+
|
| 8 |
+
namespace at::native {
|
| 9 |
+
|
| 10 |
+
[[noreturn]]
|
| 11 |
+
static void invalid_mask(const Tensor & self, int64_t idx, const Tensor & mask, int64_t maskIdx) {
|
| 12 |
+
TORCH_CHECK_INDEX(false, "The shape of the mask ", mask.sizes(), " at index ", maskIdx,
|
| 13 |
+
" does not match the shape of the indexed tensor ", self.sizes(), " at index ", idx);
|
| 14 |
+
}
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
static C10_UNUSED std::vector<Tensor> expandTensors(const Tensor & self, IOptTensorListRef indices) {
|
| 18 |
+
// If indices come in as ByteTensor or BoolTensor (masks), expand them into the equivalent indexing by LongTensors
|
| 19 |
+
std::vector<Tensor> result;
|
| 20 |
+
for (const auto& index_opt : indices) {
|
| 21 |
+
if (!index_opt.has_value()) {
|
| 22 |
+
result.emplace_back();
|
| 23 |
+
} else {
|
| 24 |
+
const auto& index = *index_opt;
|
| 25 |
+
if (index.scalar_type() == kByte || index.scalar_type() == kBool) {
|
| 26 |
+
if (index.scalar_type() == kByte) {
|
| 27 |
+
TORCH_WARN("indexing with dtype torch.uint8 is now deprecated," \
|
| 28 |
+
" please use a dtype torch.bool instead.");
|
| 29 |
+
}
|
| 30 |
+
// The sizes of the ByteTensor mask or bool tensor must match the sizes of the
|
| 31 |
+
// corresponding dimensions in self
|
| 32 |
+
for (const auto j : c10::irange(index.dim())) {
|
| 33 |
+
int64_t srcIdx = static_cast<int64_t>(result.size() + j);
|
| 34 |
+
if (index.size(j) != self.size(srcIdx)) {
|
| 35 |
+
invalid_mask(self, srcIdx, index, j);
|
| 36 |
+
}
|
| 37 |
+
}
|
| 38 |
+
// Replace with nonzeros
|
| 39 |
+
auto nonzero = index.nonzero();
|
| 40 |
+
for (const auto j : c10::irange(index.dim())) {
|
| 41 |
+
result.emplace_back(nonzero.select(1, j));
|
| 42 |
+
}
|
| 43 |
+
} else {
|
| 44 |
+
result.emplace_back(index);
|
| 45 |
+
}
|
| 46 |
+
}
|
| 47 |
+
}
|
| 48 |
+
return result;
|
| 49 |
+
}
|
| 50 |
+
|
| 51 |
+
static C10_UNUSED void checkIndexTensorTypes(IOptTensorListRef indices, bool allow_int=false) {
|
| 52 |
+
for (const auto& tensor : indices) {
|
| 53 |
+
if (tensor.has_value() && tensor->defined()) {
|
| 54 |
+
auto scalarType = tensor->scalar_type();
|
| 55 |
+
if (allow_int) {
|
| 56 |
+
if (scalarType != kLong && scalarType != kByte && scalarType != kBool && scalarType != kInt) {
|
| 57 |
+
TORCH_CHECK_INDEX(false, "tensors used as indices must be long, int, byte or bool tensors");
|
| 58 |
+
}
|
| 59 |
+
} else {
|
| 60 |
+
if (scalarType != kLong && scalarType != kByte && scalarType != kBool) {
|
| 61 |
+
TORCH_CHECK_INDEX(false, "tensors used as indices must be long, byte or bool tensors");
|
| 62 |
+
}
|
| 63 |
+
}
|
| 64 |
+
}
|
| 65 |
+
}
|
| 66 |
+
}
|
| 67 |
+
|
| 68 |
+
inline torch::List<c10::optional<Tensor>> toListOfOptionalTensors(ArrayRef<Tensor> list) {
|
| 69 |
+
torch::List<c10::optional<Tensor>> result;
|
| 70 |
+
result.reserve(list.size());
|
| 71 |
+
for (const Tensor& a : list) {
|
| 72 |
+
result.push_back(a);
|
| 73 |
+
}
|
| 74 |
+
return result;
|
| 75 |
+
}
|
| 76 |
+
|
| 77 |
+
inline torch::List<c10::optional<Tensor>> toListOfOptionalTensors(ArrayRef<IValue> list) {
|
| 78 |
+
torch::List<c10::optional<Tensor>> result;
|
| 79 |
+
result.reserve(list.size());
|
| 80 |
+
for (const IValue& a : list) {
|
| 81 |
+
result.push_back(a.isTensor() ? c10::optional<Tensor>(a.toTensor()) : c10::optional<Tensor>());
|
| 82 |
+
}
|
| 83 |
+
return result;
|
| 84 |
+
}
|
| 85 |
+
|
| 86 |
+
static C10_UNUSED bool hasContiguousSubspace(TensorList tl) {
|
| 87 |
+
// true if all the non-null tensors are adjacent
|
| 88 |
+
auto isDefined = [](const Tensor & tensor){ return tensor.defined(); };
|
| 89 |
+
auto isNull = [](const Tensor & tensor){ return !tensor.defined(); };
|
| 90 |
+
auto start = std::find_if(tl.begin(), tl.end(), isDefined);
|
| 91 |
+
auto stop = std::find_if(tl.rbegin(), tl.rend(), isDefined);
|
| 92 |
+
auto it = std::find_if(start, stop.base(), isNull);
|
| 93 |
+
return it == stop.base();
|
| 94 |
+
}
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
// Transposes the tensor and indices together so that all the non-null indices
|
| 98 |
+
// index the first k dimensions of the tensor. Returns the transposed tensor
|
| 99 |
+
// and the reordered indices. For example:
|
| 100 |
+
// transposeToFront(tensor, {nullptr, a, nullptr, b})
|
| 101 |
+
// returns
|
| 102 |
+
// tensor.permute([1, 3, 0, 2]), {a, b, nullptr, nullptr}
|
| 103 |
+
static C10_UNUSED std::tuple<Tensor, std::vector<Tensor>>
|
| 104 |
+
transposeToFront(const Tensor& self, TensorList indices) {
|
| 105 |
+
std::vector<int64_t> dims;
|
| 106 |
+
std::vector<Tensor> transposedIndices;
|
| 107 |
+
dims.reserve(self.dim());
|
| 108 |
+
for (const auto i : c10::irange(self.dim())) {
|
| 109 |
+
if (indices[i].defined()) {
|
| 110 |
+
dims.push_back(i);
|
| 111 |
+
transposedIndices.emplace_back(indices[i]);
|
| 112 |
+
}
|
| 113 |
+
}
|
| 114 |
+
for (const auto i : c10::irange(self.dim())) {
|
| 115 |
+
if (!indices[i].defined()) {
|
| 116 |
+
dims.push_back(i);
|
| 117 |
+
transposedIndices.emplace_back();
|
| 118 |
+
}
|
| 119 |
+
}
|
| 120 |
+
return std::make_tuple(self.permute(dims), std::move(transposedIndices));
|
| 121 |
+
}
|
| 122 |
+
|
| 123 |
+
inline std::tuple<Tensor, std::vector<Tensor>, std::vector<int64_t>>
|
| 124 |
+
transposeToFrontAndInvPerm(const Tensor& self, TensorList indices) {
|
| 125 |
+
std::vector<int64_t> dims;
|
| 126 |
+
std::vector<int64_t> invPerm;
|
| 127 |
+
std::vector<Tensor> transposedIndices;
|
| 128 |
+
dims.reserve(self.dim());
|
| 129 |
+
invPerm.resize(self.dim());
|
| 130 |
+
for (const auto i : c10::irange(self.dim())) {
|
| 131 |
+
if (indices[i].defined()) {
|
| 132 |
+
dims.push_back(i);
|
| 133 |
+
transposedIndices.emplace_back(indices[i]);
|
| 134 |
+
}
|
| 135 |
+
}
|
| 136 |
+
for (const auto i : c10::irange(self.dim())) {
|
| 137 |
+
if (!indices[i].defined()) {
|
| 138 |
+
dims.push_back(i);
|
| 139 |
+
transposedIndices.emplace_back();
|
| 140 |
+
}
|
| 141 |
+
}
|
| 142 |
+
for (const auto i : c10::irange(self.dim())) {
|
| 143 |
+
invPerm[dims[i]] = i;
|
| 144 |
+
}
|
| 145 |
+
return std::make_tuple(self.permute(dims), std::move(transposedIndices), std::move(invPerm));
|
| 146 |
+
}
|
| 147 |
+
|
| 148 |
+
struct AdvancedIndex {
|
| 149 |
+
AdvancedIndex(const Tensor& src, TensorList indices);
|
| 150 |
+
|
| 151 |
+
Tensor src;
|
| 152 |
+
std::vector<Tensor> indices;
|
| 153 |
+
DimVector indexed_sizes;
|
| 154 |
+
DimVector indexed_strides;
|
| 155 |
+
int64_t dims_before;
|
| 156 |
+
int64_t dims_after;
|
| 157 |
+
};
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
} //namespace at::native
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/MathBitsFallback.h
ADDED
|
@@ -0,0 +1,157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <ATen/core/Tensor.h>
|
| 2 |
+
#include <ATen/core/dispatch/Dispatcher.h>
|
| 3 |
+
#include <ATen/core/op_registration/op_registration.h>
|
| 4 |
+
#include <ATen/native/UnaryOps.h>
|
| 5 |
+
#include <ATen/native/Resize.h>
|
| 6 |
+
#include <c10/util/irange.h>
|
| 7 |
+
#include <torch/library.h>
|
| 8 |
+
|
| 9 |
+
#ifndef AT_PER_OPERATOR_HEADERS
|
| 10 |
+
#include <ATen/Functions.h>
|
| 11 |
+
#else
|
| 12 |
+
#include <ATen/ops/clone.h>
|
| 13 |
+
|
| 14 |
+
#include <utility>
|
| 15 |
+
#endif
|
| 16 |
+
|
| 17 |
+
namespace at::native {
|
| 18 |
+
// This fallback should only be used for operations that are self inverse and have a corresponding tensor
|
| 19 |
+
// bit (internally implemented using DispatchKey) to maintain the state on tensor using tensor bit.
|
| 20 |
+
// Currently there are two tensor bits that trigger this fallback: conjugate bit and negative bit.
|
| 21 |
+
// Conjugate bit is set on a tensor when `.conj()` is called and neg bit is set on a tensor when `.conj().imag` is called.
|
| 22 |
+
|
| 23 |
+
// NOTE: To use this fallback, `clone` and `copy_` should fully understand and be able to correctly handle the semantic of your math bit.
|
| 24 |
+
struct MathOpFallback {
|
| 25 |
+
MathOpFallback(DispatchKey key_, string op_name_) : key(key_), op_name(std::move(op_name_)) {}
|
| 26 |
+
virtual bool is_bit_set(const Tensor&) = 0;
|
| 27 |
+
void fallback_impl(const c10::OperatorHandle& op, DispatchKeySet dispatch_keys, torch::jit::Stack* stack) {
|
| 28 |
+
/*
|
| 29 |
+
Situations to handle:
|
| 30 |
+
1. Out-of-place operation. Easy: materialize all inputs and
|
| 31 |
+
call it a day.
|
| 32 |
+
2. Inplace operation. Desugar x.add_(2) into x.conj_().add_(2).conj_().
|
| 33 |
+
Materialize other inputs as in (1).
|
| 34 |
+
3. out= operation. Desugar add(x, 2, out=y) into y.copy_(add(x, 2))
|
| 35 |
+
Materialize other inputs as in (1).
|
| 36 |
+
|
| 37 |
+
It is important to be able to tell if we READ from an argument and if we
|
| 38 |
+
WRITE to an argument. Conservative approach is to assume that we always
|
| 39 |
+
READ from an argument, but in out= operations you can skip
|
| 40 |
+
conjugating inputs on entry that never get used. In the current schema we
|
| 41 |
+
can't easily tell if the operation is in in-place or out= operation.
|
| 42 |
+
|
| 43 |
+
Note:
|
| 44 |
+
1. Mutable tensorlists containing tensors whose math bit set to true are disallowed.
|
| 45 |
+
2. Mutable tensors with math bit set to true are unconditionally cloned to ensure
|
| 46 |
+
correct behavior in the case when the mutable tensor shares memory with non mutable arguments.
|
| 47 |
+
|
| 48 |
+
If we were to in-place resolve the math bit for mutable inputs, then the non-mutable inputs sharing partial or full memory
|
| 49 |
+
with these mutable inputs would read into wrong values in the following cases:
|
| 50 |
+
1. Non mutable inputs have their math bit set to false.
|
| 51 |
+
2. Math bit for mutable input(s) is resolved before the non mutable inputs (with bit set to true and sharing memory
|
| 52 |
+
with one or more mutable arg(s)) are cloned.
|
| 53 |
+
At the end, the final value of the mutable arguments from the stack are copied into the original input mutable tensor inputs.
|
| 54 |
+
*/
|
| 55 |
+
const auto& arguments = op.schema().arguments();
|
| 56 |
+
const auto num_arguments = arguments.size();
|
| 57 |
+
const auto stack_start = stack->size() - num_arguments;
|
| 58 |
+
|
| 59 |
+
c10::optional<bool> is_write;
|
| 60 |
+
for (const auto i : c10::irange(num_arguments)) {
|
| 61 |
+
// Three possible states:
|
| 62 |
+
// 1. alias_info has no value --> out-of-place operation
|
| 63 |
+
// 2. alias_info does have a value, alias_info->is_write=True --> in-place or out= operation
|
| 64 |
+
// 3. alias_info does have a value, alias_info->is_write=False --> view operation
|
| 65 |
+
const AliasInfo* alias_info = arguments[i].alias_info();
|
| 66 |
+
if (alias_info != nullptr) {
|
| 67 |
+
if (is_write.has_value()) {
|
| 68 |
+
TORCH_CHECK(*is_write == alias_info->isWrite(),
|
| 69 |
+
"Unsupported operator for ", op_name, " fallback: ", op.schema().name(),
|
| 70 |
+
op_name, " fallback doesn't work for operators with a mix "
|
| 71 |
+
"mutable and non-mutable inputs that alias with outputs, "
|
| 72 |
+
"this must be implemented manually. "
|
| 73 |
+
"If you got this error on a core op, please report a bug to PyTorch.");
|
| 74 |
+
} else {
|
| 75 |
+
is_write = alias_info->isWrite();
|
| 76 |
+
}
|
| 77 |
+
}
|
| 78 |
+
}
|
| 79 |
+
|
| 80 |
+
if (is_write.has_value() && !*is_write) {
|
| 81 |
+
// We assume that view operators automatically handle the math bit
|
| 82 |
+
// correctly by propagating the dispatch key in key_set.
|
| 83 |
+
// This is not necessarily always right, so you should test these cases.
|
| 84 |
+
op.redispatchBoxed(dispatch_keys & c10::DispatchKeySet(DispatchKeySet::FULL_AFTER, key), stack);
|
| 85 |
+
return;
|
| 86 |
+
}
|
| 87 |
+
|
| 88 |
+
// Mutable inputs with math bit set to True and their clones
|
| 89 |
+
std::vector<std::pair<Tensor, Tensor>> mutable_inputs_with_their_clones;
|
| 90 |
+
for (const auto i : c10::irange(num_arguments)) {
|
| 91 |
+
auto& ivalue = (*stack)[stack_start + i];
|
| 92 |
+
if (!(ivalue.isTensor() || ivalue.isTensorList())) {
|
| 93 |
+
continue;
|
| 94 |
+
}
|
| 95 |
+
const auto& argument = arguments[i];
|
| 96 |
+
bool mut_arg = false;
|
| 97 |
+
if (argument.alias_info()) {
|
| 98 |
+
// Was already tested by is_write loop above
|
| 99 |
+
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(argument.alias_info()->isWrite());
|
| 100 |
+
mut_arg = true;
|
| 101 |
+
}
|
| 102 |
+
if (ivalue.isTensor()) {
|
| 103 |
+
if (!is_bit_set(ivalue.toTensor())) {
|
| 104 |
+
continue;
|
| 105 |
+
}
|
| 106 |
+
auto tensor = std::move(ivalue).toTensor();
|
| 107 |
+
auto resolved_tensor = at::clone(tensor);
|
| 108 |
+
if (mut_arg) {
|
| 109 |
+
TORCH_CHECK(mutable_inputs_with_their_clones.empty(), op_name, " fallback does not support operators with more than one mutable tensors with ",
|
| 110 |
+
op_name, "bit set to true.");
|
| 111 |
+
mutable_inputs_with_their_clones.emplace_back(std::move(tensor), resolved_tensor);
|
| 112 |
+
}
|
| 113 |
+
(*stack)[stack_start + i] = std::move(resolved_tensor);
|
| 114 |
+
} else if (ivalue.isTensorList()) {
|
| 115 |
+
auto tensors = std::move(ivalue).toTensorList();
|
| 116 |
+
for(const auto j : c10::irange(tensors.size())) {
|
| 117 |
+
const auto& tensor = tensors[j];
|
| 118 |
+
if (!is_bit_set(tensor)) {
|
| 119 |
+
continue;
|
| 120 |
+
}
|
| 121 |
+
TORCH_CHECK(!mut_arg, " fallback doesn't currently support mutable TensorLists with ",
|
| 122 |
+
op_name, " inputs. Please materialize all the ", op_name, " input tensor(s) in the mutable TensorList inputs before calling ",
|
| 123 |
+
op.schema().name());
|
| 124 |
+
tensors[j] = at::clone(tensor);
|
| 125 |
+
}
|
| 126 |
+
(*stack)[stack_start + i] = std::move(tensors);
|
| 127 |
+
}
|
| 128 |
+
}
|
| 129 |
+
|
| 130 |
+
op.redispatchBoxed(dispatch_keys & c10::DispatchKeySet(DispatchKeySet::FULL_AFTER, key), stack);
|
| 131 |
+
|
| 132 |
+
TORCH_INTERNAL_ASSERT(mutable_inputs_with_their_clones.size() <= 1);
|
| 133 |
+
|
| 134 |
+
for (std::pair<Tensor, Tensor> mut_tensors: mutable_inputs_with_their_clones) {
|
| 135 |
+
auto& mutable_input = mut_tensors.first;
|
| 136 |
+
auto& cloned_mutable_input = mut_tensors.second;
|
| 137 |
+
auto& ivalue = (*stack)[stack_start];
|
| 138 |
+
auto returned_output = std::move(ivalue).toTensor();
|
| 139 |
+
|
| 140 |
+
// sanity check to ensure that the tensor in stack aliases the cloned_mutable_input
|
| 141 |
+
TORCH_INTERNAL_ASSERT(cloned_mutable_input.is_same(returned_output));
|
| 142 |
+
|
| 143 |
+
// necessary for out= arg
|
| 144 |
+
at::native::resize_output(mutable_input, returned_output.sizes());
|
| 145 |
+
|
| 146 |
+
mutable_input.copy_(returned_output);
|
| 147 |
+
(*stack)[stack_start] = std::move(mutable_input);
|
| 148 |
+
}
|
| 149 |
+
}
|
| 150 |
+
|
| 151 |
+
virtual ~MathOpFallback() = default;
|
| 152 |
+
|
| 153 |
+
DispatchKey key;
|
| 154 |
+
string op_name;
|
| 155 |
+
};
|
| 156 |
+
|
| 157 |
+
} // namespace at::native
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/MaxPooling.h
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/core/Tensor.h>
|
| 4 |
+
#include <ATen/Parallel.h>
|
| 5 |
+
#include <ATen/native/DispatchStub.h>
|
| 6 |
+
#include <ATen/native/Pool.h>
|
| 7 |
+
|
| 8 |
+
namespace at::native {
|
| 9 |
+
|
| 10 |
+
static void check_max_pool1d(
|
| 11 |
+
const Tensor& self,
|
| 12 |
+
IntArrayRef kernel_size,
|
| 13 |
+
IntArrayRef stride,
|
| 14 |
+
IntArrayRef padding,
|
| 15 |
+
IntArrayRef dilation,
|
| 16 |
+
bool ceil_mode) {
|
| 17 |
+
|
| 18 |
+
TORCH_CHECK(
|
| 19 |
+
self.dim() == 2 || self.dim() == 3,
|
| 20 |
+
"max_pool1d() Expected 2D or 3D input tensor, but got ", self.sym_sizes());
|
| 21 |
+
TORCH_CHECK(
|
| 22 |
+
kernel_size.size() == 1,
|
| 23 |
+
"max_pool1d() kernel_size must be an int, list of ints or tuple of ints of size 1 but got size ",
|
| 24 |
+
kernel_size.size());
|
| 25 |
+
TORCH_CHECK(
|
| 26 |
+
stride.empty() || stride.size() == 1,
|
| 27 |
+
"max_pool1d() stride must be None, an int, list of ints, or tuple of ints of size 1 but got size ",
|
| 28 |
+
stride.size());
|
| 29 |
+
TORCH_CHECK(
|
| 30 |
+
padding.size() == 1,
|
| 31 |
+
"max_pool1d() padding must be an int, list of ints, or tuple of ints of size 1 but got size ",
|
| 32 |
+
padding.size());
|
| 33 |
+
TORCH_CHECK(
|
| 34 |
+
dilation.size() == 1,
|
| 35 |
+
"max_pool1d() dilation must be an int, list of ints or tuple of ints of size 1 but got size ",
|
| 36 |
+
dilation.size());
|
| 37 |
+
|
| 38 |
+
// If stride=None then set it to kernel_size
|
| 39 |
+
if (stride.empty()) {
|
| 40 |
+
stride = kernel_size;
|
| 41 |
+
}
|
| 42 |
+
|
| 43 |
+
TORCH_CHECK(
|
| 44 |
+
kernel_size[0] > 0,
|
| 45 |
+
"max_pool1d() kernel_size must be greater than zero, but got ",
|
| 46 |
+
kernel_size[0]);
|
| 47 |
+
TORCH_CHECK(
|
| 48 |
+
stride[0] > 0, "max_pool1d() stride must be greater than zero, but got ", stride[0]);
|
| 49 |
+
TORCH_CHECK(
|
| 50 |
+
padding[0] >= 0, "max_pool1d() padding must be non-negative, but got ", padding[0]);
|
| 51 |
+
TORCH_CHECK(
|
| 52 |
+
padding[0] <= kernel_size[0] / 2,
|
| 53 |
+
"max_pool1d() padding should be at most half of kernel size, but got padding=",
|
| 54 |
+
padding[0],
|
| 55 |
+
" and kernel_size=",
|
| 56 |
+
kernel_size[0]);
|
| 57 |
+
TORCH_CHECK(
|
| 58 |
+
dilation[0] > 0, "max_pool1d() dilation must be greater than zero, but got ", dilation[0]);
|
| 59 |
+
|
| 60 |
+
const int64_t OW = pooling_output_shape(self.sym_size(-1).guard_int(__FILE__, __LINE__), kernel_size[0], padding[0], stride[0], dilation[0], ceil_mode);
|
| 61 |
+
TORCH_CHECK(OW > 0, "max_pool1d() Invalid computed output size: ", OW);
|
| 62 |
+
}
|
| 63 |
+
|
| 64 |
+
// TODO(Heitor) Template by dimension
|
| 65 |
+
struct PoolingParams1D {
|
| 66 |
+
int64_t NB; // Number of batches
|
| 67 |
+
int64_t NC; // Number of channels
|
| 68 |
+
int64_t IW; // Input width
|
| 69 |
+
int64_t OW; // Output width
|
| 70 |
+
int64_t KW; // Kernel width
|
| 71 |
+
int64_t SJ; // Column stride
|
| 72 |
+
int64_t PJ; // Column padding
|
| 73 |
+
int64_t DJ; // Column dilation
|
| 74 |
+
|
| 75 |
+
// Return index of input element for the given kernel and output index
|
| 76 |
+
inline int64_t index(int64_t kj, int64_t oj) const {
|
| 77 |
+
return oj * SJ + kj * DJ - PJ;
|
| 78 |
+
}
|
| 79 |
+
|
| 80 |
+
// Return index of first output within bounds for this kernel index
|
| 81 |
+
inline int64_t valid_output_start(int64_t kj) const {
|
| 82 |
+
int64_t ij = index(kj, 0);;
|
| 83 |
+
return ij < 0 ? at::divup(-ij, SJ) : 0;
|
| 84 |
+
}
|
| 85 |
+
|
| 86 |
+
// Return index one past last output within bounds for this kernel index
|
| 87 |
+
inline int64_t valid_output_end(int64_t kj) const {
|
| 88 |
+
int64_t ij = index(kj, OW - 1);
|
| 89 |
+
return ij >= IW ? OW - at::divup(ij - (IW - 1), SJ) : OW;
|
| 90 |
+
}
|
| 91 |
+
};
|
| 92 |
+
|
| 93 |
+
using pooling_fn = void (*)(Tensor&, const Tensor&, const PoolingParams1D&);
|
| 94 |
+
|
| 95 |
+
DECLARE_DISPATCH(pooling_fn, max_pool1d_stub);
|
| 96 |
+
|
| 97 |
+
} // namespace at::native
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/NonEmptyUtils.h
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <ATen/core/TensorBase.h>
|
| 2 |
+
#include <algorithm>
|
| 3 |
+
#include <vector>
|
| 4 |
+
|
| 5 |
+
namespace at::native {
|
| 6 |
+
|
| 7 |
+
inline int64_t ensure_nonempty_dim(int64_t dim) {
|
| 8 |
+
return std::max<int64_t>(dim, 1);
|
| 9 |
+
}
|
| 10 |
+
|
| 11 |
+
inline int64_t ensure_nonempty_size(const TensorBase &t, int64_t dim) {
|
| 12 |
+
return t.dim() == 0 ? 1 : t.size(dim);
|
| 13 |
+
}
|
| 14 |
+
|
| 15 |
+
inline int64_t ensure_nonempty_stride(const TensorBase &t, int64_t dim) {
|
| 16 |
+
return t.dim() == 0 ? 1 : t.stride(dim);
|
| 17 |
+
}
|
| 18 |
+
|
| 19 |
+
using IdxVec = std::vector<int64_t>;
|
| 20 |
+
inline IdxVec ensure_nonempty_vec(IdxVec vec) {
|
| 21 |
+
if (vec.empty()) {
|
| 22 |
+
vec.push_back(1);
|
| 23 |
+
}
|
| 24 |
+
return vec;
|
| 25 |
+
}
|
| 26 |
+
|
| 27 |
+
} // namespace at::native
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/Padding.h
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/core/Tensor.h>
|
| 4 |
+
#include <ATen/native/DispatchStub.h>
|
| 5 |
+
|
| 6 |
+
namespace at::native {
|
| 7 |
+
|
| 8 |
+
using padding_fn = void (*)(const Tensor&, const Tensor&, IntArrayRef);
|
| 9 |
+
|
| 10 |
+
// reflection padding
|
| 11 |
+
DECLARE_DISPATCH(padding_fn, reflection_pad1d_kernel);
|
| 12 |
+
DECLARE_DISPATCH(padding_fn, reflection_pad1d_backward_kernel);
|
| 13 |
+
DECLARE_DISPATCH(padding_fn, reflection_pad2d_kernel);
|
| 14 |
+
DECLARE_DISPATCH(padding_fn, reflection_pad2d_backward_kernel);
|
| 15 |
+
DECLARE_DISPATCH(padding_fn, reflection_pad3d_kernel);
|
| 16 |
+
DECLARE_DISPATCH(padding_fn, reflection_pad3d_backward_kernel);
|
| 17 |
+
|
| 18 |
+
// replication padding
|
| 19 |
+
DECLARE_DISPATCH(padding_fn, replication_pad1d_kernel);
|
| 20 |
+
DECLARE_DISPATCH(padding_fn, replication_pad1d_backward_kernel);
|
| 21 |
+
DECLARE_DISPATCH(padding_fn, replication_pad2d_kernel);
|
| 22 |
+
DECLARE_DISPATCH(padding_fn, replication_pad2d_backward_kernel);
|
| 23 |
+
DECLARE_DISPATCH(padding_fn, replication_pad3d_kernel);
|
| 24 |
+
DECLARE_DISPATCH(padding_fn, replication_pad3d_backward_kernel);
|
| 25 |
+
|
| 26 |
+
namespace padding {
|
| 27 |
+
|
| 28 |
+
template <int dim>
|
| 29 |
+
static inline void check_valid_input(const Tensor& input, IntArrayRef padding) {
|
| 30 |
+
|
| 31 |
+
TORCH_CHECK(padding.size() == 2 * dim,
|
| 32 |
+
"padding size is expected to be ", 2 * dim,
|
| 33 |
+
", but got: ", padding.size());
|
| 34 |
+
|
| 35 |
+
int input_dim = input.dim();
|
| 36 |
+
|
| 37 |
+
bool is_batch_mode = input_dim == (dim + 2);
|
| 38 |
+
|
| 39 |
+
bool valid_batch_mode = is_batch_mode;
|
| 40 |
+
bool valid_non_batch_mode = !is_batch_mode;
|
| 41 |
+
|
| 42 |
+
if (is_batch_mode) {
|
| 43 |
+
// allow batch size of 0-dim.
|
| 44 |
+
for (const auto d : c10::irange(1, input_dim)) {
|
| 45 |
+
valid_batch_mode = valid_batch_mode && input.size(d) != 0;
|
| 46 |
+
}
|
| 47 |
+
} else {
|
| 48 |
+
for (const auto d : c10::irange(0, input_dim)) {
|
| 49 |
+
valid_non_batch_mode = valid_non_batch_mode && input.size(d) != 0;
|
| 50 |
+
}
|
| 51 |
+
}
|
| 52 |
+
|
| 53 |
+
// allow empty batch size but not other dimensions.
|
| 54 |
+
TORCH_CHECK(valid_batch_mode || valid_non_batch_mode,
|
| 55 |
+
"Expected ", dim + 1, "D or ", dim + 2,
|
| 56 |
+
"D (batch mode) tensor with possibly 0 batch size and other non-zero dimensions for input, but got: ",
|
| 57 |
+
input.sizes());
|
| 58 |
+
}
|
| 59 |
+
|
| 60 |
+
} // namespace padding
|
| 61 |
+
|
| 62 |
+
} // at::native
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/PointwiseOps.h
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Ternary and higher-order pointwise operations
|
| 2 |
+
#pragma once
|
| 3 |
+
|
| 4 |
+
#include <ATen/native/DispatchStub.h>
|
| 5 |
+
|
| 6 |
+
namespace c10 {
|
| 7 |
+
class Scalar;
|
| 8 |
+
}
|
| 9 |
+
|
| 10 |
+
namespace at {
|
| 11 |
+
|
| 12 |
+
struct TensorIterator;
|
| 13 |
+
struct TensorIteratorBase;
|
| 14 |
+
|
| 15 |
+
namespace native {
|
| 16 |
+
|
| 17 |
+
using pointwise_fn = void (*)(TensorIterator&, const Scalar& scalar);
|
| 18 |
+
using structured_pointwise_fn = void (*)(TensorIteratorBase&, const Scalar& scalar);
|
| 19 |
+
using pointwise_fn_double = void (*)(TensorIterator&, const Scalar&, double);
|
| 20 |
+
|
| 21 |
+
DECLARE_DISPATCH(structured_pointwise_fn, addcmul_stub);
|
| 22 |
+
DECLARE_DISPATCH(structured_pointwise_fn, addcdiv_stub);
|
| 23 |
+
DECLARE_DISPATCH(pointwise_fn_double, smooth_l1_backward_stub);
|
| 24 |
+
DECLARE_DISPATCH(pointwise_fn_double, huber_backward_stub);
|
| 25 |
+
DECLARE_DISPATCH(pointwise_fn, mse_backward_stub);
|
| 26 |
+
|
| 27 |
+
} // namespace native
|
| 28 |
+
} // namespace at
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/Pool.h
ADDED
|
@@ -0,0 +1,340 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <ATen/core/Tensor.h>
|
| 2 |
+
#include <ATen/div_rtn.h>
|
| 3 |
+
#include <ATen/TensorUtils.h>
|
| 4 |
+
#include <ATen/native/DispatchStub.h>
|
| 5 |
+
#include <c10/util/irange.h>
|
| 6 |
+
|
| 7 |
+
#include <utility>
|
| 8 |
+
|
| 9 |
+
#pragma once
|
| 10 |
+
|
| 11 |
+
namespace at::native {
|
| 12 |
+
|
| 13 |
+
using max_pool2d_fn = void(*)(const Tensor& output, const Tensor& indices, const Tensor& input,
|
| 14 |
+
int kW, int kH, int dW, int dH, int padW, int padH, int dilationW, int dilationH);
|
| 15 |
+
using max_pool2d_backward_fn = void(*)(const Tensor& grad_input, const Tensor& grad_output, const Tensor& indices);
|
| 16 |
+
|
| 17 |
+
DECLARE_DISPATCH(max_pool2d_fn, max_pool2d_kernel);
|
| 18 |
+
DECLARE_DISPATCH(max_pool2d_backward_fn, max_pool2d_backward_kernel);
|
| 19 |
+
|
| 20 |
+
// averge pooling has same signature for forward and backward
|
| 21 |
+
using avg_pool2d_fn = void(*)(const Tensor& output, const Tensor& input, int64_t kW, int64_t kH,
|
| 22 |
+
int64_t dW, int64_t dH, int64_t padW, int64_t padH, bool count_include_pad, c10::optional<int64_t> divisor_override);
|
| 23 |
+
using avg_pool2d_backward_fn = void(*)(const Tensor& output, const Tensor& input, int kW, int kH,
|
| 24 |
+
int dW, int dH, int padW, int padH, bool count_include_pad, c10::optional<int64_t> divisor_override);
|
| 25 |
+
|
| 26 |
+
DECLARE_DISPATCH(avg_pool2d_fn, avg_pool2d_kernel);
|
| 27 |
+
DECLARE_DISPATCH(avg_pool2d_backward_fn, avg_pool2d_backward_kernel);
|
| 28 |
+
|
| 29 |
+
using max_pool3d_fn = void(*)(Tensor& output, Tensor& indices, const Tensor& input,
|
| 30 |
+
int kW, int kH, int kD, int dW, int dH, int dD, int pW, int pH, int pD, int dilationW, int dilationH, int dilationD);
|
| 31 |
+
using max_pool3d_backward_fn = void(*)(Tensor& grad_input, const Tensor& grad_output, const Tensor& indices);
|
| 32 |
+
|
| 33 |
+
DECLARE_DISPATCH(max_pool3d_fn, max_pool3d_kernel);
|
| 34 |
+
DECLARE_DISPATCH(max_pool3d_backward_fn, max_pool3d_backward_kernel);
|
| 35 |
+
namespace {
|
| 36 |
+
|
| 37 |
+
template <typename dest_t, typename src_t>
|
| 38 |
+
static inline dest_t
|
| 39 |
+
safe_downcast(src_t v)
|
| 40 |
+
{
|
| 41 |
+
TORCH_CHECK(std::numeric_limits<dest_t>::min() <= v && v <= std::numeric_limits<dest_t>::max(),
|
| 42 |
+
"integer out of range");
|
| 43 |
+
|
| 44 |
+
return static_cast<dest_t>(v);
|
| 45 |
+
}
|
| 46 |
+
|
| 47 |
+
template<typename T>
|
| 48 |
+
static inline T pooling_output_shape_pad_lr(
|
| 49 |
+
T inputSize, T kernelSize, T pad_l, T pad_r, T stride, T dilation,
|
| 50 |
+
bool ceil_mode) {
|
| 51 |
+
T outputSize = div_rtn<T>(
|
| 52 |
+
inputSize + pad_l + pad_r - dilation * (kernelSize - 1) - 1 +
|
| 53 |
+
(ceil_mode ? stride - 1 : 0), stride) + 1;
|
| 54 |
+
if (ceil_mode) {
|
| 55 |
+
// ensure that the last pooling starts inside the image
|
| 56 |
+
// needed to avoid problems in ceil mode
|
| 57 |
+
if ((outputSize - 1) * stride >= inputSize + pad_l) {
|
| 58 |
+
--outputSize;
|
| 59 |
+
}
|
| 60 |
+
}
|
| 61 |
+
return outputSize;
|
| 62 |
+
}
|
| 63 |
+
|
| 64 |
+
template<typename T>
|
| 65 |
+
static inline T pooling_output_shape(
|
| 66 |
+
T inputSize, T kernelSize, T pad, T stride, T dilation, bool ceil_mode) {
|
| 67 |
+
TORCH_CHECK(stride != 0, "stride should not be zero");
|
| 68 |
+
TORCH_CHECK(pad >= 0,
|
| 69 |
+
"pad must be non-negative, but got pad: ", pad);
|
| 70 |
+
TORCH_CHECK(pad <= ((kernelSize - 1) * dilation + 1) / 2,
|
| 71 |
+
"pad should be at most half of effective kernel size, but got pad=",
|
| 72 |
+
pad, ", kernel_size=", kernelSize, " and dilation=", dilation)
|
| 73 |
+
return pooling_output_shape_pad_lr(
|
| 74 |
+
inputSize, kernelSize, pad, pad, stride, dilation, ceil_mode);
|
| 75 |
+
}
|
| 76 |
+
|
| 77 |
+
template <typename T>
|
| 78 |
+
std::pair<T, T> _pooling_same_mode_padding_lr(
|
| 79 |
+
T inputSize, T kernelSize, T stride, T dilation) {
|
| 80 |
+
// NOTE: with strides, the output shape is ceil(inputSize/stride)
|
| 81 |
+
auto total_padding = T(dilation) * (kernelSize - 1);
|
| 82 |
+
|
| 83 |
+
// Prefer symmetric padding if possible
|
| 84 |
+
if (stride > 2 && (total_padding % 2 == 1)) {
|
| 85 |
+
// The floor in the output size calculation gives us a little wiggle room
|
| 86 |
+
auto wiggle_room = inputSize % stride - 1;
|
| 87 |
+
if (wiggle_room > 0) {
|
| 88 |
+
total_padding = total_padding - 1;
|
| 89 |
+
}
|
| 90 |
+
}
|
| 91 |
+
|
| 92 |
+
auto left = total_padding / 2;
|
| 93 |
+
return {left, total_padding - left};
|
| 94 |
+
}
|
| 95 |
+
|
| 96 |
+
inline std::pair<int64_t, int64_t> pooling_same_mode_padding_lr(
|
| 97 |
+
int64_t inputSize, int64_t kernelSize, int64_t stride, int64_t dilation) {
|
| 98 |
+
return _pooling_same_mode_padding_lr(inputSize, kernelSize, stride, dilation);
|
| 99 |
+
}
|
| 100 |
+
|
| 101 |
+
inline std::pair<c10::SymInt, c10::SymInt> pooling_same_mode_padding_lr(
|
| 102 |
+
c10::SymInt inputSize, c10::SymInt kernelSize, c10::SymInt stride, c10::SymInt dilation) {
|
| 103 |
+
return _pooling_same_mode_padding_lr(std::move(inputSize), std::move(kernelSize), std::move(stride), std::move(dilation));
|
| 104 |
+
}
|
| 105 |
+
|
| 106 |
+
// AveragePool2d/DilatedMaxPool2d (forward)
|
| 107 |
+
static inline void
|
| 108 |
+
pool2d_shape_check(
|
| 109 |
+
const Tensor& input,
|
| 110 |
+
int kH, int kW, int dH, int dW, int padH, int padW, int dilationH, int dilationW,
|
| 111 |
+
int64_t nInputPlane,
|
| 112 |
+
int64_t inputHeight, int64_t inputWidth,
|
| 113 |
+
int64_t outputHeight, int64_t outputWidth, MemoryFormat memory_format)
|
| 114 |
+
{
|
| 115 |
+
const int64_t ndim = input.ndimension();
|
| 116 |
+
const int64_t nOutputPlane = nInputPlane;
|
| 117 |
+
|
| 118 |
+
TORCH_CHECK(kW > 0 && kH > 0,
|
| 119 |
+
"kernel size should be greater than zero, but got ",
|
| 120 |
+
"kH: ", kH, " kW: ", kW);
|
| 121 |
+
TORCH_CHECK(dW > 0 && dH > 0,
|
| 122 |
+
"stride should be greater than zero, but got "
|
| 123 |
+
"dH: ", dH, " dW: ", dW);
|
| 124 |
+
TORCH_CHECK(dilationH > 0 && dilationW > 0,
|
| 125 |
+
"dilation should be greater than zero, but got ",
|
| 126 |
+
"dilationH: ", dilationH, " dilationW: ", dilationW);
|
| 127 |
+
|
| 128 |
+
bool valid_dims = input.size(1) != 0 && input.size(2) != 0;
|
| 129 |
+
if (memory_format == at::MemoryFormat::ChannelsLast){
|
| 130 |
+
// Expect tensor in NHWC format and allow 0-dim only for N.
|
| 131 |
+
TORCH_CHECK((ndim == 4 && valid_dims && input.size(3) != 0),
|
| 132 |
+
"Expected 4D (batch mode) tensor expected for input with channels_last layout"
|
| 133 |
+
" with optional 0 dim batch size for input, but got: ", input.sizes());
|
| 134 |
+
} else {
|
| 135 |
+
TORCH_CHECK((ndim == 3 && input.size(0) != 0 && valid_dims) ||
|
| 136 |
+
(ndim == 4 && valid_dims && input.size(3) != 0),
|
| 137 |
+
"Expected 3D or 4D (batch mode) tensor with optional 0 dim batch size for input, but got:",
|
| 138 |
+
input.sizes());
|
| 139 |
+
}
|
| 140 |
+
|
| 141 |
+
TORCH_CHECK(kW/2 >= padW && kH/2 >= padH,
|
| 142 |
+
"pad should be smaller than or equal to half of kernel size, but got ",
|
| 143 |
+
"padW = ", padW, ", padH = ", padH, ", kW = ", kW, ", kH = ", kH);
|
| 144 |
+
|
| 145 |
+
TORCH_CHECK(outputWidth >= 1 && outputHeight >= 1,
|
| 146 |
+
"Given input size: (",
|
| 147 |
+
nInputPlane, "x", inputHeight, "x", inputWidth, "). ",
|
| 148 |
+
"Calculated output size: (",
|
| 149 |
+
nOutputPlane, "x", outputHeight, "x", outputWidth, "). ",
|
| 150 |
+
"Output size is too small");
|
| 151 |
+
}
|
| 152 |
+
|
| 153 |
+
// DilatedMaxPool2d (backward)
|
| 154 |
+
static inline void
|
| 155 |
+
max_pool2d_backward_shape_check(
|
| 156 |
+
const Tensor& input,
|
| 157 |
+
const Tensor& gradOutput,
|
| 158 |
+
const Tensor& indices,
|
| 159 |
+
int kH, int kW, int dH, int dW, int padH, int padW, int dilationH, int dilationW,
|
| 160 |
+
int64_t nInputPlane,
|
| 161 |
+
int64_t inputHeight, int64_t inputWidth,
|
| 162 |
+
int64_t outputHeight, int64_t outputWidth, MemoryFormat memory_format)
|
| 163 |
+
{
|
| 164 |
+
pool2d_shape_check(
|
| 165 |
+
input,
|
| 166 |
+
kH, kW, dH, dW, padH, padW, dilationH, dilationW,
|
| 167 |
+
nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth, memory_format);
|
| 168 |
+
|
| 169 |
+
const int64_t ndim = input.ndimension();
|
| 170 |
+
const int64_t nOutputPlane = nInputPlane;
|
| 171 |
+
|
| 172 |
+
check_dim_size(gradOutput, ndim, ndim-3, nOutputPlane);
|
| 173 |
+
check_dim_size(gradOutput, ndim, ndim-2, outputHeight);
|
| 174 |
+
check_dim_size(gradOutput, ndim, ndim-1, outputWidth);
|
| 175 |
+
|
| 176 |
+
check_dim_size(indices, ndim, ndim-3, nOutputPlane);
|
| 177 |
+
check_dim_size(indices, ndim, ndim-2, outputHeight);
|
| 178 |
+
check_dim_size(indices, ndim, ndim-1, outputWidth);
|
| 179 |
+
}
|
| 180 |
+
|
| 181 |
+
// AveragePool2d (backward)
|
| 182 |
+
static inline void
|
| 183 |
+
avg_pool2d_backward_shape_check(
|
| 184 |
+
const Tensor& input,
|
| 185 |
+
const Tensor& gradOutput,
|
| 186 |
+
int64_t /*nbatch*/,
|
| 187 |
+
int kH, int kW, int dH, int dW, int padH, int padW,
|
| 188 |
+
int64_t nInputPlane,
|
| 189 |
+
int64_t inputHeight, int64_t inputWidth,
|
| 190 |
+
int64_t outputHeight, int64_t outputWidth,
|
| 191 |
+
MemoryFormat memory_format)
|
| 192 |
+
{
|
| 193 |
+
pool2d_shape_check(
|
| 194 |
+
input,
|
| 195 |
+
kH, kW, dH, dW, padH, padW, 1, 1,
|
| 196 |
+
nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth,
|
| 197 |
+
memory_format);
|
| 198 |
+
|
| 199 |
+
const int64_t ndim = input.ndimension();
|
| 200 |
+
const int64_t nOutputPlane = nInputPlane;
|
| 201 |
+
|
| 202 |
+
check_dim_size(gradOutput, ndim, ndim-3, nOutputPlane);
|
| 203 |
+
check_dim_size(gradOutput, ndim, ndim-2, outputHeight);
|
| 204 |
+
check_dim_size(gradOutput, ndim, ndim-1, outputWidth);
|
| 205 |
+
}
|
| 206 |
+
|
| 207 |
+
// AveragePool3d/DilatedMaxPool3d (forward)
|
| 208 |
+
static inline void
|
| 209 |
+
pool3d_shape_check(
|
| 210 |
+
const Tensor& input,
|
| 211 |
+
int64_t nslices,
|
| 212 |
+
int kT, int kH, int kW,
|
| 213 |
+
int dT, int dH, int dW,
|
| 214 |
+
int pT, int pH, int pW,
|
| 215 |
+
int dilationT, int dilationH, int dilationW,
|
| 216 |
+
int64_t itime, int64_t iheight, int64_t iwidth,
|
| 217 |
+
int64_t otime, int64_t oheight, int64_t owidth,
|
| 218 |
+
const char *fn_name,
|
| 219 |
+
bool check_input_size=false)
|
| 220 |
+
{
|
| 221 |
+
const int64_t ndim = input.ndimension();
|
| 222 |
+
|
| 223 |
+
TORCH_CHECK(kT > 0 && kW > 0 && kH > 0,
|
| 224 |
+
"kernel size should be greater than zero, but got ",
|
| 225 |
+
"kT: ", kT, " kH: ", kH, " kW: ", kW);
|
| 226 |
+
TORCH_CHECK(dT > 0 && dW > 0 && dH > 0,
|
| 227 |
+
"stride should be greater than zero, but got ",
|
| 228 |
+
"dT: ", dT, " dH: ", dH, " dW: ", dW);
|
| 229 |
+
TORCH_CHECK(dilationT > 0 && dilationW > 0 && dilationH > 0,
|
| 230 |
+
"dilation should be greater than zero, but got ",
|
| 231 |
+
"dilationT: ", dilationT, " dilationH: ", dilationH, " dilationW: ", dilationW);
|
| 232 |
+
|
| 233 |
+
TORCH_CHECK(ndim == 4 || ndim == 5,
|
| 234 |
+
fn_name, ": Expected 4D or 5D tensor for input, but got: ", input.sizes());
|
| 235 |
+
|
| 236 |
+
for (const auto i : c10::irange(ndim)) {
|
| 237 |
+
if (ndim == 5 && i == 0) {
|
| 238 |
+
// size of batch-dim can be 0.
|
| 239 |
+
continue;
|
| 240 |
+
}
|
| 241 |
+
TORCH_CHECK(
|
| 242 |
+
input.size(i) > 0,
|
| 243 |
+
fn_name,
|
| 244 |
+
": Expected input's non-batch dimensions to have positive length,"
|
| 245 |
+
" but input has a shape of ",
|
| 246 |
+
input.sizes(),
|
| 247 |
+
" and non-batch dimension ",
|
| 248 |
+
input.size(i),
|
| 249 |
+
" has length zero!")
|
| 250 |
+
}
|
| 251 |
+
|
| 252 |
+
if (check_input_size) { // AveragePool3d
|
| 253 |
+
TORCH_CHECK(itime >= kT && iheight >= kH && iwidth >= kW,
|
| 254 |
+
"input image ", "(T: ", itime, " H: ", iheight, " W: ", iwidth, ") smaller than ",
|
| 255 |
+
"kernel size ", "(kT: ", kT, " kH: ", kH, " kW: ", kW, ")");
|
| 256 |
+
}
|
| 257 |
+
|
| 258 |
+
TORCH_CHECK(kT/2 >= pT && kW/2 >= pW && kH/2 >= pH,
|
| 259 |
+
"pad should be smaller than or equal to half of kernel size, but got "
|
| 260 |
+
"kT: ", kT, " kW: ", kW, " kH: ", kH, " padT: ", pT, " padW: ", pW, " padH: ", pH);
|
| 261 |
+
|
| 262 |
+
TORCH_CHECK(otime >= 1 && owidth >= 1 && oheight >= 1,
|
| 263 |
+
"Given input size: (",
|
| 264 |
+
nslices,"x", itime, "x", iheight, "x", iwidth, "). ",
|
| 265 |
+
"Calculated output size: (",
|
| 266 |
+
nslices, "x", otime, "x", oheight, "x", owidth, "). ",
|
| 267 |
+
"Output size is too small");
|
| 268 |
+
}
|
| 269 |
+
|
| 270 |
+
static inline void
|
| 271 |
+
max_pool3d_backward_shape_check(
|
| 272 |
+
const Tensor& input,
|
| 273 |
+
const Tensor& gradOutput,
|
| 274 |
+
const Tensor& indices,
|
| 275 |
+
int64_t nslices,
|
| 276 |
+
int kT, int kH, int kW,
|
| 277 |
+
int dT, int dH, int dW,
|
| 278 |
+
int pT, int pH, int pW,
|
| 279 |
+
int dilationT, int dilationH, int dilationW,
|
| 280 |
+
int64_t itime, int64_t iheight, int64_t iwidth,
|
| 281 |
+
int64_t otime, int64_t oheight, int64_t owidth,
|
| 282 |
+
const char* fn_name)
|
| 283 |
+
{
|
| 284 |
+
const int64_t ndim = input.ndimension();
|
| 285 |
+
|
| 286 |
+
pool3d_shape_check(
|
| 287 |
+
input,
|
| 288 |
+
nslices,
|
| 289 |
+
kT, kH, kW,
|
| 290 |
+
dT, dH, dW,
|
| 291 |
+
pT, pH, pW,
|
| 292 |
+
dilationT, dilationH, dilationW,
|
| 293 |
+
itime, iheight, iwidth,
|
| 294 |
+
otime, oheight, owidth, fn_name);
|
| 295 |
+
|
| 296 |
+
check_dim_size(gradOutput, ndim, ndim-4, nslices);
|
| 297 |
+
check_dim_size(gradOutput, ndim, ndim-3, otime);
|
| 298 |
+
check_dim_size(gradOutput, ndim, ndim-2, oheight);
|
| 299 |
+
check_dim_size(gradOutput, ndim, ndim-1, owidth);
|
| 300 |
+
|
| 301 |
+
check_dim_size(indices, ndim, ndim-4, nslices);
|
| 302 |
+
check_dim_size(indices, ndim, ndim-3, otime);
|
| 303 |
+
check_dim_size(indices, ndim, ndim-2, oheight);
|
| 304 |
+
check_dim_size(indices, ndim, ndim-1, owidth);
|
| 305 |
+
}
|
| 306 |
+
|
| 307 |
+
static inline void
|
| 308 |
+
avg_pool3d_backward_shape_check(
|
| 309 |
+
const Tensor& input,
|
| 310 |
+
const Tensor& gradOutput,
|
| 311 |
+
int64_t nslices,
|
| 312 |
+
int kT, int kH, int kW,
|
| 313 |
+
int dT, int dH, int dW,
|
| 314 |
+
int pT, int pH, int pW,
|
| 315 |
+
int64_t itime, int64_t iheight, int64_t iwidth,
|
| 316 |
+
int64_t otime, int64_t oheight, int64_t owidth,
|
| 317 |
+
const char *fn_name)
|
| 318 |
+
{
|
| 319 |
+
const int64_t ndim = input.ndimension();
|
| 320 |
+
|
| 321 |
+
pool3d_shape_check(
|
| 322 |
+
input,
|
| 323 |
+
nslices,
|
| 324 |
+
kT, kH, kW,
|
| 325 |
+
dT, dH, dW,
|
| 326 |
+
pT, pH, pW,
|
| 327 |
+
1, 1, 1,
|
| 328 |
+
itime, iheight, iwidth,
|
| 329 |
+
otime, oheight, owidth,
|
| 330 |
+
fn_name, true);
|
| 331 |
+
|
| 332 |
+
check_dim_size(gradOutput, ndim, ndim-4, nslices);
|
| 333 |
+
check_dim_size(gradOutput, ndim, ndim-3, otime);
|
| 334 |
+
check_dim_size(gradOutput, ndim, ndim-2, oheight);
|
| 335 |
+
check_dim_size(gradOutput, ndim, ndim-1, owidth);
|
| 336 |
+
}
|
| 337 |
+
|
| 338 |
+
} // anonymous namespace
|
| 339 |
+
|
| 340 |
+
} // namespace at::native
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/RNN.h
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/core/Tensor.h>
|
| 4 |
+
#include <ATen/native/DispatchStub.h>
|
| 5 |
+
|
| 6 |
+
namespace at::native {
|
| 7 |
+
|
| 8 |
+
using lstm_fn = void(*)(Tensor&, Tensor&, Tensor&, const Tensor&, TensorList, TensorList, bool, int64_t, double, bool, bool, bool);
|
| 9 |
+
using rnn_fn = void(*)(Tensor&, Tensor&, const Tensor&, const Tensor&, TensorList, bool, int64_t, double, bool, bool, bool);
|
| 10 |
+
using lstm_packed_fn = void(*)(Tensor&, Tensor&, Tensor&, const Tensor&, const Tensor&, TensorList, TensorList, bool, int64_t, double, bool, bool);
|
| 11 |
+
using rnn_packed_fn = void(*)(Tensor&, Tensor&, const Tensor&, const Tensor&, const Tensor&, TensorList, bool, int64_t, double, bool, bool);
|
| 12 |
+
|
| 13 |
+
DECLARE_DISPATCH(lstm_fn, lstm_cudnn_stub);
|
| 14 |
+
DECLARE_DISPATCH(lstm_fn, lstm_miopen_stub);
|
| 15 |
+
DECLARE_DISPATCH(lstm_fn, lstm_mkldnn_stub);
|
| 16 |
+
DECLARE_DISPATCH(rnn_fn, gru_cudnn_stub);
|
| 17 |
+
DECLARE_DISPATCH(rnn_fn, gru_miopen_stub);
|
| 18 |
+
DECLARE_DISPATCH(rnn_fn, rnn_tanh_cudnn_stub);
|
| 19 |
+
DECLARE_DISPATCH(rnn_fn, rnn_tanh_miopen_stub);
|
| 20 |
+
DECLARE_DISPATCH(rnn_fn, rnn_relu_cudnn_stub);
|
| 21 |
+
DECLARE_DISPATCH(rnn_fn, rnn_relu_miopen_stub);
|
| 22 |
+
DECLARE_DISPATCH(lstm_packed_fn, lstm_packed_cudnn_stub);
|
| 23 |
+
DECLARE_DISPATCH(lstm_packed_fn, lstm_packed_miopen_stub);
|
| 24 |
+
DECLARE_DISPATCH(rnn_packed_fn, gru_packed_cudnn_stub);
|
| 25 |
+
DECLARE_DISPATCH(rnn_packed_fn, gru_packed_miopen_stub);
|
| 26 |
+
DECLARE_DISPATCH(rnn_packed_fn, rnn_tanh_packed_cudnn_stub);
|
| 27 |
+
DECLARE_DISPATCH(rnn_packed_fn, rnn_tanh_packed_miopen_stub);
|
| 28 |
+
DECLARE_DISPATCH(rnn_packed_fn, rnn_relu_packed_cudnn_stub);
|
| 29 |
+
DECLARE_DISPATCH(rnn_packed_fn, rnn_relu_packed_miopen_stub);
|
| 30 |
+
|
| 31 |
+
inline void check_attributes(const Tensor& input, const TensorList& params, const TensorList& hiddens, bool check_dtype=false) {
|
| 32 |
+
auto input_device = input.device();
|
| 33 |
+
auto input_dtype = input.scalar_type();
|
| 34 |
+
|
| 35 |
+
auto check_tensors = [&](const std::string& name, const Tensor& t) {
|
| 36 |
+
if (!t.defined()) return;
|
| 37 |
+
auto t_device = t.device();
|
| 38 |
+
TORCH_CHECK(input_device == t_device,
|
| 39 |
+
"Input and ", name, " tensors are not at the same device, found input tensor at ",
|
| 40 |
+
input_device, " and ", name, " tensor at ", t_device);
|
| 41 |
+
if (check_dtype) {
|
| 42 |
+
auto t_dtype = t.scalar_type();
|
| 43 |
+
TORCH_CHECK(input_dtype == t_dtype,
|
| 44 |
+
"Input and ", name, " tensors are not the same dtype, found input tensor with ",
|
| 45 |
+
input_dtype, " and ", name, " tensor with ", t_dtype);
|
| 46 |
+
}
|
| 47 |
+
};
|
| 48 |
+
|
| 49 |
+
for (const auto& h : hiddens) check_tensors("hidden", h);
|
| 50 |
+
for (const auto& p : params) check_tensors("parameter", p);
|
| 51 |
+
}
|
| 52 |
+
|
| 53 |
+
} // namespace at::native
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/Repeat.h
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/core/Tensor.h>
|
| 4 |
+
#include <ATen/TensorOperators.h>
|
| 5 |
+
|
| 6 |
+
#ifndef AT_PER_OPERATOR_HEADERS
|
| 7 |
+
#include <ATen/Functions.h>
|
| 8 |
+
#else
|
| 9 |
+
#include <ATen/ops/empty.h>
|
| 10 |
+
#include <ATen/ops/empty_like.h>
|
| 11 |
+
#endif
|
| 12 |
+
|
| 13 |
+
namespace at::native {
|
| 14 |
+
|
| 15 |
+
template <
|
| 16 |
+
typename index_t,
|
| 17 |
+
void compute(index_t*, int64_t*, index_t*, int64_t, int64_t)>
|
| 18 |
+
static inline Tensor repeat_interleave_common(
|
| 19 |
+
const Tensor& repeats,
|
| 20 |
+
c10::optional<int64_t> output_size) {
|
| 21 |
+
TORCH_CHECK(
|
| 22 |
+
repeats.dim() == 1, "repeat_interleave only accept 1D vector as repeat");
|
| 23 |
+
TORCH_CHECK(
|
| 24 |
+
repeats.scalar_type() == at::kLong || repeats.scalar_type() == at::kInt,
|
| 25 |
+
"repeats has to be Long or Int tensor");
|
| 26 |
+
if (repeats.size(0) == 0) {
|
| 27 |
+
return at::empty_like(repeats, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
|
| 28 |
+
}
|
| 29 |
+
Tensor repeats_ = repeats.contiguous();
|
| 30 |
+
Tensor cumsum = repeats.cumsum(0);
|
| 31 |
+
int64_t total;
|
| 32 |
+
if (output_size.has_value()) {
|
| 33 |
+
total = output_size.value();
|
| 34 |
+
} else {
|
| 35 |
+
total = cumsum[-1].item<int64_t>();
|
| 36 |
+
TORCH_CHECK(
|
| 37 |
+
(repeats >= 0).all().item<uint8_t>(), "repeats can not be negative");
|
| 38 |
+
}
|
| 39 |
+
|
| 40 |
+
Tensor result = at::empty({total}, repeats.options());
|
| 41 |
+
index_t* repeat_ptr = repeats_.data_ptr<index_t>();
|
| 42 |
+
int64_t* cumsum_ptr = cumsum.data_ptr<int64_t>();
|
| 43 |
+
index_t* result_ptr = result.data_ptr<index_t>();
|
| 44 |
+
compute(repeat_ptr, cumsum_ptr, result_ptr, repeats.size(0), total);
|
| 45 |
+
return result;
|
| 46 |
+
}
|
| 47 |
+
|
| 48 |
+
} // namespace at::native
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/Resize.h
ADDED
|
@@ -0,0 +1,173 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/core/Tensor.h>
|
| 4 |
+
#include <ATen/native/ResizeCommon.h>
|
| 5 |
+
#include <ATen/EmptyTensor.h>
|
| 6 |
+
#include <ATen/TensorUtils.h>
|
| 7 |
+
|
| 8 |
+
#include <c10/core/CPUAllocator.h>
|
| 9 |
+
|
| 10 |
+
#include <utility>
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
namespace at::native {
|
| 14 |
+
|
| 15 |
+
// TODO: make all operations that resize given outputs use this function
|
| 16 |
+
// for consistency and maintainability.
|
| 17 |
+
// Some operations like `cat` might not be able to make the use of
|
| 18 |
+
// resize_output directly. For more details to understand how it works in `cat`,
|
| 19 |
+
// see https://github.com/pytorch/pytorch/pull/62560#discussion_r687363362
|
| 20 |
+
// Resizes outputs
|
| 21 |
+
// Functions accepting output tensors, like with the "out" kwarg, should
|
| 22 |
+
// call this function to handle resizing their output tensor.
|
| 23 |
+
// Issues a warning if the output tensor has one or more elements and
|
| 24 |
+
// needs resizing
|
| 25 |
+
// NOTE: In the future the warning will become an error
|
| 26 |
+
// Returns a bool saying whether or not the resize actually happened or not
|
| 27 |
+
TORCH_API bool resize_output(const Tensor& output, IntArrayRef shape);
|
| 28 |
+
// WARNING: Do NOT call this directly. If you are resizing an output and want
|
| 29 |
+
// to support dynamic shapes call at::resize__symint and resize_output_check_symint.
|
| 30 |
+
// For more details, see: https://github.com/pytorch/pytorch/pull/111530/files#r1365845272
|
| 31 |
+
TORCH_API bool resize_output_symint(const Tensor& output, SymIntArrayRef shape);
|
| 32 |
+
|
| 33 |
+
// Utility for resize_output
|
| 34 |
+
// Returns a bool saying resize should happen or not and
|
| 35 |
+
// raises a warning if resizing for one or more elements
|
| 36 |
+
TORCH_API bool resize_output_check(const Tensor& output, IntArrayRef shape);
|
| 37 |
+
TORCH_API bool resize_output_check_symint(const Tensor& output, SymIntArrayRef shape);
|
| 38 |
+
|
| 39 |
+
TORCH_API void resize_bytes_cpu(StorageImpl* storage, size_t size_bytes);
|
| 40 |
+
TORCH_API void resize_bytes_meta(StorageImpl* storage, c10::SymInt size_bytes);
|
| 41 |
+
TORCH_API void resize_bytes_nocuda(const Storage& storage, c10::SymInt size_bytes);
|
| 42 |
+
|
| 43 |
+
static inline void maybe_resize_storage_cpu(TensorImpl* self, size_t new_size_bytes) {
|
| 44 |
+
// It does not make sense to try to resize a storage
|
| 45 |
+
// to hold 0 elements, and this can break
|
| 46 |
+
// if storage_offset is positive but
|
| 47 |
+
// new_size is 0, so just bail in that case
|
| 48 |
+
// (same comment is in cuda/Resize.h)
|
| 49 |
+
if (self->numel() == 0) {
|
| 50 |
+
return;
|
| 51 |
+
}
|
| 52 |
+
|
| 53 |
+
const Storage& storage = self->unsafe_storage();
|
| 54 |
+
if (!storage) {
|
| 55 |
+
auto new_storage = c10::make_intrusive<StorageImpl>(
|
| 56 |
+
StorageImpl::use_byte_size_t(),
|
| 57 |
+
new_size_bytes,
|
| 58 |
+
c10::GetCPUAllocator(),
|
| 59 |
+
true);
|
| 60 |
+
self->set_storage_keep_dtype(std::move(new_storage));
|
| 61 |
+
} else if (new_size_bytes > storage.nbytes()) {
|
| 62 |
+
resize_bytes_cpu(storage.unsafeGetStorageImpl(), new_size_bytes);
|
| 63 |
+
}
|
| 64 |
+
}
|
| 65 |
+
|
| 66 |
+
TORCH_API TensorImpl* resize_impl_cpu_(
|
| 67 |
+
TensorImpl* self,
|
| 68 |
+
IntArrayRef size,
|
| 69 |
+
at::OptionalIntArrayRef stride,
|
| 70 |
+
bool resize_storage = true);
|
| 71 |
+
|
| 72 |
+
// Convert a c10::SymInt into T. Only the two explicit specializations below
// are usable; every other T hits the deleted primary template at compile time.
template <typename T>
T maybe_convert_symint(c10::SymInt) = delete;

// T = SymInt: identity, symbolic values pass through untouched.
template <>
inline c10::SymInt maybe_convert_symint(c10::SymInt x) { return x; }

// T = int64_t: concretize; guard_int records a guard at this source location
// if x is symbolic.
template <>
inline int64_t maybe_convert_symint(c10::SymInt x) { return x.guard_int(__FILE__, __LINE__); }
|
| 80 |
+
|
| 81 |
+
// Verify that a view described by (size, stride, storage_offset) fits inside
// `new_storage`: the bytes it can reach (computeStorageNbytes + the offset in
// bytes) must not exceed the storage's nbytes. Zero-numel views are always
// in bounds. Raises via TORCH_CHECK with a detailed message otherwise.
// T is int64_t or c10::SymInt (see maybe_convert_symint).
template <typename T>
static inline void checkInBoundsForStorage(
    ArrayRef<T> size,
    ArrayRef<T> stride,
    T storage_offset,
    const caffe2::TypeMeta& data_type,
    const Storage& new_storage) {
  T storage_size_bytes =
      at::detail::computeStorageNbytes(size, stride, data_type.itemsize());
  T storage_offset_bytes = storage_offset * data_type.itemsize();
  if (storage_size_bytes == 0) {
    // NB: (a tensor with arbitrary 0 dims)'s storage can have any numel.
    return;
  }
  T new_storage_size_bytes = maybe_convert_symint<T>(new_storage.sym_nbytes());
  TORCH_CHECK(
      storage_size_bytes + storage_offset_bytes <= new_storage_size_bytes,
      "setStorage: sizes ",
      size,
      ", strides ",
      stride,
      ","
      " storage offset ",
      storage_offset,
      ", and itemsize ",
      data_type.itemsize(),
      " requiring a storage size of ",
      storage_size_bytes + storage_offset_bytes,
      " are out of bounds for storage of size ",
      new_storage_size_bytes);
}
|
| 112 |
+
|
| 113 |
+
// Validate the arguments of a Tensor::set_(storage, offset, size, stride)
// style call and, when `storage` is not already aliased by `result`, install
// it on the tensor (keeping the tensor's dtype). Checks: size/stride lengths
// agree (when strides were given), devices match, and storage_offset >= 0.
template <typename T>
static inline void checkSetStorage(Tensor& result, Storage storage, T storage_offset,
                                   ArrayRef<T> size, ArrayRef<T> stride) {
  // FIXME: stride should be optional
  if (stride.data()) {
    TORCH_CHECK(size.size() == stride.size(), "unequal size length (", size.size(),
                ") and stride length (", stride.size(), ")");
  }

#ifdef DEBUG
  TORCH_CHECK(size.size() <= INT_MAX, "size length (", size.size(), ") greater than INT_MAX");
#endif

  // storage: note this can't be replaced with result.set_(storage) as the semantics of that
  // function is to set the tensor size to be equal to the size of the storage.
  if (!result.storage().is_alias_of(storage)) {
    // Caffe2 might have tensors whose storages are null, but we
    // don't allow it in PyTorch.
    TORCH_INTERNAL_ASSERT(storage);
    TORCH_INTERNAL_ASSERT(result.storage());

    // We used to allow this, but this breaks device caching.
    // Let's put an actual error message for this one.
    TORCH_CHECK(result.storage().device() == storage.device(),
                "Attempted to set the storage of a tensor on device \"", result.storage().device(),
                "\" to a storage on different device \"", storage.device(),
                "\". This is no longer allowed; the devices must match.");
    result.unsafeGetTensorImpl()->set_storage_keep_dtype(std::move(storage));
  }

  // storageOffset
  TORCH_CHECK(storage_offset >= 0, "Tensor: invalid storage offset ", storage_offset);
}
|
| 146 |
+
|
| 147 |
+
/**
 * Set self's sizes, strides, and storage_offset.
 * (size, stride, storage_offset) must be in bounds for self's storage.
 *
 * Raises (TORCH_CHECK) on length mismatch, negative strides, an out-of-bounds
 * geometry, or a negative storage offset. T is int64_t or c10::SymInt.
 */
template <typename T>
inline void setStrided(
    const Tensor& self,
    ArrayRef<T> size,
    ArrayRef<T> stride,
    T storage_offset) {
  TORCH_CHECK(size.size() == stride.size(), "mismatch in length of strides and shape");
  // Negative strides are rejected here rather than supported.
  for (const auto& val : stride) {
    TORCH_CHECK(val >= 0,
                "as_strided: Negative strides are not supported at the moment, "
                "got strides: ", stride);
  }

  auto* self_ = self.unsafeGetTensorImpl();
  // Reject geometries whose last reachable byte lies outside self's storage.
  checkInBoundsForStorage(
      size, stride, storage_offset, self_->dtype(), self_->storage());

  /* storage offset */
  TORCH_CHECK(storage_offset >= 0, "Tensor: invalid storage offset ", storage_offset);
  self_->set_sizes_and_strides(size, stride, c10::make_optional(storage_offset));
}
|
| 172 |
+
|
| 173 |
+
} // namespace at::native
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/ResizeCommon.h
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/core/Tensor.h>
|
| 4 |
+
#include <ATen/native/TensorFactories.h>
|
| 5 |
+
#include <ATen/NamedTensorUtils.h>
|
| 6 |
+
#include <c10/util/irange.h>
|
| 7 |
+
|
| 8 |
+
#ifndef AT_PER_OPERATOR_HEADERS
|
| 9 |
+
#include <ATen/NativeFunctions.h>
|
| 10 |
+
#else
|
| 11 |
+
#include <ATen/ops/empty.h>
|
| 12 |
+
#endif
|
| 13 |
+
|
| 14 |
+
namespace at::native {
|
| 15 |
+
|
| 16 |
+
template <typename T>
|
| 17 |
+
inline T storage_size_for(ArrayRef<T> size, ArrayRef<T> stride) {
|
| 18 |
+
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(size.size() == stride.size(),
|
| 19 |
+
"storage_size_for(size, stride) requires that size and stride ",
|
| 20 |
+
"have the same size as a precondition.");
|
| 21 |
+
T storage_size = 1;
|
| 22 |
+
for (const auto dim : c10::irange(size.size())) {
|
| 23 |
+
if (size[dim] == 0) {
|
| 24 |
+
storage_size = 0;
|
| 25 |
+
break;
|
| 26 |
+
}
|
| 27 |
+
storage_size += (size[dim] - 1) * stride[dim];
|
| 28 |
+
}
|
| 29 |
+
return storage_size;
|
| 30 |
+
}
|
| 31 |
+
|
| 32 |
+
// Named tensors cannot actually be resized: this only validates that the
// requested size equals the current size and that no memory format was
// requested, then returns `self` unchanged. Anything else raises; the
// common trigger is passing a named tensor as an `out=` argument.
inline const Tensor& resize_named_tensor_(
    const Tensor& self,
    IntArrayRef size,
    c10::optional<MemoryFormat> optional_memory_format) {
  TORCH_INTERNAL_ASSERT(self.has_names());
  TORCH_CHECK(
      self.sizes() == size,
      "Cannot resize named tensor with resize_ or resize_as_ (tried to resize "
      "Tensor",
      self.names(),
      " with size ",
      self.sizes(),
      " to ",
      size,
      "). This may be caused by passing a named tensor ",
      "as an `out=` argument; please ensure that the sizes are the same. ")
  TORCH_CHECK(
      !optional_memory_format.has_value(),
      "Unsupported memory format for named tensor resize ",
      optional_memory_format.value());
  return self;
}
|
| 54 |
+
|
| 55 |
+
// For deterministic output, fill new elements that were added after a storage
// resize with NaN or MAX_INT. `old_storage_nbytes` is the size of the storage
// before the resize happened.
// Implementation: builds a 1-D view over only the newly added tail of the
// storage (offset = old element count, length = number of new elements) and
// fills just that region, leaving pre-existing data untouched.
inline const Tensor& fill_resize_deterministic_(const Tensor& tensor, int64_t old_storage_nbytes) {
  const at::Storage& storage = tensor.unsafeGetTensorImpl()->unsafe_storage();
  int64_t new_storage_nbytes = storage.nbytes();
  // Byte counts are converted to element counts using the tensor's itemsize.
  int64_t old_storage_numel = old_storage_nbytes / tensor.itemsize();
  int64_t new_storage_numel = new_storage_nbytes / tensor.itemsize();
  if (new_storage_numel > old_storage_numel) {
    // Scalar placeholder with matching dtype/device; set_ repoints it at the
    // freshly added storage tail.
    at::Tensor tensor_view = at::empty({}, at::TensorOptions().dtype(tensor.scalar_type()).device(tensor.device()));
    tensor_view.set_(
        storage,
        /*storage_offset=*/old_storage_numel,
        /*size=*/{new_storage_numel - old_storage_numel},
        /*stride=*/{1});
    at::native::fill_empty_deterministic_(tensor_view);
  }
  return tensor;
}
|
| 74 |
+
|
| 75 |
+
} // namespace at::native
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/SharedReduceOps.h
ADDED
|
@@ -0,0 +1,544 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
// Please note that this file is
|
| 3 |
+
// used across both CPU and GPU.
|
| 4 |
+
|
| 5 |
+
#include <type_traits>
|
| 6 |
+
#include <complex>
|
| 7 |
+
#include <c10/macros/Macros.h>
|
| 8 |
+
#include <ATen/detail/FunctionTraits.h>
|
| 9 |
+
#include <ATen/NumericUtils.h>
|
| 10 |
+
#if defined(__CUDACC__)
|
| 11 |
+
#include <ATen/cuda/DeviceUtils.cuh>
|
| 12 |
+
#include <ATen/native/cuda/DeviceSqrt.cuh>
|
| 13 |
+
#elif defined(__HIPCC__)
|
| 14 |
+
#include <ATen/hip/DeviceUtils.cuh>
|
| 15 |
+
#include <ATen/native/hip/DeviceSqrt.cuh>
|
| 16 |
+
#endif
|
| 17 |
+
#if defined(__CUDACC__) || defined(__HIPCC__)
|
| 18 |
+
#include <thrust/pair.h>
|
| 19 |
+
#else
|
| 20 |
+
#include <cmath>
|
| 21 |
+
#define device_sqrt std::sqrt
|
| 22 |
+
#endif
|
| 23 |
+
#if defined(__CUDACC__) || defined(__HIPCC__)
|
| 24 |
+
// NaN-propagating max for device code: if either operand is NaN the result
// is NaN. In the non-HIP branch, std::max(a, b) already returns `a` when `a`
// is NaN (the comparison is false), so only `b` needs an explicit check.
template <typename scalar_t>
inline C10_DEVICE scalar_t max_propagate_nan(scalar_t a, scalar_t b) {
#if defined(__HIPCC__)
  // TODO: remove this special case for HIP when issue is fixed:
  //       https://github.com/ROCm-Developer-Tools/HIP/issues/2209
  scalar_t max = at::_isnan(a) ? a : (at::_isnan(b) ? b : std::max(a, b));
#else
  scalar_t max = at::_isnan(b) ? b : std::max(a, b);
#endif
  return max;
}
// NaN-propagating min; mirror image of max_propagate_nan above.
template <typename scalar_t>
inline C10_DEVICE scalar_t min_propagate_nan(scalar_t a, scalar_t b) {
#if defined(__HIPCC__)
  // TODO: remove this special case for HIP when issue is fixed:
  //       https://github.com/ROCm-Developer-Tools/HIP/issues/2209
  scalar_t min = at::_isnan(a) ? a : (at::_isnan(b) ? b : std::min(a, b));
#else
  scalar_t min = at::_isnan(b) ? b : std::min(a, b);
#endif
  return min;
}
|
| 46 |
+
#define MAX(X, Y) max_propagate_nan(X,Y)
|
| 47 |
+
#define MIN(X, Y) min_propagate_nan(X,Y)
|
| 48 |
+
#else
|
| 49 |
+
#include <ATen/native/cpu/zmath.h>
|
| 50 |
+
#define MAX(X, Y) max_impl(X,Y)
|
| 51 |
+
#define MIN(X, Y) min_impl(X,Y)
|
| 52 |
+
#endif
|
| 53 |
+
|
| 54 |
+
// ROCM hcc doesn't work well with using std:: in kernel functions
|
| 55 |
+
#if defined(__CUDA_ARCH__)
|
| 56 |
+
#include <c10/cuda/CUDAMathCompat.h>
|
| 57 |
+
#define compat_pow c10::cuda::compat::pow
|
| 58 |
+
#elif defined(__HIPCC__)
|
| 59 |
+
#include <c10/hip/HIPMathCompat.h>
|
| 60 |
+
#define compat_pow c10::hip::compat::pow
|
| 61 |
+
#else
|
| 62 |
+
#define compat_pow std::pow
|
| 63 |
+
#endif
|
| 64 |
+
|
| 65 |
+
namespace at { namespace native {
|
| 66 |
+
|
| 67 |
+
namespace detail {
|
| 68 |
+
|
| 69 |
+
#if defined(__CUDACC__) || defined(__HIPCC__)
|
| 70 |
+
template <typename T1, typename T2> using pair = thrust::pair<T1, T2>;
|
| 71 |
+
#else
|
| 72 |
+
template <typename T1, typename T2> using pair = std::pair<T1, T2>;
|
| 73 |
+
#endif
|
| 74 |
+
|
| 75 |
+
} // namespace detail
|
| 76 |
+
|
| 77 |
+
// Running state for Welford's online mean/variance algorithm
// (consumed and produced by WelfordOps below).
template <typename scalar_t, typename index_t>
struct WelfordData {
  scalar_t mean;  // running mean
  scalar_t m2;    // running sum of squared deviations from the mean
  index_t n;      // exact integer element count (WelfordOps::combine sets -1)
  scalar_t nf;    // element count as scalar_t, used for arithmetic in combine

  C10_HOST_DEVICE WelfordData() : mean(0), m2(0), n(0), nf(0) {}

  C10_HOST_DEVICE WelfordData(
      scalar_t mean,
      scalar_t m2,
      index_t n,
      scalar_t nf)
      : mean(mean), m2(m2), n(n), nf(nf) {}
};
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
// Reduction ops implementing Welford's numerically stable single-pass
// variance/std algorithm. `res_t` packs (var-or-std, mean); `correction` is
// the Bessel-style divisor adjustment (e.g. 1 for the unbiased estimator).
template <typename scalar_t, typename acc_scalar_t, typename index_t, typename res_t>
struct WelfordOps {
  acc_scalar_t correction;
  bool take_sqrt;  // true -> project() returns std (sqrt of var) instead of var
 public:
  using acc_t = WelfordData<acc_scalar_t, index_t>;
  // Fold one input element into the accumulator (classic Welford update).
  inline C10_DEVICE acc_t reduce(acc_t acc, scalar_t data, index_t /*idx*/) const {
    // We accumulate n in index_t to avoid cumulative rounding error, but still
    // need nf for use in combine where int32 may overflow.
    index_t new_n = acc.n + 1;
    acc_scalar_t new_nf = static_cast<acc_scalar_t>(new_n);
    acc_scalar_t delta = data - acc.mean;
    acc_scalar_t new_mean = acc.mean + delta / new_nf;
    acc_scalar_t new_delta = data - new_mean;
    return {
      new_mean,
      acc.m2 + delta * new_delta,
      new_n,
      new_nf,
    };
  }
  // Merge two partial accumulators (Chan et al. parallel variance merge).
  inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const {
    if (a.nf == 0) {
      return b;
    }
    if (b.nf == 0) {
      return a;
    }
    acc_scalar_t delta = b.mean - a.mean;
    acc_scalar_t new_count = a.nf + b.nf;
    acc_scalar_t nb_over_n = b.nf / new_count;
    return {
      a.mean + delta * nb_over_n,
      a.m2 + b.m2 + delta * delta * a.nf * nb_over_n,
      // setting acc.n as -1 since acc.n might not be able to represent the count
      // correctly within its range, setting it to -1 to avoid confusion
      -1,
      new_count
    };
  }
  // Final transform: (var-or-std, mean). divisor may be 0 (nf <= correction);
  // the resulting inf/nan is intentional, hence the UBSan suppression.
  inline C10_DEVICE res_t project(acc_t acc) const __ubsan_ignore_float_divide_by_zero__ {
    const auto mean = static_cast<scalar_t>(acc.mean);
    const auto divisor = acc.nf > correction ? acc.nf - correction : 0;
    const auto var = acc.m2 / divisor;
    res_t results(take_sqrt ? device_sqrt(var) : var, mean);
    return results;
  }

  // No index state to shift for this reduction.
  static C10_DEVICE acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) {
    return acc;
  }

#if defined(__CUDACC__) || defined(__HIPCC__)
  // Warp-level tree-reduction step: shuffle each field down by `offset` lanes.
  inline __device__ acc_t warp_shfl_down(acc_t acc, int offset) const {
    return {
      WARP_SHFL_DOWN(acc.mean, offset)
      , WARP_SHFL_DOWN(acc.m2, offset)
      , WARP_SHFL_DOWN(acc.n, offset)
      , WARP_SHFL_DOWN(acc.nf, offset)
    };
  }
#endif
  C10_HOST_DEVICE WelfordOps(acc_scalar_t correction, bool take_sqrt)
    : correction(correction), take_sqrt(take_sqrt) {}
};
|
| 160 |
+
|
| 161 |
+
// Sum-then-scale reduction: accumulates a plain sum and multiplies by
// `factor` (supplied by the caller, e.g. 1/N) in project() to yield the mean.
template <typename scalar_t, typename acc_t=scalar_t, typename factor_t=acc_t, typename out_t = acc_t>
struct MeanOps {
  factor_t factor;

  // Fold one input element into the running sum.
  inline C10_DEVICE acc_t reduce(acc_t a, scalar_t b, int64_t /*idx*/) const {
    return combine(a, static_cast<acc_t>(b));
  }

  // Merge two partial sums.
  inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const {
    return a + b;
  }

  // Final transform: scale the sum by `factor`.
  inline C10_DEVICE out_t project(acc_t a) const {
    return a * factor;
  }

  // No index state to shift for this reduction.
  static C10_DEVICE acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) {
    return acc;
  }

#if defined(__CUDACC__) || defined(__HIPCC__)
  // Warp-level tree-reduction step (CUDA/HIP builds only).
  inline C10_DEVICE acc_t warp_shfl_down(acc_t data, int offset) const {
    return WARP_SHFL_DOWN(data, offset);
  }
#endif

  MeanOps(factor_t factor): factor(factor) {
  }
};
|
| 190 |
+
|
| 191 |
+
// This accumulator template is used to calculate the minimum absolute value of
// a set of numbers.
// `scalar_t` is the type of the input and `acc_t` is the type of the accumulated
// value. These types differ for complex number input support.
template <typename scalar_t, typename acc_t = scalar_t, typename out_t = acc_t>
struct AbsMinOps {

  // Fold |data| into the running minimum (MIN is NaN-propagating).
  inline C10_DEVICE acc_t reduce(acc_t acc, scalar_t data, int64_t /*idx*/) const {
    return MIN(acc, static_cast<acc_t>(std::abs(data)));
  }

  // Merge two partial minima.
  inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const {
    return MIN(a, b);
  }

  // Identity: the accumulator already is the answer.
  inline C10_DEVICE out_t project(acc_t a) const {
    return a;
  }

  // No index state to shift for this reduction.
  static C10_DEVICE acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) {
    return acc;
  }

#if defined(__CUDACC__) || defined(__HIPCC__)
  // Warp-level tree-reduction step (CUDA/HIP builds only).
  inline C10_DEVICE acc_t warp_shfl_down(acc_t acc, int offset) const {
    return WARP_SHFL_DOWN(acc, offset);
  }
#endif
};
|
| 220 |
+
|
| 221 |
+
// This accumulator template is used to calculate the maximum absolute value of
// a set of numbers.
// `scalar_t` is the type of the input and `acc_t` is the type of the accumulated
// value. These types differ for complex number input support.
template <typename scalar_t, typename acc_t = scalar_t, typename out_t = acc_t>
struct AbsMaxOps {
  // Fold |data| into the running maximum (MAX is NaN-propagating).
  inline C10_DEVICE acc_t reduce(acc_t acc, scalar_t data, int64_t /*idx*/) const {
    return MAX(acc, static_cast<acc_t>(std::abs(data)));
  }

  // Merge two partial maxima.
  inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const {
    return MAX(a, b);
  }

  // Identity: the accumulator already is the answer.
  inline C10_DEVICE out_t project(acc_t a) const {
    return a;
  }

  // No index state to shift for this reduction.
  static C10_DEVICE acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) {
    return acc;
  }

#if defined(__CUDACC__) || defined(__HIPCC__)
  // Warp-level tree-reduction step (CUDA/HIP builds only).
  inline C10_DEVICE acc_t warp_shfl_down(acc_t acc, int offset) const {
    return WARP_SHFL_DOWN(acc, offset);
  }
#endif
};
|
| 249 |
+
|
| 250 |
+
// This accumulator template is used to calculate the norm of the absolute value
// of a set of numbers.
// `scalar_t` is the type of the input and `acc_t` is the type of the accumulated
// value. These types differ for complex number input support.
// General p-norm: accumulates sum(|x|^p) and applies the 1/p power at the end.
template <typename scalar_t, typename acc_t = scalar_t, typename out_t = acc_t>
struct NormOps {
  acc_t norm_;  // the exponent p

  // Fold |data|^p into the running sum.
  inline C10_DEVICE acc_t reduce(acc_t acc, scalar_t data, int64_t /*idx*/) const {
    return acc + compat_pow(static_cast<acc_t>(std::abs(data)), norm_);
  }

  // Merge two partial sums.
  inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const {
    return a + b;
  }

  // Final transform: (sum |x|^p)^(1/p).
  inline C10_DEVICE out_t project(acc_t a) const {
    return compat_pow(a, static_cast<acc_t>(1.0) / norm_);
  }

  // No index state to shift for this reduction.
  static C10_DEVICE acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) {
    return acc;
  }

#if defined(__CUDACC__) || defined(__HIPCC__)
  // Warp-level tree-reduction step (CUDA/HIP builds only).
  inline C10_DEVICE acc_t warp_shfl_down(acc_t acc, int offset) const {
    return WARP_SHFL_DOWN(acc, offset);
  }
#endif

  NormOps(acc_t norm_): norm_(norm_) {
  }
};
|
| 283 |
+
|
| 284 |
+
// This accumulator template is used to calculate the order zero norm of the
// absolute value of a set of numbers.
// `scalar_t` is the type of the input and `acc_t` is the type of the accumulated
// value. These types differ for complex number input support.
// The 0-"norm" is simply the count of nonzero elements.
template <typename scalar_t, typename acc_t = scalar_t, typename out_t = acc_t>
struct NormZeroOps {
  // Add 1 for each nonzero element.
  inline C10_DEVICE acc_t reduce(acc_t acc, scalar_t data, int64_t /*idx*/) const {
    return acc + (data == static_cast<scalar_t>(0) ? static_cast<acc_t>(0) : static_cast<acc_t>(1));
  }

  // Merge two partial counts.
  inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const {
    return a + b;
  }

  // Identity: the accumulator already is the answer.
  inline C10_DEVICE out_t project(acc_t a) const {
    return a;
  }

  // No index state to shift for this reduction.
  static C10_DEVICE acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) {
    return acc;
  }


#if defined(__CUDACC__) || defined(__HIPCC__)
  // Warp-level tree-reduction step (CUDA/HIP builds only).
  inline C10_DEVICE acc_t warp_shfl_down(acc_t acc, int offset) const {
    return WARP_SHFL_DOWN(acc, offset);
  }
#endif
};
|
| 313 |
+
|
| 314 |
+
// This accumulator template is used to calculate the order one norm of the
// absolute value of a set of numbers.
// `scalar_t` is the type of the input and `acc_t` is the type of the accumulated
// value. These types differ for complex number input support.
// The 1-norm is sum(|x|); no final power is needed.
template <typename scalar_t, typename acc_t = scalar_t, typename out_t = acc_t>
struct NormOneOps {
  // Fold |data| into the running sum.
  inline C10_DEVICE acc_t reduce(acc_t acc, scalar_t data, int64_t /*idx*/) const {
    return acc + static_cast<acc_t>(std::abs(data));
  }

  // Merge two partial sums.
  inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const {
    return a + b;
  }

  // Identity: the accumulator already is the answer.
  inline C10_DEVICE out_t project(acc_t a) const {
    return a;
  }

  // No index state to shift for this reduction.
  static C10_DEVICE acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) {
    return acc;
  }

#if defined(__CUDACC__) || defined(__HIPCC__)
  // Warp-level tree-reduction step (CUDA/HIP builds only).
  inline C10_DEVICE acc_t warp_shfl_down(acc_t acc, int offset) const {
    return WARP_SHFL_DOWN(acc, offset);
  }
#endif
};
|
| 342 |
+
|
| 343 |
+
|
| 344 |
+
// Tag type used only for overload selection in abs_if_complex: it carries the
// target accumulator type without constructing a value of it.
template<typename acc_t>
struct AbsSwitch {};

// Real input: plain cast, no abs needed (used by NormTwoOps, which squares).
template<typename scalar_t, typename acc_t>
inline C10_DEVICE acc_t abs_if_complex(scalar_t data, AbsSwitch<acc_t>) {
  return static_cast<acc_t>(data);
}

// std::complex input: take the magnitude before casting to the accumulator.
template<typename scalar_t, typename acc_t>
inline C10_DEVICE acc_t abs_if_complex(std::complex<scalar_t> data, AbsSwitch<acc_t>) {
  return static_cast<acc_t>(std::abs(data));
}

// c10::complex input: same as above for PyTorch's own complex type.
template<typename scalar_t, typename acc_t>
inline C10_DEVICE acc_t abs_if_complex(c10::complex<scalar_t> data, AbsSwitch<acc_t>) {
  return static_cast<acc_t>(std::abs(data));
}
|
| 361 |
+
|
| 362 |
+
// This accumulator template is used to calculate the order two norm of the
// absolute value of a set of numbers.
// `scalar_t` is the type of the input and `acc_t` is the type of the accumulated
// value. These types differ for complex number input support.
// The 2-norm is sqrt(sum(|x|^2)); squaring avoids compat_pow entirely.
template <typename scalar_t, typename acc_t = scalar_t, typename out_t = acc_t>
struct NormTwoOps {
  // Fold |data|^2 into the running sum (abs only matters for complex input).
  inline C10_DEVICE acc_t reduce(acc_t acc, scalar_t data, int64_t /*idx*/) const {
    acc_t data_ = abs_if_complex(data, AbsSwitch<acc_t>());
    return acc + data_ * data_;
  }

  // Merge two partial sums.
  inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const {
    return a + b;
  }

  // Final transform: square root of the accumulated sum of squares.
  inline C10_DEVICE out_t project(acc_t a) const {
    return device_sqrt(a);
  }

  // No index state to shift for this reduction.
  static C10_DEVICE acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) {
    return acc;
  }

#if defined(__CUDACC__) || defined(__HIPCC__)
  // Warp-level tree-reduction step (CUDA/HIP builds only).
  inline C10_DEVICE acc_t warp_shfl_down(acc_t acc, int offset) const {
    return WARP_SHFL_DOWN(acc, offset);
  }
#endif
};
|
| 391 |
+
|
| 392 |
+
// Sum reduction that treats NaN inputs as zero (backs nansum).
template <typename acc_t, typename data_t>
struct NanSumOps {
  // Fold one element into the sum, skipping NaNs.
  inline C10_DEVICE acc_t reduce(acc_t a, data_t b, int64_t /*idx*/) const {
    return a + (at::_isnan(b) ? acc_t{0.} : acc_t{b});
  }

  // Merge two partial sums.
  inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const {
    return a + b;
  }

  // Final transform: narrow the accumulator back to the data type.
  inline C10_DEVICE data_t project(acc_t a) const {
    return data_t{a};
  }

  // No index state to shift for this reduction.
  static C10_DEVICE acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) {
    return acc;
  }

#if defined(__CUDACC__) || defined(__HIPCC__)
  // Warp-level tree-reduction step (CUDA/HIP builds only).
  inline C10_DEVICE acc_t warp_shfl_down(acc_t data, int offset) const {
    return WARP_SHFL_DOWN(data, offset);
  }
#endif
};
|
| 416 |
+
|
| 417 |
+
namespace detail {
|
| 418 |
+
|
| 419 |
+
// Comparator for min/argmin: NaN sorts before everything (so it wins the
// reduction), and ties are broken by the lower index. Returns true when
// (a, idx_a) should be preferred over (b, idx_b).
template <typename scalar_t>
struct LessOrNan {
  C10_DEVICE bool operator () (scalar_t a, scalar_t b, int64_t idx_a, int64_t idx_b) const {
    // If (a == b), then choose the one with lower idx, else min(a, b)
    if (at::_isnan(a)) {
      if (at::_isnan(b)) {
        return idx_a < idx_b;
      }
      return true;
    }
    return (a == b) ? idx_a < idx_b : (a < b);
  }
};
|
| 432 |
+
|
| 433 |
+
// Comparator for max/argmax: NaN wins over any number, and ties are broken by
// the lower index. Returns true when (a, idx_a) should be preferred over
// (b, idx_b).
template <typename scalar_t>
struct GreaterOrNan {
  C10_DEVICE bool operator () (scalar_t a, scalar_t b, int64_t idx_a, int64_t idx_b) const {
    // If (a == b), then choose the one with lower idx, else max(a, b)
    if (at::_isnan(a)) {
      if (at::_isnan(b)) {
        return idx_a < idx_b;
      }
      return true;
    }
    return (a == b) ? idx_a < idx_b : (a > b);
  }
};
|
| 446 |
+
|
| 447 |
+
// Generic value+index reduction parameterized by a comparator (LessOrNan /
// GreaterOrNan). The accumulator is a (value, index) pair; comp_t decides
// which pair survives. Serves as the base for min, max, argmin and argmax.
template <typename comp_t>
struct MinMaxReductionOps {
  using scalar_t = typename binary_function_traits<comp_t>::arg1_t;
  using index_t = int64_t;
  using arg_t = detail::pair<scalar_t, index_t>;

  // Min/max keep the full (value, index) pair as output.
  static C10_DEVICE arg_t project(arg_t arg) {
    return arg;
  }

  // Fold one (val, idx) element in: keep whichever side the comparator prefers.
  static C10_DEVICE arg_t reduce(arg_t arg, scalar_t val, int64_t idx) {
    return comp_t{}(arg.first, val, arg.second, idx) ? arg : arg_t(val, idx);
  }

  // Merge two partial results under the same comparator.
  static C10_DEVICE arg_t combine(arg_t a, arg_t b) {
    return comp_t{}(a.first, b.first, a.second, b.second) ? a : b;
  }

  // Indices are chunk-local; shift them by the chunk's base offset.
  static C10_DEVICE arg_t translate_idx(arg_t a, int64_t base_idx) {
    return {a.first, a.second + base_idx};
  }

#if defined(__CUDACC__) || defined(__HIPCC__)
  // Warp-level tree-reduction step: shuffle both pair members down.
  static C10_DEVICE arg_t warp_shfl_down(arg_t arg, int offset) {
    return arg_t(WARP_SHFL_DOWN(arg.first, offset),
                 WARP_SHFL_DOWN(arg.second, offset));
  }
#endif
};
|
| 476 |
+
|
| 477 |
+
// Same reduction as MinMaxReductionOps, but project() discards the value and
// returns only the winning index (for argmin/argmax).
template <typename comp_t>
struct ArgReductionOps : public MinMaxReductionOps<comp_t> {
  using typename MinMaxReductionOps<comp_t>::scalar_t;
  using typename MinMaxReductionOps<comp_t>::index_t;
  using typename MinMaxReductionOps<comp_t>::arg_t;

  static C10_DEVICE index_t project(arg_t arg) {
    return arg.second;
  }
};
|
| 487 |
+
|
| 488 |
+
} // namespace detail
|
| 489 |
+
|
| 490 |
+
// Concrete reductions: each is just a comparator plugged into the generic
// machinery above. NaN always wins; ties resolve to the lowest index.

// argmax: index of the largest element.
template <typename scalar_t>
struct ArgMaxOps :
  public detail::ArgReductionOps<detail::GreaterOrNan<scalar_t>> {
};

// argmin: index of the smallest element.
template <typename scalar_t>
struct ArgMinOps :
  public detail::ArgReductionOps<detail::LessOrNan<scalar_t>> {
};

// min: smallest (value, index) pair.
template <typename scalar_t>
struct MinOps :
  public detail::MinMaxReductionOps<detail::LessOrNan<scalar_t>> {
};

// max: largest (value, index) pair.
template <typename scalar_t>
struct MaxOps :
  public detail::MinMaxReductionOps<detail::GreaterOrNan<scalar_t>> {
};
|
| 509 |
+
|
| 510 |
+
// Simultaneous min and max in a single pass (backs aminmax). The accumulator
// is a (min, max) pair; NaN in either slot propagates through combine.
template <typename scalar_t, typename acc_scalar_t, typename index_t>
struct MinMaxOps {
  using acc_t = detail::pair<acc_scalar_t, acc_scalar_t>;
  // Fold one element in by treating it as the degenerate pair (data, data).
  inline C10_DEVICE acc_t reduce(acc_t acc, scalar_t data, index_t /*idx*/) const {
    return combine(acc, {data, data});
  }

  // Merge two partial (min, max) pairs; a NaN on the left wins both slots.
  inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const {
    auto min_val = (at::_isnan(a.first) || a.first < b.first) ? a.first : b.first;
    auto max_val = (at::_isnan(a.second) || a.second > b.second) ? a.second : b.second;

    return {min_val, max_val};
  }

  // Identity: the accumulator already is the answer.
  inline C10_DEVICE acc_t project(acc_t acc) const {
    return acc;
  }

  // No index state to shift for this reduction.
  static C10_DEVICE acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) {
    return acc;
  }

#if defined(__CUDACC__) || defined(__HIPCC__)
  // Warp-level tree-reduction step: shuffle both pair members down.
  inline C10_DEVICE acc_t warp_shfl_down(acc_t acc, int offset) const {
    return {
      WARP_SHFL_DOWN(acc.first, offset), WARP_SHFL_DOWN(acc.second, offset)
    };
  }
#endif
};
|
| 540 |
+
|
| 541 |
+
}} // namespace at::native
|
| 542 |
+
|
| 543 |
+
#undef MAX
|
| 544 |
+
#undef MIN
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/SparseTensorUtils.h
ADDED
|
@@ -0,0 +1,190 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/Parallel.h>
|
| 4 |
+
#include <ATen/SparseTensorImpl.h>
|
| 5 |
+
#include <ATen/core/Tensor.h>
|
| 6 |
+
|
| 7 |
+
#ifndef AT_PER_OPERATOR_HEADERS
|
| 8 |
+
#include <ATen/Functions.h>
|
| 9 |
+
#else
|
| 10 |
+
#include <ATen/ops/empty.h>
|
| 11 |
+
#include <ATen/ops/tensor.h>
|
| 12 |
+
#endif
|
| 13 |
+
|
| 14 |
+
namespace at::sparse {
|
| 15 |
+
|
| 16 |
+
// Just for documentary purposes
|
| 17 |
+
using SparseTensor = Tensor;
|
| 18 |
+
using SparseType = Type;
|
| 19 |
+
|
| 20 |
+
// This is an internal utility function for getting at the SparseTensorImpl,
|
| 21 |
+
// so that we can write sparse tensor specific accessors for special fields
|
| 22 |
+
// in SparseTensor. You should only use this for writing low level
|
| 23 |
+
// setters/getters for SparseTensorImpl fields; otherwise, you should use
|
| 24 |
+
// the low level setters/getters that were implemented using this.
|
| 25 |
+
//
|
| 26 |
+
// This may be called repeatedly, so make sure it's pretty cheap.
|
| 27 |
+
inline SparseTensorImpl* get_sparse_impl(const SparseTensor& self) {
|
| 28 |
+
TORCH_INTERNAL_ASSERT(
|
| 29 |
+
self.is_sparse(), "_internal_get_SparseTensorImpl: not a sparse tensor");
|
| 30 |
+
return static_cast<SparseTensorImpl*>(self.unsafeGetTensorImpl());
|
| 31 |
+
}
|
| 32 |
+
|
| 33 |
+
// Takes indices and values and directly puts them into the sparse tensor, no
|
| 34 |
+
// copy. This used to be called THSTensor_(_move)
|
| 35 |
+
inline void alias_into_sparse(
|
| 36 |
+
const SparseTensor& self,
|
| 37 |
+
const Tensor& indices,
|
| 38 |
+
const Tensor& values) {
|
| 39 |
+
get_sparse_impl(self)->set_indices_and_values_unsafe(indices, values);
|
| 40 |
+
}
|
| 41 |
+
|
| 42 |
+
// Take indices and values and makes a (data) copy of them to put into the
|
| 43 |
+
// sparse indices/values. This used to be called THSTensor_(_set)
|
| 44 |
+
inline void copy_into_sparse(
|
| 45 |
+
const SparseTensor& self,
|
| 46 |
+
const Tensor& indices,
|
| 47 |
+
const Tensor& values,
|
| 48 |
+
bool non_blocking) {
|
| 49 |
+
alias_into_sparse(
|
| 50 |
+
self,
|
| 51 |
+
indices.to(self._indices().options(), non_blocking, /*copy=*/true),
|
| 52 |
+
values.to(self._values().options(), non_blocking, /*copy=*/true));
|
| 53 |
+
}
|
| 54 |
+
|
| 55 |
+
// TODO: put this into the public API
|
| 56 |
+
inline bool is_same_tensor(const Tensor& lhs, const Tensor& rhs) {
|
| 57 |
+
return lhs.unsafeGetTensorImpl() == rhs.unsafeGetTensorImpl();
|
| 58 |
+
}
|
| 59 |
+
|
| 60 |
+
inline bool is_same_density(const SparseTensor& self, const SparseTensor& src) {
|
| 61 |
+
return self.sparse_dim() == src.sparse_dim() &&
|
| 62 |
+
self.dense_dim() == src.dense_dim();
|
| 63 |
+
}
|
| 64 |
+
|
| 65 |
+
// Give us a new values tensor, with the same dimensionality
|
| 66 |
+
// as 'values' but with a new number of non-zero elements.
|
| 67 |
+
// TODO: Expose this for real in ATen, some day?
|
| 68 |
+
// NB: Doesn't preserve data.
|
| 69 |
+
inline Tensor new_values_with_size_of(const Tensor& values, int64_t nnz) {
|
| 70 |
+
std::vector<int64_t> size = values.sizes().vec();
|
| 71 |
+
size[0] = nnz;
|
| 72 |
+
return at::empty(size, values.options());
|
| 73 |
+
}
|
| 74 |
+
|
| 75 |
+
// NOTE [ Flatten Sparse Indices ]
|
| 76 |
+
// This helper function flattens a sparse indices tensor (a Tensor) into a 1D
|
| 77 |
+
// indices tensor. E.g.,
|
| 78 |
+
// input = [[2, 4, 0],
|
| 79 |
+
// [3, 1, 10]]
|
| 80 |
+
// full_size = [2, 12]
|
| 81 |
+
// output = [ 2 * 12 + 3, 4 * 12 + 1, 0 * 12 + 10 ] = [27, 49, 10]
|
| 82 |
+
//
|
| 83 |
+
// In other words, assuming that each `indices[i, :]` is a valid index to a
|
| 84 |
+
// tensor `t` of shape `full_size`. This returns the corresponding indices to
|
| 85 |
+
// the flattened tensor `t.reshape( prod(full_size[:indices.size(0)]), -1 )`.
|
| 86 |
+
// if forceClone is true, the result will forced to be a clone of self.
|
| 87 |
+
// if force_clone is true, the result will forced to be a clone of self.
|
| 88 |
+
TORCH_API Tensor flatten_indices(
|
| 89 |
+
const Tensor& indices,
|
| 90 |
+
IntArrayRef full_size,
|
| 91 |
+
bool force_clone = false);
|
| 92 |
+
|
| 93 |
+
// Flatten sparse tensor's indices from nD to 1D, similar to NOTE [ Flatten
|
| 94 |
+
// Sparse Indices ], except this one allows partial flatten: only flatten on
|
| 95 |
+
// specified dims. Note that the flatten indices might be uncoalesced if
|
| 96 |
+
// dims_to_flatten.size() < sparse_dim. Also if input indices is already
|
| 97 |
+
// coalesced, the flattened indices will also be sorted.
|
| 98 |
+
//
|
| 99 |
+
// args:
|
| 100 |
+
// indices: sparse tensor indices
|
| 101 |
+
// sizes: sparse tensor sizes
|
| 102 |
+
// dims_to_flatten: a list of dim index to flatten
|
| 103 |
+
//
|
| 104 |
+
// Ex1:
|
| 105 |
+
// indices = [[2, 4, 0],
|
| 106 |
+
// [3, 1, 3]]
|
| 107 |
+
// sizes = [2, 12]
|
| 108 |
+
// dims_to_flatten = [0, 1]
|
| 109 |
+
// new_indices = [ 2 * 12 + 3, 4 * 12 + 1, 0 * 12 + 3 ] = [27, 49, 3]
|
| 110 |
+
//
|
| 111 |
+
// Ex2:
|
| 112 |
+
// dims_to_flatten = [1]
|
| 113 |
+
// new_indices = [ 3, 1, 3 ] # uncoalesced
|
| 114 |
+
TORCH_API Tensor flatten_indices_by_dims(
|
| 115 |
+
const Tensor& indices,
|
| 116 |
+
const IntArrayRef& sizes,
|
| 117 |
+
const IntArrayRef& dims_to_flatten);
|
| 118 |
+
|
| 119 |
+
// Find the CSR representation for a row `indices` from the COO format
|
| 120 |
+
TORCH_API Tensor coo_to_csr(const int64_t* indices, int64_t dim, int64_t nnz);
|
| 121 |
+
|
| 122 |
+
TORCH_API Tensor zeros_like_with_indices(const Tensor& t);
|
| 123 |
+
|
| 124 |
+
template <size_t static_shape_max_len>
|
| 125 |
+
class TensorGeometryHolder {
|
| 126 |
+
using geometry_holder_t = std::array<int64_t, static_shape_max_len>;
|
| 127 |
+
|
| 128 |
+
public:
|
| 129 |
+
explicit TensorGeometryHolder(
|
| 130 |
+
IntArrayRef sizes,
|
| 131 |
+
IntArrayRef strides,
|
| 132 |
+
TensorOptions options = {}) {
|
| 133 |
+
std::copy(sizes.begin(), sizes.end(), t_sizes.begin());
|
| 134 |
+
std::copy(strides.begin(), strides.end(), t_strides.begin());
|
| 135 |
+
}
|
| 136 |
+
|
| 137 |
+
explicit TensorGeometryHolder(const Tensor& t)
|
| 138 |
+
: TensorGeometryHolder(t.sizes(), t.strides()) {}
|
| 139 |
+
|
| 140 |
+
auto operator*() const {
|
| 141 |
+
return std::make_tuple(t_sizes, t_strides);
|
| 142 |
+
}
|
| 143 |
+
|
| 144 |
+
private:
|
| 145 |
+
geometry_holder_t t_sizes;
|
| 146 |
+
geometry_holder_t t_strides;
|
| 147 |
+
};
|
| 148 |
+
|
| 149 |
+
template <>
|
| 150 |
+
class TensorGeometryHolder<0> {
|
| 151 |
+
using geometry_holder_t = Tensor;
|
| 152 |
+
|
| 153 |
+
public:
|
| 154 |
+
explicit TensorGeometryHolder(
|
| 155 |
+
IntArrayRef sizes,
|
| 156 |
+
IntArrayRef strides,
|
| 157 |
+
TensorOptions options) {
|
| 158 |
+
const int64_t t_ndims = sizes.size();
|
| 159 |
+
const auto cpu_options = TensorOptions(options).dtype(kLong).device(kCPU);
|
| 160 |
+
Tensor t_sizes_and_strides_cpu = at::empty({2, t_ndims}, cpu_options);
|
| 161 |
+
t_sizes_and_strides_cpu.select(0, 0).copy_(at::tensor(sizes, cpu_options));
|
| 162 |
+
t_sizes_and_strides_cpu.select(0, 1).copy_(
|
| 163 |
+
at::tensor(strides, cpu_options));
|
| 164 |
+
const Tensor t_sizes_and_strides =
|
| 165 |
+
t_sizes_and_strides_cpu.to(options.device());
|
| 166 |
+
t_sizes = t_sizes_and_strides.select(0, 0);
|
| 167 |
+
t_strides = t_sizes_and_strides.select(0, 1);
|
| 168 |
+
}
|
| 169 |
+
|
| 170 |
+
explicit TensorGeometryHolder(const Tensor& t)
|
| 171 |
+
: TensorGeometryHolder(t.sizes(), t.strides(), t.options()) {}
|
| 172 |
+
|
| 173 |
+
auto operator*() const {
|
| 174 |
+
return std::make_tuple(
|
| 175 |
+
t_sizes.template data_ptr<int64_t>(),
|
| 176 |
+
t_strides.template data_ptr<int64_t>());
|
| 177 |
+
}
|
| 178 |
+
|
| 179 |
+
private:
|
| 180 |
+
geometry_holder_t t_sizes;
|
| 181 |
+
geometry_holder_t t_strides;
|
| 182 |
+
};
|
| 183 |
+
|
| 184 |
+
// Return all indices of a tensor with the given shape.
|
| 185 |
+
//
|
| 186 |
+
// full_coo_indices(shape) is equivalent to
|
| 187 |
+
// torch.ones(shape).nonzero().transpose(-2, -1) but much faster.
|
| 188 |
+
TORCH_API Tensor full_coo_indices(IntArrayRef sizes, TensorOptions options);
|
| 189 |
+
|
| 190 |
+
} // namespace at::sparse
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/StridedRandomAccessor.h
ADDED
|
@@ -0,0 +1,301 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
namespace at::native {
|
| 4 |
+
|
| 5 |
+
// (Const)StridedRandomAccessor is a
|
| 6 |
+
// (const) random access iterator defined over
|
| 7 |
+
// a strided array.
|
| 8 |
+
|
| 9 |
+
// The traits below are to introduce __restrict__
|
| 10 |
+
// modifier on different platforms.
|
| 11 |
+
|
| 12 |
+
template <typename T>
|
| 13 |
+
struct DefaultPtrTraits {
|
| 14 |
+
using PtrType = T*;
|
| 15 |
+
};
|
| 16 |
+
|
| 17 |
+
#if (defined(_WIN32) || defined(_WIN64))
|
| 18 |
+
#define RESTRICT __restrict
|
| 19 |
+
#else
|
| 20 |
+
#define RESTRICT __restrict__
|
| 21 |
+
#endif
|
| 22 |
+
|
| 23 |
+
template <typename T>
|
| 24 |
+
struct RestrictPtrTraits {
|
| 25 |
+
using PtrType = T* RESTRICT;
|
| 26 |
+
};
|
| 27 |
+
|
| 28 |
+
template <
|
| 29 |
+
typename T,
|
| 30 |
+
typename index_t = int64_t,
|
| 31 |
+
template <typename U> class PtrTraits = DefaultPtrTraits
|
| 32 |
+
>
|
| 33 |
+
class ConstStridedRandomAccessor {
|
| 34 |
+
public:
|
| 35 |
+
using difference_type = index_t;
|
| 36 |
+
using value_type = const T;
|
| 37 |
+
using pointer = const typename PtrTraits<T>::PtrType;
|
| 38 |
+
using reference = const value_type&;
|
| 39 |
+
using iterator_category = std::random_access_iterator_tag;
|
| 40 |
+
|
| 41 |
+
using PtrType = typename PtrTraits<T>::PtrType;
|
| 42 |
+
using index_type = index_t;
|
| 43 |
+
|
| 44 |
+
// Constructors {
|
| 45 |
+
C10_HOST_DEVICE
|
| 46 |
+
ConstStridedRandomAccessor(PtrType ptr, index_t stride)
|
| 47 |
+
: ptr{ptr}, stride{stride}
|
| 48 |
+
{}
|
| 49 |
+
|
| 50 |
+
C10_HOST_DEVICE
|
| 51 |
+
explicit ConstStridedRandomAccessor(PtrType ptr)
|
| 52 |
+
: ptr{ptr}, stride{static_cast<index_t>(1)}
|
| 53 |
+
{}
|
| 54 |
+
|
| 55 |
+
C10_HOST_DEVICE
|
| 56 |
+
ConstStridedRandomAccessor()
|
| 57 |
+
: ptr{nullptr}, stride{static_cast<index_t>(1)}
|
| 58 |
+
{}
|
| 59 |
+
// }
|
| 60 |
+
|
| 61 |
+
// Pointer-like operations {
|
| 62 |
+
C10_HOST_DEVICE
|
| 63 |
+
reference operator*() const {
|
| 64 |
+
return *ptr;
|
| 65 |
+
}
|
| 66 |
+
|
| 67 |
+
C10_HOST_DEVICE
|
| 68 |
+
const value_type* operator->() const {
|
| 69 |
+
return reinterpret_cast<const value_type*>(ptr);
|
| 70 |
+
}
|
| 71 |
+
|
| 72 |
+
C10_HOST_DEVICE
|
| 73 |
+
reference operator[](index_t idx) const {
|
| 74 |
+
return ptr[idx * stride];
|
| 75 |
+
}
|
| 76 |
+
// }
|
| 77 |
+
|
| 78 |
+
// Prefix/postfix increment/decrement {
|
| 79 |
+
C10_HOST_DEVICE
|
| 80 |
+
ConstStridedRandomAccessor& operator++() {
|
| 81 |
+
ptr += stride;
|
| 82 |
+
return *this;
|
| 83 |
+
}
|
| 84 |
+
|
| 85 |
+
C10_HOST_DEVICE
|
| 86 |
+
ConstStridedRandomAccessor operator++(int) {
|
| 87 |
+
ConstStridedRandomAccessor copy(*this);
|
| 88 |
+
++*this;
|
| 89 |
+
return copy;
|
| 90 |
+
}
|
| 91 |
+
|
| 92 |
+
C10_HOST_DEVICE
|
| 93 |
+
ConstStridedRandomAccessor& operator--() {
|
| 94 |
+
ptr -= stride;
|
| 95 |
+
return *this;
|
| 96 |
+
}
|
| 97 |
+
|
| 98 |
+
C10_HOST_DEVICE
|
| 99 |
+
ConstStridedRandomAccessor operator--(int) {
|
| 100 |
+
ConstStridedRandomAccessor copy(*this);
|
| 101 |
+
--*this;
|
| 102 |
+
return copy;
|
| 103 |
+
}
|
| 104 |
+
// }
|
| 105 |
+
|
| 106 |
+
// Arithmetic operations {
|
| 107 |
+
C10_HOST_DEVICE
|
| 108 |
+
ConstStridedRandomAccessor& operator+=(index_t offset) {
|
| 109 |
+
ptr += offset * stride;
|
| 110 |
+
return *this;
|
| 111 |
+
}
|
| 112 |
+
|
| 113 |
+
C10_HOST_DEVICE
|
| 114 |
+
ConstStridedRandomAccessor operator+(index_t offset) const {
|
| 115 |
+
return ConstStridedRandomAccessor(ptr + offset * stride, stride);
|
| 116 |
+
}
|
| 117 |
+
|
| 118 |
+
C10_HOST_DEVICE
|
| 119 |
+
friend ConstStridedRandomAccessor operator+(
|
| 120 |
+
index_t offset,
|
| 121 |
+
const ConstStridedRandomAccessor& accessor
|
| 122 |
+
) {
|
| 123 |
+
return accessor + offset;
|
| 124 |
+
}
|
| 125 |
+
|
| 126 |
+
C10_HOST_DEVICE
|
| 127 |
+
ConstStridedRandomAccessor& operator-=(index_t offset) {
|
| 128 |
+
ptr -= offset * stride;
|
| 129 |
+
return *this;
|
| 130 |
+
}
|
| 131 |
+
|
| 132 |
+
C10_HOST_DEVICE
|
| 133 |
+
ConstStridedRandomAccessor operator-(index_t offset) const {
|
| 134 |
+
return ConstStridedRandomAccessor(ptr - offset * stride, stride);
|
| 135 |
+
}
|
| 136 |
+
|
| 137 |
+
// Note that this operator is well-defined when `this` and `other`
|
| 138 |
+
// represent the same sequences, i.e. when
|
| 139 |
+
// 1. this.stride == other.stride,
|
| 140 |
+
// 2. |other - this| / this.stride is an Integer.
|
| 141 |
+
C10_HOST_DEVICE
|
| 142 |
+
difference_type operator-(const ConstStridedRandomAccessor& other) const {
|
| 143 |
+
return (ptr - other.ptr) / stride;
|
| 144 |
+
}
|
| 145 |
+
// }
|
| 146 |
+
|
| 147 |
+
// Comparison operators {
|
| 148 |
+
C10_HOST_DEVICE
|
| 149 |
+
bool operator==(const ConstStridedRandomAccessor& other) const {
|
| 150 |
+
return (ptr == other.ptr) && (stride == other.stride);
|
| 151 |
+
}
|
| 152 |
+
|
| 153 |
+
C10_HOST_DEVICE
|
| 154 |
+
bool operator!=(const ConstStridedRandomAccessor& other) const {
|
| 155 |
+
return !(*this == other);
|
| 156 |
+
}
|
| 157 |
+
|
| 158 |
+
C10_HOST_DEVICE
|
| 159 |
+
bool operator<(const ConstStridedRandomAccessor& other) const {
|
| 160 |
+
return ptr < other.ptr;
|
| 161 |
+
}
|
| 162 |
+
|
| 163 |
+
C10_HOST_DEVICE
|
| 164 |
+
bool operator<=(const ConstStridedRandomAccessor& other) const {
|
| 165 |
+
return (*this < other) || (*this == other);
|
| 166 |
+
}
|
| 167 |
+
|
| 168 |
+
C10_HOST_DEVICE
|
| 169 |
+
bool operator>(const ConstStridedRandomAccessor& other) const {
|
| 170 |
+
return !(*this <= other);
|
| 171 |
+
}
|
| 172 |
+
|
| 173 |
+
C10_HOST_DEVICE
|
| 174 |
+
bool operator>=(const ConstStridedRandomAccessor& other) const {
|
| 175 |
+
return !(*this < other);
|
| 176 |
+
}
|
| 177 |
+
// }
|
| 178 |
+
|
| 179 |
+
protected:
|
| 180 |
+
PtrType ptr;
|
| 181 |
+
index_t stride;
|
| 182 |
+
};
|
| 183 |
+
|
| 184 |
+
template <
|
| 185 |
+
typename T,
|
| 186 |
+
typename index_t = int64_t,
|
| 187 |
+
template <typename U> class PtrTraits = DefaultPtrTraits
|
| 188 |
+
>
|
| 189 |
+
class StridedRandomAccessor
|
| 190 |
+
: public ConstStridedRandomAccessor<T, index_t, PtrTraits> {
|
| 191 |
+
public:
|
| 192 |
+
using difference_type = index_t;
|
| 193 |
+
using value_type = T;
|
| 194 |
+
using pointer = typename PtrTraits<T>::PtrType;
|
| 195 |
+
using reference = value_type&;
|
| 196 |
+
|
| 197 |
+
using BaseType = ConstStridedRandomAccessor<T, index_t, PtrTraits>;
|
| 198 |
+
using PtrType = typename PtrTraits<T>::PtrType;
|
| 199 |
+
|
| 200 |
+
// Constructors {
|
| 201 |
+
C10_HOST_DEVICE
|
| 202 |
+
StridedRandomAccessor(PtrType ptr, index_t stride)
|
| 203 |
+
: BaseType(ptr, stride)
|
| 204 |
+
{}
|
| 205 |
+
|
| 206 |
+
C10_HOST_DEVICE
|
| 207 |
+
explicit StridedRandomAccessor(PtrType ptr)
|
| 208 |
+
: BaseType(ptr)
|
| 209 |
+
{}
|
| 210 |
+
|
| 211 |
+
C10_HOST_DEVICE
|
| 212 |
+
StridedRandomAccessor()
|
| 213 |
+
: BaseType()
|
| 214 |
+
{}
|
| 215 |
+
// }
|
| 216 |
+
|
| 217 |
+
// Pointer-like operations {
|
| 218 |
+
C10_HOST_DEVICE
|
| 219 |
+
reference operator*() const {
|
| 220 |
+
return *this->ptr;
|
| 221 |
+
}
|
| 222 |
+
|
| 223 |
+
C10_HOST_DEVICE
|
| 224 |
+
value_type* operator->() const {
|
| 225 |
+
return reinterpret_cast<value_type*>(this->ptr);
|
| 226 |
+
}
|
| 227 |
+
|
| 228 |
+
C10_HOST_DEVICE
|
| 229 |
+
reference operator[](index_t idx) const {
|
| 230 |
+
return this->ptr[idx * this->stride];
|
| 231 |
+
}
|
| 232 |
+
// }
|
| 233 |
+
|
| 234 |
+
// Prefix/postfix increment/decrement {
|
| 235 |
+
C10_HOST_DEVICE
|
| 236 |
+
StridedRandomAccessor& operator++() {
|
| 237 |
+
this->ptr += this->stride;
|
| 238 |
+
return *this;
|
| 239 |
+
}
|
| 240 |
+
|
| 241 |
+
C10_HOST_DEVICE
|
| 242 |
+
StridedRandomAccessor operator++(int) {
|
| 243 |
+
StridedRandomAccessor copy(*this);
|
| 244 |
+
++*this;
|
| 245 |
+
return copy;
|
| 246 |
+
}
|
| 247 |
+
|
| 248 |
+
C10_HOST_DEVICE
|
| 249 |
+
StridedRandomAccessor& operator--() {
|
| 250 |
+
this->ptr -= this->stride;
|
| 251 |
+
return *this;
|
| 252 |
+
}
|
| 253 |
+
|
| 254 |
+
C10_HOST_DEVICE
|
| 255 |
+
StridedRandomAccessor operator--(int) {
|
| 256 |
+
StridedRandomAccessor copy(*this);
|
| 257 |
+
--*this;
|
| 258 |
+
return copy;
|
| 259 |
+
}
|
| 260 |
+
// }
|
| 261 |
+
|
| 262 |
+
// Arithmetic operations {
|
| 263 |
+
C10_HOST_DEVICE
|
| 264 |
+
StridedRandomAccessor& operator+=(index_t offset) {
|
| 265 |
+
this->ptr += offset * this->stride;
|
| 266 |
+
return *this;
|
| 267 |
+
}
|
| 268 |
+
|
| 269 |
+
C10_HOST_DEVICE
|
| 270 |
+
StridedRandomAccessor operator+(index_t offset) const {
|
| 271 |
+
return StridedRandomAccessor(this->ptr + offset * this->stride, this->stride);
|
| 272 |
+
}
|
| 273 |
+
|
| 274 |
+
C10_HOST_DEVICE
|
| 275 |
+
friend StridedRandomAccessor operator+(
|
| 276 |
+
index_t offset,
|
| 277 |
+
const StridedRandomAccessor& accessor
|
| 278 |
+
) {
|
| 279 |
+
return accessor + offset;
|
| 280 |
+
}
|
| 281 |
+
|
| 282 |
+
C10_HOST_DEVICE
|
| 283 |
+
StridedRandomAccessor& operator-=(index_t offset) {
|
| 284 |
+
this->ptr -= offset * this->stride;
|
| 285 |
+
return *this;
|
| 286 |
+
}
|
| 287 |
+
|
| 288 |
+
C10_HOST_DEVICE
|
| 289 |
+
StridedRandomAccessor operator-(index_t offset) const {
|
| 290 |
+
return StridedRandomAccessor(this->ptr - offset * this->stride, this->stride);
|
| 291 |
+
}
|
| 292 |
+
|
| 293 |
+
// Note that here we call BaseType::operator- version
|
| 294 |
+
C10_HOST_DEVICE
|
| 295 |
+
difference_type operator-(const BaseType& other) const {
|
| 296 |
+
return (static_cast<const BaseType&>(*this) - other);
|
| 297 |
+
}
|
| 298 |
+
// }
|
| 299 |
+
};
|
| 300 |
+
|
| 301 |
+
} // namespace at::native
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/TensorAdvancedIndexingUtils.h
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
#include <ATen/core/Tensor.h>
|
| 3 |
+
#include <ATen/native/IndexingUtils.h>
|
| 4 |
+
#include <ATen/native/TensorIterator.h>
|
| 5 |
+
|
| 6 |
+
namespace at::native {
|
| 7 |
+
namespace {
|
| 8 |
+
static std::string shapes_as_str(TensorList tensors) {
|
| 9 |
+
std::ostringstream os;
|
| 10 |
+
bool first = true;
|
| 11 |
+
for (auto& tensor : tensors) {
|
| 12 |
+
if (tensor.defined()) {
|
| 13 |
+
if (!first) {
|
| 14 |
+
os << ", ";
|
| 15 |
+
}
|
| 16 |
+
os << tensor.sizes();
|
| 17 |
+
first = false;
|
| 18 |
+
}
|
| 19 |
+
}
|
| 20 |
+
return os.str();
|
| 21 |
+
}
|
| 22 |
+
} // anonymous namespace
|
| 23 |
+
|
| 24 |
+
static std::tuple<bool, Tensor> canDispatchToMaskedFill(const Tensor& self, const torch::List<c10::optional<at::Tensor>>& indices,
|
| 25 |
+
const Tensor& value){
|
| 26 |
+
if (!(value.numel() ==1 && value.device().is_cpu())){
|
| 27 |
+
return std::make_tuple(false,Tensor());
|
| 28 |
+
}
|
| 29 |
+
int64_t num_ind = 0;
|
| 30 |
+
Tensor mask;
|
| 31 |
+
auto self_device = self.device();
|
| 32 |
+
for (const c10::optional<Tensor>& i: indices) {
|
| 33 |
+
if (!i.has_value() || !(*i).defined()){
|
| 34 |
+
num_ind++;
|
| 35 |
+
} else {
|
| 36 |
+
const Tensor &index = *i;
|
| 37 |
+
if ((index.scalar_type() != kByte && index.scalar_type() != kBool) ||
|
| 38 |
+
index.device() != self_device || mask.defined()){
|
| 39 |
+
return std::make_tuple(false, Tensor());
|
| 40 |
+
} else {
|
| 41 |
+
mask = index;
|
| 42 |
+
for (const auto j : c10::irange(index.dim())) {
|
| 43 |
+
int64_t srcIdx = num_ind + j;
|
| 44 |
+
TORCH_CHECK_INDEX(index.size(j) == self.size(srcIdx), "The shape of the mask ", index.sizes(), " at index ", j,
|
| 45 |
+
" does not match the shape of the indexed tensor ", self.sizes(), " at index ", srcIdx);
|
| 46 |
+
}
|
| 47 |
+
num_ind += mask.ndimension();
|
| 48 |
+
}
|
| 49 |
+
}
|
| 50 |
+
}
|
| 51 |
+
for (C10_UNUSED const auto i : c10::irange(num_ind, self.ndimension())) {
|
| 52 |
+
mask = mask.unsqueeze(-1);
|
| 53 |
+
}
|
| 54 |
+
return std::make_tuple(true, mask);
|
| 55 |
+
}
|
| 56 |
+
|
| 57 |
+
static AdvancedIndex make_info(Tensor self, IOptTensorListRef orig) {
|
| 58 |
+
checkIndexTensorTypes(orig, /*allow_int*/ true);
|
| 59 |
+
// first expand BoolTensor (masks) or ByteTensor (masks) into 1 or more LongTensors
|
| 60 |
+
auto indices = expandTensors(self, orig);
|
| 61 |
+
// next broadcast all index tensors together
|
| 62 |
+
try {
|
| 63 |
+
indices = expand_outplace(indices);
|
| 64 |
+
} catch (std::exception& e) {
|
| 65 |
+
TORCH_CHECK_INDEX(false, "shape mismatch: indexing tensors could not be broadcast together"
|
| 66 |
+
" with shapes ", shapes_as_str(indices));
|
| 67 |
+
}
|
| 68 |
+
// add missing null Tensors so that it matches self.dim()
|
| 69 |
+
while (indices.size() < (size_t)self.dim()) {
|
| 70 |
+
indices.emplace_back();
|
| 71 |
+
}
|
| 72 |
+
// if the non-null indices are not all adjacent, transpose self and indices
|
| 73 |
+
// together so that they're adjacent at the front
|
| 74 |
+
if (!hasContiguousSubspace(indices)) {
|
| 75 |
+
std::tie(self, indices) = transposeToFront(self, indices);
|
| 76 |
+
}
|
| 77 |
+
// Ensure indices are on the same device as self
|
| 78 |
+
for (auto & indice : indices) {
|
| 79 |
+
if (indice.defined() && indice.device() != self.device()) {
|
| 80 |
+
indice = indice.to(self.device());
|
| 81 |
+
}
|
| 82 |
+
}
|
| 83 |
+
for (auto & indice : indices) {
|
| 84 |
+
if (indice.defined() && indice.dtype() == at::kInt) {
|
| 85 |
+
indice = indice.to(at::kLong);
|
| 86 |
+
}
|
| 87 |
+
}
|
| 88 |
+
|
| 89 |
+
return AdvancedIndex(self, indices);
|
| 90 |
+
}
|
| 91 |
+
|
| 92 |
+
} // namespace at::native
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/TensorDimApply.h
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
#include <ATen/core/Tensor.h>
|
| 3 |
+
#include <c10/util/irange.h>
|
| 4 |
+
|
| 5 |
+
namespace at::native {
|
| 6 |
+
//input tensors are non-zero dim and non-empty
|
| 7 |
+
template<typename T1, typename T2, typename Function>
|
| 8 |
+
|
| 9 |
+
void tensor_dim_apply3(const Tensor& self, Tensor& values, Tensor& indices, int64_t dim, Function func) {
|
| 10 |
+
int ndims = self.dim();
|
| 11 |
+
int tensor_dim_apply_has_finished = 0;
|
| 12 |
+
std::vector<int64_t> counter(ndims, 0);
|
| 13 |
+
const T1* self_data = self.const_data_ptr<T1>();
|
| 14 |
+
T1* values_data = values.data_ptr<T1>();
|
| 15 |
+
T2* indices_data = indices.data_ptr<T2>();
|
| 16 |
+
int64_t self_stride = self.stride(dim);
|
| 17 |
+
int64_t values_stride = values.stride(dim);
|
| 18 |
+
int64_t indices_stride = indices.stride(dim);
|
| 19 |
+
int self_dim_size = self.size(dim);
|
| 20 |
+
|
| 21 |
+
while (!tensor_dim_apply_has_finished) {
|
| 22 |
+
func(self_data, values_data, indices_data, self_dim_size, self_stride, values_stride, indices_stride);
|
| 23 |
+
if (ndims == 1) {
|
| 24 |
+
break;
|
| 25 |
+
}
|
| 26 |
+
for (const auto dim_i : c10::irange(ndims)) {
|
| 27 |
+
if (dim_i == dim) {
|
| 28 |
+
if (dim_i == (ndims - 1)) {
|
| 29 |
+
tensor_dim_apply_has_finished = 1;
|
| 30 |
+
break;
|
| 31 |
+
}
|
| 32 |
+
continue;
|
| 33 |
+
}
|
| 34 |
+
counter[dim_i]++;
|
| 35 |
+
self_data += self.stride(dim_i);
|
| 36 |
+
values_data += values.stride(dim_i);
|
| 37 |
+
indices_data += indices.stride(dim_i);
|
| 38 |
+
|
| 39 |
+
if (counter[dim_i] == self.size(dim_i)) {
|
| 40 |
+
if (dim_i == ndims-1) {
|
| 41 |
+
tensor_dim_apply_has_finished = 1;
|
| 42 |
+
break;
|
| 43 |
+
} else {
|
| 44 |
+
self_data -= counter[dim_i]*self.stride(dim_i);
|
| 45 |
+
values_data -= counter[dim_i]*values.stride(dim_i);
|
| 46 |
+
indices_data -= counter[dim_i]*indices.stride(dim_i);
|
| 47 |
+
counter[dim_i] = 0;
|
| 48 |
+
}
|
| 49 |
+
} else {
|
| 50 |
+
break;
|
| 51 |
+
}
|
| 52 |
+
}
|
| 53 |
+
}
|
| 54 |
+
}
|
| 55 |
+
} // namespace at::native
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/TensorFactories.h
ADDED
|
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/core/Tensor.h>
|
| 4 |
+
#include <ATen/EmptyTensor.h>
|
| 5 |
+
#include <ATen/TensorIterator.h>
|
| 6 |
+
#include <ATen/Dispatch.h>
|
| 7 |
+
#include <ATen/Dispatch_v2.h>
|
| 8 |
+
#include <ATen/native/DispatchStub.h>
|
| 9 |
+
|
| 10 |
+
#ifndef AT_PER_OPERATOR_HEADERS
|
| 11 |
+
#include <ATen/Functions.h>
|
| 12 |
+
#else
|
| 13 |
+
#include <ATen/ops/scalar_tensor.h>
|
| 14 |
+
#endif
|
| 15 |
+
|
| 16 |
+
namespace at::native {
|
| 17 |
+
// Different combinations of row, col, and offset can lead to two cases:
|
| 18 |
+
//
|
| 19 |
+
// Case 1 - Trapezoid (Triangle as a special case): row + offset <= col
|
| 20 |
+
// Example A: offset > 0
|
| 21 |
+
// 1 1 0 0 0
|
| 22 |
+
// 1 1 1 0 0
|
| 23 |
+
// 1 1 1 1 0
|
| 24 |
+
// Example B: offset <= 0
|
| 25 |
+
// 0 0 0
|
| 26 |
+
// 1 0 0
|
| 27 |
+
// 1 1 0
|
| 28 |
+
// In this case, we calculate the number of elements in the first row and
|
| 29 |
+
// last row of the tril respectively, and then compute the tril size.
|
| 30 |
+
//
|
| 31 |
+
// Case 2 - Trapezoid + Rectangle: row + offset > col
|
| 32 |
+
// Example:
|
| 33 |
+
// 1 1 0
|
| 34 |
+
// 1 1 1
|
| 35 |
+
// 1 1 1
|
| 36 |
+
// In this case, we first calculate the size of top trapezoid, and then
|
| 37 |
+
// calculate the size of the bottom rectangle.
|
| 38 |
+
// Number of elements in the lower-triangular part (tril) of a row x col
// matrix with the given diagonal offset. See the case analysis above:
// the tril decomposes into a top trapezoid plus (possibly) a full-width
// bottom rectangle.
inline int64_t get_tril_size(int64_t row, int64_t col, int64_t offset) {
  // A degenerate matrix has no tril at all.
  if (row == 0 || col == 0) {
    return 0;
  }
  // Elements in the first row: for a positive offset it is 1 + offset
  // (capped at col); otherwise it is 1 when the tril is non-empty
  // (row + offset > 0) and 0 when it is empty.
  const int64_t first_row_len = offset > 0
      ? std::min<int64_t>(col, 1 + offset)
      : static_cast<int64_t>(row + offset > 0);
  // Elements in the last row, clamped to [0, col].
  const int64_t last_row_len =
      std::max<int64_t>(0, std::min<int64_t>(col, row + offset));
  // Number of rows that contain at least one element, clamped to [0, row].
  const int64_t rows_with_elems =
      std::max<int64_t>(0, std::min<int64_t>(row, row + offset));
  const int64_t trapezoid_rows = last_row_len - first_row_len + 1;

  // Arithmetic-series sum over the trapezoidal part.
  int64_t size = ((first_row_len + last_row_len) * trapezoid_rows) >> 1;

  // Any remaining rows are completely filled: a rectangle of width col.
  const int64_t rect_rows = rows_with_elems - trapezoid_rows;
  if (rect_rows > 0) {
    size += rect_rows * col;
  }

  return size;
}
|
| 64 |
+
|
| 65 |
+
// Validates tril/triu-style factory arguments: `row` and `col` must be
// non-negative and, when a layout is supplied, it must be the (only
// supported) strided layout.
inline void check_args(
    int64_t row, int64_t col, c10::optional<Layout> layout_opt) {
  TORCH_CHECK(row >= 0, "row must be non-negative, got", row);
  TORCH_CHECK(col >= 0, "col must be non-negative, got", col);
  // An absent layout means "use the default", which is always acceptable.
  if (layout_opt.has_value()) {
    TORCH_CHECK(
      *layout_opt == at::kStrided,
      "only support layout=torch.strided, got",
      *layout_opt)
  }
}
|
| 76 |
+
|
| 77 |
+
using at::check_size_nonnegative;
|
| 78 |
+
|
| 79 |
+
// assumes maximum value in created tensor is n-1 (e.g., torch.randperm(n))
|
| 80 |
+
// Checks that all integers in [0, n-1] are exactly representable in the
// result tensor's dtype (used e.g. by torch.randperm(n)).
inline void check_supported_max_int_with_precision(int64_t n, const Tensor& tensor) {
  // match defined() to behavior of checks below
  // Constructing a scalar tensor of the max value triggers the dtype's own
  // range checking for integral types.
  TORCH_CHECK(at::scalar_tensor(n>0?n-1:n, tensor.options()).defined(),
              "n is too large for result tensor type: '", tensor.toString(), "'");

  // Ensure sufficient precision for floating point representation.
  // Each bound is 2^mantissa_bits + 1, so the largest element n-1 is at most
  // 2^mantissa_bits, the last integer the format represents exactly.
  switch (tensor.scalar_type()) {
    case at::ScalarType::Half:
      // Half: 11-bit significand, exact integers up to 2^11 = 2048.
      TORCH_CHECK(n <= (int64_t(1) << 11) + 1, "n cannot be greater than 2049 for Half type.");
      break;
    case at::ScalarType::Float:
      // Float: 24-bit significand, exact integers up to 2^24.
      TORCH_CHECK(n <= (int64_t(1) << 24) + 1, "n cannot be greater than 2^24+1 for Float type.");
      break;
    case at::ScalarType::Double:  // Unlikely to happen, but doesn't hurt to check
      // Double: 53-bit significand, exact integers up to 2^53.
      TORCH_CHECK(n <= (int64_t(1) << 53) + 1, "n cannot be greater than 2^53+1 for Double type.");
      break;
    default:
      // Integral dtypes were already fully validated by the scalar_tensor
      // construction above; nothing more to check.
      break;
  }
}
|
| 100 |
+
|
| 101 |
+
// Called by `empty*` functions when deterministic algorithms are enabled to
|
| 102 |
+
// fill the tensor with NaN if it is floating point or complex type, or fill
|
| 103 |
+
// with max value if it is integer type
|
| 104 |
+
// Fills an uninitialized tensor in-place with a deterministic "poison"
// value: quiet NaN for floating/complex dtypes, the dtype's max for
// integral (and bool) dtypes. Returns the same tensor for chaining.
inline Tensor& fill_empty_deterministic_(Tensor& tensor) {
  if (tensor.is_floating_point() || tensor.is_complex()) {
    // NaN exists for every floating/complex dtype, including the
    // reduced-precision kBFloat16 and kHalf types added here.
    AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
      kBFloat16, kHalf, tensor.scalar_type(), "fill_empty_deterministic_", [&]() {
        tensor.fill_(std::numeric_limits<scalar_t>::quiet_NaN());
    });
  } else {
    // Integral dtypes have no NaN; use the dtype's maximum value instead.
    AT_DISPATCH_V2(
      tensor.scalar_type(), "fill_empty_deterministic_", AT_WRAP([&]() {
        tensor.fill_(std::numeric_limits<scalar_t>::max());
    }), kBool, AT_EXPAND(AT_INTEGRAL_TYPES_V2));
  }
  return tensor;
}
|
| 118 |
+
|
| 119 |
+
// The ZeroTensor allocator ignores whatever allocation is requested and always
|
| 120 |
+
// gives you nullptr
|
| 121 |
+
struct ZeroTensorAllocator final : public at::Allocator {
  ZeroTensorAllocator(at::Device device) : device_(device) {};
  ~ZeroTensorAllocator() override = default;
  // No memory is ever handed out, so the deleter only asserts that the
  // pointer being "freed" is the nullptr this allocator produced.
  static void deleter(void* const pointer) {
    TORCH_INTERNAL_ASSERT(!pointer);
  }
  // Ignores the requested byte count and returns a null DataPtr tagged
  // with the target device.
  DataPtr allocate(const size_t /*nbytes*/) override {
    return {nullptr, nullptr, &deleter, device_};
  }
  DeleterFnPtr raw_deleter() const override {
    return deleter;
  }
  // Nothing to copy: ZeroTensor storage has no backing bytes.
  void copy_data(void* dest, const void* src, std::size_t count) const final {}
  // Device reported on every allocation.
  at::Device device_;
};
|
| 136 |
+
|
| 137 |
+
using binary_fn = void (*)(TensorIterator&);
|
| 138 |
+
|
| 139 |
+
DECLARE_DISPATCH(binary_fn, complex_stub);
|
| 140 |
+
DECLARE_DISPATCH(binary_fn, polar_stub);
|
| 141 |
+
|
| 142 |
+
} // namespace at::native
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/TensorIteratorDynamicCasting.h
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <complex>
|
| 4 |
+
#include <type_traits>
|
| 5 |
+
#include <c10/core/ScalarType.h>
|
| 6 |
+
#include <ATen/detail/FunctionTraits.h>
|
| 7 |
+
#include <ATen/native/TensorIterator.h>
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
// This file includes utilities for dynamic_casting done by TensorIterator, see CUDALoops.cuh and Loops.h.
|
| 11 |
+
|
| 12 |
+
// dynamic_casting handles when the types expected by the iterator do not match the types of the arguments
|
| 13 |
+
// to the function that is being called.
|
| 14 |
+
// On CUDA, the cast is currently pushed down into the kernel (for performance reasons).
|
| 15 |
+
// On CPU, there is currently an internal assert that a dynamic_cast is not needed.
|
| 16 |
+
|
| 17 |
+
namespace at::native {
|
| 18 |
+
|
| 19 |
+
// `needs_dynamic_casting` compares the types expected by iterator
|
| 20 |
+
// (i.e. dtypes of the operands) with the actual type of the arguments
|
| 21 |
+
// (and returns) of func_t
|
| 22 |
+
template<typename func_t, int nargs=function_traits<func_t>::arity>
struct needs_dynamic_casting {
  // Recursively compares the C++ type of each argument of func_t (walking
  // from the last argument down to the first) against the corresponding
  // input dtype of the iterator. Returns true as soon as any pair differs.
  static bool check(TensorIteratorBase& iter) {
    using traits = function_traits<func_t>;
    using cpp_type = typename traits::template arg<nargs - 1>::type;
    using cpp_map = c10::CppTypeToScalarType<cpp_type>;

    if (iter.input_dtype(nargs-1) != cpp_map::value) {
      return true;
    }
    // This argument matched; recurse into the remaining nargs - 1 arguments.
    return needs_dynamic_casting<func_t, nargs - 1>::check(iter);
  }
};
|
| 35 |
+
|
| 36 |
+
// Base case (nargs == 0): all input arguments matched, so only the return
// type remains to be compared against the iterator's output dtype.
template<typename func_t>
struct needs_dynamic_casting<func_t, 0> {
  static bool check(TensorIteratorBase& iter) {
    using traits = function_traits<func_t>;
    using cpp_type = typename traits::result_type;

    // we could assert output numbers are correct here, but checks
    // (including arity) are currently pushed outside of this struct.
    if constexpr (std::is_void_v<cpp_type>) {
      // A void-returning functor produces no output value to cast.
      return false;
    } else {
      return iter.dtype(0) != c10::CppTypeToScalarType<cpp_type>::value;
    }
  }
};
|
| 51 |
+
|
| 52 |
+
} //namespace at::native
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/TensorProperties.h
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
// See NOTE: [Tensor vs. TensorBase]
|
| 4 |
+
namespace at {
|
| 5 |
+
class TensorBase;
|
| 6 |
+
}
|
| 7 |
+
|
| 8 |
+
namespace at::native {
|
| 9 |
+
|
| 10 |
+
TORCH_API bool cudnn_is_acceptable(const TensorBase& self);
|
| 11 |
+
|
| 12 |
+
} // namespace at::native
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/TensorShape.h
ADDED
|
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
#include <ATen/core/Tensor.h>
|
| 3 |
+
#include <c10/util/irange.h>
|
| 4 |
+
#include <ATen/core/IListRef.h>
|
| 5 |
+
|
| 6 |
+
namespace at::native {
|
| 7 |
+
|
| 8 |
+
TORCH_API at::Tensor clone_preserve_strides(const at::Tensor& self);
|
| 9 |
+
|
| 10 |
+
inline bool cat_should_skip_tensor(const Tensor& t) {
|
| 11 |
+
return t.numel() == 0 && t.dim() == 1;
|
| 12 |
+
}
|
| 13 |
+
|
| 14 |
+
// Check to see if the shape of tensors is compatible
|
| 15 |
+
// for being concatenated along a given dimension.
|
| 16 |
+
inline void check_cat_shape_except_dim(const Tensor & first, const Tensor & second, int64_t dimension, int64_t index) {
|
| 17 |
+
int64_t first_dims = first.dim();
|
| 18 |
+
int64_t second_dims = second.dim();
|
| 19 |
+
TORCH_CHECK(first_dims == second_dims, "Tensors must have same number of dimensions: got ",
|
| 20 |
+
first_dims, " and ", second_dims);
|
| 21 |
+
for (const auto dim : c10::irange(first_dims)) {
|
| 22 |
+
if (dim == dimension) {
|
| 23 |
+
continue;
|
| 24 |
+
}
|
| 25 |
+
int64_t first_dim_size = first.sizes()[dim];
|
| 26 |
+
int64_t second_dim_size = second.sizes()[dim];
|
| 27 |
+
TORCH_CHECK(first_dim_size == second_dim_size, "Sizes of tensors must match except in dimension ",
|
| 28 |
+
dimension, ". Expected size ", static_cast<long long>(first_dim_size), " but got size ", static_cast<long long>(second_dim_size), " for tensor number ", index, " in the list.");
|
| 29 |
+
}
|
| 30 |
+
}
|
| 31 |
+
|
| 32 |
+
inline void check_cat_no_zero_dim(const MaterializedITensorListRef& tensors) {
|
| 33 |
+
int64_t i = 0;
|
| 34 |
+
for(const Tensor& t : tensors) {
|
| 35 |
+
TORCH_CHECK(t.dim() > 0,
|
| 36 |
+
"zero-dimensional tensor (at position ", i, ") cannot be concatenated");
|
| 37 |
+
i++;
|
| 38 |
+
}
|
| 39 |
+
}
|
| 40 |
+
|
| 41 |
+
inline int64_t get_num_splits(const Tensor& self, int64_t split_size, int64_t dim) {
|
| 42 |
+
TORCH_CHECK(self.dim() != 0, "split expects at least a 1-dimensional tensor");
|
| 43 |
+
TORCH_CHECK(split_size >= 0, "split expects split_size be non-negative, but got split_size=", split_size);
|
| 44 |
+
int64_t dim_size = self.size(dim);
|
| 45 |
+
TORCH_CHECK(split_size > 0 || dim_size == 0,
|
| 46 |
+
"split_size can only be 0 if dimension size is 0, "
|
| 47 |
+
"but got dimension size of ", dim_size);
|
| 48 |
+
// if split_size is 0 and dimension size is 0, there is 1 split.
|
| 49 |
+
int64_t num_splits = 1;
|
| 50 |
+
if (split_size != 0) {
|
| 51 |
+
// ensuring num_splits is at least 1 makes consistent the case where split_size > dim_size
|
| 52 |
+
// (returns a single split). We might want to error here, but keep it for BC.
|
| 53 |
+
num_splits = std::max<int64_t>((dim_size + split_size - 1) / split_size, 1);
|
| 54 |
+
}
|
| 55 |
+
return num_splits;
|
| 56 |
+
}
|
| 57 |
+
|
| 58 |
+
// True iff every tensor in the (non-empty) list has the same
// dimensionality as the first one.
inline bool have_same_ndims(TensorList tensors) {
  const auto expected_ndim = tensors[0].dim();
  for (const auto idx : c10::irange(tensors.size())) {
    if (tensors[idx].dim() != expected_ndim) {
      return false;
    }
  }
  return true;
}
|
| 67 |
+
|
| 68 |
+
// Verify that all tensors agree on the sizes of dimensions 0..dim-1,
// a precondition of _chunk_cat.
inline void leading_dimension_matches(TensorList tensors, int64_t dim) {
  const auto reference_sizes = tensors[0].sizes();
  std::vector<c10::SymInt> leading_dim_sizes(reference_sizes.begin(), reference_sizes.begin() + dim);
  for (const auto tensor_idx : c10::irange(tensors.size())) {
    const at::Tensor& tensor = tensors[tensor_idx];
    for (const auto leading_dim : c10::irange(dim)) {
      TORCH_CHECK(
          tensor.size(leading_dim) == leading_dim_sizes[leading_dim],
          "_chunk_cat expects same sizes of 0,...,dim-1 dimensions for all tensors"
      );
    }
  }
}
|
| 81 |
+
|
| 82 |
+
// Validates the inputs to _chunk_cat and canonicalizes `dim`.
// Returns the (possibly wrapped) non-negative dimension index.
inline int64_t preprocess_chunk_cat_inputs(TensorList tensors, int64_t dim, int64_t num_chunks) {
  TORCH_CHECK(num_chunks >= 1, "_chunk_cat expects positive num_chunks");
  TORCH_CHECK(!tensors.empty(),
           "_chunk_cat expects a non-empty input tensor list");
  // All tensors must be non-empty and share dtype/device with the first one.
  auto expected_dtype = tensors[0].dtype();
  auto expected_device = tensors[0].device();
  for(const auto i : c10::irange(tensors.size())) {
    TORCH_CHECK(tensors[i].numel() > 0, "_chunk_cat expects non-empty tensor");
    TORCH_CHECK(tensors[i].dtype() == expected_dtype, "_chunk_cat expects all input tensors with the same dtype");
    TORCH_CHECK(tensors[i].device() == expected_device, "_chunk_cat expects all inputs tensors on the same device");
  }
  if (have_same_ndims(tensors)) {
    // When every tensor has the same rank, a negative dim can be wrapped
    // unambiguously against that shared rank.
    dim = maybe_wrap_dim(dim, tensors[0].dim());
  } else {
    // With mixed ranks a negative dim would be ambiguous, so require a
    // non-negative dim and check it is in range for each tensor separately.
    TORCH_CHECK(dim >= 0, "_chunk_cat expects non-negative dim when input tensors have different ndims")
    for(const auto i : c10::irange(tensors.size())) {
      TORCH_CHECK(dim < tensors[i].ndimension(), "_chunk_cat expects dim < ndim for all input tensors");
    }
  }
  // Finally, all leading (pre-dim) dimensions must match across tensors.
  leading_dimension_matches(tensors, dim);
  return dim;
}
|
| 104 |
+
|
| 105 |
+
} // namespace at::native
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/TopKImpl.h
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
#include <ATen/core/TensorAccessor.h>
|
| 3 |
+
#include <ATen/NumericUtils.h>
|
| 4 |
+
|
| 5 |
+
namespace at::native {
|
| 6 |
+
|
| 7 |
+
#ifdef CPU_CAPABILITY
|
| 8 |
+
inline namespace CPU_CAPABILITY {
|
| 9 |
+
#else
|
| 10 |
+
inline namespace DEFAULT {
|
| 11 |
+
#endif
|
| 12 |
+
|
| 13 |
+
// Core topk loop, shared between CPU and QuantizedCPU.
// Selects the top `k` (largest or smallest) values and their indices along a
// dimension of length `dim_size`, for each of the `n` slices described by the
// TensorIterator-style (data, strides) arrays:
//   data[0]/strides[0] - output values  (scalar_t)
//   data[1]/strides[1] - output indices (int64_t)
//   data[2]/strides[2] - input values   (scalar_t, read-only)
template <typename scalar_t, typename accscalar_t>
void topk_impl_loop(
    const int64_t mode_values_stride,
    const int64_t mode_indices_stride,
    const int64_t tmp_values_stride,
    const int64_t k,
    const int64_t dim_size,
    const bool largest,
    const bool sorted,
    char** data, const int64_t* strides, const int64_t n) {

  // If k is zero, then output values and indices are empty tensors
  // So iterating over other dims is pointless
  if (k == 0) {
    return;
  }
  using elem_t = std::pair<accscalar_t, int64_t>;
  // Scratch buffer of (value, original index) pairs, reused for every slice.
  std::vector<elem_t> queue(dim_size);
  for (const auto i : c10::irange(n)) {
    // 1-D accessors over the current slice's outputs and input.
    TensorAccessor<scalar_t, 1> mode_values(
        reinterpret_cast<scalar_t*>(data[0] + i * strides[0]),
        &k, &mode_values_stride);
    TensorAccessor<int64_t, 1> mode_indices(
        reinterpret_cast<int64_t*>(data[1] + i * strides[1]),
        &k, &mode_indices_stride);
    TensorAccessor<const scalar_t, 1> tmp_values(
        reinterpret_cast<scalar_t*>(data[2] + i * strides[2]),
        &dim_size, &tmp_values_stride);

    auto n_2 = dim_size;
    // Heuristic: partial_sort only when k is much smaller than the slice;
    // otherwise nth_element (plus an optional sort of the prefix) is used.
    auto use_partial_sort = k * 64 <= n_2;

    // Load the slice into the scratch buffer, remembering original indices.
    for (const auto j : c10::irange(n_2)) {
      queue[j].first = tmp_values[j];
      queue[j].second = j;
    }

    // we want nan to be sorted as top for numpy compatibility
    if (use_partial_sort) {
      if (largest) {
        std::partial_sort(queue.begin(), queue.begin() + k, queue.end(),
          [](const elem_t& x, const elem_t& y) -> bool {
            return ((_isnan<accscalar_t>(x.first) && !_isnan<accscalar_t>(y.first)) || (x.first > y.first));
          });
      } else {
        std::partial_sort(queue.begin(), queue.begin() + k, queue.end(),
          [](const elem_t& x, const elem_t& y) -> bool {
            return ((!_isnan<accscalar_t>(x.first) && _isnan<accscalar_t>(y.first)) || (x.first < y.first));
          });
      }
    } else {
      if (largest) {
        // nth_element puts the (k-1)-th element in its final sorted position
        // and partitions everything "better" before it ...
        std::nth_element(queue.begin(), queue.begin() + k - 1, queue.end(),
          [](const elem_t& x, const elem_t& y) -> bool {
            return ((_isnan<accscalar_t>(x.first) && !_isnan<accscalar_t>(y.first)) || (x.first > y.first));
          });
        if (sorted) {
          // ... so only the first k-1 elements still need a full sort.
          std::sort(queue.begin(), queue.begin() + k - 1,
            [](const elem_t& x, const elem_t& y) -> bool {
              return ((_isnan<accscalar_t>(x.first) && !_isnan<accscalar_t>(y.first)) || (x.first > y.first));
            });
        }
      } else {
        std::nth_element(queue.begin(), queue.begin() + k -1, queue.end(),
          [](const elem_t& x, const elem_t& y) -> bool {
            return ((!_isnan<accscalar_t>(x.first) && _isnan<accscalar_t>(y.first)) || (x.first < y.first));
          });
        if (sorted) {
          std::sort(queue.begin(), queue.begin() + k -1,
            [](const elem_t& x, const elem_t& y) -> bool {
              return ((!_isnan<accscalar_t>(x.first) && _isnan<accscalar_t>(y.first)) || (x.first < y.first));
            });
        }
      }
    }

    // Write the selected top-k (value, index) pairs to the outputs.
    for (const auto j : c10::irange(k)) {
      mode_values[j] = queue[j].first;
      mode_indices[j] = queue[j].second;
    }
  }
}
|
| 96 |
+
|
| 97 |
+
} // namespace CPU_CAPABILITY
|
| 98 |
+
} // namespace at::native
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/TransposeType.h
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
#include <c10/util/Exception.h>
|
| 3 |
+
|
| 4 |
+
namespace at::native {
|
| 5 |
+
|
| 6 |
+
// Used as an interface between the different BLAS-like libraries
|
| 7 |
+
enum class TransposeType {
|
| 8 |
+
NoTranspose,
|
| 9 |
+
Transpose,
|
| 10 |
+
ConjTranspose,
|
| 11 |
+
};
|
| 12 |
+
|
| 13 |
+
// Transforms TransposeType into the BLAS / LAPACK format
|
| 14 |
+
// Maps a TransposeType to the single-character code used by BLAS / LAPACK
// ('N' = no transpose, 'T' = transpose, 'C' = conjugate transpose).
static inline char to_blas(TransposeType trans) {
  switch (trans) {
    case TransposeType::Transpose: return 'T';
    case TransposeType::NoTranspose: return 'N';
    case TransposeType::ConjTranspose: return 'C';
  }
  // All enumerators are handled above; reaching here means a corrupted value.
  TORCH_INTERNAL_ASSERT(false, "Invalid transpose type");
}
|
| 22 |
+
|
| 23 |
+
} // namespace at::native
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/Unfold3d.h
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <c10/core/ScalarType.h>
|
| 4 |
+
|
| 5 |
+
namespace at::native {
|
| 6 |
+
|
| 7 |
+
void Unfold3dCopyCPU(
|
| 8 |
+
ScalarType dtype,
|
| 9 |
+
const void *src,
|
| 10 |
+
int64_t C,
|
| 11 |
+
int64_t X_D,
|
| 12 |
+
int64_t X_H,
|
| 13 |
+
int64_t X_W,
|
| 14 |
+
int64_t Y_D,
|
| 15 |
+
int64_t Y_H,
|
| 16 |
+
int64_t Y_W,
|
| 17 |
+
int64_t kernel_d,
|
| 18 |
+
int64_t kernel_h,
|
| 19 |
+
int64_t kernel_w,
|
| 20 |
+
int64_t stride_d,
|
| 21 |
+
int64_t stride_h,
|
| 22 |
+
int64_t stride_w,
|
| 23 |
+
int64_t pad_d,
|
| 24 |
+
int64_t pad_h,
|
| 25 |
+
int64_t pad_w,
|
| 26 |
+
void* dst);
|
| 27 |
+
|
| 28 |
+
void Unfold3dAccCPU(
|
| 29 |
+
ScalarType dtype,
|
| 30 |
+
const void *src,
|
| 31 |
+
int64_t C,
|
| 32 |
+
int64_t X_D,
|
| 33 |
+
int64_t X_H,
|
| 34 |
+
int64_t X_W,
|
| 35 |
+
int64_t Y_D,
|
| 36 |
+
int64_t Y_H,
|
| 37 |
+
int64_t Y_W,
|
| 38 |
+
int64_t kernel_d,
|
| 39 |
+
int64_t kernel_h,
|
| 40 |
+
int64_t kernel_w,
|
| 41 |
+
int64_t stride_d,
|
| 42 |
+
int64_t stride_h,
|
| 43 |
+
int64_t stride_w,
|
| 44 |
+
int64_t pad_d,
|
| 45 |
+
int64_t pad_h,
|
| 46 |
+
int64_t pad_w,
|
| 47 |
+
void *dst);
|
| 48 |
+
|
| 49 |
+
} // namespace at::native
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/UnfoldBackward.h
ADDED
|
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/core/Tensor.h>
|
| 4 |
+
#include <ATen/TensorIterator.h>
|
| 5 |
+
#include <ATen/native/DispatchStub.h>
|
| 6 |
+
#include <ATen/native/NonEmptyUtils.h>
|
| 7 |
+
|
| 8 |
+
#ifndef AT_PER_OPERATOR_HEADERS
|
| 9 |
+
#include <ATen/Functions.h>
|
| 10 |
+
#else
|
| 11 |
+
#include <ATen/ops/arange.h>
|
| 12 |
+
#endif
|
| 13 |
+
|
| 14 |
+
namespace at::native {
|
| 15 |
+
|
| 16 |
+
using unfold_backward_fn = void (*)(
|
| 17 |
+
Tensor& grad_in,
|
| 18 |
+
const Tensor& grad,
|
| 19 |
+
int64_t dim,
|
| 20 |
+
int64_t size,
|
| 21 |
+
int64_t step
|
| 22 |
+
);
|
| 23 |
+
|
| 24 |
+
DECLARE_DISPATCH(unfold_backward_fn, unfold_backward_stub);
|
| 25 |
+
|
| 26 |
+
namespace {
|
| 27 |
+
|
| 28 |
+
// Note on naming: it is unconventional.
|
| 29 |
+
// grad_in does not mean that it is a gradient wrt to input,
|
| 30 |
+
// grad_in/grad_out is just an input/output of unfold_backward kernel.
|
| 31 |
+
|
| 32 |
+
// Builds the TensorIterator used by the unfold_backward kernels that iterate
// over grad_out. Output 0 is a restrided view of grad_out; input 0 is a
// restrided grad_in (with `dim` and the trailing fold dimension collapsed,
// to be indexed inside the kernel); input 1 is an index tensor that exposes
// the position along `dim` to the kernel via broadcasting.
static C10_UNUSED TensorIterator _make_unfold_backward_iter_over_grad_out(
    Tensor& grad_out,
    const Tensor& grad_in,
    int64_t dim,
    int64_t size,
    int64_t step
) {
  dim = maybe_wrap_dim(dim, grad_out.dim());
  // last dim stores the folds

  auto grad_out_dim_size = ensure_nonempty_size(grad_out, dim);
  auto grad_in_dim_size = ensure_nonempty_size(grad_in, dim);
  // dictates the number of elements to iterate over
  // in dimension `dim`
  auto iter_dim_size = std::min(
    grad_out_dim_size,
    (grad_in_dim_size - 1) * step + size
  );

  /* prepare grad_out for TensorIterator { */
  auto grad_out_strides = ensure_nonempty_vec(grad_out.strides().vec());
  auto grad_out_sizes = ensure_nonempty_vec(grad_out.sizes().vec());
  // Trim `dim` down to the range actually covered by the unfolded windows.
  grad_out_sizes[dim] = iter_dim_size;
  auto grad_out_restrided = grad_out.as_strided(
    grad_out_sizes, grad_out_strides
  );
  /* } */

  /* prepare grad_in for TensorIterator { */
  auto grad_in_strides = ensure_nonempty_vec(grad_in.strides().vec());
  auto grad_in_sizes = ensure_nonempty_vec(grad_in.sizes().vec());

  // set strides for dim to 0
  // and size to 1 because
  // this dimension is indexed inside the kernel
  grad_in_strides[dim] = 0;
  grad_in_sizes[dim] = 1;

  // The trailing fold dimension is likewise indexed inside the kernel.
  grad_in_strides.pop_back();
  grad_in_sizes.pop_back();

  auto grad_in_restrided = grad_in.squeeze(-1).as_strided(
    grad_in_sizes, grad_in_strides
  );
  /* } */

  // During the TensorIterator iteration we have to know
  // i_dim in grad_out[i_1,...,i_dim,...i_n],
  // idx_dim stores this information
  /* prepare idx_dim for TensorIterator { */
  auto idx_dim = at::arange(
    0, iter_dim_size, grad_in.options().dtype(at::kLong)
  );

  auto grad_out_dim = ensure_nonempty_dim(grad_out.dim());

  // Shape [1,...,iter_dim_size,...,1] with zero strides everywhere except
  // `dim`, so idx_dim broadcasts against grad_out inside the iterator.
  auto idx_dim_strides = std::vector<int64_t>(grad_out_dim, 0);
  auto idx_dim_sizes = std::vector<int64_t>(grad_out_dim, 1);

  idx_dim_strides[dim] = 1;
  idx_dim_sizes[dim] = iter_dim_size;

  // idx_dim size will broadcast over determined by grad_out sizes in TensorIterator
  auto idx_dim_restrided = idx_dim.as_strided(idx_dim_sizes, idx_dim_strides);
  /* } */

  // Overlap/dtype checks are disabled on purpose: the restrided views
  // intentionally alias memory and mix value/index dtypes.
  auto iter = TensorIteratorConfig()
    .set_check_mem_overlap(false)
    .check_all_same_dtype(false)
    .resize_outputs(false)
    .add_owned_output(grad_out_restrided)
    .add_owned_input(grad_in_restrided)
    .add_owned_input(idx_dim_restrided)
    .build();

  return iter;
}
|
| 109 |
+
|
| 110 |
+
}
|
| 111 |
+
|
| 112 |
+
} // namespace at::native
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/UpSample.h
ADDED
|
@@ -0,0 +1,506 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <math.h>
|
| 4 |
+
|
| 5 |
+
#include <ATen/OpMathType.h>
|
| 6 |
+
#include <ATen/TensorUtils.h>
|
| 7 |
+
#include <ATen/OpMathType.h>
|
| 8 |
+
#include <ATen/core/Tensor.h>
|
| 9 |
+
#include <ATen/cpu/vec/functional.h>
|
| 10 |
+
#include <ATen/cpu/vec/vec.h>
|
| 11 |
+
#include <ATen/native/DispatchStub.h>
|
| 12 |
+
#include <ATen/native/cpu/utils.h>
|
| 13 |
+
|
| 14 |
+
/**
|
| 15 |
+
* Note [compute_scales_value]
|
| 16 |
+
* Note [area_pixel_compute_scale]
|
| 17 |
+
* ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 18 |
+
* Interpolate with scale_factor can have different behaviors
|
| 19 |
+
* depending on the value of recompute_scale_factor:
|
| 20 |
+
*
|
| 21 |
+
* - With recompute_scale_factor = True (current default behavior):
|
| 22 |
+
* the scale_factor, when provided by the user, are used to calculate
|
| 23 |
+
* the output size. The input size and the computed output_size
|
| 24 |
+
* are then used to infer new values for the scales which are
|
| 25 |
+
* used in the interpolation. Because floating-point math is not exact,
|
| 26 |
+
* this may be a different value from the user-supplied scales.
|
| 27 |
+
*
|
| 28 |
+
* - With recompute_scale_factor = False (which will be the default
|
| 29 |
+
* behavior starting 1.5.0):
|
| 30 |
+
* the behavior follows opencv logic, and the scales provided by
|
| 31 |
+
* the user are the ones used in the interpolation calculations.
|
| 32 |
+
*
|
| 33 |
+
* If the scales are not provided or if they are provided but
|
| 34 |
+
* recompute_scale_factor is set to True (default behavior), the scales
|
| 35 |
+
* are computed from the input and the output size;
|
| 36 |
+
*
|
| 37 |
+
*
|
| 38 |
+
* When the scales are inferred from the input and output sizes,
|
| 39 |
+
* we view each pixel as an area, idx + 0.5 as its center index.
|
| 40 |
+
* Here is an example formula in 1D case.
|
| 41 |
+
* if align_corners: center of two corner pixel areas are preserved,
|
| 42 |
+
* (0.5, 0.5) -> (0.5, 0.5),
|
| 43 |
+
* (input_size - 0.5, 0.5) -> (output_size - 0.5)
|
| 44 |
+
* scale = (input_size - 0.5 - 0.5) / (output_size - 0.5 - 0.5)
|
| 45 |
+
* src_index + 0.5 - 0.5 = scale * (dst_index + 0.5 - 0.5)
|
| 46 |
+
* if not align_corners: the whole range is scaled accordingly
|
| 47 |
+
* scale = input_size / output_size
|
| 48 |
+
* src_idx + 0.5 = scale * (dst_index + 0.5)
|
| 49 |
+
*/
|
| 50 |
+
|
| 51 |
+
namespace at::native {
|
| 52 |
+
|
| 53 |
+
namespace upsample {
|
| 54 |
+
|
| 55 |
+
TORCH_API c10::SmallVector<int64_t, 3> compute_output_size(
|
| 56 |
+
c10::IntArrayRef input_size, // Full input tensor size.
|
| 57 |
+
at::OptionalIntArrayRef output_size,
|
| 58 |
+
c10::optional<c10::ArrayRef<double>> scale_factors);
|
| 59 |
+
|
| 60 |
+
// Picks the scale factor for spatial dimension `idx` out of the optional
// scales list; yields nullopt when no scales were supplied at all.
inline c10::optional<double> get_scale_value(c10::optional<c10::ArrayRef<double>> scales, int idx) {
  if (scales) {
    return scales->at(idx);
  }
  return c10::nullopt;
}
|
| 66 |
+
|
| 67 |
+
} // namespace upsample
|
| 68 |
+
|
| 69 |
+
using scale_t = c10::optional<double>;
|
| 70 |
+
using upsampling_nearest1d = void(*)(const Tensor& output, const Tensor& input, scale_t scales_w);
|
| 71 |
+
using _upsampling_nearest_exact1d = void(*)(const Tensor& output, const Tensor& input, scale_t scales_w);
|
| 72 |
+
using upsampling_nearest2d = void(*)(const Tensor& output, const Tensor& input, scale_t scales_h, scale_t scales_w);
|
| 73 |
+
using _upsampling_nearest_exact2d = void(*)(const Tensor& output, const Tensor& input, scale_t scales_h, scale_t scales_w);
|
| 74 |
+
using upsampling_nearest3d = void(*)(const Tensor& output, const Tensor& input, scale_t scales_d, scale_t scales_h, scale_t scales_w);
|
| 75 |
+
using _upsampling_nearest_exact3d = void(*)(const Tensor& output, const Tensor& input, scale_t scales_d, scale_t scales_h, scale_t scales_w);
|
| 76 |
+
using upsampling_linear1d = void(*)(const Tensor& output, const Tensor& input, bool align_corners, scale_t scales_w);
|
| 77 |
+
using upsampling_bilinear2d = void(*)(const Tensor& output, const Tensor& input, bool align_corners, scale_t scales_h, scale_t scales_w);
|
| 78 |
+
using _upsampling_bilinear2d_aa = void(*)(const Tensor& output, const Tensor& input, bool align_corners, scale_t scales_h, scale_t scales_w);
|
| 79 |
+
using upsampling_trilinear3d = void(*)(const Tensor& output, const Tensor& input, bool align_corners, scale_t scales_d, scale_t scales_h, scale_t scales_w);
|
| 80 |
+
using upsampling_bicubic2d = void(*)(const Tensor& output, const Tensor& input, bool align_corners, scale_t scales_h, scale_t scales_w);
|
| 81 |
+
using _upsampling_bicubic2d_aa = void(*)(const Tensor& output, const Tensor& input, bool align_corners, scale_t scales_h, scale_t scales_w);
|
| 82 |
+
DECLARE_DISPATCH(upsampling_nearest1d, upsample_nearest1d_kernel);
|
| 83 |
+
DECLARE_DISPATCH(_upsampling_nearest_exact1d, _upsample_nearest_exact1d_kernel);
|
| 84 |
+
DECLARE_DISPATCH(upsampling_nearest2d, upsample_nearest2d_kernel);
|
| 85 |
+
DECLARE_DISPATCH(_upsampling_nearest_exact2d, _upsample_nearest_exact2d_kernel);
|
| 86 |
+
DECLARE_DISPATCH(upsampling_nearest3d, upsample_nearest3d_kernel);
|
| 87 |
+
DECLARE_DISPATCH(_upsampling_nearest_exact3d, _upsample_nearest_exact3d_kernel);
|
| 88 |
+
DECLARE_DISPATCH(upsampling_nearest1d, upsample_nearest1d_backward_kernel);
|
| 89 |
+
DECLARE_DISPATCH(_upsampling_nearest_exact1d, _upsample_nearest_exact1d_backward_kernel);
|
| 90 |
+
DECLARE_DISPATCH(upsampling_nearest2d, upsample_nearest2d_backward_kernel);
|
| 91 |
+
DECLARE_DISPATCH(_upsampling_nearest_exact2d, _upsample_nearest_exact2d_backward_kernel);
|
| 92 |
+
DECLARE_DISPATCH(upsampling_nearest3d, upsample_nearest3d_backward_kernel);
|
| 93 |
+
DECLARE_DISPATCH(_upsampling_nearest_exact3d, _upsample_nearest_exact3d_backward_kernel);
|
| 94 |
+
DECLARE_DISPATCH(upsampling_linear1d, upsample_linear1d_kernel);
|
| 95 |
+
DECLARE_DISPATCH(upsampling_bilinear2d, upsample_bilinear2d_kernel);
|
| 96 |
+
DECLARE_DISPATCH(_upsampling_bilinear2d_aa, _upsample_bilinear2d_aa_kernel);
|
| 97 |
+
DECLARE_DISPATCH(upsampling_trilinear3d, upsample_trilinear3d_kernel);
|
| 98 |
+
DECLARE_DISPATCH(upsampling_linear1d, upsample_linear1d_backward_kernel);
|
| 99 |
+
DECLARE_DISPATCH(upsampling_bilinear2d, upsample_bilinear2d_backward_kernel);
|
| 100 |
+
DECLARE_DISPATCH(_upsampling_bilinear2d_aa, _upsample_bilinear2d_aa_backward_kernel);
|
| 101 |
+
DECLARE_DISPATCH(upsampling_trilinear3d, upsample_trilinear3d_backward_kernel);
|
| 102 |
+
DECLARE_DISPATCH(upsampling_bicubic2d, upsample_bicubic2d_kernel);
|
| 103 |
+
DECLARE_DISPATCH(_upsampling_bicubic2d_aa, _upsample_bicubic2d_aa_kernel);
|
| 104 |
+
DECLARE_DISPATCH(_upsampling_bicubic2d_aa, _upsample_bicubic2d_aa_backward_kernel);
|
| 105 |
+
|
| 106 |
+
// Validates shapes for 1d upsampling: input must be 3d (N, C, W) and
// output_size must hold exactly one positive width.
// Returns the full output shape {N, C, W_out}.
static C10_UNUSED std::array<int64_t, 3> upsample_1d_common_check(IntArrayRef input_size, IntArrayRef output_size) {
  TORCH_CHECK(
      output_size.size() == 1,
      "It is expected output_size equals to 1, but got size ",
      output_size.size());

  TORCH_CHECK(
      input_size.size() == 3,
      "It is expected input_size equals to 3, but got size ",
      input_size.size());

  const int64_t output_width = output_size[0];

  const int64_t nbatch = input_size[0];
  const int64_t channels = input_size[1];
  const int64_t input_width = input_size[2];

  TORCH_CHECK(
      input_width > 0 && output_width > 0,
      "Input and output sizes should be greater than 0, but got input (W: ",
      input_width, ") and output (W: ", output_width, ")");

  return {nbatch, channels, output_width};
}
|
| 133 |
+
|
| 134 |
+
// Validates shapes for 2d upsampling: input must be 4d (N, C, H, W) and
// output_size must hold {H_out, W_out}, all positive.
// Returns the full output shape {N, C, H_out, W_out}.
static C10_UNUSED std::array<int64_t, 4> upsample_2d_common_check(IntArrayRef input_size, IntArrayRef output_size) {
  TORCH_CHECK(
      output_size.size() == 2,
      "It is expected output_size equals to 2, but got size ",
      output_size.size());

  TORCH_CHECK(
      input_size.size() == 4,
      "It is expected input_size equals to 4, but got size ",
      input_size.size());

  const int64_t output_height = output_size[0];
  const int64_t output_width = output_size[1];

  const int64_t nbatch = input_size[0];
  const int64_t channels = input_size[1];
  const int64_t input_height = input_size[2];
  const int64_t input_width = input_size[3];

  TORCH_CHECK(
      input_height > 0 && input_width > 0 && output_height > 0 &&
          output_width > 0,
      "Input and output sizes should be greater than 0, but got input (H: ",
      input_height, ", W: ", input_width, ") output (H: ",
      output_height, ", W: ", output_width, ")");

  return {nbatch, channels, output_height, output_width};
}
|
| 169 |
+
|
| 170 |
+
// Validates shapes for 3d upsampling: input must be 5d (N, C, D, H, W) and
// output_size must hold {D_out, H_out, W_out}, all positive.
// Returns the full output shape {N, C, D_out, H_out, W_out}.
static C10_UNUSED
std::array<int64_t, 5> upsample_3d_common_check(IntArrayRef input_size, IntArrayRef output_size) {
  TORCH_CHECK(
      output_size.size() == 3,
      "It is expected output_size equals to 3, but got size ",
      output_size.size());

  TORCH_CHECK(
      input_size.size() == 5,
      "It is expected input_size equals to 5, but got size ",
      input_size.size());

  const int64_t output_depth = output_size[0];
  const int64_t output_height = output_size[1];
  const int64_t output_width = output_size[2];

  const int64_t nbatch = input_size[0];
  const int64_t channels = input_size[1];
  const int64_t input_depth = input_size[2];
  const int64_t input_height = input_size[3];
  const int64_t input_width = input_size[4];

  TORCH_CHECK(
      input_depth > 0 && input_height > 0 && input_width > 0 &&
          output_depth > 0 && output_height > 0 && output_width > 0,
      "Input and output sizes should be greater than 0, but got input (D: ",
      input_depth, ", H: ", input_height, ", W: ", input_width,
      ") output (D: ", output_depth, ", H: ", output_height,
      ", W: ", output_width, ")");

  return {nbatch, channels, output_depth, output_height, output_width};
}
|
| 212 |
+
|
| 213 |
+
static inline void upsample_2d_shape_check(
|
| 214 |
+
const Tensor& input,
|
| 215 |
+
const Tensor& grad_output,
|
| 216 |
+
int64_t nbatch,
|
| 217 |
+
int64_t nchannels,
|
| 218 |
+
int64_t input_height,
|
| 219 |
+
int64_t input_width,
|
| 220 |
+
int64_t output_height,
|
| 221 |
+
int64_t output_width) {
|
| 222 |
+
TORCH_CHECK(
|
| 223 |
+
input_height > 0 && input_width > 0 && output_height > 0 &&
|
| 224 |
+
output_width > 0,
|
| 225 |
+
"Input and output sizes should be greater than 0,"
|
| 226 |
+
" but got input (H: ",
|
| 227 |
+
input_height,
|
| 228 |
+
", W: ",
|
| 229 |
+
input_width,
|
| 230 |
+
") output (H: ",
|
| 231 |
+
output_height,
|
| 232 |
+
", W: ",
|
| 233 |
+
output_width,
|
| 234 |
+
")");
|
| 235 |
+
|
| 236 |
+
if (input.defined()) {
|
| 237 |
+
// Allow for empty batch size but not other dimensions
|
| 238 |
+
TORCH_CHECK(
|
| 239 |
+
(input.numel() != 0 ||
|
| 240 |
+
(input.size(1) != 0 && input.size(2) != 0 && input.size(3) != 0)
|
| 241 |
+
) &&
|
| 242 |
+
input.dim() == 4,
|
| 243 |
+
"Non-empty 4D data tensor expected but got a tensor with sizes ",
|
| 244 |
+
input.sizes());
|
| 245 |
+
} else if (grad_output.defined()) {
|
| 246 |
+
check_dim_size(grad_output, 4, 0, nbatch);
|
| 247 |
+
check_dim_size(grad_output, 4, 1, nchannels);
|
| 248 |
+
check_dim_size(grad_output, 4, 2, output_height);
|
| 249 |
+
check_dim_size(grad_output, 4, 3, output_width);
|
| 250 |
+
}
|
| 251 |
+
}
|
| 252 |
+
|
| 253 |
+
template <typename scalar_t>
|
| 254 |
+
static inline scalar_t compute_scales_value(
|
| 255 |
+
const c10::optional<double> scale,
|
| 256 |
+
int64_t input_size,
|
| 257 |
+
int64_t output_size) {
|
| 258 |
+
// see Note [compute_scales_value]
|
| 259 |
+
// FIXME: remove magic > 0 after we ensure no models were serialized with -1 defaults.
|
| 260 |
+
return (scale.has_value() && scale.value() > 0.)
|
| 261 |
+
? static_cast<scalar_t>(1.0 / scale.value())
|
| 262 |
+
: (static_cast<scalar_t>(input_size) / output_size);
|
| 263 |
+
}
|
| 264 |
+
|
| 265 |
+
template <typename scalar_t>
|
| 266 |
+
static inline scalar_t area_pixel_compute_scale(
|
| 267 |
+
int64_t input_size,
|
| 268 |
+
int64_t output_size,
|
| 269 |
+
bool align_corners,
|
| 270 |
+
const c10::optional<double> scale) {
|
| 271 |
+
// see Note [area_pixel_compute_scale]
|
| 272 |
+
if(align_corners) {
|
| 273 |
+
if(output_size > 1) {
|
| 274 |
+
return static_cast<scalar_t>(input_size - 1) / (output_size - 1);
|
| 275 |
+
} else {
|
| 276 |
+
return static_cast<scalar_t>(0);
|
| 277 |
+
}
|
| 278 |
+
} else {
|
| 279 |
+
return compute_scales_value<scalar_t>(scale, input_size, output_size);
|
| 280 |
+
}
|
| 281 |
+
}
|
| 282 |
+
|
| 283 |
+
// Maps a destination index back to a (fractional) source index under the
// area-pixel model. align_corners is a plain linear map; otherwise the
// pixel-center convention src = scale * (dst + 0.5) - 0.5 is used.
template <typename scalar_t>
static inline scalar_t area_pixel_compute_source_index(
    scalar_t scale,
    int64_t dst_index,
    bool align_corners,
    bool cubic) {
  if (align_corners) {
    return scale * dst_index;
  }
  const scalar_t half = static_cast<scalar_t>(0.5);
  scalar_t src_idx = scale * (dst_index + half) - half;
  // [Note] Follow Opencv resize logic:
  // We allow negative src_idx here and later will use
  //   dx = src_idx - floorf(src_idx)
  // to compute the "distance"(which affects weights).
  // For linear modes, weight distribution doesn't matter
  // for negative indices as they use 2 pixels to interpolate.
  // For example, [-1, 0], they both use pixel 0 value so it
  // doesn't affect if we bound the src_idx to 0 or not.
  // TODO: Our current linear mode impls use unbound indices
  // where we should and then remove this cubic flag.
  // This matters in cubic mode, as we might need [-1, 0, 1, 2]
  // to interpolate and the weights can be affected.
  if (cubic || src_idx >= static_cast<scalar_t>(0)) {
    return src_idx;
  }
  return scalar_t(0);
}
|
| 310 |
+
|
| 311 |
+
// Maps a destination index to its nearest-neighbor source index.
// Index computation matching OpenCV INTER_NEAREST, which is buggy and
// kept for BC: floor(dst * scale), clamped to the last valid input index.
static inline int64_t nearest_neighbor_compute_source_index(
    const float scale,
    int64_t dst_index,
    int64_t input_size) {
  const int64_t candidate = static_cast<int64_t>(floorf(dst_index * scale));
  return std::min(candidate, input_size - 1);
}
|
| 321 |
+
|
| 322 |
+
// Maps a destination index to its nearest source index with the "exact"
// convention:
//   index_f32 = (output_index + 0.5) * scale - 0.5
//   input_index = round(index_f32)
// Same as Pillow and Scikit-Image/Scipy ndi.zoom; rounding is realized as
// floor((dst + 0.5) * scale), then clamped to the last valid input index.
static inline int64_t nearest_neighbor_exact_compute_source_index(
    const float scale,
    int64_t dst_index,
    int64_t input_size) {
  const int64_t candidate =
      static_cast<int64_t>(floorf((dst_index + 0.5) * scale));
  return std::min(candidate, input_size - 1);
}
|
| 333 |
+
|
| 334 |
+
// Maps an output index to its nearest-neighbor source index, keeping two
// legacy fast paths (identity and exact 2x upscale).
// This method specificly treats cases: output_size == input_size or
// output_size == 2 * input_size, that we would like to get rid of.
// We keep this method for BC and consider as deprecated.
// See nearest_exact_idx as replacement.
static inline int64_t nearest_idx(
    int64_t output_index,
    int64_t input_size,
    int64_t output_size,
    c10::optional<double> scales) {
  if (output_size == input_size) {
    // scale_factor = 1, simply copy
    return output_index;
  }
  if (output_size == 2 * input_size) {
    // scale_factor = 2, shift input index
    return output_index >> 1;
  }
  const float scale = compute_scales_value<float>(scales, input_size, output_size);
  return nearest_neighbor_compute_source_index(scale, output_index, input_size);
}
|
| 354 |
+
|
| 355 |
+
// Maps an output index to its source index using the "exact" rounding
// convention (matches Pillow / scipy ndi.zoom); replacement for nearest_idx.
static inline int64_t nearest_exact_idx(
    int64_t output_index,
    int64_t input_size,
    int64_t output_size,
    c10::optional<double> scales) {
  const float scale = compute_scales_value<float>(scales, input_size, output_size);
  return nearest_neighbor_exact_compute_source_index(scale, output_index, input_size);
}
|
| 363 |
+
|
| 364 |
+
// Define a typedef to dispatch to nearest_idx or nearest_exact_idx
|
| 365 |
+
typedef int64_t (*nearest_idx_fn_t)(int64_t, int64_t, int64_t, c10::optional<double>);
|
| 366 |
+
|
| 367 |
+
// Reads data[y][x] from a row-major (height x width) buffer, clamping the
// coordinates into the valid rectangle (border replication).
template <typename scalar_t>
static scalar_t upsample_get_value_bounded(
    scalar_t* data,
    int64_t width,
    int64_t height,
    int64_t x,
    int64_t y) {
  const int64_t ax = std::max(std::min(x, width - 1), static_cast<int64_t>(0));
  const int64_t ay = std::max(std::min(y, height - 1), static_cast<int64_t>(0));
  return data[ay * width + ax];
}
|
| 378 |
+
|
| 379 |
+
// Accumulates `value` into data[y][x] of a row-major (height x width)
// buffer, clamping the coordinates into the valid rectangle first
// (used by backward passes to scatter gradients at the borders).
template <typename scalar_t>
static void upsample_increment_value_bounded(
    scalar_t* data,
    int64_t width,
    int64_t height,
    int64_t x,
    int64_t y,
    scalar_t value) {
  const int64_t ax = std::max(std::min(x, width - 1), static_cast<int64_t>(0));
  const int64_t ay = std::max(std::min(y, height - 1), static_cast<int64_t>(0));
  data[ay * width + ax] += value;
}
|
| 391 |
+
|
| 392 |
+
// Based on
|
| 393 |
+
// https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm
|
| 394 |
+
// Keys cubic kernel, |x| <= 1 branch: (A+2)|x|^3 - (A+3)|x|^2 + 1,
// evaluated in Horner form (same operation order as the reference).
template <typename scalar_t>
static inline scalar_t cubic_convolution1(scalar_t x, scalar_t A) {
  const scalar_t inner = (A + 2) * x - (A + 3);
  return inner * x * x + 1;
}
|
| 398 |
+
|
| 399 |
+
// Keys cubic kernel, 1 < |x| < 2 branch: A|x|^3 - 5A|x|^2 + 8A|x| - 4A,
// evaluated step-by-step in the same Horner order as the reference.
template <typename scalar_t>
static inline scalar_t cubic_convolution2(scalar_t x, scalar_t A) {
  scalar_t acc = A * x - 5 * A;
  acc = acc * x + 8 * A;
  return acc * x - 4 * A;
}
|
| 403 |
+
|
| 404 |
+
template <typename scalar_t>
|
| 405 |
+
static inline void get_cubic_upsample_coefficients(
|
| 406 |
+
scalar_t coeffs[4],
|
| 407 |
+
scalar_t t) {
|
| 408 |
+
scalar_t A = -0.75;
|
| 409 |
+
|
| 410 |
+
scalar_t x1 = t;
|
| 411 |
+
coeffs[0] = cubic_convolution2<scalar_t>(x1 + 1.0, A);
|
| 412 |
+
coeffs[1] = cubic_convolution1<scalar_t>(x1, A);
|
| 413 |
+
|
| 414 |
+
// opposite coefficients
|
| 415 |
+
scalar_t x2 = 1.0 - t;
|
| 416 |
+
coeffs[2] = cubic_convolution1<scalar_t>(x2, A);
|
| 417 |
+
coeffs[3] = cubic_convolution2<scalar_t>(x2 + 1.0, A);
|
| 418 |
+
}
|
| 419 |
+
|
| 420 |
+
template <typename scalar_t>
|
| 421 |
+
static inline scalar_t cubic_interp1d(
|
| 422 |
+
scalar_t x0,
|
| 423 |
+
scalar_t x1,
|
| 424 |
+
scalar_t x2,
|
| 425 |
+
scalar_t x3,
|
| 426 |
+
scalar_t t) {
|
| 427 |
+
scalar_t coeffs[4];
|
| 428 |
+
get_cubic_upsample_coefficients<scalar_t>(coeffs, t);
|
| 429 |
+
|
| 430 |
+
return x0 * coeffs[0] + x1 * coeffs[1] + x2 * coeffs[2] + x3 * coeffs[3];
|
| 431 |
+
}
|
| 432 |
+
|
| 433 |
+
// when `real_input_index` becomes larger than the range the floating point
|
| 434 |
+
// type can accurately represent, the type casting to `int64_t` might exceed
|
| 435 |
+
// `input_size`, causing overflow. So we guard it with `std::min` below.
|
| 436 |
+
// Splits a fractional source index into an integer index plus an
// interpolation weight lambda in [0, 1].
// when `real_input_index` becomes larger than the range the floating point
// type can accurately represent, the type casting to `int64_t` might exceed
// `input_size`, causing overflow. So we guard it with `std::min` below.
template<typename scalar_t, typename opmath_t>
static inline void guard_index_and_lambda(const opmath_t& real_input_index, const int64_t& input_size, int64_t& input_index, scalar_t& lambda) {
  const int64_t floored = static_cast<int64_t>(floorf(real_input_index));
  input_index = std::min(floored, input_size - 1);
  const opmath_t frac = real_input_index - input_index;
  lambda = std::min(std::max(frac, static_cast<opmath_t>(0)),
                    static_cast<opmath_t>(1));
}
|
| 444 |
+
|
| 445 |
+
template<typename scalar_t, typename opmath_t>
|
| 446 |
+
static inline void compute_source_index_and_lambda(
|
| 447 |
+
int64_t& input_index0,
|
| 448 |
+
int64_t& input_index1,
|
| 449 |
+
scalar_t& lambda0,
|
| 450 |
+
scalar_t& lambda1,
|
| 451 |
+
opmath_t ratio,
|
| 452 |
+
int64_t output_index,
|
| 453 |
+
int64_t input_size,
|
| 454 |
+
int64_t output_size,
|
| 455 |
+
bool align_corners) {
|
| 456 |
+
if (output_size == input_size) {
|
| 457 |
+
// scale_factor = 1, simply copy
|
| 458 |
+
input_index0 = output_index;
|
| 459 |
+
input_index1 = output_index;
|
| 460 |
+
lambda0 = static_cast<scalar_t>(1);
|
| 461 |
+
lambda1 = static_cast<scalar_t>(0);
|
| 462 |
+
} else {
|
| 463 |
+
const auto real_input_index =
|
| 464 |
+
area_pixel_compute_source_index<opmath_t>(
|
| 465 |
+
ratio, output_index, align_corners, /*cubic=*/false);
|
| 466 |
+
guard_index_and_lambda(real_input_index, input_size, input_index0, lambda1);
|
| 467 |
+
int64_t offset = (input_index0 < input_size - 1) ? 1 : 0;
|
| 468 |
+
input_index1 = input_index0 + offset;
|
| 469 |
+
lambda0 = static_cast<scalar_t>(1.) - lambda1;
|
| 470 |
+
}
|
| 471 |
+
}
|
| 472 |
+
|
| 473 |
+
// It will not be used by data types other than BFloat16 and Half.
|
| 474 |
+
// Error-reporting overload of apply_grad_input: selected (via enable_if)
// whenever the (scalar_in, scalar_out) pair is NOT (float -> reduced
// floating point). Any combination that reaches this overload fails at
// least one of the TORCH_CHECKs below, so it never does real work.
template <typename scalar_in, typename scalar_out,
          typename std::enable_if_t<!is_reduced_floating_point_v<scalar_out> || !std::is_same<scalar_in, float>::value, int> = 0>
void inline apply_grad_input(scalar_in* buffer_ptr, scalar_out* gin, int64_t size) {
  TORCH_CHECK((is_reduced_floating_point_v<scalar_out>),
              "Upsample backward only support BFloat16 and Half in the lower precision data types on CPU.")
  TORCH_CHECK((std::is_same<scalar_in, float>::value),
              "Upsample backward should use float as acc buffer for BFloat16 and Half grad input on CPU.")
  return;
}
|
| 483 |
+
|
| 484 |
+
// Accumulates the float accumulation buffer `buffer_ptr` into the
// reduced-precision (BFloat16/Half) grad input `gin`, zeroing the buffer
// as it goes so it can be reused. Selected only for scalar_in == float and
// reduced-precision scalar_out. Vectorized main loop plus scalar tail.
template <typename scalar_in, typename scalar_out,
          typename std::enable_if_t<is_reduced_floating_point_v<scalar_out> && std::is_same<scalar_in, float>::value, int> = 0>
void inline apply_grad_input(scalar_in* buffer_ptr, scalar_out* gin, int64_t size) {
  using bVec = Vectorized<scalar_out>;
  using fVec = Vectorized<float>;
  int64_t d = 0;
  // Main loop: one reduced-precision vector of gin widens to two float
  // vectors; add the corresponding float buffer lanes, then write back.
  for (; d < size - (size % bVec::size()); d += bVec::size()) {
    bVec gin_bvec = bVec::loadu(gin + d);
    fVec gin_fvec0, gin_fvec1;
    std::tie(gin_fvec0, gin_fvec1) = convert_to_float<scalar_out>(gin_bvec);
    gin_fvec0 += fVec::loadu(buffer_ptr + d);
    gin_fvec1 += fVec::loadu(buffer_ptr + d + fVec::size());
    // Clear the consumed buffer lanes for the next accumulation round.
    fVec(0).store(buffer_ptr + d);
    fVec(0).store(buffer_ptr + d + fVec::size());
    convert_from_float<scalar_out>(gin_fvec0, gin_fvec1).store(gin + d);
  }
  // Scalar tail for the remaining (size % bVec::size()) elements.
  for (; d < size; d++) {
    gin[d] += buffer_ptr[d];
    buffer_ptr[d] = 0;
  }
}
|
| 505 |
+
|
| 506 |
+
} // namespace at::native
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/BinaryInternal.h
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// DON'T include this except from Binary*.cu files. It should not leak into
|
| 2 |
+
// headers.
|
| 3 |
+
#pragma once
|
| 4 |
+
#define TORCH_ASSERT_NO_OPERATORS
|
| 5 |
+
#include <ATen/AccumulateType.h>
|
| 6 |
+
#include <ATen/Dispatch.h>
|
| 7 |
+
#include <ATen/native/BinaryOps.h>
|
| 8 |
+
#include <ATen/native/DispatchStub.h>
|
| 9 |
+
#include <ATen/native/TensorIterator.h>
|
| 10 |
+
#include <c10/cuda/CUDAGuard.h>
|
| 11 |
+
#include <c10/cuda/CUDAMathCompat.h>
|
| 12 |
+
#include <c10/util/TypeSafeSignMath.h>
|
| 13 |
+
#include <ATen/native/cuda/JitLoops.cuh>
|
| 14 |
+
#include <ATen/native/cuda/Loops.cuh>
|
| 15 |
+
|
| 16 |
+
#include <type_traits>
|
| 17 |
+
|
| 18 |
+
namespace at {
|
| 19 |
+
namespace native {
|
| 20 |
+
namespace binary_internal {
|
| 21 |
+
|
| 22 |
+
// Elementwise division functor used by the CUDA binary division kernels.
template <typename scalar_t>
struct DivFunctor {
  __device__ scalar_t operator()(scalar_t a, scalar_t b) const {
    return a / b;
  }
};
|
| 28 |
+
|
| 29 |
+
// Elementwise multiplication functor used by the CUDA binary mul kernels.
template <typename T>
struct MulFunctor {
  __device__ T operator()(T a, T b) const {
    return a * b;
  }
};
|
| 35 |
+
|
| 36 |
+
// Workaround for the error: '*' in boolean context, suggest '&&' instead
|
| 37 |
+
// [-Werror=int-in-bool-context]
|
| 38 |
+
// bool specialization: logical AND instead of `*` (see the workaround
// comment above -- avoids -Werror=int-in-bool-context).
template <>
struct MulFunctor<bool> {
  __device__ bool operator()(bool a, bool b) const {
    return a && b;
  }
};
|
| 44 |
+
void div_true_kernel_cuda(TensorIteratorBase& iter);
|
| 45 |
+
void div_trunc_kernel_cuda(TensorIteratorBase& iter);
|
| 46 |
+
} // namespace binary_internal
|
| 47 |
+
} // namespace native
|
| 48 |
+
} // namespace at
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/CompositeRandomAccessor.h
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/native/CompositeRandomAccessorCommon.h>
|
| 4 |
+
#include <thrust/tuple.h>
|
| 5 |
+
|
| 6 |
+
namespace at { namespace native {
|
| 7 |
+
|
| 8 |
+
// Tuple-type traits plugged into CompositeRandomAccessor: here the tuple
// machinery is thrust's (device-compatible) rather than std::tuple.
// NOTE(review): the name says CPU but it is backed by thrust -- presumably
// kept for symmetry with the CPU counterpart; confirm against
// CompositeRandomAccessorCommon.h.
struct TupleInfoCPU {
  template <typename ...Types>
  using tuple = thrust::tuple<Types...>;

  template <typename ...Types>
  static constexpr auto tie(Types&... args) noexcept {
    return thrust::tie(args...);
  }
};
|
| 17 |
+
|
| 18 |
+
template <typename KeyAccessor, typename ValueAccessor>
|
| 19 |
+
using CompositeRandomAccessorCPU =
|
| 20 |
+
CompositeRandomAccessor<KeyAccessor, ValueAccessor, TupleInfoCPU>;
|
| 21 |
+
|
| 22 |
+
// Swaps the values referred to by two references_holder proxies via
// thrust::swap (proxies are taken by value; the swap acts on the
// referenced data).
template <typename Values, typename References>
void swap(
    references_holder<Values, References> rh1,
    references_holder<Values, References> rh2
) {
  return thrust::swap(rh1.data(), rh2.data());
}
|
| 29 |
+
|
| 30 |
+
// Tuple-protocol get<N> for references_holder proxies, forwarding to
// thrust::get on the underlying tuple.
template <int N, typename Values, typename References>
auto get(references_holder<Values, References> rh) -> decltype(thrust::get<N>(rh.data())) {
  return thrust::get<N>(rh.data());
}
|
| 34 |
+
|
| 35 |
+
}} // namespace at::native
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/CuFFTPlanCache.h
ADDED
|
@@ -0,0 +1,494 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <ATen/Config.h>
|
| 2 |
+
#include <ATen/core/DimVector.h>
|
| 3 |
+
#include <ATen/cuda/CUDAContext.h>
|
| 4 |
+
#include <ATen/native/cuda/CuFFTUtils.h>
|
| 5 |
+
#include <ATen/native/utils/ParamsHash.h>
|
| 6 |
+
#include <c10/util/accumulate.h>
|
| 7 |
+
#include <c10/util/irange.h>
|
| 8 |
+
|
| 9 |
+
#include <cufft.h>
|
| 10 |
+
#include <cufftXt.h>
|
| 11 |
+
|
| 12 |
+
#include <limits>
|
| 13 |
+
#include <list>
|
| 14 |
+
#include <sstream>
|
| 15 |
+
#include <stdexcept>
|
| 16 |
+
#include <string>
|
| 17 |
+
#include <unordered_map>
|
| 18 |
+
|
| 19 |
+
namespace at { namespace native { namespace detail {
|
| 20 |
+
|
| 21 |
+
// Enum representing the FFT type
|
| 22 |
+
enum class CuFFTTransformType : int8_t {
|
| 23 |
+
C2C, // Complex-to-complex
|
| 24 |
+
R2C, // Real-to-complex
|
| 25 |
+
C2R, // Complex-to-real
|
| 26 |
+
};
|
| 27 |
+
|
| 28 |
+
// This struct is used to let us easily compute hashes of the
|
| 29 |
+
// parameters.
|
| 30 |
+
// It will be the **key** to the plan cache.
|
| 31 |
+
// Key type for the cuFFT plan cache; hashed bytewise (see the comment
// above), which is why the constructor zeroes the whole object first.
struct CuFFTParams
{
  int64_t signal_ndim_; // between 1 and max_rank, i.e., 1 <= signal_ndim <= 3
  // These include additional batch dimension as well.
  int64_t sizes_[max_rank + 1];
  int64_t input_strides_[max_rank + 1];
  int64_t output_strides_[max_rank + 1];
  CuFFTTransformType fft_type_;
  ScalarType value_type_;

  CuFFTParams() = default;

  CuFFTParams(IntArrayRef in_strides, IntArrayRef out_strides,
    IntArrayRef signal_sizes, CuFFTTransformType fft_type, ScalarType value_type) {
    // Padding bits must be zeroed for hashing
    // (memset on `this` is only valid because the type is trivial --
    // enforced by the static_assert just below this struct).
    memset(this, 0, sizeof(*this));
    signal_ndim_ = signal_sizes.size() - 1;
    fft_type_ = fft_type;
    value_type_ = value_type;

    TORCH_INTERNAL_ASSERT(in_strides.size() == signal_sizes.size());
    TORCH_INTERNAL_ASSERT(out_strides.size() == signal_sizes.size());
    TORCH_INTERNAL_ASSERT(1 <= signal_ndim_ && signal_ndim_ <= max_rank);

    std::copy(signal_sizes.cbegin(), signal_sizes.cend(), sizes_)
;
    std::copy(in_strides.cbegin(), in_strides.cend(), input_strides_);
    std::copy(out_strides.cbegin(), out_strides.cend(), output_strides_);
  }
};
|
| 60 |
+
|
| 61 |
+
static_assert(std::is_trivial<CuFFTParams>::value, "");
|
| 62 |
+
|
| 63 |
+
// Returns true if the transform type has complex input
|
| 64 |
+
// Returns true if the transform type has complex input
// (covering every enumerator keeps -Wswitch exhaustiveness checking; the
// trailing assert fires only on an out-of-range enum value).
inline bool cufft_complex_input(CuFFTTransformType type) {
  switch (type) {
    case CuFFTTransformType::R2C:
      return false;

    case CuFFTTransformType::C2C:
    case CuFFTTransformType::C2R:
      return true;
  }
  TORCH_INTERNAL_ASSERT(false);
}
|
| 75 |
+
|
| 76 |
+
// Returns true if the transform type has complex output
|
| 77 |
+
// Returns true if the transform type has complex output
// (covering every enumerator keeps -Wswitch exhaustiveness checking; the
// trailing assert fires only on an out-of-range enum value).
inline bool cufft_complex_output(CuFFTTransformType type) {
  switch (type) {
    case CuFFTTransformType::C2R:
      return false;

    case CuFFTTransformType::C2C:
    case CuFFTTransformType::R2C:
      return true;
  }
  TORCH_INTERNAL_ASSERT(false);
}
|
| 88 |
+
|
| 89 |
+
// Create transform type enum from bools representing if input and output are complex
|
| 90 |
+
inline CuFFTTransformType GetCuFFTTransformType(bool complex_input, bool complex_output) {
|
| 91 |
+
if (complex_input && complex_output) {
|
| 92 |
+
return CuFFTTransformType::C2C;
|
| 93 |
+
} else if (complex_input && !complex_output) {
|
| 94 |
+
return CuFFTTransformType::C2R;
|
| 95 |
+
} else if (!complex_input && complex_output) {
|
| 96 |
+
return CuFFTTransformType::R2C;
|
| 97 |
+
}
|
| 98 |
+
TORCH_INTERNAL_ASSERT(false, "Real to real FFTs are not supported");
|
| 99 |
+
}
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
// RAII wrapper around a raw ::cufftHandle: the handle is created in the
// constructor and released in the destructor, so plans cannot leak.
// NOTE(review): the class has no copy/move suppression here — copying it
// would double-destroy the handle; presumably owners never copy it. Verify
// against the owning class (CuFFTConfig below holds it by value).
class CuFFTHandle {
  ::cufftHandle handle_;
public:

  CuFFTHandle() {
    CUFFT_CHECK(cufftCreate(&handle_));
  }

  // Mutable and const access to the underlying raw handle.
  ::cufftHandle & get() { return handle_; }
  const ::cufftHandle & get() const { return handle_; }

  ~CuFFTHandle() {
// Not using fftDestroy() for rocFFT to work around double freeing of handles
#if !defined(USE_ROCM)
    cufftDestroy(handle_);
#endif
  }
};
|
| 120 |
+
|
| 121 |
+
// Returns true iff x is a positive power of two.
//
// The classic bit trick `(x & (x - 1)) == 0` alone also accepts x == 0
// (and is meaningless for negative inputs); guard on x > 0 so that
// degenerate sizes are rejected rather than silently treated as valid
// power-of-two signal dimensions.
__forceinline__
static bool is_pow_of_two(int64_t x) {
  return x > 0 && (x & (x - 1)) == 0;
}
|
| 125 |
+
|
| 126 |
+
using cufft_size_type = long long int;
|
| 127 |
+
|
| 128 |
+
using CuFFTDimVector = c10::SmallVector<cufft_size_type, at::kDimVectorStaticSize>;
|
| 129 |
+
|
| 130 |
+
// Struct representing a tensor in CuFFT's data layout for planning transforms
// See NOTE [ cuFFT Embedded Strides ].
struct CuFFTDataLayout {
  // Embedded extent of each signal dimension (batch dimension excluded).
  CuFFTDimVector embed;
  // stride: element stride along the innermost dimension;
  // dist: distance between consecutive batches.
  cufft_size_type stride, dist;
  // must_clone: the strides cannot be expressed as a cuFFT embedding, so the
  //             tensor has to be copied to contiguous memory first.
  // simple: the layout is plain contiguous (lets planning pass inembed ==
  //         onembed == nullptr).
  bool must_clone, simple;
};
|
| 137 |
+
|
| 138 |
+
// Returns a cufft embedding for a contiguous signal of the given size.
|
| 139 |
+
// e.g. if the input is cloned, this will be the resulting data layout
|
| 140 |
+
// See NOTE [ cuFFT Embedded Strides ].
|
| 141 |
+
inline CuFFTDataLayout cufft_simple_embed(IntArrayRef sizes, bool onesided) {
|
| 142 |
+
CuFFTDataLayout layout;
|
| 143 |
+
layout.simple = true;
|
| 144 |
+
layout.must_clone = false;
|
| 145 |
+
layout.embed.assign(sizes.cbegin() + 1, sizes.cend());
|
| 146 |
+
if (onesided) {
|
| 147 |
+
layout.embed.back() = sizes.back() / 2 + 1;
|
| 148 |
+
}
|
| 149 |
+
layout.stride = 1;
|
| 150 |
+
layout.dist = 1;
|
| 151 |
+
for (const auto& len : layout.embed) {
|
| 152 |
+
layout.dist *= len;
|
| 153 |
+
}
|
| 154 |
+
return layout;
|
| 155 |
+
}
|
| 156 |
+
|
| 157 |
+
// Convert strides to a CuFFT embedded representation.
// If strides cannot be embedded, returns a simple layout and sets must_clone flag
// See NOTE [ cuFFT Embedded Strides ].
//
// `strides` and `sizes` both include the leading batch dimension; `onesided`
// selects the half-spectrum extent for the last dimension.
inline CuFFTDataLayout as_cufft_embed(IntArrayRef strides, IntArrayRef sizes, bool onesided) {
  const auto signal_ndim = strides.size() - 1;
  CuFFTDataLayout layout;
  auto last_stride = strides[signal_ndim];
  // Non-positive innermost stride can never be embedded.
  layout.must_clone = (last_stride <= 0);

  const auto last_dim_size = onesided ?
      sizes[signal_ndim] / 2 + 1 : sizes[signal_ndim];
  // Number of elements in one signal (batch dim excluded, one-sided aware).
  const auto signal_numel = c10::multiply_integers(sizes.slice(1, sizes.size() - 2)) * last_dim_size;

  // Zero strides are not allowed, even if the batch size is one.
  // If that happens just set a dummy case
  if (sizes[0] == 1) {
    layout.dist = signal_numel;
  } else if (strides[0] == 0) {
    layout.must_clone = true;
  } else {
    layout.dist = strides[0];
  }

  // Calculate the embedding shape, or set must_clone if the strides cannot be embedded
  // Walk dimensions from innermost to outermost: each embedded extent is the
  // ratio of consecutive strides, so every stride must divide the next.
  layout.embed.resize(signal_ndim);
  for (auto i = signal_ndim - 1; !layout.must_clone && i > 0; i--) {
    auto stride = strides[i];
    if (sizes[i] == 1) {
      // Size-1 dims have arbitrary strides; embed as 1.
      layout.embed[i] = 1;
    } else if (stride > 0 && stride % last_stride == 0) {
      layout.embed[i] = stride / last_stride;
      last_stride = stride;
    } else {
      layout.must_clone = true;
    }
  }

  if (layout.must_clone) {
    // If the input needs to be cloned, assume it will be contiguous
    layout = cufft_simple_embed(sizes, onesided);
    layout.must_clone = true;
  } else {
    // embed[0] is ignored by cuFFT's stride math but must be set; mirror
    // the first signal size.
    layout.embed[0] = sizes[1];
    layout.stride = strides[signal_ndim];
    // Determine if layout represents a simple embedding (contiguous data)
    layout.simple = [&] {
      for (const auto i : c10::irange(1, signal_ndim - 1)) {
        if (layout.embed[i] != sizes[i + 1]) {
          return false;
        }
      }

      return (layout.stride == 1 && layout.dist == signal_numel &&
          layout.embed.back() == last_dim_size);
    }();
  }
  return layout;
}
|
| 215 |
+
|
| 216 |
+
// This class contains all the information needed to execute a cuFFT plan:
//   1. the plan
//   2. whether to clone input before executing the plan
//   3. the workspace size needed
//
// This class will be the **value** in the plan cache.
// It **owns** the raw plan via a unique_ptr.
class CuFFTConfig {
public:

  // Only move semantics is enough for this class. Although we already use
  // unique_ptr for the plan, still remove copy constructor and assignment op so
  // we don't accidentally copy and take perf hit.
  CuFFTConfig(const CuFFTConfig&) = delete;
  CuFFTConfig& operator=(CuFFTConfig const&) = delete;

  // Convenience constructor: unpack a CuFFTParams key into the raw
  // stride/size arrays (signal_ndim_ + 1 entries each, batch dim included).
  explicit CuFFTConfig(const CuFFTParams& params):
      CuFFTConfig(
          IntArrayRef(params.input_strides_, params.signal_ndim_ + 1),
          IntArrayRef(params.output_strides_, params.signal_ndim_ + 1),
          IntArrayRef(params.sizes_, params.signal_ndim_ + 1),
          params.fft_type_,
          params.value_type_) {}

  // For complex types, strides are in units of 2 * element_size(dtype)
  // sizes are for the full signal, including batch size and always two-sided
  CuFFTConfig(IntArrayRef in_strides, IntArrayRef out_strides,
      IntArrayRef sizes, CuFFTTransformType fft_type, ScalarType dtype):
        fft_type_(fft_type), value_type_(dtype) {

    // signal sizes (excluding batch dim)
    CuFFTDimVector signal_sizes(sizes.begin() + 1, sizes.end());

    // input batch size
    const int64_t batch = sizes[0];
    const int64_t signal_ndim = sizes.size() - 1;

    // Since cuFFT has limited non-unit stride support and various constraints, we
    // use a flag to keep track throughout this function to see if we need to
    // input = input.clone();

#if defined(USE_ROCM)
    // clone input to avoid issues with hipfft clobering the input and failing tests
    clone_input = true;
#else
    clone_input = false;
#endif

    // For half, base strides on the real part of real-to-complex and
    // complex-to-real transforms are not supported. Since our output is always
    // contiguous, only need to check real-to-complex case.
    if (dtype == ScalarType::Half) {
      // cuFFT on half requires compute capability of at least SM_53
      auto dev_prop = at::cuda::getCurrentDeviceProperties();
      TORCH_CHECK(dev_prop->major >= 5 && !(dev_prop->major == 5 && dev_prop->minor < 3),
                  "cuFFT doesn't support signals of half type with compute "
                  "capability less than SM_53, but the device containing input half "
                  "tensor only has SM_", dev_prop->major, dev_prop->minor);
      // Half-precision cuFFT additionally requires power-of-two signal sizes.
      for (const auto i : c10::irange(signal_ndim)) {
        TORCH_CHECK(is_pow_of_two(sizes[i + 1]),
                    "cuFFT only supports dimensions whose sizes are powers of two when"
                    " computing in half precision, but got a signal size of",
                    sizes.slice(1));
      }
      clone_input |= in_strides.back() != 1;
    }

    // Compute the embedded layout for input and output. Cloned inputs are
    // guaranteed contiguous, so they take the simple layout.
    CuFFTDataLayout in_layout;
    if (clone_input) {
      in_layout = cufft_simple_embed(sizes, fft_type == CuFFTTransformType::C2R);
    } else {
      in_layout = as_cufft_embed(in_strides, sizes, fft_type == CuFFTTransformType::C2R);
    }
    auto out_layout = as_cufft_embed(out_strides, sizes, fft_type == CuFFTTransformType::R2C);
    TORCH_INTERNAL_ASSERT(!out_layout.must_clone, "Out strides cannot be represented as CuFFT embedding");
    clone_input |= in_layout.must_clone;

    // Check if we can take advantage of simple data layout.
    //
    // See NOTE [ cuFFT Embedded Strides ] in native/cuda/SpectralOps.cu.

    const bool simple_layout = in_layout.simple && out_layout.simple;
    // Pick cuFFT element types for input/output/execution from the dtype and
    // whether each side is complex.
    cudaDataType itype, otype, exec_type;
    const auto complex_input = cufft_complex_input(fft_type);
    const auto complex_output = cufft_complex_output(fft_type);
    if (dtype == ScalarType::Float) {
      itype = complex_input ? CUDA_C_32F : CUDA_R_32F;
      otype = complex_output ? CUDA_C_32F : CUDA_R_32F;
      exec_type = CUDA_C_32F;
    } else if (dtype == ScalarType::Double) {
      itype = complex_input ? CUDA_C_64F : CUDA_R_64F;
      otype = complex_output ? CUDA_C_64F : CUDA_R_64F;
      exec_type = CUDA_C_64F;
    } else if (dtype == ScalarType::Half) {
      itype = complex_input ? CUDA_C_16F : CUDA_R_16F;
      otype = complex_output ? CUDA_C_16F : CUDA_R_16F;
      exec_type = CUDA_C_16F;
    } else {
      TORCH_CHECK(false, "cuFFT doesn't support tensor of type: ", dtype);
    }

    // disable auto allocation of workspace to use THC allocator
    CUFFT_CHECK(cufftSetAutoAllocation(plan(), /* autoAllocate */ 0));

    size_t ws_size_t;

    // make plan
    if (simple_layout) {
      // If with unit-stride, we tell cuFFT by setting inembed == onembed == NULL.
      // In such case, cuFFT ignores istride, ostride, idist, and odist
      // by assuming istride = ostride = 1.
      //
      // See NOTE [ cuFFT Embedded Strides ] in native/cuda/SpectralOps.cu.
      CUFFT_CHECK(cufftXtMakePlanMany(plan(), signal_ndim, signal_sizes.data(),
        /* inembed */ nullptr, /* base_istride */ 1, /* idist */ 1, itype,
        /* onembed */ nullptr, /* base_ostride */ 1, /* odist */ 1, otype,
        batch, &ws_size_t, exec_type));
    } else {
      CUFFT_CHECK(cufftXtMakePlanMany(plan(), signal_ndim, signal_sizes.data(),
        in_layout.embed.data(), in_layout.stride, in_layout.dist, itype,
        out_layout.embed.data(), out_layout.stride, out_layout.dist, otype,
        batch, &ws_size_t, exec_type));
    }
    ws_size = static_cast<int64_t>(ws_size_t);
  }

  const cufftHandle &plan() const { return plan_ptr.get(); }

  CuFFTTransformType transform_type() const { return fft_type_; }
  ScalarType data_type() const { return value_type_; }
  bool should_clone_input() const { return clone_input; }
  int64_t workspace_size() const { return ws_size; }

private:
  CuFFTHandle plan_ptr;
  bool clone_input;
  int64_t ws_size;
  CuFFTTransformType fft_type_;
  ScalarType value_type_;
};
|
| 356 |
+
|
| 357 |
+
// Upper bound on the number of cached cuFFT plans, and the default cache
// size, selected per-backend.
#if defined(USE_ROCM)
// Note that the max plan number for CUDA version < 10 has to be 1023
// due to a bug that fails on the 1024th plan
constexpr int64_t CUFFT_MAX_PLAN_NUM = 1023;
constexpr int64_t CUFFT_DEFAULT_CACHE_SIZE = CUFFT_MAX_PLAN_NUM;
#else
constexpr int64_t CUFFT_MAX_PLAN_NUM = std::numeric_limits<int64_t>::max();
// The default max cache size chosen for CUDA version > 10 is arbitrary.
// This number puts a limit on how big of a plan cache should we maintain by
// default. Users can always configure it via cufft_set_plan_cache_max_size.
constexpr int64_t CUFFT_DEFAULT_CACHE_SIZE = 4096;
#endif
// Sanity-check the constants at compile time: both must fit in a size_t and
// the default must not exceed the maximum.
static_assert(0 <= CUFFT_MAX_PLAN_NUM && CUFFT_MAX_PLAN_NUM <= std::numeric_limits<int64_t>::max(),
              "CUFFT_MAX_PLAN_NUM not in size_t range");
static_assert(CUFFT_DEFAULT_CACHE_SIZE >= 0 && CUFFT_DEFAULT_CACHE_SIZE <= CUFFT_MAX_PLAN_NUM,
              "CUFFT_DEFAULT_CACHE_SIZE not in [0, CUFFT_MAX_PLAN_NUM] range");
|
| 373 |
+
|
| 374 |
+
// This cache assumes that the mapping from key to value never changes.
// This is **NOT** thread-safe. Please use a mutex when using it **AND** the
// value returned from try_emplace_value.
// The contract of using this cache is that try_emplace_value should only be
// used when the max_size is positive.
//
// Implementation: classic LRU — a std::list holds (params, config) pairs in
// most-recently-used order, and an unordered_map keyed by a
// reference_wrapper to the params stored *inside the list node* points back
// at that node. List splicing keeps map iterators/references valid.
class CuFFTParamsLRUCache {
public:
  using kv_t = typename std::pair<CuFFTParams, CuFFTConfig>;
  using map_t = typename std::unordered_map<std::reference_wrapper<CuFFTParams>,
                                            typename std::list<kv_t>::iterator,
                                            ParamsHash<CuFFTParams>,
                                            ParamsEqual<CuFFTParams>>;
  using map_kkv_iter_t = typename map_t::iterator;


  CuFFTParamsLRUCache() : CuFFTParamsLRUCache(CUFFT_DEFAULT_CACHE_SIZE) {}

  CuFFTParamsLRUCache(int64_t max_size) {
    _set_max_size(max_size);
  }

  CuFFTParamsLRUCache(CuFFTParamsLRUCache&& other) noexcept :
    _usage_list(std::move(other._usage_list)),
    _cache_map(std::move(other._cache_map)),
    _max_size(other._max_size) {}

  CuFFTParamsLRUCache& operator=(CuFFTParamsLRUCache&& other) noexcept {
    _usage_list = std::move(other._usage_list);
    _cache_map = std::move(other._cache_map);
    _max_size = other._max_size;
    return *this;
  }

  // If key is in this cache, return the cached config. Otherwise, emplace the
  // config in this cache and return it.
  // Return const reference because CuFFTConfig shouldn't be tampered with once
  // created.
  const CuFFTConfig &lookup(CuFFTParams params) {
    AT_ASSERT(_max_size > 0);

    map_kkv_iter_t map_it = _cache_map.find(params);
    // Hit, put to list front
    if (map_it != _cache_map.end()) {
      // splice preserves iterators, so the map entry stays valid.
      _usage_list.splice(_usage_list.begin(), _usage_list, map_it->second);
      return map_it->second->second;
    }

    // Miss
    // remove if needed
    if (_usage_list.size() >= _max_size) {
      auto last = _usage_list.end();
      last--;
      // Erase the map entry before popping the node it references.
      _cache_map.erase(last->first);
      _usage_list.pop_back();
    }

    // construct new plan at list front, then insert into _cache_map
    // (constructing CuFFTConfig from params builds the cuFFT plan here)
    _usage_list.emplace_front(std::piecewise_construct,
                       std::forward_as_tuple(params),
                       std::forward_as_tuple(params));
    auto kv_it = _usage_list.begin();
    _cache_map.emplace(std::piecewise_construct,
                std::forward_as_tuple(kv_it->first),
                std::forward_as_tuple(kv_it));
    return kv_it->second;
  }

  // Drop every cached plan.
  void clear() {
    _cache_map.clear();
    _usage_list.clear();
  }

  // Shrink (or grow) the capacity, evicting least-recently-used entries if
  // the current contents exceed the new limit.
  void resize(int64_t new_size) {
    _set_max_size(new_size);
    auto cur_size = _usage_list.size();
    if (cur_size > _max_size) {
      auto delete_it = _usage_list.end();
      for (size_t i = 0; i < cur_size - _max_size; i++) {
        delete_it--;
        _cache_map.erase(delete_it->first);
      }
      _usage_list.erase(delete_it, _usage_list.end());
    }
  }

  size_t size() const { return _cache_map.size(); }

  size_t max_size() const noexcept { return _max_size; }

  // Callers must hold this mutex across lookup() and any use of the
  // returned config (see class comment).
  std::mutex mutex;

private:
  // Only sets size and does value check. Does not resize the data structures.
  void _set_max_size(int64_t new_size) {
    // We check that 0 <= new_size <= CUFFT_MAX_PLAN_NUM here. Since
    // CUFFT_MAX_PLAN_NUM is of type size_t, we need to do non-negativity check
    // first.
    TORCH_CHECK(new_size >= 0,
             "cuFFT plan cache size must be non-negative, but got ", new_size);
    TORCH_CHECK(new_size <= CUFFT_MAX_PLAN_NUM,
             "cuFFT plan cache size can not be larger than ", CUFFT_MAX_PLAN_NUM, ", but got ", new_size);
    _max_size = static_cast<size_t>(new_size);
  }

  std::list<kv_t> _usage_list;
  map_t _cache_map;
  size_t _max_size;
};
|
| 482 |
+
|
| 483 |
+
// Since ATen is separated into CPU build and CUDA build, we need a way to call
|
| 484 |
+
// these functions only when CUDA is loaded. We use CUDA hooks for this purpose
|
| 485 |
+
// (at cuda/detail/CUDAHooks.cpp), and call the hooked functions from the actual
|
| 486 |
+
// native function counterparts (at native/SpectralOps.cpp), i.e.,
|
| 487 |
+
// _cufft_get_plan_cache_max_size, _cufft_set_plan_cache_max_size
|
| 488 |
+
// _cufft_get_plan_cache_size, and _cufft_clear_plan_cache.
|
| 489 |
+
int64_t cufft_get_plan_cache_max_size_impl(DeviceIndex device_index);
|
| 490 |
+
void cufft_set_plan_cache_max_size_impl(DeviceIndex device_index, int64_t max_size);
|
| 491 |
+
int64_t cufft_get_plan_cache_size_impl(DeviceIndex device_index);
|
| 492 |
+
void cufft_clear_plan_cache_impl(DeviceIndex device_index);
|
| 493 |
+
|
| 494 |
+
}}} // namespace at::native::detail
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/DeviceSqrt.cuh
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
namespace at { namespace native {
|
| 4 |
+
#if defined(USE_ROCM)
|
| 5 |
+
// take these out when ROCm implements std:: math functions
|
| 6 |
+
#include <math.h>
|
| 7 |
+
template <typename scalar_t>
|
| 8 |
+
static __forceinline__ __device__ scalar_t device_sqrt(scalar_t val);
|
| 9 |
+
|
| 10 |
+
template <>
|
| 11 |
+
__forceinline__ __device__ float device_sqrt(float val) {
|
| 12 |
+
return ::sqrtf(val);
|
| 13 |
+
}
|
| 14 |
+
|
| 15 |
+
template <>
|
| 16 |
+
__forceinline__ __device__ double device_sqrt(double val) {
|
| 17 |
+
return ::sqrt(val);
|
| 18 |
+
}
|
| 19 |
+
#else
|
| 20 |
+
template<typename scalar_t>
|
| 21 |
+
__forceinline__ __device__ double device_sqrt(scalar_t val) {
|
| 22 |
+
return std::sqrt(val);
|
| 23 |
+
}
|
| 24 |
+
#endif
|
| 25 |
+
}}
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/DistributionTemplates.h
ADDED
|
@@ -0,0 +1,672 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/AccumulateType.h>
|
| 4 |
+
#include <ATen/Dispatch.h>
|
| 5 |
+
#include <ATen/Dispatch_v2.h>
|
| 6 |
+
#include <ATen/ExpandBase.h>
|
| 7 |
+
#include <ATen/OpMathType.h>
|
| 8 |
+
#include <ATen/native/TensorIterator.h>
|
| 9 |
+
#include <ATen/native/cuda/Loops.cuh>
|
| 10 |
+
#include <c10/util/Half.h>
|
| 11 |
+
#include <ATen/cuda/CUDAApplyUtils.cuh>
|
| 12 |
+
#include <ATen/cuda/CUDAContext.h>
|
| 13 |
+
#include <ATen/cuda/detail/OffsetCalculator.cuh>
|
| 14 |
+
#include <ATen/cuda/CUDAGraphsUtils.cuh>
|
| 15 |
+
#include <ATen/detail/FunctionTraits.h>
|
| 16 |
+
#include <ATen/core/DistributionsHelper.h>
|
| 17 |
+
|
| 18 |
+
#include <curand.h>
|
| 19 |
+
#include <curand_kernel.h>
|
| 20 |
+
#include <curand_philox4x32_x.h>
|
| 21 |
+
#include <cstdint>
|
| 22 |
+
#include <limits>
|
| 23 |
+
#include <utility>
|
| 24 |
+
#include <mutex>
|
| 25 |
+
#include <tuple>
|
| 26 |
+
#include <type_traits>
|
| 27 |
+
|
| 28 |
+
namespace at {
|
| 29 |
+
namespace native {
|
| 30 |
+
namespace {
|
| 31 |
+
|
| 32 |
+
// launch bounds used for kernels utilizing TensorIterator
|
| 33 |
+
const uint32_t block_size_bound = 256;
|
| 34 |
+
const uint32_t grid_size_bound = 4;
|
| 35 |
+
// number of randoms given by distributions like curand_uniform4, curand_uniform2_double
|
| 36 |
+
// used in calculating philox offset.
|
| 37 |
+
const uint32_t curand4_engine_calls = 4;
|
| 38 |
+
|
| 39 |
+
// utility function that calculates proper philox_offset
|
| 40 |
+
// for distributions utilizing TensorIterator. For distributions using
|
| 41 |
+
// TensorIterator, we are using a grid-stride loop with each
|
| 42 |
+
// thread yielding one element per thread. For the edge of the grid-stride
|
| 43 |
+
// loop, if the tensor size is large, the unroll loop will kick in and the float4
|
| 44 |
+
// from curand4 will start getting utilized (for common tensor sizes, we end up
|
| 45 |
+
// using rand.x from each thread). Hence, the philox_offset is
|
| 46 |
+
// (number of elements per thread * number of engine calls), which makes
|
| 47 |
+
// sure that philox offset increment is not less than the number of randoms used
|
| 48 |
+
// in each thread.
|
| 49 |
+
std::tuple<uint64_t, dim3, dim3> calc_execution_policy(int64_t total_elements) {
|
| 50 |
+
const uint64_t numel = static_cast<uint64_t>(total_elements);
|
| 51 |
+
const uint32_t block_size = block_size_bound;
|
| 52 |
+
const uint32_t unroll = curand4_engine_calls;
|
| 53 |
+
dim3 dim_block(block_size);
|
| 54 |
+
dim3 grid((numel + block_size - 1) / block_size);
|
| 55 |
+
uint32_t blocks_per_sm = at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor / block_size;
|
| 56 |
+
grid.x = std::min(
|
| 57 |
+
static_cast<uint32_t>(at::cuda::getCurrentDeviceProperties()->multiProcessorCount) * blocks_per_sm,
|
| 58 |
+
grid.x);
|
| 59 |
+
//number of times random will be generated per thread, to offset philox counter in thc random state
|
| 60 |
+
uint64_t counter_offset = ((numel - 1) / (block_size * grid.x * unroll) + 1)
|
| 61 |
+
* curand4_engine_calls;
|
| 62 |
+
return std::make_tuple(counter_offset, grid, dim_block);
|
| 63 |
+
}
|
| 64 |
+
|
| 65 |
+
// grid stride loop kernel for distributions
//
// Each thread owns an independent Philox sub-sequence (seeded by its global
// thread id), draws `unroll_factor` randoms per round via dist_func, and
// hands each in-bounds value to transform_func together with its linear
// output index. numel is rounded up so every thread executes the same number
// of rounds, keeping the per-thread RNG state in lock-step with the
// counter_offset computed in calc_execution_policy.
template<typename accscalar_t, int unroll_factor, typename dist_t, typename transform_t>
C10_LAUNCH_BOUNDS_2(block_size_bound, grid_size_bound)
__global__ void distribution_elementwise_grid_stride_kernel(int numel,
                                                            PhiloxCudaState philox_args,
                                                            const dist_t dist_func,
                                                            const transform_t transform_func) {
  auto seeds = at::cuda::philox::unpack(philox_args);
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  curandStatePhilox4_32_10_t state;
  curand_init(std::get<0>(seeds),
              idx,
              std::get<1>(seeds),
              &state);

  // Round numel up to a multiple of (total threads * unroll_factor) so all
  // threads run the same number of loop iterations.
  int rounded_size = ((numel - 1)/(blockDim.x * gridDim.x * unroll_factor)+1) *
      blockDim.x * gridDim.x * unroll_factor;
  for(int linear_index = idx; linear_index < rounded_size; linear_index += blockDim.x * gridDim.x * unroll_factor) {
    // One generator call yields unroll_factor values (e.g. float4 from
    // curand_uniform4); consume them for strided output positions.
    auto rand = dist_func(&state);
    #pragma unroll
    for (int ii = 0; ii < unroll_factor; ii++) {
      int li = linear_index + blockDim.x * gridDim.x * ii;
      if (li < numel) {
        transform_func(li, static_cast<accscalar_t>((&rand.x)[ii]));
      }
    }
    __syncthreads();
  }
}
|
| 94 |
+
|
| 95 |
+
/**
 * distribution_nullary_kernel is analogous to gpu_kernel in
 * ATen/native/cuda/Loops.cuh. Like gpu_kernel, it uses
 * TensorIterator to launch a kernel. However, the differences are
 *   - it launches a grid-stride loop based kernel. The kernel is not
 *     generic like elementwise_kernel in Loops.cuh and is specialized
 *     for the distribution kernels here.
 *   - For big size tensors, we can launch multiple kernels recursively
 *     (i.e. if (!iter.can_use_32bit_indexing())) and hence, the philox
 *     offset calculation is done in this function.
 *
 * FIXME: Can we specialize elementwise_kernel and launch_kernel in Loops.cuh
 * to have grid-stride loop kernel and then use that to launch our distribution
 * kernels? Note that we need a grid-stride loop kernel because, we found by testing
 * that it achieves peak effective bandwidth.
 *
 * dist_func draws unroll_factor randoms from a curand Philox state;
 * transform_func maps each raw random (as accscalar_t) to the output
 * scalar_t value.
 */
template<typename scalar_t,
         typename accscalar_t,
         int unroll_factor,
         typename RNG,
         typename dist_t,
         typename transform_t>
void distribution_nullary_kernel(at::TensorIteratorBase& iter,
                                 RNG gen,
                                 const dist_t& dist_func,
                                 const transform_t transform_func) {
  static_assert(unroll_factor >= 1, "unroll_factor must be >= 1.");
  int64_t numel = iter.numel();
  if (numel == 0) {
    return;
  }

  auto execution_policy = calc_execution_policy(numel);
  auto counter_offset = std::get<0>(execution_policy);
  auto grid = std::get<1>(execution_policy);
  auto block = std::get<2>(execution_policy);
  PhiloxCudaState rng_engine_inputs;
  {
    // See Note [Acquire lock when using random generators]
    std::lock_guard<std::mutex> lock(gen->mutex_);
    rng_engine_inputs = gen->philox_cuda_state(counter_offset);
  }

  // Tensors too large for 32-bit indexing are processed in sub-iterators;
  // each recursive call reserves its own philox offset above.
  if (!iter.can_use_32bit_indexing()) {
    for (auto& sub_iter : iter.with_32bit_indexing()) {
      distribution_nullary_kernel<scalar_t, accscalar_t, unroll_factor>(sub_iter,
        gen, dist_func, transform_func);
    }
    return;
  }

  char* out_data = (char*)iter.data_ptr(0);

  auto stream = at::cuda::getCurrentCUDAStream();
  if (iter.is_trivial_1d()) {
    // Fast path: output is a dense 1-D run; a single byte stride suffices.
    auto strides = iter.get_inner_strides();
    int stride0 = strides[0];
    distribution_elementwise_grid_stride_kernel<accscalar_t, unroll_factor><<<grid, block, 0, stream>>>(
      numel,
      rng_engine_inputs,
      dist_func,
      // Lambda captures the raw device pointer by value; the iterator's
      // storage must outlive the kernel launch.
      [=]__device__(int idx, accscalar_t rand) {
        scalar_t* out = (scalar_t*)&out_data[stride0 * idx];
        *out = transform_func(rand);
      }
    );
    C10_CUDA_KERNEL_LAUNCH_CHECK();
  } else {
    // General path: translate linear indices through an offset calculator.
    auto offset_calc = make_offset_calculator<1>(iter);
    distribution_elementwise_grid_stride_kernel<accscalar_t, unroll_factor><<<grid, block, 0, stream>>>(
      numel,
      rng_engine_inputs,
      dist_func,
      [=]__device__(int idx, accscalar_t rand) {
        auto offsets = offset_calc.get(idx);
        scalar_t* out = (scalar_t*)&out_data[offsets[0]];
        *out = transform_func(rand);
      }
    );
    C10_CUDA_KERNEL_LAUNCH_CHECK();
  }
}
|
| 177 |
+
|
| 178 |
+
// Binary kernel
|
| 179 |
+
// Device kernel for binary distributions: applies `f(state, a, b)` where
// `state` is a per-thread Philox RNG state and (a, b) are the two input
// elements. Each block handles block_work_size() elements; inputs are staged
// into registers before the compute/store pass.
template <typename func_t, typename inp_offset_calc_t, typename out_offset_calc_t>
__global__ void distribution_binary_elementwise_kernel(
    int numel,                    // total number of output elements
    func_t f,                     // functor: (curand state&, in1, in2) -> out
    PhiloxCudaState philox_args,  // packed seed/offset for curand_init
    typename function_traits<func_t>::result_type *output_data,
    const typename function_traits<func_t>::template arg<1>::type *input_data_1,
    const typename function_traits<func_t>::template arg<2>::type *input_data_2,
    inp_offset_calc_t inp_calc,   // linear index -> input element offsets
    out_offset_calc_t out_calc) { // linear index -> output element offset
  auto seeds = at::cuda::philox::unpack(philox_args);

  using input_t_1 = typename function_traits<func_t>::template arg<1>::type;
  using input_t_2 = typename function_traits<func_t>::template arg<2>::type;

  // Per-thread register buffers holding this thread's slice of the inputs.
  input_t_1 inputs_1[thread_work_size()];
  input_t_2 inputs_2[thread_work_size()];

  int base_index = block_work_size() * blockIdx.x;
  // Elements left for this block (may be short in the last block).
  int remaining = std::min<int>(numel - base_index, block_work_size());

  // Each thread gets its own Philox subsequence keyed by its global thread id,
  // so threads draw independent random streams from the same seed.
  curandStatePhilox4_32_10_t state;
  curand_init(std::get<0>(seeds),
              blockIdx.x * blockDim.x + threadIdx.x,
              std::get<1>(seeds),
              &state);

  // load data into registers
  int thread_idx = threadIdx.x;
  #pragma unroll
  for (int i = 0; i < thread_work_size(); i++) {
    if (thread_idx >= remaining) {
      break;
    }
    int input_idx = thread_idx + base_index;
    auto offsets = inp_calc.get(input_idx);
    inputs_1[i] = input_data_1[offsets[0]];
    inputs_2[i] = input_data_2[offsets[1]];

    thread_idx += num_threads();
  }

  // compute and store
  thread_idx = threadIdx.x;
  #pragma unroll
  for (int i = 0; i < thread_work_size(); i++) {
    if (thread_idx >= remaining) {
      break;
    }
    int input_idx = thread_idx + base_index;
    auto offsets = out_calc.get(input_idx);
    output_data[offsets[0]] = f(state, inputs_1[i], inputs_2[i]);
    thread_idx += num_threads();
  }
}
|
| 234 |
+
|
| 235 |
+
// Host-side launcher for distribution_binary_elementwise_kernel.
// `f` must take a curandStatePhilox4_32_10_t& as its first argument, followed
// by the two input element values; its result is written to output 0 of `iter`.
template <typename func_t>
void distribution_binary_kernel(TensorIteratorBase &iter, PhiloxCudaState philox_args, const func_t &f) {
  static_assert(std::is_same<typename function_traits<func_t>::template arg<0>::type, curandStatePhilox4_32_10_t&>::value, "the first argument of functor must be curandStatePhilox4_32_10_t");
  using input_t_1 = typename function_traits<func_t>::template arg<1>::type;
  using input_t_2 = typename function_traits<func_t>::template arg<2>::type;
  using output_t = typename function_traits<func_t>::result_type;

  // Tensors too large for 32-bit indexing are split and handled recursively.
  if (!iter.can_use_32bit_indexing()) {
    for (auto& sub_iter : iter.with_32bit_indexing()) {
      distribution_binary_kernel(sub_iter, philox_args, f);
    }
    return;
  }

  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(iter.can_use_32bit_indexing());

  int64_t numel = iter.numel();
  if (numel == 0) {
    // Nothing to do for empty tensors.
    return;
  }

  output_t *output_data = static_cast<output_t *>(iter.data_ptr(0));
  const input_t_1 *input_data_1 = static_cast<const input_t_1 *>(iter.data_ptr(1));
  const input_t_2 *input_data_2 = static_cast<const input_t_2 *>(iter.data_ptr(2));

  // One block per block_work_size() elements, rounded up.
  int64_t grid = (numel + block_work_size() - 1) / block_work_size();
  auto stream = at::cuda::getCurrentCUDAStream();

  if (iter.is_contiguous()) {
    // Contiguous fast path: identity offset calculators avoid per-element
    // stride arithmetic in the kernel.
    distribution_binary_elementwise_kernel<<<grid,num_threads(), 0, stream>>>(
        numel, f, philox_args, output_data, input_data_1, input_data_2,
        TrivialOffsetCalculator<2>(), TrivialOffsetCalculator<1>());
    C10_CUDA_KERNEL_LAUNCH_CHECK();
  } else {
    distribution_binary_elementwise_kernel<<<grid, num_threads(), 0, stream>>>(
        numel, f, philox_args, output_data, input_data_1, input_data_2,
        make_input_offset_calculator<2>(iter), make_output_offset_calculator(iter));
    C10_CUDA_KERNEL_LAUNCH_CHECK();
  }
}
|
| 275 |
+
|
| 276 |
+
} // namespace
|
| 277 |
+
}} // namespace at::native
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
namespace at {
|
| 281 |
+
namespace native {
|
| 282 |
+
namespace templates {
|
| 283 |
+
namespace cuda {
|
| 284 |
+
|
| 285 |
+
// ==================================================== Random ========================================================
|
| 286 |
+
|
| 287 |
+
// Fills `iter` with integers drawn from [base, base + range).
// Types that can represent values needing more than 32 bits take the 64-bit
// path when range >= 2^32 (consuming two 32-bit curand outputs per sample);
// otherwise a single 32-bit output per sample suffices.
template<typename RNG>
void random_from_to_kernel(TensorIteratorBase& iter, uint64_t range, int64_t base, RNG gen) {
  AT_DISPATCH_V2(iter.dtype(), "random_from_to_kernel_cuda", AT_WRAP([&] {
    if ((
      std::is_same<scalar_t, int64_t>::value ||
      std::is_same<scalar_t, double>::value ||
      std::is_same<scalar_t, float>::value ||
      std::is_same<scalar_t, at::BFloat16>::value) && range >= 1ULL << 32)
    {
      // define lambda to mod with range and add base
      auto random_func = [range, base] __device__ (uint64_t rand) {
        return transformation::uniform_int_from_to<scalar_t>(rand, range, base);
      };
      // curand4_engine_calls/2: each engine call yields two 64-bit values here.
      distribution_nullary_kernel<scalar_t, uint64_t, curand4_engine_calls/2>(iter,
        gen,
        [] __device__ (curandStatePhilox4_32_10_t* state) -> ulonglong2 {
          // Pack four 32-bit curand outputs into two 64-bit values.
          ulonglong2 ret;
          uint4 rand_val = curand4(state);
          ret.x = (static_cast<uint64_t>(rand_val.x) << 32) | rand_val.y;
          ret.y = (static_cast<uint64_t>(rand_val.z) << 32) | rand_val.w;
          return ret;
        },
        random_func);
    } else {
      auto random_func = [range, base] __device__ (uint32_t rand) {
        return transformation::uniform_int_from_to<scalar_t>(rand, range, base);
      };
      distribution_nullary_kernel<scalar_t, uint32_t, curand4_engine_calls>(iter,
        gen,
        [] __device__ (curandStatePhilox4_32_10_t* state) {
          return curand4(state);
        },
        random_func);
    }
  }), AT_EXPAND(AT_ALL_TYPES), kBool, kHalf, kBFloat16, AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES));
}
|
| 323 |
+
|
| 324 |
+
// This is the special kernel to handle single specific case:
|
| 325 |
+
// from(inclusive) = std::numeric_limits<int64_t>::lowest()
|
| 326 |
+
// to(exclusive) = None (= std::numeric_limits<int64_t>::max() + 1)
|
| 327 |
+
// random_() over the full 64-bit range (from = int64 lowest, to = none).
// Only types that can hold 64-bit-range values are supported; anything else
// is rejected with TORCH_CHECK.
template<typename RNG>
void random_full_64_bits_range_kernel(TensorIteratorBase& iter, RNG gen) {
  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::BFloat16, iter.dtype(), "random_full_64_bits_range_kernel_cuda", [&] {
    if (std::is_same<scalar_t, int64_t>::value ||
        std::is_same<scalar_t, double>::value ||
        std::is_same<scalar_t, float>::value ||
        std::is_same<scalar_t, at::BFloat16>::value) {
      auto random_func = [] __device__ (uint64_t rand) {
        return transformation::uniform_int_full_range<scalar_t>(rand);
      };
      // Two 64-bit samples per curand4 call, hence curand4_engine_calls/2.
      distribution_nullary_kernel<scalar_t, uint64_t, curand4_engine_calls/2>(iter,
        gen,
        [] __device__ (curandStatePhilox4_32_10_t* state) -> ulonglong2 {
          // Pack four 32-bit curand outputs into two 64-bit values.
          ulonglong2 ret;
          uint4 rand_val = curand4(state);
          ret.x = (static_cast<uint64_t>(rand_val.x) << 32) | rand_val.y;
          ret.y = (static_cast<uint64_t>(rand_val.z) << 32) | rand_val.w;
          return ret;
        },
        random_func);
    } else {
      TORCH_CHECK(false, "random_full_64_bits_range_kernel_cuda handles only int64, double, float and bfloat16");
    }
  });
}
|
| 352 |
+
|
| 353 |
+
// Dispatch functor used by the random_ stub registration: routes to the
// bounded-range kernel or the special full-64-bit-range kernel.
template<typename RNG>
struct RandomFromToKernel {
  // random_(from, to): samples from [base, base + range).
  void operator()(TensorIteratorBase& iter, uint64_t range, int64_t base, c10::optional<Generator> gen) {
    random_from_to_kernel(iter, range, base, check_generator<RNG>(gen));
  }
  // random_(int64 lowest, none): full 64-bit range.
  void operator()(TensorIteratorBase& iter, c10::optional<Generator> gen) {
    random_full_64_bits_range_kernel(iter, check_generator<RNG>(gen));
  }
};
|
| 362 |
+
|
| 363 |
+
// Default random_(): fills `iter` with uniform integers over the dtype's
// natural range. 64-bit-wide types (double, int64) consume two 32-bit curand
// outputs per sample; all other types use one.
template<typename RNG>
void random_kernel(TensorIteratorBase& iter, RNG gen) {
  AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Half, at::ScalarType::BFloat16, at::ScalarType::Bool, iter.dtype(), "random_kernel_cuda", [&] {
    if (std::is_same<scalar_t, double>::value || std::is_same<scalar_t, int64_t>::value) {
      auto random_func = [] __device__ (uint64_t rand) {
        return transformation::uniform_int<scalar_t>(rand);
      };
      distribution_nullary_kernel<scalar_t, uint64_t, curand4_engine_calls/2>(iter, gen,
        [] __device__ (curandStatePhilox4_32_10_t* state) -> ulonglong2 {
          // Pack four 32-bit curand outputs into two 64-bit values.
          ulonglong2 ret;
          uint4 rand_val = curand4(state);
          ret.x = (static_cast<uint64_t>(rand_val.x) << 32) | rand_val.y;
          ret.y = (static_cast<uint64_t>(rand_val.z) << 32) | rand_val.w;
          return ret;
        },
        random_func);
    } else {
      auto random_func = [] __device__ (uint32_t rand) {
        return transformation::uniform_int<scalar_t>(rand);
      };
      distribution_nullary_kernel<scalar_t, uint32_t, curand4_engine_calls>(iter,
        gen,
        [] __device__ (curandStatePhilox4_32_10_t* state) {
          return curand4(state);
        },
        random_func);
    }
  });
}
|
| 392 |
+
|
| 393 |
+
// Dispatch functor wrapping random_kernel for stub registration.
template<typename RNG>
struct RandomKernel {
  void operator()(TensorIteratorBase& iter, RNG gen) {
    random_kernel(iter, gen);
  }
};
|
| 399 |
+
|
| 400 |
+
// ====================================================================================================================
|
| 401 |
+
|
| 402 |
+
// Draws uniform randoms via curand and applies `transform` to each value.
// The double path uses curand_uniform2_double (two doubles per engine call),
// so it passes curand4_engine_calls/2; the float path uses curand_uniform4.
template<typename scalar_t, typename accscalar_t, size_t curand4_engine_calls, typename RNG, typename transform_t>
void uniform_and_transform(TensorIteratorBase& iter, RNG gen, transform_t transform) {
  if (std::is_same<scalar_t, double>::value) {
    distribution_nullary_kernel<scalar_t, accscalar_t, curand4_engine_calls/2>(iter,
      gen,
      [] __device__ (curandStatePhilox4_32_10_t* state) { return curand_uniform2_double(state); },
      transform);
  } else {
    distribution_nullary_kernel<scalar_t, accscalar_t, curand4_engine_calls>(iter,
      gen,
      [] __device__ (curandStatePhilox4_32_10_t* state) { return curand_uniform4(state); },
      transform);
  }
}
|
| 416 |
+
|
| 417 |
+
// Draws standard-normal randoms via curand and applies `transform` to each
// value. Mirrors uniform_and_transform: double uses curand_normal2_double
// (two values per call, hence curand4_engine_calls/2), float uses
// curand_normal4.
template<typename scalar_t, typename accscalar_t, size_t curand4_engine_calls, typename RNG, typename transform_t>
void normal_and_transform(TensorIteratorBase& iter, RNG gen, transform_t transform) {
  if (std::is_same<scalar_t, double>::value) {
    distribution_nullary_kernel<scalar_t, accscalar_t, curand4_engine_calls/2>(iter,
      gen,
      [] __device__ (curandStatePhilox4_32_10_t* state) { return curand_normal2_double(state); },
      transform);
  } else {
    distribution_nullary_kernel<scalar_t, accscalar_t, curand4_engine_calls>(iter,
      gen,
      [] __device__ (curandStatePhilox4_32_10_t* state) { return curand_normal4(state); },
      transform);
  }
}
|
| 431 |
+
|
| 432 |
+
// ==================================================== Normal ========================================================
|
| 433 |
+
|
| 434 |
+
// Fills `self` in-place with samples transformed by
// transformation::normal(rand, mean_, std_).
template<typename RNG>
void normal_kernel(const TensorBase &self, double mean_, double std_, RNG gen) {
  auto iter = TensorIterator::borrowing_nullary_op(self);
  AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "normal_kernel_cuda", [&] {
    // Accumulate in a wider type (e.g. float for half/bfloat16) for accuracy.
    using accscalar_t = at::acc_type<scalar_t, true>;
    auto mean = static_cast<accscalar_t>(mean_);
    auto std = static_cast<accscalar_t>(std_);
    // define lambda to multiply std and add mean
    auto normal_func = [mean, std] __device__ (accscalar_t rand) {
      return static_cast<scalar_t>(transformation::normal<accscalar_t>(rand, mean, std));
    };
    normal_and_transform<scalar_t, accscalar_t, curand4_engine_calls>(iter, gen, normal_func);
  });
}
|
| 448 |
+
|
| 449 |
+
// Dispatch functor wrapping normal_kernel for stub registration.
template<typename RNG>
struct NormalKernel {
  void operator()(const TensorBase &self, double mean, double std, c10::optional<Generator> gen) {
    normal_kernel(self, mean, std, check_generator<RNG>(gen));
  }
};
|
| 455 |
+
|
| 456 |
+
// ==================================================== Uniform ========================================================
|
| 457 |
+
|
| 458 |
+
// Fills `iter` with samples from the half-open interval [from_, to_).
// The transform below is deliberately ordered; see the linked issues before
// modifying anything in it.
template<typename RNG>
void uniform_kernel(TensorIteratorBase& iter, double from_, double to_, RNG gen) {
  AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "uniform_kernel_cuda", [&] {
    auto from = static_cast<scalar_t>(from_);
    auto to = static_cast<scalar_t>(to_);
    // Do the affine math in the op-math type (e.g. float for half/bfloat16).
    using opmath_t = at::opmath_type<scalar_t>;
    auto range = static_cast<opmath_t>(to-from);
    // define lambda to reverse bounds, multiply 'range' and add 'from_'
    auto uniform_func = [range, from, to] __device__ (opmath_t rand) {
      // Compute output value before reversing the bounds
      // BEFORE TOUCHING THIS CODE READ: https://github.com/pytorch/pytorch/issues/96947
      auto value = static_cast<scalar_t>(rand * range + from);
      // reverse the bounds of curand4 from (0, 1] to [0, 1)
      // Note that this method is from legacy THCTensorRandom and is likely to give
      // you more 0-s, since, the probability of gettings 1-s is higher than 0-s and
      // by reversing the bounds, we are flipping the probabilities of 1-s and 0-s.
      // BEFORE TOUCHING THIS CODE READ: https://github.com/pytorch/pytorch/issues/16706
      auto reverse_bound_value = value == to ? from : value;
      return reverse_bound_value;
    };
    uniform_and_transform<scalar_t, opmath_t, curand4_engine_calls>(iter, gen, uniform_func);
  });
}
|
| 481 |
+
|
| 482 |
+
// Dispatch functor wrapping uniform_kernel for stub registration.
template<typename RNG>
struct UniformKernel {
  void operator()(TensorIteratorBase& iter, double from, double to, c10::optional<Generator> gen) {
    uniform_kernel(iter, from, to, check_generator<RNG>(gen));
  }
};
|
| 488 |
+
|
| 489 |
+
// ================================================== LogNormal =======================================================
|
| 490 |
+
|
| 491 |
+
// Fills `iter` with log-normal samples: a normal(mean_, std_) draw is passed
// through transformation::log_normal.
template<typename RNG>
void log_normal_kernel(TensorIteratorBase& iter, double mean_, double std_, RNG gen) {
  AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "log_normal_cuda", [&] {
    // Accumulate in a wider type for accuracy on reduced-precision dtypes.
    using accscalar_t = at::acc_type<scalar_t, true>;
    auto mean = static_cast<accscalar_t>(mean_);
    auto std = static_cast<accscalar_t>(std_);
    // define lambda for log_normal transformation
    auto log_normal_func = [mean, std] __device__ (accscalar_t rand) {
      return static_cast<scalar_t>(transformation::log_normal<accscalar_t>(transformation::normal<accscalar_t>(rand, mean, std)));
    };
    normal_and_transform<scalar_t, accscalar_t, curand4_engine_calls>(iter, gen, log_normal_func);
  });
}
|
| 504 |
+
|
| 505 |
+
// Dispatch functor wrapping log_normal_kernel for stub registration.
template<typename RNG>
struct LogNormalKernel {
  void operator()(TensorIteratorBase& iter, double mean, double std, c10::optional<Generator> gen) {
    log_normal_kernel(iter, mean, std, check_generator<RNG>(gen));
  }
};
|
| 511 |
+
|
| 512 |
+
// =================================================== Geometric ======================================================
|
| 513 |
+
|
| 514 |
+
// Fills `iter` with samples transformed by transformation::geometric(rand, p)
// from uniform draws.
template<typename RNG>
void geometric_kernel(TensorIteratorBase& iter, double p, RNG gen) {
  AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "geometric_cuda", [&] {
    using accscalar_t = at::DiscreteDistributionType<scalar_t>::type;
    // define lambda for geometric transformation
    auto geometric_func = [p] __device__ (accscalar_t rand) {
      return static_cast<scalar_t>(transformation::geometric<accscalar_t>(rand, p));
    };
    uniform_and_transform<scalar_t, accscalar_t, curand4_engine_calls>(iter, gen, geometric_func);
  });
}
|
| 525 |
+
|
| 526 |
+
// Dispatch functor wrapping geometric_kernel for stub registration.
template<typename RNG>
struct GeometricKernel {
  void operator()(TensorIteratorBase& iter, double p, c10::optional<Generator> gen) {
    geometric_kernel(iter, p, check_generator<RNG>(gen));
  }
};
|
| 532 |
+
|
| 533 |
+
// ================================================== Exponential =====================================================
|
| 534 |
+
|
| 535 |
+
// Fills `iter` with samples transformed by
// transformation::exponential(rand, lambda_) from uniform draws.
// Floating-point dtypes only (enforced by the TORCH_CHECK).
template<typename RNG>
void exponential_kernel(TensorIteratorBase& iter, double lambda_, RNG gen) {
  TORCH_CHECK(isFloatingType(iter.dtype()), "Exponential distribution is a continuous probability distribution. dtype must be a floating point but you specified ", iter.dtype());
  AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "exponential_cuda", [&] {
    // Accumulate in a wider type for accuracy on reduced-precision dtypes.
    using accscalar_t = at::acc_type<scalar_t, true>;
    auto lambda = static_cast<accscalar_t>(lambda_);
    // define lambda for exponential transformation
    auto exponential_func = [lambda] __device__ (accscalar_t rand) {
      return static_cast<scalar_t>(transformation::exponential<accscalar_t>(rand, lambda));
    };
    uniform_and_transform<scalar_t, accscalar_t, curand4_engine_calls>(iter, gen, exponential_func);
  });
}
|
| 548 |
+
|
| 549 |
+
// Dispatch functor wrapping exponential_kernel for stub registration.
template<typename RNG>
struct ExponentialKernel {
  void operator()(TensorIteratorBase& iter, double lambda, c10::optional<Generator> gen) {
    exponential_kernel(iter, lambda, check_generator<RNG>(gen));
  }
};
|
| 555 |
+
|
| 556 |
+
// ==================================================== Cauchy ========================================================
|
| 557 |
+
|
| 558 |
+
// Fills `iter` with samples transformed by
// transformation::cauchy(rand, median_, sigma_) from uniform draws.
template<typename RNG>
void cauchy_kernel(TensorIteratorBase& iter, double median_, double sigma_, RNG gen) {
  AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "cauchy_cuda", [&] {
    // Accumulate in a wider type for accuracy on reduced-precision dtypes.
    using accscalar_t = at::acc_type<scalar_t, true>;
    auto median = static_cast<accscalar_t>(median_);
    auto sigma = static_cast<accscalar_t>(sigma_);
    // define lambda for cauchy transformation
    auto cauchy_func = [median, sigma] __device__ (accscalar_t rand) {
      return static_cast<scalar_t>(transformation::cauchy<accscalar_t>(rand, median, sigma));
    };
    uniform_and_transform<scalar_t, accscalar_t, curand4_engine_calls>(iter, gen, cauchy_func);
  });
}
|
| 571 |
+
|
| 572 |
+
// Dispatch functor wrapping cauchy_kernel for stub registration.
template<typename RNG>
struct CauchyKernel {
  void operator()(TensorIteratorBase& iter, double median, double sigma, c10::optional<Generator> gen) {
    cauchy_kernel(iter, median, sigma, check_generator<RNG>(gen));
  }
};
|
| 578 |
+
|
| 579 |
+
// ==================================================== Bernoulli =====================================================
|
| 580 |
+
|
| 581 |
+
// Writes Bernoulli draws into `ret` using per-element probabilities from `p`:
// for each element, out = (uniform_rand <= prob). Processes up to four
// (output, probability) pairs per functor invocation via CUDA_tensor_apply2.
template<typename scalar_t, typename prob_t>
void bernoulli_tensor_cuda_kernel(
    const TensorBase &ret, const at::TensorBase &p,
    PhiloxCudaState philox_args) {
  auto functor = [philox_args] __device__(
          int n, scalar_t& v1, scalar_t& v2, scalar_t& v3, scalar_t& v4,
          const prob_t& p1, const prob_t& p2, const prob_t& p3, const prob_t& p4) {
        auto seeds = at::cuda::philox::unpack(philox_args);
        // Each thread keys its Philox subsequence by its global thread id.
        curandStatePhilox4_32_10_t state;
        curand_init(std::get<0>(seeds),
                    blockIdx.x * blockDim.x + threadIdx.x,
                    std::get<1>(seeds),
                    &state);

        // See Note [Register spilling in curand call for CUDA < 10]
        float4 rand = curand_uniform4(&state);
        // n is the number of valid pairs in this group (1..4); the switch
        // intentionally falls through so cases n down to 1 are all written.
        // Each probability is asserted to lie in [0, 1].
        switch (n) {
          case 4: {
            CUDA_KERNEL_ASSERT(0 <= p4 && p4 <= 1);
            v4 = static_cast<scalar_t>(rand.w <= p4);
            // fallthrough
          }
          case 3: {
            CUDA_KERNEL_ASSERT(0 <= p3 && p3 <= 1);
            v3 = static_cast<scalar_t>(rand.z <= p3);
            // fallthrough
          }
          case 2: {
            CUDA_KERNEL_ASSERT(0 <= p2 && p2 <= 1);
            v2 = static_cast<scalar_t>(rand.y <= p2);
            // fallthrough
          }
          case 1: {
            CUDA_KERNEL_ASSERT(0 <= p1 && p1 <= 1);
            v1 = static_cast<scalar_t>(rand.x <= p1);
          }
        }
      };
  // The template argument `4` below indicates that we want to operate on four
  // element at each time. See NOTE [ CUDA_tensor_applyN helpers ] for details.
  at::cuda::CUDA_tensor_apply2<scalar_t, prob_t, 4, decltype(functor),
                               /*max_threads_per_block=*/512,
                               /*min_blocks_per_sm=*/2>(ret, p, functor);
}
|
| 625 |
+
|
| 626 |
+
// bernoulli_(p_tensor): fills `self` with Bernoulli draws using per-element
// probabilities from `p_`. `p_` is moved to self's device and cast before
// sampling; it is broadcast against `self` via expand_inplace.
template<typename RNG>
void bernoulli_kernel(const TensorBase &self, const TensorBase &p_, RNG gen) {
  PhiloxCudaState rng_engine_inputs;
  {
    // See Note [Acquire lock when using random generators]
    std::lock_guard<std::mutex> lock(gen->mutex_);
    // 10 is the Philox counter-offset argument; presumably an upper bound on
    // engine calls per thread for this launch — TODO confirm.
    rng_engine_inputs = gen->philox_cuda_state(10);
  }
  TORCH_CHECK(at::isFloatingType(p_.scalar_type()), "expected probabilities tensor to have floating type, got ", p_.scalar_type());
  // cast probabilities tensor to double for double `self` tensor, and to `float` for everything else
  const auto p_type = self.dtype() == at::kDouble ? at::kDouble : at::kFloat;
  auto p_cuda = p_.to(TensorOptions().device(self.device()).dtype(p_type));
  auto p = expand_inplace(self, p_cuda);
  AT_DISPATCH_ALL_TYPES_AND3(
    at::ScalarType::Half, at::ScalarType::BFloat16, at::ScalarType::Bool, self.scalar_type(), "bernoulli_tensor_cuda_self_", [&] {
      if (std::is_same<scalar_t, double>::value) {
        return bernoulli_tensor_cuda_kernel<double, double>(self, *p, rng_engine_inputs);
      } else {
        return bernoulli_tensor_cuda_kernel<scalar_t, float>(self, *p, rng_engine_inputs);
      }
   });
}
|
| 648 |
+
|
| 649 |
+
// bernoulli_(p): fills `iter` with Bernoulli draws using a single scalar
// probability `p`, via transformation::bernoulli applied to uniform draws.
template<typename RNG>
void bernoulli_kernel(TensorIteratorBase& iter, double p, RNG gen) {
  AT_DISPATCH_ALL_TYPES_AND3(
    at::ScalarType::Half, at::ScalarType::BFloat16, at::ScalarType::Bool, iter.dtype(), "bernoulli_scalar_cuda_", [&] {
      using accscalar_t = at::DiscreteDistributionType<scalar_t>::type;
      // define lambda for bernoulli transformation
      auto bernoulli_func = [p] __device__ (accscalar_t rand) {
        return static_cast<scalar_t>(transformation::bernoulli<accscalar_t>(rand, p));
      };
      uniform_and_transform<scalar_t, accscalar_t, curand4_engine_calls>(iter, gen, bernoulli_func);
    });
}
|
| 661 |
+
|
| 662 |
+
// Dispatch functor: routes to the scalar-probability or tensor-probability
// bernoulli kernel.
template<typename RNG>
struct BernoulliKernel {
  void operator()(TensorIteratorBase& iter, double p, c10::optional<Generator> gen) {
    bernoulli_kernel(iter, p, check_generator<RNG>(gen));
  }
  void operator()(const TensorBase &self, const TensorBase &p_, c10::optional<Generator> gen) {
    bernoulli_kernel(self, p_, check_generator<RNG>(gen));
  }
};
|
| 671 |
+
|
| 672 |
+
}}}}
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/Distributions.h
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once

// Launcher declarations for CUDA distribution kernels. Only forward
// declarations are used here so that including this header stays cheap;
// the definitions live in a CUDA translation unit elsewhere in the project.
namespace at {
struct CUDAGeneratorImpl;
struct TensorIteratorBase;
class TensorBase;

namespace native {

void launch_poisson_cuda_kernel(
    const TensorBase &ret, const TensorBase &lambda, CUDAGeneratorImpl *gen);

void launch_gamma_kernel(
    const TensorBase &ret, const TensorBase &alpha, CUDAGeneratorImpl *gen);

void launch_binomial_cuda_kernel(
    TensorIteratorBase &iter, CUDAGeneratorImpl *gen);

void launch_dirichlet_kernel(TensorIteratorBase &iter);

void launch_standard_gamma_grad_kernel(TensorIteratorBase &iter);

void launch_dirichlet_grad_kernel(TensorIteratorBase &iter);

}} // namespace at::native
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/EmbeddingBackwardKernel.cuh
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
#include <ATen/core/Tensor.h>
#include <ATen/cuda/Atomic.cuh>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/TensorUtils.h>

namespace at {
namespace native {

// Computes the weight gradient for embedding-style backward passes on CUDA.
// `sorted_indices` and `count` appear to be the sort/dedup metadata for
// `orig_indices`; the trailing defaulted tensors look like the
// embedding_bag-specific inputs — TODO confirm against the .cu definition.
Tensor embedding_backward_cuda_kernel(
    const Tensor &grad,
    const Tensor &orig_indices,
    const Tensor &sorted_indices,
    const Tensor &count,
    int64_t num_weights,
    int padding_idx = -1,
    bool mode_mean = false,
    const Tensor &offset2bag = Tensor(),
    const Tensor &bag_size = Tensor(),
    const Tensor &per_sample_weights = Tensor());

}}
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/ForeachFunctors.cuh
ADDED
|
@@ -0,0 +1,681 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
#include <ATen/OpMathType.h>
|
| 3 |
+
#include <ATen/native/ForeachUtils.h>
|
| 4 |
+
#include <ATen/native/cuda/MultiTensorApply.cuh>
|
| 5 |
+
#include <ATen/native/cuda/Pow.cuh>
|
| 6 |
+
|
| 7 |
+
namespace at::native {
|
| 8 |
+
|
| 9 |
+
namespace {
|
| 10 |
+
|
| 11 |
+
// TODO(crcrpar): Handle version bump in codegen.
|
| 12 |
+
// rel:
|
| 13 |
+
// https://github.com/pytorch/pytorch/blob/9cf84347767c8abb8feba18a9a1baba321eeb8b9/tools/autograd/gen_inplace_or_view_type.py#L481-L482
|
| 14 |
+
// Bumps the autograd version counter on every tensor in `tensors`.
// Foreach kernels mutate tensors in-place outside the codegen'd
// inplace path, so the version bump must be done manually here.
inline void increment_version(TensorList tensors) {
  for (const auto& tensor : tensors) {
    tensor.unsafeGetTensorImpl()->bump_version();
  }
}
|
| 19 |
+
|
| 20 |
+
// Initializes args and checks if all args are aligned
|
| 21 |
+
template <int depth, typename T>
|
| 22 |
+
__device__ bool init_args(
|
| 23 |
+
T** args,
|
| 24 |
+
TensorListMetadata<depth>& tl,
|
| 25 |
+
const int64_t chunk_idx,
|
| 26 |
+
const int64_t chunk_size,
|
| 27 |
+
const int64_t tensor_loc) {
|
| 28 |
+
bool all_aligned = true;
|
| 29 |
+
for (int i = 0; i < depth; i++) {
|
| 30 |
+
args[i] = (T*)tl.addresses[i][tensor_loc];
|
| 31 |
+
args[i] += chunk_idx * chunk_size;
|
| 32 |
+
|
| 33 |
+
if (!is_aligned(args[i])) {
|
| 34 |
+
all_aligned = false;
|
| 35 |
+
}
|
| 36 |
+
}
|
| 37 |
+
return all_aligned;
|
| 38 |
+
}
|
| 39 |
+
|
| 40 |
+
// Initializes args and checks if all args are aligned
|
| 41 |
+
template <int depth, typename T, typename T2>
|
| 42 |
+
__device__ bool init_args(
|
| 43 |
+
T** args,
|
| 44 |
+
TensorListScalarListMetadata<T2, depth>& tl,
|
| 45 |
+
const int64_t chunk_idx,
|
| 46 |
+
const int64_t chunk_size,
|
| 47 |
+
const int64_t tensor_loc) {
|
| 48 |
+
bool all_aligned = true;
|
| 49 |
+
for (int i = 0; i < depth; i++) {
|
| 50 |
+
args[i] = (T*)tl.addresses[i][tensor_loc];
|
| 51 |
+
args[i] += chunk_idx * chunk_size;
|
| 52 |
+
|
| 53 |
+
if (!is_aligned(args[i])) {
|
| 54 |
+
all_aligned = false;
|
| 55 |
+
}
|
| 56 |
+
}
|
| 57 |
+
return all_aligned;
|
| 58 |
+
}
|
| 59 |
+
|
| 60 |
+
template <int depth, typename T>
|
| 61 |
+
__device__ bool init_args(
|
| 62 |
+
T** args,
|
| 63 |
+
FusedOptimizerTensorListMetadata<depth>& tl,
|
| 64 |
+
const int64_t chunk_idx,
|
| 65 |
+
const int64_t chunk_size,
|
| 66 |
+
const int64_t tensor_loc) {
|
| 67 |
+
bool all_aligned = true;
|
| 68 |
+
for (int i = 0; i < depth; i++) {
|
| 69 |
+
args[i] = (T*)tl.addresses[i][tensor_loc];
|
| 70 |
+
args[i] += chunk_idx * chunk_size;
|
| 71 |
+
|
| 72 |
+
if (!is_aligned(args[i])) {
|
| 73 |
+
all_aligned = false;
|
| 74 |
+
}
|
| 75 |
+
}
|
| 76 |
+
return all_aligned;
|
| 77 |
+
}
|
| 78 |
+
|
| 79 |
+
template <int depth, typename T>
|
| 80 |
+
__device__ void load_args(
|
| 81 |
+
T r_args[][kILP],
|
| 82 |
+
T** args,
|
| 83 |
+
const int64_t i_start,
|
| 84 |
+
const int64_t chunk_size,
|
| 85 |
+
const int64_t n) {
|
| 86 |
+
#pragma unroll
|
| 87 |
+
for (int ii = 0; ii < kILP; ii++) {
|
| 88 |
+
const auto i = i_start + threadIdx.x + ii * blockDim.x;
|
| 89 |
+
for (int r_index = 0; r_index < depth; r_index++) {
|
| 90 |
+
r_args[r_index][ii] = 0;
|
| 91 |
+
if (i < n && i < chunk_size) {
|
| 92 |
+
r_args[r_index][ii] = args[r_index][i];
|
| 93 |
+
}
|
| 94 |
+
}
|
| 95 |
+
}
|
| 96 |
+
}
|
| 97 |
+
|
| 98 |
+
template <typename T>
|
| 99 |
+
__device__ void store_args(
|
| 100 |
+
T* dst,
|
| 101 |
+
T* src,
|
| 102 |
+
const int64_t i_start,
|
| 103 |
+
const int64_t chunk_size,
|
| 104 |
+
const int64_t n) {
|
| 105 |
+
#pragma unroll
|
| 106 |
+
for (int ii = 0; ii < kILP; ii++) {
|
| 107 |
+
const int64_t i = i_start + threadIdx.x + ii * blockDim.x;
|
| 108 |
+
if (i < n && i < chunk_size)
|
| 109 |
+
dst[i] = src[ii];
|
| 110 |
+
}
|
| 111 |
+
}
|
| 112 |
+
|
| 113 |
+
// Elementwise out[i] = op(in[i], scalar) over one chunk.
// args[0] is the input and args[res_arg_index] the output (the same
// pointer for the in-place variant). Math runs in opmath_t precision and
// is cast back to T on store; r_args is per-thread register storage.
template <int res_arg_index, typename Op, typename T, typename opmath_t>
__device__ __forceinline__ void binary_op_scalar(
    T r_args[][kILP],
    T** args,
    opmath_t scalar,
    const int64_t n,
    const int64_t chunk_size,
    const bool all_aligned,
    Op op) {
  // to make things simple, we put aligned case in a different code path
  if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) {
    // Vectorized path: each thread moves kILP contiguous elements per
    // iteration with a single wide load/store.
    for (int64_t i_start = threadIdx.x;
         i_start * kILP < n && i_start * kILP < chunk_size;
         i_start += blockDim.x) {
      // load
      load_store(r_args[0], args[0], 0, i_start);
#pragma unroll
      for (int ii = 0; ii < kILP; ii++) {
        r_args[0][ii] = static_cast<T>(
            op(static_cast<opmath_t>(r_args[0][ii]),
               static_cast<opmath_t>(scalar)));
      }
      // store
      load_store(args[res_arg_index], r_args[0], i_start, 0);
    }
  } else {
    // Fallback path: per-element bounds-checked loads/stores.
    for (int64_t i_start = 0; i_start < n && i_start < chunk_size;
         i_start += blockDim.x * kILP) {
      // Regardless if depth is 1 (for inplace) or 2 (for out of place), r_args
      // has depth 1
      load_args<1>(r_args, args, i_start, chunk_size, n);
#pragma unroll
      for (int ii = 0; ii < kILP; ii++) {
        r_args[0][ii] = static_cast<T>(
            op(static_cast<opmath_t>(r_args[0][ii]),
               static_cast<opmath_t>(scalar)));
      }
      store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n);
    }
  }
}
|
| 154 |
+
|
| 155 |
+
// Elementwise out[i] = in0[i] + scalar * op(in1[i], in2[i]) over one chunk
// (the addcmul/addcdiv family). args[0..2] are the three inputs and
// args[res_arg_index] the output. Math runs in opmath_t precision and is
// cast back to T on store; r_args is per-thread register storage.
template <int res_arg_index, typename Op, typename T, typename opmath_t>
__device__ __forceinline__ void pointwise_op_scalar(
    T r_args[][kILP],
    T** args,
    opmath_t scalar,
    const int64_t n,
    const int64_t chunk_size,
    const bool all_aligned,
    Op op) {
  // to make things simple, we put aligned case in a different code path
  if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) {
    // Vectorized path: wide loads of all three operands, wide store of the
    // result.
    for (int64_t i_start = threadIdx.x;
         i_start * kILP < n && i_start * kILP < chunk_size;
         i_start += blockDim.x) {
      // load
      load_store(r_args[0], args[0], 0, i_start);
      load_store(r_args[1], args[1], 0, i_start);
      load_store(r_args[2], args[2], 0, i_start);
#pragma unroll
      for (int ii = 0; ii < kILP; ii++) {
        r_args[0][ii] = static_cast<T>(
            static_cast<opmath_t>(r_args[0][ii]) +
            scalar *
                op(static_cast<opmath_t>(r_args[1][ii]),
                   static_cast<opmath_t>(r_args[2][ii])));
      }
      // store
      load_store(args[res_arg_index], r_args[0], i_start, 0);
    }
  } else {
    // Fallback path: per-element bounds-checked loads/stores.
    for (int64_t i_start = 0; i_start < n && i_start < chunk_size;
         i_start += blockDim.x * kILP) {
      // Regardless if depth is 3 (for inplace) or 4 (for out of place), r_args
      // has depth 3
      load_args<3>(r_args, args, i_start, chunk_size, n);
#pragma unroll
      for (int ii = 0; ii < kILP; ii++) {
        r_args[0][ii] = static_cast<T>(
            static_cast<opmath_t>(r_args[0][ii]) +
            scalar *
                op(static_cast<opmath_t>(r_args[1][ii]),
                   static_cast<opmath_t>(r_args[2][ii])));
      }
      store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n);
    }
  }
}
|
| 202 |
+
|
| 203 |
+
//
|
| 204 |
+
// Binary Functors
|
| 205 |
+
//
|
| 206 |
+
template <typename T, int depth, int r_args_depth, int res_arg_index>
|
| 207 |
+
struct BinaryOpScalarFunctor {
|
| 208 |
+
using opmath_t = at::opmath_type<T>;
|
| 209 |
+
template <typename Op>
|
| 210 |
+
__device__ __forceinline__ void operator()(
|
| 211 |
+
int chunk_size,
|
| 212 |
+
TensorListMetadata<depth>& tl,
|
| 213 |
+
Op op,
|
| 214 |
+
opmath_t scalar) {
|
| 215 |
+
const int tensor_loc = tl.block_to_tensor[blockIdx.x];
|
| 216 |
+
const int chunk_idx = tl.block_to_chunk[blockIdx.x];
|
| 217 |
+
auto n = tl.numel_for_tensor[tensor_loc];
|
| 218 |
+
|
| 219 |
+
T* args[depth];
|
| 220 |
+
const bool all_aligned =
|
| 221 |
+
init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc);
|
| 222 |
+
n -= chunk_idx * chunk_size;
|
| 223 |
+
T r_args[r_args_depth][kILP];
|
| 224 |
+
|
| 225 |
+
binary_op_scalar<res_arg_index>(
|
| 226 |
+
r_args, args, scalar, n, chunk_size, all_aligned, op);
|
| 227 |
+
}
|
| 228 |
+
};
|
| 229 |
+
|
| 230 |
+
template <typename T, int depth, int r_args_depth, int res_arg_index>
|
| 231 |
+
struct BinaryOpScalarListFunctor {
|
| 232 |
+
using opmath_t = at::opmath_type<T>;
|
| 233 |
+
template <typename Op>
|
| 234 |
+
__device__ __forceinline__ void operator()(
|
| 235 |
+
int chunk_size,
|
| 236 |
+
TensorListScalarListMetadata<opmath_t, depth>& tl,
|
| 237 |
+
Op op) {
|
| 238 |
+
const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
|
| 239 |
+
const auto chunk_idx = tl.block_to_chunk[blockIdx.x];
|
| 240 |
+
auto n = tl.numel_for_tensor[tensor_loc];
|
| 241 |
+
|
| 242 |
+
T* args[depth];
|
| 243 |
+
const bool all_aligned =
|
| 244 |
+
init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc);
|
| 245 |
+
opmath_t scalar = tl.scalar_vals[tensor_loc];
|
| 246 |
+
n -= chunk_idx * chunk_size;
|
| 247 |
+
T r_args[r_args_depth][kILP];
|
| 248 |
+
|
| 249 |
+
binary_op_scalar<res_arg_index>(
|
| 250 |
+
r_args, args, scalar, n, chunk_size, all_aligned, op);
|
| 251 |
+
}
|
| 252 |
+
};
|
| 253 |
+
|
| 254 |
+
// Multi-tensor-apply functor for list-list binary ops with an alpha
// multiplier: out[i] = op(a[i], alpha * b[i]). Each block handles one
// (tensor, chunk) pair. Math runs in opmath_t precision.
template <typename T, int depth, int r_args_depth, int res_arg_index>
struct BinaryOpListAlphaFunctor {
  using opmath_t = at::opmath_type<T>;
  template <typename Op>
  __device__ __forceinline__ void operator()(
      int chunk_size,
      TensorListMetadata<depth>& tl,
      Op op,
      opmath_t alpha) {
    // Map this block to its tensor and chunk.
    const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
    const auto chunk_idx = tl.block_to_chunk[blockIdx.x];
    auto n = tl.numel_for_tensor[tensor_loc];

    T* args[depth];
    const bool all_aligned =
        init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc);
    // Elements remaining from the start of this chunk.
    n -= chunk_idx * chunk_size;
    T r_args[r_args_depth][kILP];

    // to make things simple, we put aligned case in a different code path
    if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) {
      // Vectorized path: kILP-wide loads/stores per thread.
      for (int64_t i_start = threadIdx.x;
           i_start * kILP < n && i_start * kILP < chunk_size;
           i_start += blockDim.x) {
        // load
        load_store(r_args[0], args[0], 0, i_start);
        load_store(r_args[1], args[1], 0, i_start);
#pragma unroll
        for (int ii = 0; ii < kILP; ii++) {
          r_args[0][ii] = static_cast<T>(
              op(static_cast<opmath_t>(r_args[0][ii]),
                 alpha * static_cast<opmath_t>(r_args[1][ii])));
        }
        // store
        load_store(args[res_arg_index], r_args[0], i_start, 0);
      }
    } else {
      // Fallback path with per-element bounds checks.
      for (int64_t i_start = 0; i_start < n && i_start < chunk_size;
           i_start += blockDim.x * kILP) {
        load_args<r_args_depth>(r_args, args, i_start, chunk_size, n);
#pragma unroll
        for (int ii = 0; ii < kILP; ii++) {
          r_args[0][ii] = static_cast<T>(
              op(static_cast<opmath_t>(r_args[0][ii]),
                 alpha * static_cast<opmath_t>(r_args[1][ii])));
        }
        store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n);
      }
    }
  }
};
|
| 305 |
+
|
| 306 |
+
// Multi-tensor-apply functor for binary ops whose scalar operand lives in
// device memory (`scalar` points at a single T value, presumably a 0-dim
// tensor's storage — scaled by alpha): out[i] = op(a[i], alpha * *scalar).
template <typename T, int depth, int r_args_depth, int res_arg_index>
struct BinaryOpScalarTensorFunctor {
  using opmath_t = at::opmath_type<T>;
  template <typename Op>
  __device__ __forceinline__ void operator()(
      int chunk_size,
      TensorListMetadata<depth>& tl,
      Op op,
      T* scalar,
      opmath_t alpha) {
    // Map this block to its tensor and chunk.
    const int tensor_loc = tl.block_to_tensor[blockIdx.x];
    const int chunk_idx = tl.block_to_chunk[blockIdx.x];
    auto n = tl.numel_for_tensor[tensor_loc];

    T* args[depth];
    const bool all_aligned =
        init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc);
    // Elements remaining from the start of this chunk.
    n -= chunk_idx * chunk_size;
    T r_args[r_args_depth][kILP];

    // to make things simple, we put aligned case in a different code path
    if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) {
      // Vectorized path: kILP-wide loads/stores per thread.
      for (int64_t i_start = threadIdx.x;
           i_start * kILP < n && i_start * kILP < chunk_size;
           i_start += blockDim.x) {
        // load
        load_store(r_args[0], args[0], 0, i_start);
#pragma unroll
        for (int ii = 0; ii < kILP; ii++) {
          // Note: *scalar is re-read from global memory each iteration.
          r_args[0][ii] = static_cast<T>(op(
              static_cast<opmath_t>(r_args[0][ii]),
              static_cast<opmath_t>(alpha) * static_cast<opmath_t>(*scalar)));
        }
        // store
        load_store(args[res_arg_index], r_args[0], i_start, 0);
      }
    } else {
      // Fallback path with per-element bounds checks.
      for (int64_t i_start = 0; i_start < n && i_start < chunk_size;
           i_start += blockDim.x * kILP) {
        // Regardless if depth is 1 (for inplace) or 2 (for out of place),
        // r_args has depth 1
        load_args<1>(r_args, args, i_start, chunk_size, n);
#pragma unroll
        for (int ii = 0; ii < kILP; ii++) {
          r_args[0][ii] = static_cast<T>(op(
              static_cast<opmath_t>(r_args[0][ii]),
              static_cast<opmath_t>(alpha) * static_cast<opmath_t>(*scalar)));
        }
        store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n);
      }
    }
  }
};
|
| 359 |
+
|
| 360 |
+
//
|
| 361 |
+
// Unary Functors
|
| 362 |
+
//
|
| 363 |
+
|
| 364 |
+
// Multi-tensor-apply functor that zero-fills every tensor in the list
// (foreach_zero_). The metadata is hard-coded to depth 1.
template <typename T, int depth, int r_args_depth, int res_arg_index>
struct ZeroFunctor {
  __device__ __forceinline__ void operator()(
      int chunk_size,
      TensorListMetadata<1>& tl) {
    // Map this block to its tensor and chunk.
    const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
    const auto chunk_idx = tl.block_to_chunk[blockIdx.x];
    auto n = tl.numel_for_tensor[tensor_loc];

    T* args[depth];
    const auto all_aligned =
        init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc);
    // Elements remaining from the start of this chunk.
    n -= chunk_idx * chunk_size;
    T r_args[r_args_depth][kILP];

    // to make things simple, we put aligned case in a different code path
    if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) {
      for (int64_t i_start = threadIdx.x;
           i_start * kILP < n && i_start * kILP < chunk_size;
           i_start += blockDim.x) {
#pragma unroll
        for (int ii = 0; ii < kILP; ii++) {
          r_args[0][ii] = 0;
        }
        // store
        // NOTE(review): this path writes args[0] while the fallback writes
        // args[res_arg_index]; equivalent only when res_arg_index == 0 —
        // confirm instantiations.
        load_store(args[0], r_args[0], i_start, 0);
      }
    } else {
      for (int64_t i_start = 0; i_start < n && i_start < chunk_size;
           i_start += blockDim.x * kILP) {
#pragma unroll
        for (int ii = 0; ii < kILP; ii++) {
          r_args[0][ii] = 0;
        }
        store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n);
      }
    }
  }
};
|
| 403 |
+
|
| 404 |
+
// Multi-tensor-apply functor for unary ops: out[i] = op(in[i]).
// Each block handles one (tensor, chunk) pair; math runs in opmath_t
// precision and is cast back to T on store.
template <typename T, int depth, int r_args_depth, int res_arg_index>
struct UnaryOpFunctor {
  using opmath_t = at::opmath_type<T>;
  template <typename Op>
  __device__ __forceinline__ void operator()(
      int chunk_size,
      TensorListMetadata<depth>& tl,
      Op op) {
    // Map this block to its tensor and chunk.
    const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
    const auto chunk_idx = tl.block_to_chunk[blockIdx.x];
    auto n = tl.numel_for_tensor[tensor_loc];

    T* args[depth];
    bool all_aligned =
        init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc);
    // Elements remaining from the start of this chunk.
    n -= chunk_idx * chunk_size;
    T r_args[r_args_depth][kILP];

    // to make things simple, we put aligned case in a different code path
    if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) {
      // Vectorized path: kILP-wide loads/stores per thread.
      for (int64_t i_start = threadIdx.x;
           i_start * kILP < n && i_start * kILP < chunk_size;
           i_start += blockDim.x) {
        // load
        load_store(r_args[0], args[0], 0, i_start);
#pragma unroll
        for (int ii = 0; ii < kILP; ii++) {
          r_args[0][ii] =
              static_cast<T>(op(static_cast<opmath_t>(r_args[0][ii])));
        }
        // store
        load_store(args[res_arg_index], r_args[0], i_start, 0);
      }
    } else {
      // Fallback path with per-element bounds checks.
      for (int64_t i_start = 0; i_start < n && i_start < chunk_size;
           i_start += blockDim.x * kILP) {
        load_args<r_args_depth>(r_args, args, i_start, chunk_size, n);
#pragma unroll
        for (int ii = 0; ii < kILP; ii++) {
          r_args[0][ii] =
              static_cast<T>(op(static_cast<opmath_t>(r_args[0][ii])));
        }
        store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n);
      }
    }
  }
};
|
| 451 |
+
|
| 452 |
+
//
|
| 453 |
+
// Pointwise Functors
|
| 454 |
+
//
|
| 455 |
+
|
| 456 |
+
template <typename T, int depth, int r_args_depth, int res_arg_index>
|
| 457 |
+
struct PointwiseOpScalarFunctor {
|
| 458 |
+
using opmath_t = at::opmath_type<T>;
|
| 459 |
+
template <typename Op>
|
| 460 |
+
__device__ __forceinline__ void operator()(
|
| 461 |
+
int chunk_size,
|
| 462 |
+
TensorListMetadata<depth>& tl,
|
| 463 |
+
Op op,
|
| 464 |
+
opmath_t scalar) {
|
| 465 |
+
const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
|
| 466 |
+
const auto chunk_idx = tl.block_to_chunk[blockIdx.x];
|
| 467 |
+
auto n = tl.numel_for_tensor[tensor_loc];
|
| 468 |
+
|
| 469 |
+
T* args[depth];
|
| 470 |
+
const bool all_aligned =
|
| 471 |
+
init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc);
|
| 472 |
+
n -= chunk_idx * chunk_size;
|
| 473 |
+
T r_args[r_args_depth][kILP];
|
| 474 |
+
|
| 475 |
+
pointwise_op_scalar<res_arg_index>(
|
| 476 |
+
r_args, args, scalar, n, chunk_size, all_aligned, op);
|
| 477 |
+
}
|
| 478 |
+
};
|
| 479 |
+
|
| 480 |
+
template <typename T, int depth, int r_args_depth, int res_arg_index>
|
| 481 |
+
struct PointwiseOpScalarListFunctor {
|
| 482 |
+
using opmath_t = at::opmath_type<T>;
|
| 483 |
+
template <typename Op>
|
| 484 |
+
__device__ __forceinline__ void operator()(
|
| 485 |
+
int chunk_size,
|
| 486 |
+
TensorListScalarListMetadata<opmath_t, depth>& tl,
|
| 487 |
+
Op op) {
|
| 488 |
+
const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
|
| 489 |
+
const auto chunk_idx = tl.block_to_chunk[blockIdx.x];
|
| 490 |
+
auto n = tl.numel_for_tensor[tensor_loc];
|
| 491 |
+
|
| 492 |
+
T* args[depth];
|
| 493 |
+
const bool all_aligned =
|
| 494 |
+
init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc);
|
| 495 |
+
opmath_t scalar = tl.scalar_vals[tensor_loc];
|
| 496 |
+
n -= chunk_idx * chunk_size;
|
| 497 |
+
T r_args[r_args_depth][kILP];
|
| 498 |
+
|
| 499 |
+
pointwise_op_scalar<res_arg_index>(
|
| 500 |
+
r_args, args, scalar, n, chunk_size, all_aligned, op);
|
| 501 |
+
}
|
| 502 |
+
};
|
| 503 |
+
|
| 504 |
+
// Multi-tensor-apply functor for list-list pointwise ops:
// out[i] = op(a[i], b[i]), with the result written to the last operand.
// NOTE(review): the store targets args[2] and r_args has depth - 1 rows,
// which presumes depth == 3 (two inputs + one output) — confirm
// instantiations.
template <typename T, int depth>
struct PointwiseOpListFunctor {
  using opmath_t = at::opmath_type<T>;
  template <typename Op>
  __device__ __forceinline__ void operator()(
      int chunk_size,
      TensorListMetadata<depth>& tl,
      Op op) {
    // Map this block to its tensor and chunk.
    const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
    const auto chunk_idx = tl.block_to_chunk[blockIdx.x];
    auto n = tl.numel_for_tensor[tensor_loc];

    T* args[depth];
    const bool all_aligned =
        init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc);
    // Elements remaining from the start of this chunk.
    n -= chunk_idx * chunk_size;
    T r_args[depth - 1][kILP];

    // to make things simple, we put aligned case in a different code path
    if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) {
      // Vectorized path: kILP-wide loads/stores per thread.
      for (int64_t i_start = threadIdx.x;
           i_start * kILP < n && i_start * kILP < chunk_size;
           i_start += blockDim.x) {
        // load
        load_store(r_args[0], args[0], 0, i_start);
        load_store(r_args[1], args[1], 0, i_start);
#pragma unroll
        for (int ii = 0; ii < kILP; ii++) {
          r_args[0][ii] = static_cast<T>(
              op(static_cast<opmath_t>(r_args[0][ii]),
                 static_cast<opmath_t>(r_args[1][ii])));
        }
        // store
        load_store(args[2], r_args[0], i_start, 0);
      }
    } else {
      // Fallback path with per-element bounds checks.
      for (int64_t i_start = 0; i_start < n && i_start < chunk_size;
           i_start += blockDim.x * kILP) {
        load_args<depth - 1>(r_args, args, i_start, chunk_size, n);
#pragma unroll
        for (int ii = 0; ii < kILP; ii++) {
          r_args[0][ii] = static_cast<T>(
              op(static_cast<opmath_t>(r_args[0][ii]),
                 static_cast<opmath_t>(r_args[1][ii])));
        }
        store_args(args[2], r_args[0], i_start, chunk_size, n);
      }
    }
  }
};
|
| 554 |
+
|
| 555 |
+
// Multi-tensor-apply functor for ternary list ops (e.g. lerp with tensor
// weights): out[i] = op(a[i], b[i], c[i]). depth 3 is in-place, depth 4
// out-of-place; the result goes to args[res_arg_index].
template <typename T, int depth, int r_args_depth, int res_arg_index>
struct TernaryOpListFunctor {
  using opmath_t = at::opmath_type<T>;
  template <typename Op>
  __device__ __forceinline__ void operator()(
      int chunk_size,
      TensorListMetadata<depth>& tl,
      Op op) {
    // Compile-time sanity checks on the template configuration.
    static_assert(depth == 3 || depth == 4, "");
    static_assert(depth >= r_args_depth, "");
    static_assert(res_arg_index == depth - 1 || res_arg_index == 0, "");
    // Map this block to its tensor and chunk.
    const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
    const auto chunk_idx = tl.block_to_chunk[blockIdx.x];
    auto n = tl.numel_for_tensor[tensor_loc];

    T* args[depth];
    const bool all_aligned =
        init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc);
    // Elements remaining from the start of this chunk.
    n -= chunk_idx * chunk_size;
    T r_args[r_args_depth][kILP];

    // Aligned, evenly divisible chunks take the vectorized path.
    if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) {
      for (int64_t i_start = threadIdx.x;
           i_start * kILP < n && i_start * kILP < chunk_size;
           i_start += blockDim.x) {
        load_store(r_args[0], args[0], 0, i_start);
        load_store(r_args[1], args[1], 0, i_start);
        load_store(r_args[2], args[2], 0, i_start);
#pragma unroll
        for (int ii = 0; ii < kILP; ii++) {
          r_args[0][ii] =
              op(static_cast<opmath_t>(r_args[0][ii]),
                 static_cast<opmath_t>(r_args[1][ii]),
                 static_cast<opmath_t>(r_args[2][ii]));
        }
        load_store(args[res_arg_index], r_args[0], i_start, 0);
      }
    } else {
      // Fallback path with per-element bounds checks.
      for (int64_t i_start = 0; i_start < n && i_start < chunk_size;
           i_start += blockDim.x * kILP) {
        load_args<r_args_depth>(r_args, args, i_start, chunk_size, n);
#pragma unroll
        for (int ii = 0; ii < kILP; ii++) {
          r_args[0][ii] =
              op(static_cast<opmath_t>(r_args[0][ii]),
                 static_cast<opmath_t>(r_args[1][ii]),
                 static_cast<opmath_t>(r_args[2][ii]));
        }
        store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n);
      }
    }
  }
};
|
| 608 |
+
|
| 609 |
+
// Multi-tensor-apply functor for ternary ops whose third operand is a
// single scalar (e.g. lerp with a scalar weight):
// out[i] = op(a[i], b[i], alpha). depth 2 is in-place, depth 3
// out-of-place; the result goes to args[res_arg_index].
template <typename T, int depth, int r_args_depth, int res_arg_index>
struct TernaryOpScalarFunctor {
  using opmath_t = at::opmath_type<T>;
  template <typename Op>
  __device__ __forceinline__ void operator()(
      int chunk_size,
      TensorListMetadata<depth>& tl,
      Op op,
      opmath_t alpha) {
    // Compile-time sanity checks on the template configuration.
    static_assert(depth == 2 || depth == 3, "");
    static_assert(depth >= r_args_depth, "");
    static_assert(res_arg_index == depth - 1 || res_arg_index == 0, "");
    // Map this block to its tensor and chunk.
    const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
    const auto chunk_idx = tl.block_to_chunk[blockIdx.x];
    auto n = tl.numel_for_tensor[tensor_loc];

    T* args[depth];
    const bool all_aligned =
        init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc);
    // Elements remaining from the start of this chunk.
    n -= chunk_idx * chunk_size;
    T r_args[r_args_depth][kILP];

    // to make things simple, we put aligned case in a different code path
    if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) {
      for (int64_t i_start = threadIdx.x;
           i_start * kILP < n && i_start * kILP < chunk_size;
           i_start += blockDim.x) {
        // load
        load_store(r_args[0], args[0], 0, i_start);
        load_store(r_args[1], args[1], 0, i_start);
#pragma unroll
        for (int ii = 0; ii < kILP; ii++) {
          r_args[0][ii] =
              op(static_cast<opmath_t>(r_args[0][ii]),
                 static_cast<opmath_t>(r_args[1][ii]),
                 alpha);
        }
        // store
        load_store(args[res_arg_index], r_args[0], i_start, 0);
      }
    } else {
      // Fallback path with per-element bounds checks.
      for (int64_t i_start = 0; i_start < n && i_start < chunk_size;
           i_start += blockDim.x * kILP) {
        load_args<r_args_depth>(r_args, args, i_start, chunk_size, n);
#pragma unroll
        for (int ii = 0; ii < kILP; ii++) {
          r_args[0][ii] =
              op(static_cast<opmath_t>(r_args[0][ii]),
                 static_cast<opmath_t>(r_args[1][ii]),
                 alpha);
        }
        store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n);
      }
    }
  }
};
|
| 665 |
+
|
| 666 |
+
template <typename T>
|
| 667 |
+
struct power_functor {
|
| 668 |
+
C10_DEVICE T operator()(const T& a, const T& b) const {
|
| 669 |
+
return at::native::pow_(a, b);
|
| 670 |
+
}
|
| 671 |
+
};
|
| 672 |
+
|
| 673 |
+
template <typename T>
|
| 674 |
+
struct reverse_power_functor {
|
| 675 |
+
C10_DEVICE T operator()(const T& a, const T& b) const {
|
| 676 |
+
return at::native::pow_(b, a);
|
| 677 |
+
}
|
| 678 |
+
};
|
| 679 |
+
|
| 680 |
+
} // namespace
|
| 681 |
+
} // namespace at::native
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/ForeachMinMaxFunctors.cuh
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/NumericUtils.h>
|
| 4 |
+
|
| 5 |
+
namespace at::native {
|
| 6 |
+
|
| 7 |
+
// std:: does not have clamp functors
|
| 8 |
+
template <typename T>
|
| 9 |
+
struct minimum {
|
| 10 |
+
__device__ T operator()(const T& a, const T& b) const {
|
| 11 |
+
return (_isnan(a) || a < b) ? a : b;
|
| 12 |
+
}
|
| 13 |
+
};
|
| 14 |
+
|
| 15 |
+
template <typename T>
|
| 16 |
+
struct maximum {
|
| 17 |
+
__device__ T operator()(const T& a, const T& b) const {
|
| 18 |
+
return (_isnan(a) || a > b) ? a : b;
|
| 19 |
+
}
|
| 20 |
+
};
|
| 21 |
+
|
| 22 |
+
} // namespace at::native
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/GridSampler.cuh
ADDED
|
@@ -0,0 +1,321 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
#include <ATen/native/cuda/KernelUtils.cuh>
|
| 3 |
+
#include <ATen/native/GridSamplerUtils.h>
|
| 4 |
+
|
| 5 |
+
namespace at { namespace native {
|
| 6 |
+
|
| 7 |
+
using detail::GridSamplerInterpolation;
|
| 8 |
+
using detail::GridSamplerPadding;
|
| 9 |
+
|
| 10 |
+
// Unnormalizes a coordinate from the -1 to +1 scale to its pixel index value,
|
| 11 |
+
// where we view each pixel as an area between (idx - 0.5) and (idx + 0.5).
|
| 12 |
+
// if align_corners: -1 and +1 get sent to the centers of the corner pixels
|
| 13 |
+
// -1 --> 0
|
| 14 |
+
// +1 --> (size - 1)
|
| 15 |
+
// scale_factor = (size - 1) / 2
|
| 16 |
+
// if not align_corners: -1 and +1 get sent to the image edges
|
| 17 |
+
// -1 --> -0.5
|
| 18 |
+
// +1 --> (size - 1) + 0.5 == size - 0.5
|
| 19 |
+
// scale_factor = size / 2
|
| 20 |
+
template <typename scalar_t>
|
| 21 |
+
static __forceinline__ __device__
|
| 22 |
+
scalar_t grid_sampler_unnormalize(scalar_t coord, int size, bool align_corners) {
|
| 23 |
+
if (align_corners) {
|
| 24 |
+
// unnormalize coord from [-1, 1] to [0, size - 1]
|
| 25 |
+
return ((coord + 1.f) / 2) * (size - 1);
|
| 26 |
+
} else {
|
| 27 |
+
// unnormalize coord from [-1, 1] to [-0.5, size - 0.5]
|
| 28 |
+
return ((coord + 1.f) * size - 1) / 2;
|
| 29 |
+
}
|
| 30 |
+
}
|
| 31 |
+
|
| 32 |
+
// grid_sampler_unnormalize_set_grad works the same as grid_sampler_unnormalize
|
| 33 |
+
// except that it also returns the `d output / d input` via pointer argument
|
| 34 |
+
// `grad_in`.
|
| 35 |
+
// This is useful in the backward pass of grid_sampler.
|
| 36 |
+
template <typename scalar_t>
static __forceinline__ __device__
scalar_t grid_sampler_unnormalize_set_grad(scalar_t coord, int size,
                                           bool align_corners, scalar_t *grad_in) {
  if (align_corners) {
    // unnormalize coord from [-1, 1] to [0, size - 1]
    // The mapping is affine, so d(output)/d(coord) is the constant
    // scale factor (size - 1) / 2.
    *grad_in = static_cast<scalar_t>(size - 1) / 2;
    return ((coord + 1.f) / 2) * (size - 1);
  } else {
    // unnormalize coord from [-1, 1] to [-0.5, size - 0.5]
    // Here the constant scale factor is size / 2.
    *grad_in = static_cast<scalar_t>(size) / 2;
    return ((coord + 1.f) * size - 1) / 2;
  }
}
|
| 50 |
+
|
| 51 |
+
// Clips coordinates to between 0 and clip_limit - 1
|
| 52 |
+
template <typename scalar_t>
|
| 53 |
+
static __forceinline__ __device__
|
| 54 |
+
scalar_t clip_coordinates(scalar_t in, int clip_limit) {
|
| 55 |
+
return ::min(static_cast<scalar_t>(clip_limit - 1), ::max(in, static_cast<scalar_t>(0)));
|
| 56 |
+
}
|
| 57 |
+
|
| 58 |
+
// clip_coordinates_set_grad works similarly to clip_coordinates except that
|
| 59 |
+
// it also returns the `d output / d input` via pointer argument `grad_in`.
|
| 60 |
+
// This is useful in the backward pass of grid_sampler.
|
| 61 |
+
template <typename scalar_t>
static __forceinline__ __device__
scalar_t clip_coordinates_set_grad(scalar_t in, int clip_limit, scalar_t *grad_in) {
  // Note that it is important for the gradient calculation that borders
  // are considered out of bounds.
  if (in <= static_cast<scalar_t>(0)) {
    // Clamped at (or below) the low border: gradient does not flow.
    *grad_in = static_cast<scalar_t>(0);
    return static_cast<scalar_t>(0);
  } else {
    scalar_t max = static_cast<scalar_t>(clip_limit - 1);
    if (in >= max) {
      // Clamped at (or above) the high border: gradient does not flow.
      *grad_in = static_cast<scalar_t>(0);
      return max;
    } else {
      // Strictly inside the valid range: clipping is the identity here,
      // so the local derivative is 1.
      *grad_in = static_cast<scalar_t>(1);
      return in;
    }
  }
}
|
| 80 |
+
|
| 81 |
+
// Reflects coordinates until they fall between low and high (inclusive).
|
| 82 |
+
// The bounds are passed as twice their value so that half-integer values
|
| 83 |
+
// can be represented as ints.
|
| 84 |
+
template <typename scalar_t>
static __forceinline__ __device__
scalar_t reflect_coordinates(scalar_t in, int twice_low, int twice_high) {
  // Degenerate interval (low == high): everything collapses to a single
  // point, represented here as 0. Also avoids dividing by a zero `span`.
  if (twice_low == twice_high) {
    return static_cast<scalar_t>(0);
  }
  scalar_t min = static_cast<scalar_t>(twice_low) / 2;
  scalar_t span = static_cast<scalar_t>(twice_high - twice_low) / 2;
  // Work with the non-negative distance from the lower bound; reflection
  // about `min` is symmetric so the sign can be discarded.
  in = ::fabs(in - min);
  // `fmod` returns same sign as `in`, which is positive after the `fabs` above.
  scalar_t extra = ::fmod(in, span);
  // Number of whole spans traversed; its parity tells whether the current
  // segment runs forward or backward.
  int flips = static_cast<int>(::floor(in / span));
  if (flips % 2 == 0) {
    return extra + min;
  } else {
    return span - extra + min;
  }
}
|
| 102 |
+
|
| 103 |
+
// reflect_coordinates_set_grad works similarly to reflect_coordinates except
|
| 104 |
+
// that it also returns the `d output / d input` via pointer argument
|
| 105 |
+
// `grad_in`.
|
| 106 |
+
// This is useful in the backward pass of grid_sampler.
|
| 107 |
+
template <typename scalar_t>
static __forceinline__ __device__
scalar_t reflect_coordinates_set_grad(scalar_t in, int twice_low, int twice_high,
                                      scalar_t *grad_in) {
  // Degenerate interval: output is constant, so the gradient is 0.
  if (twice_low == twice_high) {
    *grad_in = static_cast<scalar_t>(0);
    return static_cast<scalar_t>(0);
  }
  // Tracks the sign flip introduced by taking the absolute value below.
  int grad_in_mult_;
  scalar_t min = static_cast<scalar_t>(twice_low) / 2;
  scalar_t span = static_cast<scalar_t>(twice_high - twice_low) / 2;
  in = in - min;
  if (in < static_cast<scalar_t>(0)) {
    grad_in_mult_ = -1;
    in = -in;
  } else {
    grad_in_mult_ = 1;
  }
  // `fmod` returns same sign as `in`, which is positive after the `if` above.
  scalar_t extra = ::fmod(in, span);
  int flips = static_cast<int>(::floor(in / span));
  if (flips % 2 == 0) {
    // Forward segment: slope is +1 times the earlier sign flip.
    *grad_in = static_cast<scalar_t>(grad_in_mult_);
    return extra + min;
  } else {
    // Reflected segment: slope is -1 times the earlier sign flip.
    *grad_in = static_cast<scalar_t>(-grad_in_mult_);
    return span - extra + min;
  }
}
|
| 136 |
+
|
| 137 |
+
template<typename scalar_t>
static __forceinline__ __device__
scalar_t safe_downgrade_to_int_range(scalar_t x){
  // Replace values that cannot be safely cast to int (too large, too small,
  // or non-finite) with an in-range sentinel that will fail the subsequent
  // bounds checks.
  // -100.0 does not have special meaning. This is just to make sure
  // it's not within_bounds_2d or within_bounds_3d, and does not cause
  // undefined behavior. See #35506.
  if (x > INT_MAX-1 || x < INT_MIN || !::isfinite(static_cast<double>(x)))
    return static_cast<scalar_t>(-100.0);
  return x;
}
|
| 147 |
+
|
| 148 |
+
// Applies the requested padding mode to an (already unnormalized) pixel
// coordinate, then sanitizes it for safe int conversion.
template<typename scalar_t>
static __forceinline__ __device__
scalar_t compute_coordinates(scalar_t coord, int size,
                             GridSamplerPadding padding_mode,
                             bool align_corners) {
  if (padding_mode == GridSamplerPadding::Border) {
    // clip coordinates to image borders
    coord = clip_coordinates(coord, size);
  } else if (padding_mode == GridSamplerPadding::Reflection) {
    // reflect coordinates by image borders
    // Bounds are passed doubled so half-integer edges fit in ints.
    if (align_corners) {
      coord = reflect_coordinates(coord, 0, 2*(size - 1));
    } else {
      coord = reflect_coordinates(coord, -1, 2*size - 1);
    }
    // clip coordinates to image borders
    coord = clip_coordinates(coord, size);
  }

  // Zeros-padding falls through untouched; only the final sanitization runs.
  coord = safe_downgrade_to_int_range(coord);
  return coord;
}
|
| 170 |
+
|
| 171 |
+
// Computes the pixel source index value for a grid coordinate
|
| 172 |
+
template <typename scalar_t>
|
| 173 |
+
static __forceinline__ __device__
|
| 174 |
+
scalar_t grid_sampler_compute_source_index(
|
| 175 |
+
scalar_t coord,
|
| 176 |
+
int size,
|
| 177 |
+
GridSamplerPadding padding_mode,
|
| 178 |
+
bool align_corners) {
|
| 179 |
+
coord = grid_sampler_unnormalize(coord, size, align_corners);
|
| 180 |
+
coord = compute_coordinates(coord, size, padding_mode, align_corners);
|
| 181 |
+
return coord;
|
| 182 |
+
}
|
| 183 |
+
|
| 184 |
+
// grid_sampler_compute_source_index_set_grad works similarly to
|
| 185 |
+
// grid_sampler_compute_source_index except that it also returns the
|
| 186 |
+
// `d output / d input` via pointer argument `grad_in`.
|
| 187 |
+
// This is useful in the backward pass of grid_sampler.
|
| 188 |
+
template <typename scalar_t>
static __forceinline__ __device__
scalar_t grid_sampler_compute_source_index_set_grad(
    scalar_t coord,
    int size,
    GridSamplerPadding padding_mode,
    bool align_corners,
    scalar_t *grad_in) {
  // Local gradients of the clip and reflect stages, combined below by the
  // chain rule into `grad_in`.
  scalar_t grad_clip, grad_refl;
  // `grad_in` starts as the unnormalization scale factor.
  coord = grid_sampler_unnormalize_set_grad(coord, size, align_corners, grad_in);
  if (padding_mode == GridSamplerPadding::Border) {
    // clip coordinates to image borders
    coord = clip_coordinates_set_grad(coord, size, &grad_clip);
    *grad_in = (*grad_in) * grad_clip;
  } else if (padding_mode == GridSamplerPadding::Reflection) {
    // reflect coordinates by image borders
    if (align_corners) {
      coord = reflect_coordinates_set_grad(coord, 0, 2*(size - 1), &grad_refl);
    } else {
      coord = reflect_coordinates_set_grad(coord, -1, 2*size - 1, &grad_refl);
    }
    // clip coordinates to image borders
    coord = clip_coordinates_set_grad(coord, size, &grad_clip);
    // Chain rule across both stages: reflect, then clip.
    *grad_in = (*grad_in) * grad_refl * grad_clip;
  }

  coord = safe_downgrade_to_int_range(coord);
  return coord;
}
|
| 217 |
+
|
| 218 |
+
// True iff (h, w) indexes a valid pixel of an H x W image.
static __forceinline__ __device__
bool within_bounds_2d(int h, int w, int H, int W) {
  const bool h_ok = (h >= 0) && (h < H);
  const bool w_ok = (w >= 0) && (w < W);
  return h_ok && w_ok;
}
|
| 222 |
+
|
| 223 |
+
// True iff (d, h, w) indexes a valid voxel of a D x H x W volume.
static __forceinline__ __device__
bool within_bounds_3d(int d, int h, int w, int D, int H, int W) {
  const bool d_ok = (d >= 0) && (d < D);
  const bool h_ok = (h >= 0) && (h < H);
  const bool w_ok = (w >= 0) && (w < W);
  return d_ok && h_ok && w_ok;
}
|
| 227 |
+
|
| 228 |
+
template<typename scalar_t>
|
| 229 |
+
static __forceinline__ __device__
|
| 230 |
+
scalar_t get_value_bounded(
|
| 231 |
+
scalar_t *data, scalar_t x, scalar_t y, int W, int H, int sW, int sH,
|
| 232 |
+
GridSamplerPadding padding_mode,
|
| 233 |
+
bool align_corners) {
|
| 234 |
+
|
| 235 |
+
x = compute_coordinates(x, W, padding_mode, align_corners);
|
| 236 |
+
y = compute_coordinates(y, H, padding_mode, align_corners);
|
| 237 |
+
|
| 238 |
+
int ix = static_cast<int>(x);
|
| 239 |
+
int iy = static_cast<int>(y);
|
| 240 |
+
|
| 241 |
+
if (within_bounds_2d(iy, ix, H, W)) {
|
| 242 |
+
return data[iy * sH + ix * sW];
|
| 243 |
+
}
|
| 244 |
+
return static_cast<scalar_t>(0);
|
| 245 |
+
}
|
| 246 |
+
|
| 247 |
+
template<typename scalar_t, typename index_t>
|
| 248 |
+
static __forceinline__ __device__
|
| 249 |
+
void safe_add_2d(scalar_t *data, int h, int w,
|
| 250 |
+
int sH, int sW, int H, int W,
|
| 251 |
+
scalar_t delta,
|
| 252 |
+
const index_t NC_offset,
|
| 253 |
+
const index_t memory_span) {
|
| 254 |
+
if (within_bounds_2d(h, w, H, W)) {
|
| 255 |
+
fastAtomicAdd(data,
|
| 256 |
+
NC_offset + h * sH + w * sW,
|
| 257 |
+
memory_span,
|
| 258 |
+
delta,
|
| 259 |
+
true);
|
| 260 |
+
}
|
| 261 |
+
}
|
| 262 |
+
|
| 263 |
+
template<typename scalar_t, typename index_t>
|
| 264 |
+
static __forceinline__ __device__
|
| 265 |
+
void safe_add_3d(scalar_t *data, int d, int h, int w,
|
| 266 |
+
int sD, int sH, int sW, int D, int H, int W,
|
| 267 |
+
scalar_t delta,
|
| 268 |
+
const index_t NC_offset,
|
| 269 |
+
const index_t memory_span) {
|
| 270 |
+
if (within_bounds_3d(d, h, w, D, H, W)) {
|
| 271 |
+
fastAtomicAdd(data,
|
| 272 |
+
NC_offset + d * sD + h * sH + w * sW,
|
| 273 |
+
memory_span,
|
| 274 |
+
delta,
|
| 275 |
+
true);
|
| 276 |
+
}
|
| 277 |
+
}
|
| 278 |
+
|
| 279 |
+
template<typename scalar_t, typename index_t>
|
| 280 |
+
static __forceinline__ __device__
|
| 281 |
+
void add_value_bounded(
|
| 282 |
+
scalar_t* data, scalar_t x, scalar_t y, int W, int H, int sW, int sH,
|
| 283 |
+
scalar_t delta,
|
| 284 |
+
GridSamplerPadding padding_mode,
|
| 285 |
+
bool align_corners,
|
| 286 |
+
const index_t NC_offset,
|
| 287 |
+
const index_t memory_span) {
|
| 288 |
+
|
| 289 |
+
x = compute_coordinates(x, W, padding_mode, align_corners);
|
| 290 |
+
y = compute_coordinates(y, H, padding_mode, align_corners);
|
| 291 |
+
|
| 292 |
+
int ix = static_cast<int>(x);
|
| 293 |
+
int iy = static_cast<int>(y);
|
| 294 |
+
|
| 295 |
+
safe_add_2d(data, iy, ix, sH, sW, H, W, delta, NC_offset, memory_span);
|
| 296 |
+
}
|
| 297 |
+
|
| 298 |
+
// Calculate the differential of the cubic convolution, i.e. `d coeff / d x`
|
| 299 |
+
// Writes into `coeffs` the derivative with respect to `t` of each of the
// four cubic-convolution interpolation weights, for t in [0, 1].
template<typename scalar_t>
static __forceinline__ __device__
void get_cubic_coefficients_grad(
    scalar_t coeffs[4],
    scalar_t t) {

  // Must be the same as forward calculation in
  // aten/src/ATen/native/cuda/UpSample.cuh:get_cubic_upsample_coefficients
  // Cubic-convolution parameter (Keys' kernel with a = -0.75).
  scalar_t A = -0.75;

  // Each branch differentiates the corresponding forward weight, with the
  // inner chain-rule sign folded into the polynomial coefficients.
  scalar_t x;
  x = -1 - t; // 1 < x = |-1 - tx| < 2
  coeffs[0] = (-3 * A * x - 10 * A ) * x - 8 * A;
  x = -t; // x = |0 - tx| <= 1
  coeffs[1] = (-3 * (A + 2) * x - 2 * (A + 3)) * x;
  x = 1 - t; // x = |1 - tx| <= 1
  coeffs[2] = (3 * (A + 2) * x - 2 * (A + 3)) * x;
  x = 2 - t; // 1 < x = |2 - tx| < 2
  coeffs[3] = (3 * A * x - 10 * A) * x + 8 * A;
}
|
| 319 |
+
|
| 320 |
+
|
| 321 |
+
}} // namespace at::native
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/JitLoops.cuh
ADDED
|
@@ -0,0 +1,187 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/jit_macros.h>
|
| 4 |
+
|
| 5 |
+
#if AT_USE_JITERATOR()
|
| 6 |
+
|
| 7 |
+
#include <ATen/cuda/CUDAConfig.h>
|
| 8 |
+
|
| 9 |
+
#include <ATen/OpMathType.h>
|
| 10 |
+
#include <ATen/TensorIterator.h>
|
| 11 |
+
#include <ATen/native/TensorIteratorDynamicCasting.h>
|
| 12 |
+
|
| 13 |
+
#include <ATen/native/cuda/MemoryAccess.cuh>
|
| 14 |
+
|
| 15 |
+
#include <ATen/native/cuda/CUDAJitLoops.cuh>
|
| 16 |
+
|
| 17 |
+
namespace at {
|
| 18 |
+
namespace native {
|
| 19 |
+
|
| 20 |
+
/* Note [Jiterator]
|
| 21 |
+
The "jiterator" simply just-in-time compiles the same kernels that
|
| 22 |
+
Loops.cuh (and CUDALoops.cuh) usually build. This reduces build time,
|
| 23 |
+
build size, and initial CUDA context size.
|
| 24 |
+
|
| 25 |
+
By default on non-Windows systems, it also caches compiled kernels in ~/.cache/torch/kernels.
|
| 26 |
+
This behavior is controlled with two environment variables:
|
| 27 |
+
- USE_PYTORCH_KERNEL_CACHE, if set to zero then this will disable all cache use
|
| 28 |
+
- PYTORCH_KERNEL_CACHE_PATH, if set specifies the folder to use for cached kernels
|
| 29 |
+
|
| 30 |
+
The jiterator currently has some limitations, however. It cannot:
|
| 31 |
+
- handle math on complex datatypes
|
| 32 |
+
- handle kernels with scalar parameters
|
| 33 |
+
|
| 34 |
+
These improvements will likely come soon.
|
| 35 |
+
|
| 36 |
+
For examples of how to use the jiterator see the i1 and gcd kernel
|
| 37 |
+
implementations, which pass jittable strings implementing their
|
| 38 |
+
operations instead of the typical CUDA functors.
|
| 39 |
+
|
| 40 |
+
To pass a runtime argument (similar to lambda captures in non-JIT kernels),
|
| 41 |
+
we need to pass additional arguments to `jitted_gpu_kernel` by value.
|
| 42 |
+
Currently only primitive C++ types used for computation are valid.
|
| 43 |
+
The order of these extra arguments should be same as the order they appear
|
| 44 |
+
in kernel's function signature. (look at polygamma for example)
|
| 45 |
+
|
| 46 |
+
NOTE: One big restriction being that these arguments should be after the
|
| 47 |
+
arguments provided by TensorIterator. Eg. While capturing `n`, where
|
| 48 |
+
`scalar_t x` and `scalar_t y` are provided by TensorIterator,
|
| 49 |
+
* foo(scalar_t x, scalar_t y, int n) works!
|
| 50 |
+
* foo(int n, scalar_t x, scalar_t y) doesn't work
|
| 51 |
+
* foo(scalar_t x, int n, scalar_t y) doesn't work
|
| 52 |
+
|
| 53 |
+
*/
|
| 54 |
+
|
| 55 |
+
// Entrypoint for jitted GPU kernels.
|
| 56 |
+
// Only handles elementwise unary and binary kernels with a
|
| 57 |
+
// common dtype and a single output.
|
| 58 |
+
// NOTE: this assumes the op's iterator has a common_dtype.
|
| 59 |
+
// NOTE: We use std::tuple instead of parameter pack
|
| 60 |
+
// for `extra_args` due to following
|
| 61 |
+
// bug on older versions of clang
|
| 62 |
+
// https://bugs.llvm.org/show_bug.cgi?id=23029
|
| 63 |
+
// Entrypoint for jitted GPU kernels (see Note [Jiterator] above).
// Validates devices, handles the empty and 64-bit-indexing cases, decides
// whether dynamic casting is needed, and dispatches on the scalar position.
template <
    char const* name,
    typename return_type,
    typename f_inputs_type,
    int arity,
    typename... Args>
void jitted_gpu_kernel(
    TensorIteratorBase& iter,
    const std::string& f,
    at::cuda::jit::BinaryFuncVariant scalar_pos =
        at::cuda::jit::BinaryFuncVariant::NoScalar,
    at::opmath_type<f_inputs_type> scalar_val = 0,
    std::tuple<Args...> extra_args = std::make_tuple()) {
  // TODO: much of preamble is common to both jitted_gpu_kernel and gpu_kernel
  // Maybe it could be refactored?
  // All operands (outputs and inputs) must live on a CUDA device.
  for (int arg = 0; arg < iter.ntensors(); arg++) {
    TORCH_INTERNAL_ASSERT(
      iter.device(arg).is_cuda(),
      "argument ", arg, ": expected a CUDA device but found ", iter.device(arg));
  }

  // Nothing to compute for empty iterators.
  if (iter.numel() == 0) {
    return;
  }

  // Large tensors are split into 32-bit-indexable sub-iterators and
  // processed recursively.
  if (!iter.can_use_32bit_indexing()) {
    for (auto& sub_iter : iter.with_32bit_indexing()) {
      jitted_gpu_kernel<name, return_type, f_inputs_type, arity>(
          sub_iter, f, scalar_pos, scalar_val, extra_args);
    }

    return;
  }

  // Computes if dynamic casting is needed
  // Dynamic casting is needed if an input's dtype differs from the common dtype
  // or if the result dtype differs from the output's dtype
  // Note: this is intentionally divergent from calling needs_dynamic_casting,
  // which is more general and inspects a lambda to determine if dynamic
  // casting is needed.
  bool needs_dynamic_casting = false;

  // Checks output
  const ScalarType return_scalar_type = c10::CppTypeToScalarType<return_type>::value;
  const auto dtype0 = iter.dtype(0);
  if (dtype0 != return_scalar_type) {
    needs_dynamic_casting = true;
  }

  // Checks input(s)
  // Operand 0 is the output, so inputs occupy indices [1, arity].
  const ScalarType inputs_scalar_type = c10::CppTypeToScalarType<f_inputs_type>::value;
  for (auto i = decltype(arity){1}; i < (arity + 1); ++i) {
    const auto dtypei = iter.dtype(i);
    if (dtypei != inputs_scalar_type) {
      needs_dynamic_casting = true;
      break;
    }
  }
  if (scalar_pos == at::cuda::jit::BinaryFuncVariant::NoScalar) {
    // NOTE: With `scalar_pos=NoScalar`,`scalar_val` is not used
    // for computation in the generated code and hence we pass a dummy
    // value of `0`.
    jitted_gpu_kernel_impl<
        /*name*/ name,
        /*return_type=*/return_type,
        /*f_inputs_type=*/f_inputs_type,
        arity,
        at::cuda::jit::BinaryFuncVariant::NoScalar>(
        iter, f, needs_dynamic_casting, /*scalar_val=*/scalar_val, extra_args);
  } else if (scalar_pos == at::cuda::jit::BinaryFuncVariant::RhsScalar) {
    jitted_gpu_kernel_impl<
        /*name*/ name,
        /*return_type=*/return_type,
        /*f_inputs_type=*/f_inputs_type,
        arity,
        at::cuda::jit::BinaryFuncVariant::RhsScalar>(
        iter,
        f,
        needs_dynamic_casting,
        scalar_val,
        extra_args);

  } else {
    jitted_gpu_kernel_impl<
        /*name*/ name,
        /*return_type=*/return_type,
        /*f_inputs_type=*/f_inputs_type,
        arity,
        at::cuda::jit::BinaryFuncVariant::LhsScalar>(
        iter,
        f,
        needs_dynamic_casting,
        scalar_val,
        extra_args);
  }
}
|
| 159 |
+
|
| 160 |
+
// TODO: support runtime state capture similar to `jitted_gpu_kernel`.
|
| 161 |
+
// TODO: support runtime state capture similar to `jitted_gpu_kernel`.
// Binary-op convenience wrapper: if one input is a CPU scalar, it is
// extracted (in opmath precision), that operand is removed from the
// iterator, and a unary jitted kernel is launched with the scalar bound
// to the matching side; otherwise a plain binary jitted kernel runs.
template <char const *name, typename return_type, typename f_inputs_type>
void opmath_jitted_gpu_kernel_with_scalars(TensorIteratorBase& iter, const std::string& f) {
  // Exactly one output and two inputs.
  TORCH_INTERNAL_ASSERT(iter.ntensors() == 3);
  //currently jiterator only handles binary functions where both inputs are of the same type (f_inputs_type)
  using opmath_t = at::opmath_type<f_inputs_type>;
  if (iter.is_cpu_scalar(1)) {
    auto scalar_val = iter.scalar_value<opmath_t>(1);
    // Remove the scalar operand before launch; the remaining input is the
    // single tensor argument of the arity-1 kernel.
    iter.remove_operand(1);
    // TODO: When all kernels that use gpu_kernel_with_scalars are
    // ported to structured, this device guard can be deleted. This
    // works around incorrect device guard generation for pre-structured
    // kernels device guards, but structured kernels do it right and
    // we can assume the device is already set correctly
    const OptionalDeviceGuard device_guard(iter.device(1));
    jitted_gpu_kernel<name, return_type, f_inputs_type, 1>(iter, f, at::cuda::jit::BinaryFuncVariant::LhsScalar, scalar_val);
  } else if (iter.is_cpu_scalar(2)) {
    auto scalar_val = iter.scalar_value<opmath_t>(2);
    iter.remove_operand(2);
    jitted_gpu_kernel<name, return_type, f_inputs_type, 1>(iter, f, at::cuda::jit::BinaryFuncVariant::RhsScalar, scalar_val);
  } else {
    // No CPU scalar: launch as a regular binary (arity-2) kernel.
    jitted_gpu_kernel<name, return_type, f_inputs_type, 2>(iter, f);
  }
}
|
| 184 |
+
|
| 185 |
+
}} // at::native
|
| 186 |
+
|
| 187 |
+
#endif // AT_USE_JITERATOR()
|