koichi12 commited on
Commit
49fc886
·
verified ·
1 Parent(s): 22f3d85

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/autotune_process.py +656 -0
  2. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/coordinate_descent_tuner.py +315 -0
  3. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/debug.py +655 -0
  4. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/decomposition.py +678 -0
  5. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/test_operators.py +24 -0
  6. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/virtualized.py +351 -0
  7. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/Activation.h +98 -0
  8. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/AdaptivePooling.h +39 -0
  9. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/BucketizationUtils.h +173 -0
  10. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/ConvUtils.h +446 -0
  11. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/Cross.h +14 -0
  12. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/DistributionTemplates.h +394 -0
  13. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/Histogram.h +16 -0
  14. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/IndexKernel.h +41 -0
  15. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/IndexingUtils.h +160 -0
  16. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/MathBitsFallback.h +157 -0
  17. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/MaxPooling.h +97 -0
  18. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/NonEmptyUtils.h +27 -0
  19. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/Padding.h +62 -0
  20. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/PointwiseOps.h +28 -0
  21. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/Pool.h +340 -0
  22. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/RNN.h +53 -0
  23. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/Repeat.h +48 -0
  24. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/Resize.h +173 -0
  25. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/ResizeCommon.h +75 -0
  26. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/SharedReduceOps.h +544 -0
  27. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/SparseTensorUtils.h +190 -0
  28. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/StridedRandomAccessor.h +301 -0
  29. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/TensorAdvancedIndexingUtils.h +92 -0
  30. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/TensorDimApply.h +55 -0
  31. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/TensorFactories.h +142 -0
  32. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/TensorIteratorDynamicCasting.h +52 -0
  33. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/TensorProperties.h +12 -0
  34. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/TensorShape.h +105 -0
  35. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/TopKImpl.h +98 -0
  36. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/TransposeType.h +23 -0
  37. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/Unfold3d.h +49 -0
  38. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/UnfoldBackward.h +112 -0
  39. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/UpSample.h +506 -0
  40. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/BinaryInternal.h +48 -0
  41. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/CompositeRandomAccessor.h +35 -0
  42. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/CuFFTPlanCache.h +494 -0
  43. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/DeviceSqrt.cuh +25 -0
  44. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/DistributionTemplates.h +672 -0
  45. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/Distributions.h +25 -0
  46. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/EmbeddingBackwardKernel.cuh +22 -0
  47. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/ForeachFunctors.cuh +681 -0
  48. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/ForeachMinMaxFunctors.cuh +22 -0
  49. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/GridSampler.cuh +321 -0
  50. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/JitLoops.cuh +187 -0
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/autotune_process.py ADDED
@@ -0,0 +1,656 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import contextlib
4
+ import dataclasses
5
+ import functools
6
+ import logging
7
+ import os
8
+ import queue
9
+ import time
10
+ import warnings
11
+ from concurrent.futures import ThreadPoolExecutor
12
+ from ctypes import byref, c_size_t, c_void_p
13
+ from multiprocessing.process import BaseProcess
14
+ from multiprocessing.queues import Queue
15
+ from typing import (
16
+ Any,
17
+ Callable,
18
+ Dict,
19
+ Iterable,
20
+ List,
21
+ Optional,
22
+ Sequence,
23
+ TYPE_CHECKING,
24
+ Union,
25
+ )
26
+
27
+ import torch
28
+ from torch import multiprocessing
29
+ from torch._dynamo.testing import rand_strided
30
+
31
+ from torch._inductor import ir
32
+ from torch._inductor.codecache import CUDACodeCache, DLLWrapper, PyCodeCache
33
+
34
+ if TYPE_CHECKING:
35
+ from torch._inductor.select_algorithm import TritonTemplateCaller
36
+
37
+ from . import config
38
+ from .utils import do_bench
39
+ from .virtualized import V
40
+
41
+ CUDA_VISIBLE_DEVICES = "CUDA_VISIBLE_DEVICES"
42
+ EXIT_HANDLER_REGISTERED = False
43
+
44
+ log = logging.getLogger(__name__)
45
+
46
+
47
# Used to synchronize between parent and child processes
class Ping:
    """Handshake request: the parent sends this to the benchmarking child
    to confirm it is alive; the child answers with a Pong."""
50
+
51
+
52
class Pong:
    """Handshake reply sent by the benchmarking child in response to a Ping."""
54
+
55
+
56
@contextlib.contextmanager
def set_cuda_visible_device(device: Optional[int]):
    """
    Context manager to set the CUDA_VISIBLE_DEVICES environment variable to the
    specified single device. If device is None, don't manipulate the environment.

    The previous value (or absence) of the variable is restored on exit, even
    if the body raises.
    """
    if device is None:
        yield
        return

    saved = os.environ.get(CUDA_VISIBLE_DEVICES)
    os.environ[CUDA_VISIBLE_DEVICES] = str(device)
    try:
        yield
    finally:
        if saved is None:
            os.environ.pop(CUDA_VISIBLE_DEVICES)
        else:
            os.environ[CUDA_VISIBLE_DEVICES] = saved
75
+
76
+
77
@dataclasses.dataclass
class TuningProcess:
    """
    Abstraction for launching a helper process to benchmark kernels. Spawns
    a child process (crash-isolated from the parent) and uses multiprocessing
    queues to send benchmark requests and return results.
    """

    # Device index the child should see (applied via CUDA_VISIBLE_DEVICES);
    # None leaves the environment untouched.
    device: Optional[int] = None
    # Handle to the running child; None until initialize().
    process: Optional[BaseProcess] = None
    # Parent -> child work queue.
    request_queue: Optional[Queue[Any]] = None
    # Child -> parent result queue.
    response_queue: Optional[Queue[Any]] = None

    @staticmethod
    def process_main(
        request_queue: Queue[Any],
        response_queue: Queue[Any],
    ) -> None:
        """
        Entry point for the child process: run the work loop and log (rather
        than propagate) any exception, so the child exits instead of hanging.
        """
        log.debug(
            "Entering TuningProcess child. Visible devices = %s",
            os.environ.get(CUDA_VISIBLE_DEVICES),
        )
        try:
            TuningProcess.workloop(request_queue, response_queue)
        except Exception as ex:
            log.exception("Exception in TuningProcess: %s", ex)

    @staticmethod
    def workloop(request_queue: Queue[Any], response_queue: Queue[Any]) -> None:
        """
        Work loop for the benchmarking subprocess: block on the request queue
        and answer each message until the shutdown sentinel (None) arrives.
        """
        while True:
            obj = request_queue.get()

            if obj is None:
                break  # None is a sentinel for the child to terminate
            elif isinstance(obj, Ping):
                # Warm-up / liveness handshake.
                response_queue.put(Pong())
            elif isinstance(obj, BenchmarkRequest):
                response_queue.put(obj.benchmark())
            else:
                raise RuntimeError(f"Invalid request type {type(obj)}")

    def valid(self) -> bool:
        """
        True if the sub-process has been initialized.
        """
        return (
            self.process is not None
            and self.request_queue is not None
            and self.response_queue is not None
        )

    def clear(self) -> None:
        """
        Reset to an uninitialized state.
        """
        self.process = self.request_queue = self.response_queue = None

    def initialize(self) -> None:
        """
        Create child process, request/response queues, and do the warm up.
        Set the environment to make only the provided GPU device visible
        to the process. No-op if already initialized.
        """
        if self.valid():
            return

        # cuda runtime does not work with "fork", use "spawn" to start processes.
        ctx = multiprocessing.get_context("spawn")
        self.request_queue = ctx.Queue()
        self.response_queue = ctx.Queue()

        self.process = ctx.Process(
            target=self.process_main,
            args=(
                self.request_queue,
                self.response_queue,
            ),
        )
        assert self.process is not None
        # Restrict visibility only around start(); the spawned child captures
        # the environment at this point.
        with set_cuda_visible_device(self.device):
            self.process.start()

    def put(self, obj: Any) -> None:
        """
        Push a work item to the child process.
        """
        # In case of a prior crash, ensure the subprocess is running
        self.initialize()
        assert self.request_queue is not None
        self.request_queue.put(obj)

    def get(self) -> Any:
        """
        Get a response from the child process. Raises queue.Empty if the
        child exited without producing a result.
        """
        assert self.process is not None
        assert self.response_queue is not None
        while True:
            try:
                # Poll with a timeout so a dead child is noticed instead of
                # blocking forever on the queue.
                return self.response_queue.get(timeout=1.0)
            except queue.Empty:
                status = self.process.exitcode
                if status is None:
                    # child process is still running
                    continue
                # child process crashed; clear state so the next put()
                # re-spawns it, then surface the Empty to the caller
                self.clear()
                raise

    def terminate(self) -> None:
        """
        Signal the child process to terminate (does not wait; see wait()).
        """
        if self.valid():
            assert self.process is not None
            assert self.request_queue is not None
            self.request_queue.put(None)

    def wait(self) -> None:
        """
        Wait for the child process to exit.
        """
        if self.process is not None:
            self.process.join()
            self.clear()
208
+
209
+
210
@dataclasses.dataclass
class TuningProcessPool:
    """
    Maintains a pool of TuningProcesses to benchmark kernels in parallel
    across devices. By default, we create one TuningProcess per device and
    set the sub-process environment to make only that device visible.
    """

    # Idle sub-processes; a helper thread takes one, benchmarks in it, and
    # returns it. None until initialize().
    processes: Optional[queue.Queue[TuningProcess]] = None
    # Thread pool whose workers each block on an available TuningProcess.
    executor: Optional[ThreadPoolExecutor] = None

    def initialize(self) -> None:
        """
        Start the child processes. No-op if already started.
        """
        # processes and executor are created/destroyed together; anything
        # else indicates inconsistent state.
        assert (self.processes is None) == (self.executor is None)
        if self.processes is not None:
            return

        devices = self.get_device_list()
        log.debug("Sub-process autotune device list: %s", devices)

        # Launch the child processes and push a msg to "warm up"
        self.processes = queue.Queue()
        for device in devices:
            p = TuningProcess(device=device)
            p.initialize()
            p.put(Ping())
            self.processes.put(p)

        # Wait for the initialization to finish
        for p in self.processes.queue:
            assert isinstance(p.get(), Pong)

        # Use a thread pool to manage distributing work to the subprocesses.
        # Threads block on an available process, so it makes sense to match
        # the number of threads with the number of devices.
        self.executor = ThreadPoolExecutor(max_workers=len(devices))

        # Register the exit handler for the parent process so it will terminate
        # the child processes.
        global EXIT_HANDLER_REGISTERED
        if not EXIT_HANDLER_REGISTERED:
            EXIT_HANDLER_REGISTERED = True
            import atexit

            atexit.register(self.terminate)

    def get_device_list(self) -> Sequence[Optional[int]]:
        """
        Gather the list of devices to be used in the pool. Returns [None]
        (a single unrestricted process) unless multi-device autotune is on.
        """
        if not config.autotune_multi_device:
            # Don't use multiple devices
            return [None]

        count = torch.cuda.device_count()

        # If the user specified the visible devices in the env, use those.
        if CUDA_VISIBLE_DEVICES in os.environ:
            devices = [int(d) for d in os.environ[CUDA_VISIBLE_DEVICES].split(",")]
            assert len(devices) <= count
            return devices

        return list(range(count))

    def terminate(self) -> None:
        """
        Signal all child processes to terminate, wait for them, and release
        the thread pool. Safe to call more than once.
        """
        if self.executor is not None:
            self.executor.shutdown()
            self.executor = None

        if self.processes is not None:
            # First signal everyone (non-blocking), then join.
            for p in self.processes.queue:
                p.terminate()
            for p in self.processes.queue:
                p.wait()
            self.processes = None

    def target(self, choice: TritonTemplateCaller) -> float:
        """
        Entry point for the thread-pool helper threads: Wait for an open TuningProcess,
        remove it from the queue, execute the benchmark in that subprocess, and return
        the TuningProcess to the queue.
        """
        assert choice.bmreq is not None
        assert self.processes is not None

        process = self.processes.get()
        process.put(choice.bmreq)
        try:
            return process.get()
        except queue.Empty:
            # The child died while benchmarking this choice (see
            # TuningProcess.get); treat the choice as unusable.
            warnings.warn(
                f"Failed to benchmark choice '{choice}'. It will be ignored. "
                "Please debug the root cause in case the choice can bring perf gains."
            )
            # set to INF so this choice will be ignored
            return float("inf")
        finally:
            # Always hand the process back, even after a crash (it re-spawns
            # lazily on the next put()).
            self.processes.put(process)

    def benchmark(
        self,
        choices: List[TritonTemplateCaller],
    ) -> Dict[TritonTemplateCaller, float]:
        """
        Benchmark each choice in a separate process; returns a mapping from
        choice to measured latency (inf for choices that failed).
        """
        assert self.processes is not None, "Tuning process pool is not initialized"
        assert self.executor is not None

        results = {}

        # Use a ThreadExecutorPool to spread the work across the subprocesses and
        # to grab subprocesses as soon as they're free.
        for choice, result in zip(choices, self.executor.map(self.target, choices)):
            results[choice] = result

        return results
332
+
333
+
334
# Module-level singleton pool shared by all autotuning in this process;
# must be initialize()d before benchmark() is called.
tuning_pool = TuningProcessPool()


# Either a concrete IR buffer or just a layout; both carry enough metadata
# for TensorMeta.from_irnodes.
LayoutOrBuffer = Union[ir.Layout, ir.Buffer]
338
+
339
+
340
@dataclasses.dataclass
class TensorMeta:
    """
    Picklable description of a tensor (device/dtype/sizes/strides/offset),
    sufficient to materialize a random tensor with the same layout inside
    the benchmarking subprocess (see to_tensor).
    """

    device: torch.device
    dtype: torch.dtype
    sizes: torch._prims_common.ShapeType
    strides: torch._prims_common.StrideType
    offset: int

    @classmethod
    def from_irnodes(
        cls, irnodes: Union[LayoutOrBuffer, Sequence[LayoutOrBuffer]]
    ) -> Union[TensorMeta, List[TensorMeta]]:
        """
        Build a TensorMeta from an IR layout/buffer, or a list of TensorMeta
        from a sequence of them. Symbolic sizes/strides/offsets are resolved
        to concrete hints via V.graph.sizevars.
        """
        if isinstance(irnodes, Sequence):
            result: List[Any] = [cls.from_irnodes(x) for x in irnodes]
            assert all(isinstance(x, TensorMeta) for x in result)
            return result

        node = irnodes
        if isinstance(node, ir.Layout):
            # Wrap a bare layout in a throwaway buffer so the accessors below
            # work uniformly.
            node = ir.Buffer("fake", node)

        dtype = node.get_dtype()
        assert dtype is not None

        return TensorMeta(
            device=node.get_device(),
            dtype=dtype,
            sizes=V.graph.sizevars.size_hints(
                node.get_size(),
                fallback=config.unbacked_symint_fallback,
            ),
            strides=V.graph.sizevars.size_hints(
                node.get_stride(),
                fallback=config.unbacked_symint_fallback,
            ),
            offset=V.graph.sizevars.size_hint(
                node.get_layout().offset,
                fallback=config.unbacked_symint_fallback,
            ),
        )

    def to_tensor(self) -> torch.Tensor:
        """
        Materialize a random tensor matching this metadata (the offset is
        accounted for via extra allocated storage).
        """
        return rand_strided(
            self.sizes,
            self.strides,
            device=self.device,
            dtype=self.dtype,
            extra_size=self.offset,
        )
389
+
390
+
391
@dataclasses.dataclass
class BenchmarkRequest:
    """
    Only handle triton template benchmark for now. The extern kernel benchmark
    can be done inside the same process since they usually don't cause crash.

    Important: Instances of this class and subclasses have to be serializable
    across process boundaries. Do not put CUDA Tensors in here!
    """

    # NOTE: despite the @dataclass decorator, subclasses and this class define
    # __init__ explicitly; dataclasses does not overwrite a user-defined
    # __init__.
    def __init__(
        self,
        kernel_name: str,
        input_tensor_meta: Union[TensorMeta, List[TensorMeta]],
        output_tensor_meta: Union[TensorMeta, List[TensorMeta]],
        extra_args: Iterable[Any],
    ):
        # the kernel name defined in the module
        self.kernel_name = kernel_name

        # Normalize to a list of input metas.
        if isinstance(input_tensor_meta, TensorMeta):
            input_tensor_meta = [input_tensor_meta]
        self.input_tensor_meta = input_tensor_meta

        # Exactly one output is supported; unwrap a singleton sequence.
        if isinstance(output_tensor_meta, (tuple, list)):
            assert len(output_tensor_meta) == 1
            output_tensor_meta = output_tensor_meta[0]
        self.output_tensor_meta = output_tensor_meta

        self.extra_args = extra_args

    def make_run_fn(
        self, *input_tensors: torch.Tensor, output_tensor: torch.Tensor
    ) -> Callable[[], None]:
        """Subclasses return a zero-arg callable that launches the kernel."""
        raise NotImplementedError()

    def cleanup_run_fn(self) -> None:
        """Hook for subclasses to release resources created by make_run_fn."""
        pass

    def benchmark(
        self,
        *input_tensors: torch.Tensor,
        output_tensor: Optional[torch.Tensor] = None,
    ) -> float:
        """
        Time the kernel with do_bench and return the measured latency.
        When called without tensors (the subprocess path), inputs/output are
        materialized from the stored TensorMeta.
        """
        debug = log.isEnabledFor(logging.DEBUG)
        if debug:
            start_ts = time.time()

        # create args and out tensor
        if output_tensor is None:
            # Either all tensors are supplied by the caller, or none are.
            assert len(input_tensors) == 0
            input_tensors = tuple(x.to_tensor() for x in self.input_tensor_meta)
            output_tensor = self.output_tensor_meta.to_tensor()

        if debug:
            create_tensor_elapse = time.time() - start_ts  # type: ignore[possibly-undefined]
            start_ts = time.time()

        fn = self.make_run_fn(*input_tensors, output_tensor=output_tensor)

        if debug:
            load_elapse = time.time() - start_ts  # type: ignore[possibly-undefined]
            start_ts = time.time()

        out = do_bench(fn)
        torch.cuda.synchronize()  # shake out any CUDA errors

        if debug:
            bench_elapse = time.time() - start_ts  # type: ignore[possibly-undefined]
            log.debug(
                "InChildProcess %s: load %f, create tensor %f, bench %f",
                str(self),
                load_elapse,  # type: ignore[possibly-undefined]
                create_tensor_elapse,  # type: ignore[possibly-undefined]
                bench_elapse,
            )
        self.cleanup_run_fn()
        return out
469
+
470
+
471
class TestBenchmarkRequest(BenchmarkRequest):
    """
    Supports unit testing. Defined in this file so that the TuningProcess
    sub-process knows how to unpickle these objects.
    """

    def __init__(self, value: Optional[float] = None) -> None:
        # Deliberately does not call BenchmarkRequest.__init__: this stub
        # carries no tensor metadata, only a canned timing value.
        self.value = value

    def benchmark(
        self, *input_tensors: torch.Tensor, output_tensor: Optional[torch.Tensor] = None
    ) -> float:
        """Return the canned value, or raise when configured to fail."""
        value = self.value
        if value is not None:
            return value
        raise Exception("Failed to run")
486
+
487
+
488
class TritonBenchmarkRequest(BenchmarkRequest):
    """
    Benchmark request for a Triton template kernel, loaded by cache key/path
    from PyCodeCache in the benchmarking subprocess.
    """

    # Important: Instances of this class have to be serializable
    # across process boundaries. Do not put CUDA Tensors in here!

    def __init__(
        self,
        kernel_name: str,
        input_tensor_meta: Union[TensorMeta, List[TensorMeta]],
        output_tensor_meta: Union[TensorMeta, List[TensorMeta]],
        extra_args: Iterable[Any],
        module_path: str,  # the path of the module defining the triton kernel
        module_cache_key: str,
        grid: List[int],
        num_stages: int,
        num_warps: int,
        matrix_instr_nonkdim: int = 0,  # only used for hip to choose the shape of mfma instruction.
    ):
        super().__init__(kernel_name, input_tensor_meta, output_tensor_meta, extra_args)
        self.module_path = module_path
        self.module_cache_key = module_cache_key
        self.grid = grid
        self.num_stages = num_stages
        self.num_warps = num_warps
        self.matrix_instr_nonkdim = matrix_instr_nonkdim

    def make_run_fn(
        self, *input_tensors: torch.Tensor, output_tensor: torch.Tensor
    ) -> Callable[[], None]:
        """
        Load the kernel module from the code cache and return a callable that
        launches it with the stored grid/stage/warp configuration.
        """
        mod = PyCodeCache.load_by_key_path(self.module_cache_key, self.module_path)
        log.debug(
            "benchmark module key: %s, path: %s",
            self.module_cache_key,
            self.module_path,
        )

        run_method = getattr(mod, self.kernel_name).run

        # Materialize extra_args exactly once: it is typed as an Iterable, so
        # a one-shot iterator would be exhausted by a first traversal. (The
        # previous code built this list and then re-expanded self.extra_args,
        # which would silently yield no args for a generator input.)
        extra_args = list(self.extra_args)

        # Newer version of triton add warmup argument to JITFunction.run.
        # This code handles backward-compatibility.
        warmup_arg = {}
        import inspect

        if "warmup" in inspect.signature(run_method).parameters:
            warmup_arg["warmup"] = False

        if torch.version.hip and self.matrix_instr_nonkdim != 0:
            return functools.partial(
                run_method,
                *input_tensors,
                output_tensor,
                *extra_args,
                grid=self.grid,
                **warmup_arg,
                num_stages=self.num_stages,
                num_warps=self.num_warps,
                matrix_instr_nonkdim=self.matrix_instr_nonkdim,
            )
        else:
            return functools.partial(
                run_method,
                *input_tensors,
                output_tensor,
                *extra_args,
                grid=self.grid,
                **warmup_arg,
                num_stages=self.num_stages,
                num_warps=self.num_warps,
            )

    def __str__(self) -> str:
        return f"{self.kernel_name=}, {self.module_path=}, {self.module_cache_key=}"
560
+
561
+
562
class CUDABenchmarkRequest(BenchmarkRequest):
    """
    Benchmark request for a CUDA kernel compiled to a shared library via
    CUDACodeCache and invoked through ctypes.
    """

    # Important: Instances of this class have to be serializable
    # across process boundaries. Do not put CUDA Tensors in here!

    def __init__(
        self,
        kernel_name: str,
        input_tensor_meta: Union[TensorMeta, List[TensorMeta]],
        output_tensor_meta: Union[TensorMeta, List[TensorMeta]],
        extra_args: Iterable[Any],
        source_code: str,
    ):
        super().__init__(kernel_name, input_tensor_meta, output_tensor_meta, extra_args)
        self.source_code = source_code
        self.workspace_size: int = 0
        self.workspace: Optional[torch.Tensor] = None
        # DLL handle is populated lazily by make_run_fn (not picklable, so it
        # must not be set when this object crosses the process boundary).
        self.DLL: Optional[DLLWrapper] = None
        self.hash_key: str = ""
        self.source_file: str = ""
        # Write the source out now so hash_key/source_file are known up front.
        self.hash_key, self.source_file = CUDACodeCache.write(self.source_code, "so")

    def precompile(self):
        # Prepopulate CUDACodeCache
        # may happen in separate Threadpool
        log.debug("Precompiling %s", self)
        CUDACodeCache.load(self.source_code, "so")
        log.debug("Done precompiling %s", self)

    def make_run_fn(
        self, *input_tensors: torch.Tensor, output_tensor: torch.Tensor
    ) -> Callable[[], None]:
        """
        Load the compiled kernel and return a callable bound to the tensor
        pointers, extra args, and current CUDA stream. Also queries the
        kernel's required workspace size (which must currently be zero).
        """
        self.DLL, self.hash_key, self.source_file = CUDACodeCache.load(
            self.source_code, "so"
        )
        args = [
            c_void_p(tensor.data_ptr())
            for tensor in list(input_tensors) + [output_tensor]
        ]
        log.debug(
            "make_run_fn: self.kernel_name=%s, self.source_file=%s, self.hash_key=%s, self.DLL=%s, args=%s, self.extra_args=%s",
            self.kernel_name,
            self.source_file,
            self.hash_key,
            self.DLL,
            args,
            self.extra_args,
        )
        run_method = getattr(self.DLL, self.kernel_name)
        stream_ptr = c_void_p(torch.cuda.current_stream().cuda_stream)

        # Retrieve workspace_size and initialize workspace.
        c_workspace_size = c_size_t()
        run_method(
            *args,  # input ptrs and output ptrs
            *self.extra_args,
            byref(
                c_workspace_size
            ),  # set workspace size ptr to retrieve workspace size
            None,  # null workspace ptr
            stream_ptr,
        )
        self.workspace_size = c_workspace_size.value
        # TODO: Support non-zero workspace_size.
        assert self.workspace_size == 0, (
            "Things need to be fixed to support non-zero workspace_size: "
            "1) max autotune cache needs to store workspace size; "
            "2) memory allocation needs to allocate / deallocate workspace correctly; "
        )

        # Generate partial function.
        return functools.partial(
            run_method,
            *args,
            *self.extra_args,
            None,  # null workspace size ptr
            None,  # set workspace ptr, TODO: update it to a real ptr if workspace_size > 0
            stream_ptr,
        )

    def cleanup_run_fn(self) -> None:
        """Close the DLL opened by make_run_fn and drop the workspace."""
        if self.DLL is not None:
            self.DLL.close()
        self.workspace = None

    def __str__(self) -> str:
        return f"{self.kernel_name=}, {self.source_file=}, {self.hash_key=}"
648
+
649
+
650
def benchmark_in_sub_process(
    choices: List[TritonTemplateCaller],
) -> Dict[TritonTemplateCaller, float]:
    """
    Do benchmarking in a subprocess and return the perf number (latency).

    Thin wrapper over the module-level tuning_pool singleton; the pool must
    have been initialize()d beforehand.
    """
    return tuning_pool.benchmark(choices)
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/coordinate_descent_tuner.py ADDED
@@ -0,0 +1,315 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import itertools
3
+ import logging
4
+ from typing import Callable, Optional
5
+
6
+ from torch.utils._triton import has_triton
7
+ from .utils import red_text, triton_config_to_hashable
8
+
9
+ if has_triton():
10
+ import triton
11
+ else:
12
+ triton = None
13
+
14
+ from . import config as inductor_config
15
+
16
+ log = logging.getLogger(__name__)
17
+
18
+
19
def get_field(config, name):
    """Read a tuning field from a triton Config: num_warps/num_stages are
    attributes of the Config itself; every other field lives in
    config.kwargs (None when absent)."""
    if name == "num_warps":
        return config.num_warps
    if name == "num_stages":
        return config.num_stages
    return config.kwargs.get(name, None)
26
+
27
+
28
def set_field(config, name, value):
    """Write a tuning field on a triton Config (inverse of get_field):
    num_warps/num_stages are attributes; anything else goes into kwargs."""
    if name in ("num_warps", "num_stages"):
        setattr(config, name, value)
    else:
        config.kwargs[name] = value
35
+
36
+
37
+ class CoordescTuner:
38
+ """
39
+ The coordinate descent tuner. Tune one field/coordinate at a time.
40
+
41
+ TODO will it be necessary to tune multiple fields simultaneously.
42
+
43
+
44
+ TODO: what if both increasing and decreasing a field can improve perf.
45
+ i.e., there are multiple local optima..
46
+ """
47
+
48
+ def __init__(self, is_mm=False, name="unknown", size_hints=None):
49
+ self.is_mm = is_mm # we will tune num_stages for mm
50
+ self.cached_benchmark_results = {}
51
+ self.name = name
52
+ self.size_hints = size_hints
53
+
54
+ def get_xmax(self):
55
+ xmax = inductor_config.triton.max_block["X"]
56
+ if self.size_hints and len(self.size_hints) > 0:
57
+ xmax = min(xmax, self.size_hints[0])
58
+ return xmax
59
+
60
+ def get_ymax(self):
61
+ ymax = inductor_config.triton.max_block["Y"]
62
+ if self.size_hints and len(self.size_hints) > 1:
63
+ ymax = min(ymax, self.size_hints[1])
64
+ return ymax
65
+
66
+ def get_zmax(self):
67
+ zmax = inductor_config.triton.max_block["Z"]
68
+ if self.size_hints and len(self.size_hints) > 2:
69
+ zmax = min(zmax, self.size_hints[2])
70
+ return zmax
71
+
72
+ def get_rmax(self):
73
+ if self.size_hints and len(self.size_hints) > 0:
74
+ return self.size_hints[-1] # the last one is for reduction
75
+ else:
76
+ # large enough. We should not pick this large RBLOCK anyway
77
+ return 2**30
78
+
79
+ def get_warpsmax(self):
80
+ # Currently, CUDA has a maximum of 1024 threads, so 32 is the max
81
+ # number of warps.
82
+ return 1024 // 32
83
+
84
+ def cache_benchmark_result(self, config, timing):
85
+ self.cached_benchmark_results[triton_config_to_hashable(config)] = timing
86
+
87
+ def lookup_in_cache(self, config):
88
+ return self.cached_benchmark_results.get(triton_config_to_hashable(config))
89
+
90
+ def call_func(self, func, config):
91
+ found = self.lookup_in_cache(config)
92
+ if found is not None:
93
+ log.debug(" CACHED")
94
+ return found
95
+ timing = func(config)
96
+ self.cache_benchmark_result(config, timing)
97
+ return timing
98
+
99
+ @property
100
+ def tunable_fields(self):
101
+ out = [
102
+ "XBLOCK",
103
+ "YBLOCK",
104
+ "ZBLOCK",
105
+ # NOTE: we should not tune RBLOCK for persistent reduction.
106
+ # We rely on the fact that persistent reduction's triton.Config
107
+ # does not have the RBLOCK field to guarantee that.
108
+ "RBLOCK",
109
+ # the following 3 are for mm
110
+ "BLOCK_M",
111
+ "BLOCK_N",
112
+ "BLOCK_K",
113
+ "num_warps",
114
+ ]
115
+ if self.is_mm:
116
+ out.append("num_stages")
117
+
118
+ return out
119
+
120
+ def value_too_large(self, name, val):
121
+ if name == "XBLOCK":
122
+ return val > self.get_xmax()
123
+ if name == "YBLOCK":
124
+ return val > self.get_ymax()
125
+ if name == "ZBLOCK":
126
+ return val > self.get_zmax()
127
+ if name == "RBLOCK":
128
+ return val > self.get_rmax()
129
+ if name == "num_warps":
130
+ return val > self.get_warpsmax()
131
+
132
+ return False
133
+
134
+ def get_neighbour_values(self, name, orig_val, radius=1, include_self=False):
135
+ """
136
+ Get neighbour values in 'radius' steps. The original value is not
137
+ returned as it's own neighbour.
138
+ """
139
+ assert radius >= 1
140
+
141
+ def update(cur_val, inc=True):
142
+ if name == "num_stages":
143
+ if inc:
144
+ return cur_val + 1
145
+ else:
146
+ return cur_val - 1
147
+ else:
148
+ if inc:
149
+ return cur_val * 2
150
+ else:
151
+ return cur_val // 2
152
+
153
+ out = []
154
+ # increment loop
155
+ cur_val = orig_val
156
+ for _ in range(radius):
157
+ cur_val = update(cur_val, True)
158
+ if self.value_too_large(name, cur_val):
159
+ break
160
+ out.append(cur_val)
161
+
162
+ # decrement loop
163
+ cur_val = orig_val
164
+ for _ in range(radius):
165
+ cur_val = update(cur_val, False)
166
+ if cur_val <= 0:
167
+ break
168
+ out.append(cur_val)
169
+
170
+ if include_self:
171
+ out.append(orig_val)
172
+ return out
173
+
174
+ @staticmethod
175
+ def has_improvement(baseline, test):
176
+ threshold = 0.001 # 0.1%
177
+ return test is not None and test < baseline * (1 - threshold)
178
+
179
+ def check_all_tuning_directions(
180
+ self,
181
+ func: Callable[["triton.Config"], float],
182
+ best_config,
183
+ best_timing,
184
+ ):
185
+ """
186
+ Check all directions. We only do this once the regular coordinate
187
+ descent tuning find no better choices any more.
188
+ We only have a few tunable fields, so this should be fine.
189
+ """
190
+ candidate_values_list = []
191
+ effective_fields = []
192
+ for field in self.tunable_fields:
193
+ old_value = get_field(best_config, field)
194
+ if old_value is None:
195
+ continue
196
+ candidate_values = self.get_neighbour_values(
197
+ field,
198
+ old_value,
199
+ radius=inductor_config.coordinate_descent_search_radius,
200
+ include_self=True,
201
+ )
202
+ candidate_values_list.append(candidate_values)
203
+ effective_fields.append(field)
204
+
205
+ choices = itertools.product(*candidate_values_list)
206
+ improved = False
207
+ for choice in choices:
208
+ assert len(choice) == len(effective_fields)
209
+ candidate_config = copy.deepcopy(best_config)
210
+ for new_val, field in zip(choice, effective_fields):
211
+ set_field(candidate_config, field, new_val)
212
+ cmp_res, candidate_timing = self.compare_config(
213
+ func, candidate_config, best_config, best_timing
214
+ )
215
+ if cmp_res:
216
+ improved = True
217
+ best_config = candidate_config
218
+ best_timing = candidate_timing
219
+
220
+ return improved, best_config, best_timing
221
+
222
+ def compare_config(self, func, candidate_config, best_config, best_timing):
223
+ """
224
+ Check if candidate_config is better than best_config.
225
+
226
+ Return a touple of (compare_result, candidate_timing).
227
+ compare_result is true iff candidate_config is better.
228
+ """
229
+ log.debug("Try config %s", candidate_config)
230
+ try:
231
+ candidate_timing = self.call_func(func, candidate_config)
232
+ except Exception as e:
233
+ log.debug("Got exception %s", e)
234
+ return False, float("inf")
235
+
236
+ if self.has_improvement(best_timing, candidate_timing):
237
+ log.debug(
238
+ "Tune from %s %f -> %s %f",
239
+ best_config,
240
+ best_timing,
241
+ candidate_config,
242
+ candidate_timing,
243
+ )
244
+
245
+ return True, candidate_timing
246
+ return False, candidate_timing
247
+
248
+ def autotune(
249
+ self,
250
+ func: Callable[["triton.Config"], float],
251
+ baseline_config: "triton.Config",
252
+ baseline_timing: Optional[float] = None,
253
+ ) -> "triton.Config":
254
+ if baseline_timing is None:
255
+ baseline_timing = self.call_func(func, baseline_config)
256
+
257
+ log.debug("= Do coordinate descent tuning for %s =", self.name)
258
+ log.debug(
259
+ "Baseline Config %s, baseline timing %f", baseline_config, baseline_timing
260
+ )
261
+ improved = True
262
+ best_config = baseline_config
263
+ best_timing = baseline_timing
264
+ tunable_fields = self.tunable_fields
265
+
266
+ while improved:
267
+ improved = False
268
+
269
+ for name in tunable_fields:
270
+ cur_val = get_field(best_config, name)
271
+ # some kernel don't have RBLOCK/YBLOCK/ZBLOCK. So cur_val may be None
272
+ if cur_val is None:
273
+ continue
274
+
275
+ # It's possible that candidate_values is empty.
276
+ # E.g., if XBLOCK is 1 initially and size_hint for x is also 1.
277
+ # We would not try either larger or smaller XBLOCK in this case.
278
+ candidate_values = self.get_neighbour_values(name, cur_val)
279
+
280
+ for next_val in candidate_values:
281
+ candidate_config = copy.deepcopy(best_config)
282
+ set_field(candidate_config, name, next_val)
283
+
284
+ cmp_res, candidate_timing = self.compare_config(
285
+ func, candidate_config, best_config, best_timing
286
+ )
287
+ if cmp_res:
288
+ improved = True
289
+ best_config, best_timing = candidate_config, candidate_timing
290
+
291
+ if not improved and inductor_config.coordinate_descent_check_all_directions:
292
+ old_best_timing = best_timing
293
+ improved, best_config, best_timing = self.check_all_tuning_directions(
294
+ func, best_config, best_timing
295
+ )
296
+
297
+ if improved:
298
+ msg = red_text(
299
+ "Coordinate descend tuning found improvement of %.3fx by looking in all directions."
300
+ )
301
+ log.debug(
302
+ msg,
303
+ old_best_timing / best_timing,
304
+ )
305
+
306
+ log.debug(
307
+ "Improve from %s %f -> %s %f, %.3fx",
308
+ baseline_config,
309
+ baseline_timing,
310
+ best_config,
311
+ best_timing,
312
+ baseline_timing / best_timing,
313
+ )
314
+
315
+ return best_config
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/debug.py ADDED
@@ -0,0 +1,655 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import collections
2
+ import contextlib
3
+ import cProfile
4
+ import dataclasses
5
+ import functools
6
+ import itertools
7
+ import logging
8
+ import os
9
+ import os.path
10
+ import pickle
11
+ import pstats
12
+ import shutil
13
+ import subprocess
14
+ from typing import Any, Dict, List, Optional
15
+ from unittest.mock import patch
16
+
17
+ from functorch.compile import draw_graph, get_aot_graph_name, get_graph_being_compiled
18
+
19
+ import torch
20
+ from torch import fx as fx
21
+
22
+ from torch._dynamo.repro.after_aot import save_graph_repro, wrap_compiler_debug
23
+ from torch._dynamo.utils import get_debug_dir
24
+ from torch.fx.graph_module import GraphModule
25
+ from torch.fx.passes.shape_prop import _extract_tensor_metadata, TensorMetadata
26
+ from torch.fx.passes.tools_common import legalize_graph
27
+ from torch.utils._pytree import tree_map
28
+
29
+ from . import config, ir # noqa: F811, this is needed
30
+ from .scheduler import (
31
+ BaseSchedulerNode,
32
+ FusedSchedulerNode,
33
+ NopKernelSchedulerNode,
34
+ OutputNode,
35
+ SchedulerNode,
36
+ )
37
+ from .virtualized import V
38
+
39
+ log = logging.getLogger(__name__)
40
+
41
+ SchedulerNodeList = List[Any]
42
+ BufMeta = collections.namedtuple("BufMeta", ["name", "n_origin"])
43
+ GRAPHVIZ_COMMAND_SCALABLE = ["dot", "-Gnslimit=2", "-Gnslimit1=2", "-Gmaxiter=5000"]
44
+
45
+
46
+ @functools.lru_cache(None)
47
+ def has_dot() -> bool:
48
+ try:
49
+ subprocess.check_output(["which", "dot"], stderr=subprocess.PIPE)
50
+ return True
51
+ except subprocess.SubprocessError:
52
+ return False
53
+
54
+
55
+ def draw_buffers(nodes: List[BaseSchedulerNode], print_graph=False, fname=None):
56
+ """
57
+ Draw a graph in fname.svg.
58
+ """
59
+ if not has_dot():
60
+ log.warning("draw_buffers() requires `graphviz` package")
61
+ return
62
+
63
+ if fname is None:
64
+ fname = get_graph_being_compiled()
65
+
66
+ graph = create_fx_from_snodes(nodes)
67
+
68
+ for node in graph.nodes:
69
+ if "fusion_meta" not in node.meta:
70
+ continue
71
+ group = node.meta["fusion_meta"].group
72
+ if isinstance(group, tuple):
73
+ if isinstance(group[1], int):
74
+ group = (group[1],)
75
+ else:
76
+ group = group[1]
77
+
78
+ # gather meta data
79
+ dtype = None
80
+ if isinstance(node, ir.ComputedBuffer):
81
+ dtype = node.data.dtype
82
+
83
+ metadata = TensorMetadata(group, dtype, None, None, None, None, None) # type: ignore[arg-type]
84
+ node.meta["tensor_meta"] = metadata
85
+
86
+ if print_graph:
87
+ print(graph)
88
+
89
+ gm = GraphModule({}, graph)
90
+ legalize_graph(gm)
91
+ gm.graph.lint()
92
+ draw_graph(
93
+ gm, fname, clear_meta=False, dot_graph_shape=config.trace.dot_graph_shape
94
+ )
95
+
96
+
97
+ def create_fx_from_snodes(snodes: List[BaseSchedulerNode]) -> fx.Graph:
98
+ """
99
+ Creates a FX Graph from a list of SchedulerNode objects.
100
+ """
101
+
102
+ def get_fake_func(name):
103
+ def func1(*args):
104
+ return 0
105
+
106
+ func1.__name__ = name
107
+ return func1
108
+
109
+ FusionMeta = collections.namedtuple("FusionMeta", ["group", "snode", "type"])
110
+
111
+ buf_to_fx_node = {}
112
+ graph = torch.fx.Graph()
113
+ first_node = None
114
+
115
+ outputs = []
116
+ group: Any = None
117
+ # create call_function node for each Buffer and Kernel
118
+ for snode in snodes:
119
+ if snode.is_extern():
120
+ node_type = "extern"
121
+ group = node_type
122
+ elif snode.is_template():
123
+ node_type = "template"
124
+ group = node_type
125
+ elif isinstance(snode, NopKernelSchedulerNode):
126
+ node_type = "nop"
127
+ group = node_type
128
+ elif isinstance(snode, SchedulerNode):
129
+ node_type = "compute"
130
+ group = snode.group
131
+ elif isinstance(snode, FusedSchedulerNode):
132
+ node_type = "fused"
133
+ group = snode.group
134
+ else:
135
+ raise RuntimeError("Unknown node type")
136
+
137
+ fused_name = torch._inductor.utils.get_fused_kernel_name(
138
+ snode.get_nodes(), "original_aten"
139
+ )
140
+ func_name = f"{node_type}: {fused_name}"
141
+ node_func = get_fake_func(func_name)
142
+ kwargs = {}
143
+ if hasattr(snode, "get_device"):
144
+ kwargs = {"device": snode.get_device()}
145
+ fx_node = graph.call_function(node_func, args=(), kwargs=kwargs)
146
+
147
+ def in_output(snode):
148
+ if isinstance(snode, FusedSchedulerNode):
149
+ return any(in_output(x) for x in snode.snodes)
150
+ return any(isinstance(user.node, OutputNode) for user in snode.users)
151
+
152
+ if in_output(snode):
153
+ outputs.append(fx_node)
154
+ name = snode.get_name()
155
+ fx_node.name = name
156
+
157
+ fx_node.meta["fusion_meta"] = FusionMeta(group, snode, node_type)
158
+
159
+ if isinstance(snode, FusedSchedulerNode):
160
+ for x in snode.snodes:
161
+ buf_to_fx_node[x.get_name()] = fx_node
162
+ buf_to_fx_node[name] = fx_node
163
+
164
+ if first_node is None:
165
+ first_node = fx_node
166
+
167
+ # create edges between nodes
168
+ for snode in snodes:
169
+ name = snode.get_name()
170
+ deps = snode.read_writes.reads
171
+
172
+ fx_node = buf_to_fx_node[name]
173
+ new_args = []
174
+ for dep in deps:
175
+ if dep.name in buf_to_fx_node:
176
+ dep_node = buf_to_fx_node[dep.name]
177
+ else:
178
+ with graph.inserting_before(first_node):
179
+ dep_node = graph.placeholder(dep.name)
180
+ buf_to_fx_node[dep.name] = dep_node
181
+ new_args.append(dep_node)
182
+
183
+ fx_node.args = tuple(new_args)
184
+
185
+ graph.output(outputs[0] if len(outputs) == 1 else tuple(outputs))
186
+ return graph
187
+
188
+
189
+ def update_orig_fx_node_name_to_buf_name(
190
+ nodes: SchedulerNodeList,
191
+ node_name_to_buf_name: Dict[str, str],
192
+ parent_buf_name: Optional[str] = None,
193
+ n_origins: int = 0,
194
+ ):
195
+ if nodes is None:
196
+ return
197
+ for node in nodes:
198
+ # for FusedSchedulerNode, traverse recursively into get_nodes()
199
+ buf_name = node.get_name()
200
+ children_nodes = node.get_nodes()
201
+ if children_nodes is not None and len(children_nodes) > 1:
202
+ update_orig_fx_node_name_to_buf_name(
203
+ children_nodes,
204
+ node_name_to_buf_name,
205
+ buf_name if parent_buf_name is None else parent_buf_name,
206
+ )
207
+ continue
208
+ else:
209
+ assert len(children_nodes) == 1 and children_nodes[0] == node
210
+
211
+ ir_node = node.node
212
+ if ir_node is None or ir_node.origins is None:
213
+ continue
214
+ for origin in ir_node.origins:
215
+ node_name = origin.name
216
+ # when buf1 and buf2 both have origin=node1
217
+ # we draw node1 according to buf1
218
+ if node_name not in node_name_to_buf_name:
219
+ node_name_to_buf_name[node_name] = (
220
+ buf_name if parent_buf_name is None else parent_buf_name
221
+ )
222
+
223
+
224
+ def get_node_name_to_buf_meta(node_name_to_buf_name: Dict[str, str]):
225
+ buf_name_to_n_node = {}
226
+ for node_name, buf_name in node_name_to_buf_name.items():
227
+ if buf_name not in buf_name_to_n_node:
228
+ buf_name_to_n_node[buf_name] = {node_name}
229
+ else:
230
+ buf_name_to_n_node[buf_name].add(node_name)
231
+
232
+ node_name_to_buf_meta = {}
233
+ for node_name, buf_name in node_name_to_buf_name.items():
234
+ n_node = len(buf_name_to_n_node[buf_name])
235
+ node_name_to_buf_meta[node_name] = BufMeta(buf_name, n_node)
236
+ return node_name_to_buf_meta
237
+
238
+
239
+ def annotate_orig_fx_with_snodes(
240
+ gm: torch.fx.GraphModule, snodes: SchedulerNodeList
241
+ ) -> None:
242
+ """
243
+ Creates a FX Graph from a list of SchedulerNode objects.
244
+ """
245
+ node_name_to_buf_name: Dict[str, str] = {}
246
+ update_orig_fx_node_name_to_buf_name(snodes, node_name_to_buf_name)
247
+ if node_name_to_buf_name is None:
248
+ return
249
+ node_name_to_buf_meta = get_node_name_to_buf_meta(node_name_to_buf_name)
250
+ for node in gm.graph.nodes:
251
+ if node.name in node_name_to_buf_meta:
252
+ node.meta["buf_meta"] = node_name_to_buf_meta.get(node.name)
253
+
254
+
255
+ @contextlib.contextmanager
256
+ def enable_aot_logging():
257
+ compile_debug = os.environ.get("TORCH_COMPILE_DEBUG", "0") == "1"
258
+
259
+ import torch._functorch.aot_autograd
260
+
261
+ log = logging.getLogger(torch._functorch.aot_autograd.__name__)
262
+
263
+ stack = contextlib.ExitStack()
264
+ if not compile_debug:
265
+ try:
266
+ yield
267
+ finally:
268
+ stack.close()
269
+ return
270
+
271
+ # Enable all graphs to be logged to a file by setting the flags to True
272
+ # and the log level of the file logger to DEBUG
273
+ stack.enter_context(patch("functorch.compile.config.debug_partitioner", True))
274
+
275
+ path = os.path.join(get_debug_dir(), "torchinductor")
276
+ os.makedirs(path, exist_ok=True)
277
+
278
+ fh = logging.FileHandler(
279
+ os.path.join(
280
+ path,
281
+ f"aot_{get_aot_graph_name()}_debug.log",
282
+ )
283
+ )
284
+ fh.setLevel(logging.DEBUG)
285
+ fh.setFormatter(
286
+ logging.Formatter("[%(filename)s:%(lineno)d %(levelname)s] %(message)s")
287
+ )
288
+ log.addHandler(fh)
289
+ try:
290
+ yield
291
+ finally:
292
+ log.removeHandler(fh)
293
+ stack.close()
294
+
295
+
296
+ class DebugContext:
297
+ _counter = itertools.count()
298
+
299
+ @staticmethod
300
+ def wrap(fn):
301
+ @functools.wraps(fn)
302
+ def inner(*args, **kwargs):
303
+ with DebugContext():
304
+ return fn(*args, **kwargs)
305
+
306
+ return wrap_compiler_debug(inner, compiler_name="inductor")
307
+
308
+ @staticmethod
309
+ def create_debug_dir(folder_name: str) -> Optional[str]:
310
+ debug_dir = config.trace.debug_dir or get_debug_dir()
311
+ for n in DebugContext._counter:
312
+ dirname = os.path.join(
313
+ debug_dir,
314
+ "torchinductor",
315
+ f"{folder_name}.{n}",
316
+ )
317
+ if not os.path.exists(dirname):
318
+ os.makedirs(dirname)
319
+ return dirname
320
+ return None
321
+
322
+ def __init__(self):
323
+ self._prof = None
324
+ self._path = None
325
+ self._stack = contextlib.ExitStack()
326
+
327
+ def copy(self, new_path: str):
328
+ if not self._path:
329
+ return
330
+ assert new_path.endswith(".debug"), new_path
331
+ if os.path.exists(new_path):
332
+ shutil.rmtree(new_path)
333
+ try:
334
+ shutil.copytree(self._path, new_path)
335
+ self._path = new_path
336
+ except OSError:
337
+ log.warning(
338
+ "Failed to copy debug files from %s to %s", self._path, new_path
339
+ )
340
+ pass
341
+
342
+ def fopen(self, filename: str, write_mode: str = "w", *args, **kwargs):
343
+ assert self._path
344
+ return open(os.path.join(self._path, filename), write_mode, *args, **kwargs)
345
+
346
+ @contextlib.contextmanager
347
+ def fopen_context(self, filename: str, write_mode: str = "w", *args, **kwargs):
348
+ assert self._path
349
+ with open(os.path.join(self._path, filename), write_mode, *args, **kwargs) as f:
350
+ yield f
351
+
352
+ def filename(self, suffix: str):
353
+ assert self._path
354
+ return os.path.join(self._path, suffix)
355
+
356
+ def upload_tar(self):
357
+ if config.trace.upload_tar is not None:
358
+ import tarfile
359
+
360
+ assert self._path
361
+ tar_file = os.path.join(
362
+ self._path, f"{os.path.basename(self._path)}.tar.gz"
363
+ )
364
+ with tarfile.open(tar_file, "w:gz") as tar:
365
+ tar.add(self._path, arcname=os.path.basename(self._path))
366
+ config.trace.upload_tar(tar_file)
367
+
368
+ def __enter__(self):
369
+ if config.debug:
370
+ log = logging.getLogger("torch._dynamo")
371
+ prev_level = log.level
372
+ log.setLevel(logging.DEBUG)
373
+
374
+ def reset_log_level(level):
375
+ log.setLevel(level)
376
+
377
+ self._stack.callback(reset_log_level, prev_level)
378
+
379
+ self._stack.enter_context(V.set_debug_handler(self))
380
+
381
+ if not config.trace.enabled:
382
+ return
383
+
384
+ self._path = self.create_debug_dir(get_aot_graph_name())
385
+
386
+ if config.trace.debug_log:
387
+ self._setup_log_capture("debug.log", logging.DEBUG)
388
+ if config.trace.info_log:
389
+ self._setup_log_capture("info.log", logging.INFO)
390
+ if config.trace.compile_profile:
391
+ self._prof = cProfile.Profile()
392
+ self._prof.enable()
393
+
394
+ def _setup_log_capture(self, filename: str, level: int):
395
+ log = logging.getLogger("torch._inductor")
396
+ fd = self._stack.enter_context(self.fopen(filename))
397
+ ch = logging.StreamHandler(fd)
398
+ ch.setLevel(level)
399
+ ch.setFormatter(
400
+ logging.Formatter("[%(filename)s:%(lineno)d %(levelname)s] %(message)s")
401
+ )
402
+ log.addHandler(ch)
403
+ log.setLevel(min(log.level, level))
404
+ self._stack.callback(log.removeHandler, ch)
405
+
406
+ def __exit__(self, exc_type, exc_val, exc_tb):
407
+ if self._prof:
408
+ self._prof.disable()
409
+ self._save_profile_data()
410
+
411
+ if self._path:
412
+ self.upload_tar()
413
+ log.warning("%s debug trace: %s", get_graph_being_compiled(), self._path)
414
+ self._stack.close()
415
+
416
+ def _save_profile_data(self):
417
+ assert self._prof
418
+ self._prof.dump_stats(self.filename("compile.prof"))
419
+ with self.fopen("compile.stats") as fd:
420
+ stats = pstats.Stats(self._prof, stream=fd)
421
+ stats.strip_dirs()
422
+ stats.sort_stats("cumtime")
423
+ stats.print_stats(100)
424
+ stats.sort_stats("tottime")
425
+ stats.print_stats(100)
426
+
427
+ def __getattr__(self, name):
428
+ if config.trace.enabled and getattr(config.trace, name):
429
+ try:
430
+ return getattr(DebugFormatter(self), name)
431
+ except Exception:
432
+ log.warning("Ignoring exception in debug code", exc_info=True)
433
+ else:
434
+
435
+ def ignored(*args, **kwargs):
436
+ pass
437
+
438
+ return ignored
439
+
440
+
441
+ class DebugFormatter:
442
+ def __init__(self, handler):
443
+ self.fopen = handler.fopen
444
+ self.fopen_context = handler.fopen_context
445
+ self.filename = handler.filename
446
+ self.handler = handler
447
+
448
+ def fx_graph(self, gm: torch.fx.GraphModule, inputs: List[torch.Tensor]):
449
+ with self.fopen("fx_graph_runnable.py") as fd:
450
+ save_graph_repro(fd, gm, inputs, "inductor")
451
+
452
+ with self.fopen("fx_graph_readable.py") as fd:
453
+ fd.write(gm.print_readable(print_output=False))
454
+
455
+ def fx_graph_transformed(
456
+ self, gm: torch.fx.GraphModule, inputs: List[torch.Tensor]
457
+ ):
458
+ with self.fopen("fx_graph_transformed.py") as fd:
459
+ fd.write(gm.print_readable(print_output=False))
460
+
461
+ def ir_pre_fusion(self, nodes: SchedulerNodeList):
462
+ self._write_ir("ir_pre_fusion.txt", nodes)
463
+
464
+ def ir_post_fusion(self, nodes: SchedulerNodeList):
465
+ self._write_ir("ir_post_fusion.txt", nodes)
466
+
467
+ def _write_ir(self, filename: str, nodes: SchedulerNodeList):
468
+ with self.fopen(filename) as fd:
469
+ log.info("Writing debug ir to %s", fd.name)
470
+ for node in nodes:
471
+ fd.write(node.debug_str())
472
+ fd.write("\n\n\n")
473
+
474
+ def graph_diagram(self, nodes: SchedulerNodeList):
475
+ draw_buffers(nodes, fname=self.filename("graph_diagram.svg"))
476
+
477
+ def draw_orig_fx_graph(self, gm: torch.fx.GraphModule, nodes: SchedulerNodeList):
478
+ annotate_orig_fx_with_snodes(gm, nodes)
479
+ draw_graph(
480
+ gm,
481
+ fname=self.filename("orig_fx_graph_diagram.svg"),
482
+ clear_meta=False,
483
+ prog=GRAPHVIZ_COMMAND_SCALABLE,
484
+ parse_stack_trace=True,
485
+ dot_graph_shape=config.trace.dot_graph_shape,
486
+ )
487
+
488
+ def output_code(self, filename):
489
+ shutil.copy(filename, self.filename("output_code.py"))
490
+
491
+ def log_autotuning_results(
492
+ self,
493
+ name: str,
494
+ input_nodes: List[ir.IRNode],
495
+ timings: Dict["ChoiceCaller", float], # type: ignore[name-defined] # noqa: F821
496
+ elapse: float,
497
+ ):
498
+ import json
499
+
500
+ from .ir import FixedLayout
501
+
502
+ def build_node_info(node: ir.IRNode):
503
+ if hasattr(node, "name"):
504
+ node_name = node.name
505
+ else:
506
+ node_name = ""
507
+ node_info = {
508
+ "name": node_name,
509
+ "type": type(node).__name__,
510
+ }
511
+ try:
512
+ layout = node.get_layout()
513
+ if isinstance(layout, FixedLayout):
514
+ offset = 0
515
+ try:
516
+ offset = int(layout.offset)
517
+ except Exception:
518
+ try:
519
+ offset = V.graph.sizevars.size_hint(
520
+ layout.offset, fallback=0
521
+ )
522
+ except Exception:
523
+ pass
524
+ static_layout = FixedLayout(
525
+ layout.device,
526
+ dtype=layout.dtype,
527
+ size=list(V.graph.sizevars.size_hints(layout.size)),
528
+ stride=list(V.graph.sizevars.size_hints(layout.stride)),
529
+ offset=offset,
530
+ )
531
+ node_info["layout"] = str(static_layout)
532
+ else:
533
+ node_info["layout"] = str(node.get_layout())
534
+ except Exception as e:
535
+ pass
536
+ try:
537
+ node_info["dtype"] = str(node.get_dtype())
538
+ except Exception as e:
539
+ pass
540
+ try:
541
+ node_info["device"] = str(node.get_device())
542
+ except Exception as e:
543
+ pass
544
+ try:
545
+ node_info["stride"] = str(
546
+ V.graph.sizevars.size_hints(node.get_stride())
547
+ )
548
+ except Exception as e:
549
+ pass
550
+ try:
551
+ node_info["size"] = str(V.graph.sizevars.size_hints(node.get_size()))
552
+ except Exception as e:
553
+ pass
554
+ try:
555
+ node_info["numel"] = str(V.graph.sizevars.size_hint(node.get_numel()))
556
+ except Exception as e:
557
+ pass
558
+ if hasattr(node, "data") and isinstance(node.data, ir.IRNode):
559
+ node_info["data"] = build_node_info(node.data)
560
+ return node_info
561
+
562
+ general_properties = {
563
+ "op_name": name,
564
+ "cuda_device_name": torch.cuda.get_device_name(),
565
+ "cuda_device_count": torch.cuda.device_count(),
566
+ "input_nodes": [build_node_info(node) for node in input_nodes],
567
+ "autotuning_time": elapse,
568
+ }
569
+ with self.fopen_context(
570
+ "autotuning_result_json_list.txt", "at", encoding="utf-8"
571
+ ) as fd:
572
+ for caller, time in timings.items():
573
+ info_dict = dict(caller.info_dict())
574
+ info_dict.update(general_properties)
575
+ info_dict["benchmark_result"] = time
576
+ json.dump(info_dict, fd)
577
+ fd.write("\n")
578
+
579
+
580
+ @dataclasses.dataclass
581
+ class TensorMetadataHolder:
582
+ tensor_metadata: TensorMetadata
583
+ device: torch.device
584
+
585
+
586
+ save_args_cnt = itertools.count()
587
+
588
+
589
+ def save_args_for_compile_fx_inner(*args, **kwargs):
590
+ """
591
+ This function is used to save arguments for a compile_fx_inner function call
592
+ to the file system. Later on one can replay the compile_fx_inner call
593
+ with the saved arguments using load_args_and_run_compile_fx_inner.
594
+ """
595
+
596
+ folder = "/tmp/inductor_saved_args"
597
+ if not os.path.exists(folder):
598
+ os.mkdir(folder)
599
+
600
+ def handle_tensor(x):
601
+ """
602
+ Pickle FakeTensor will result in error:
603
+ AttributeError: Can't pickle local object 'WeakValueDictionary.__init__.<locals>.remove'
604
+
605
+ Convert all Tensor to metadata. This may also makes pickle faster.
606
+ """
607
+ if isinstance(x, torch.Tensor):
608
+ return TensorMetadataHolder(_extract_tensor_metadata(x), x.device)
609
+ else:
610
+ return x
611
+
612
+ args_to_save, kwargs_to_save = tree_map(handle_tensor, (args, kwargs))
613
+
614
+ fn_name = "compile_fx_inner"
615
+ path = f"{folder}/{fn_name}_{next(save_args_cnt)}.pkl"
616
+ with open(path, "wb") as f:
617
+ pickle.dump((args_to_save, kwargs_to_save), f)
618
+
619
+ if log.isEnabledFor(logging.DEBUG):
620
+ message = f"""
621
+ Arguments for a compile_fx_inner call is saved to {path}. To replay the call,
622
+ run the following:
623
+
624
+ from torch._inductor.debug import load_args_and_run_compile_fx_inner
625
+ load_args_and_run_compile_fx_inner({path!r})
626
+ """
627
+ # call print rather than log.debug. log.debug will print message
628
+ # prefix for each line which makes the code snippet harder to be
629
+ # copied.
630
+ # Not a big deal since the code is already been guarded by checking
631
+ # the log level.
632
+ print(message)
633
+
634
+
635
+ def load_args_and_run_compile_fx_inner(path: str):
636
+ from torch._inductor.compile_fx import compile_fx_inner
637
+
638
+ with open(path, "rb") as f:
639
+ args, kwargs = pickle.load(f)
640
+
641
+ def handle_tensor(x):
642
+ if isinstance(x, TensorMetadataHolder):
643
+ return torch._dynamo.testing.rand_strided(
644
+ x.tensor_metadata.shape,
645
+ x.tensor_metadata.stride,
646
+ x.tensor_metadata.dtype,
647
+ x.device,
648
+ )
649
+ else:
650
+ return x
651
+
652
+ fake_mode = torch._subclasses.FakeTensorMode(allow_non_fake_inputs=True)
653
+ with fake_mode, config.patch("save_args", False):
654
+ args, kwargs = tree_map(handle_tensor, (args, kwargs))
655
+ return compile_fx_inner(*args, **kwargs)
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/decomposition.py ADDED
@@ -0,0 +1,678 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import functools
2
+ import logging
3
+ import math
4
+ import sys
5
+ import typing
6
+ from typing import Optional
7
+
8
+ import torch
9
+ import torch._decomp as decomp
10
+ import torch._prims_common as utils
11
+ import torch.ao.quantization.fx._decomposed
12
+ from torch._decomp import (
13
+ core_aten_decompositions,
14
+ get_decompositions,
15
+ remove_decompositions,
16
+ )
17
+ from torch._decomp.decompositions import (
18
+ _grid_sampler_2d as decomp_grid_sampler_2d,
19
+ pw_cast_for_opmath,
20
+ )
21
+ from torch._decomp.decompositions_for_rng import extra_random_decomps
22
+ from torch._higher_order_ops.out_dtype import out_dtype
23
+ from torch._prims_common import (
24
+ elementwise_dtypes,
25
+ ELEMENTWISE_TYPE_PROMOTION_KIND,
26
+ type_to_dtype,
27
+ )
28
+
29
+ from . import config, inductor_prims
30
+
31
+ log = logging.getLogger(__name__)
32
+ aten = torch.ops.aten
33
+ prims = torch.ops.prims
34
+ quantized_decomposed = torch.ops.quantized_decomposed
35
+
36
+ inductor_decompositions = get_decompositions(
37
+ [
38
+ aten._adaptive_avg_pool2d_backward,
39
+ aten.arange,
40
+ aten.bitwise_and_,
41
+ aten.bitwise_or_,
42
+ aten.clamp_min_,
43
+ aten.dist,
44
+ aten.empty_like,
45
+ aten.flip,
46
+ aten.gelu,
47
+ aten.hardtanh,
48
+ aten.index_select,
49
+ aten.lcm,
50
+ aten.leaky_relu,
51
+ aten.linalg_vector_norm,
52
+ aten._log_softmax,
53
+ aten.max_pool2d_with_indices_backward,
54
+ aten._native_batch_norm_legit,
55
+ aten._native_batch_norm_legit_functional,
56
+ aten._native_batch_norm_legit_no_training,
57
+ aten.native_batch_norm,
58
+ aten.native_group_norm,
59
+ aten.native_layer_norm,
60
+ aten.nll_loss2d_backward,
61
+ aten._softmax,
62
+ aten.sin_,
63
+ aten.sqrt_,
64
+ out_dtype,
65
+ aten._to_copy,
66
+ aten.tril_indices,
67
+ aten.triu_indices,
68
+ aten.upsample_bilinear2d.vec,
69
+ ]
70
+ )
71
+ decompositions = {**core_aten_decompositions(), **inductor_decompositions}
72
+
73
+ # Remove unwanted decompositions included via the core ATen decompositions from
74
+ # the Inductor decomp table.
75
+ decomps_to_exclude = [
76
+ aten._unsafe_index,
77
+ aten._scaled_dot_product_flash_attention_for_cpu.default, # See comments in torch/_decomp/decompositions.py
78
+ aten.clamp_max,
79
+ aten.clamp_min,
80
+ aten.glu, # inductor lowers this directly
81
+ aten.split.Tensor, # inductor lowers this directly
82
+ aten.squeeze, # inductor lowers this directly
83
+ aten.sum, # inductor lowers this directly
84
+ aten.unbind, # inductor lowers this directly
85
+ ]
86
+
87
+ remove_decompositions(decompositions, decomps_to_exclude)
88
+
89
+
90
def register_decomposition(ops):
    """Register a decomposition into Inductor's decomp table.

    Accepts either a single op overload (callable) or an iterable of them.
    Warns when an op is already present, so accidental double registration
    shows up in logs, then defers to ``torch._decomp``.
    """
    for op in [ops] if callable(ops) else ops:
        if op in decompositions:
            # Log the specific colliding op (the original logged the whole
            # `ops` argument, which hides which entry was duplicated when a
            # list is passed).
            log.warning("duplicate decomp: %s", op)
    return decomp.register_decomposition(ops, decompositions)
95
+
96
+
97
+ # TODO: for now, inductor doesn't handle asserts
98
+ # because the condition is symbool -> tensor in the graph.
99
+ @register_decomposition([aten._assert_async.msg])
100
+ def assert_async_msg_decomp(tensor, msg):
101
+ return
102
+
103
+
104
+ # Following `assert_async_msg_decomp` and implement as non-op.
105
+ @register_decomposition([aten._functional_assert_async.msg])
106
+ def functional_assert_async_msg_decomp(tensor, msg):
107
+ return
108
+
109
+
110
+ @register_decomposition([aten.sym_constrain_range_for_size.default])
111
+ def sym_constrain_range_for_size(symbol, *, min=None, max=None):
112
+ return
113
+
114
+
115
@register_decomposition([aten.clamp])
@pw_cast_for_opmath
def clamp(x, min=None, max=None):
    """Decompose clamp into independent clamp_min/clamp_max steps so each
    bound lowers to its own pointwise op."""
    result = x
    if min is not None:
        result = result.clamp_min(min)
    if max is not None:
        result = result.clamp_max(max)
    return result
123
+
124
+
125
@register_decomposition([aten.full])
def full(size, fill_value, **kwargs):
    """When no dtype was given, infer one from the Python type of
    ``fill_value`` and re-dispatch; otherwise decline the decomposition."""
    if kwargs.get("dtype") is None:
        kwargs["dtype"] = type_to_dtype(type(fill_value))
        return aten.full(size, fill_value, **kwargs)
    return NotImplemented
132
+
133
+
134
+ # Not really sure how to put this into the main library. PrimTorch wants
135
+ # empty_permuted to go to the prim, and typically users don't really want
136
+ # to decompose to empty_strided (but inductor is OK with it, because we are
137
+ # cool with strides and everything goes to empty_strided)
138
@register_decomposition([aten.empty_permuted.default])
def empty_permuted(size, physical_layout, **kwargs):
    """Allocate in physical order, then permute back to logical order.

    ``inverse`` is the inverse permutation of ``physical_layout``: it maps
    each logical dimension back to its physical position.
    """
    inverse = [0] * len(size)
    for physical_dim, logical_dim in enumerate(physical_layout):
        inverse[logical_dim] = physical_dim
    physical_size = [size[d] for d in physical_layout]
    return torch.empty(physical_size, **kwargs).permute(inverse)
144
+
145
+
146
@register_decomposition([aten.convolution_backward])
def convolution_backward(
    grad_output,
    input,
    weight,
    bias_sizes,
    stride,
    padding,
    dilation,
    transposed,
    output_padding,
    groups,
    output_mask,
):
    """Peel the bias gradient off CUDA convolution backward.

    The bias grad is a plain sum over every non-channel dimension, which
    Inductor can fuse; the remaining input/weight grads are re-dispatched
    with the bias entry of ``output_mask`` cleared. Declines on non-CUDA
    devices or when no bias grad is requested.
    """
    if grad_output.device.type != "cuda" or not output_mask[2]:
        return NotImplemented
    reduce_dims = [0, *range(2, grad_output.dim())]
    grad_bias = aten.sum(grad_output, reduce_dims)
    grad_inp, grad_weight, _ = aten.convolution_backward(
        grad_output,
        input,
        weight,
        bias_sizes,
        stride,
        padding,
        dilation,
        transposed,
        output_padding,
        groups,
        [output_mask[0], output_mask[1], False],
    )
    return (grad_inp, grad_weight, grad_bias)
177
+
178
+
179
@register_decomposition([aten.log2])
def log2(x):
    # log2(x) == ln(x) / ln(2); multiply by the reciprocal so the constant
    # division is folded at trace time.
    inv_ln2 = 1.0 / math.log(2.0)
    return torch.log(x) * inv_ln2
182
+
183
+
184
@register_decomposition([aten.round.decimals])
def round_dec(x, decimals=0):
    # Scale up by 10**decimals, round to integers, then scale back down.
    scale = 10.0**decimals
    return aten.round(x * scale) * (1.0 / scale)
188
+
189
+
190
@register_decomposition([aten.bmm])
@pw_cast_for_opmath
def bmm(self, batch2):
    """Lower degenerate batched matmuls (matrix-vector-like shapes) to a
    broadcast multiply + reduction; otherwise keep the op as-is."""
    if config.coordinate_descent_tuning and (
        self.shape[1] == 1 or batch2.shape[2] == 1
    ):
        return (self.unsqueeze(-1) * batch2.unsqueeze(1)).sum(dim=2)
    if self.device.type == "cpu" and self.size(1) == 1 and batch2.size(-1) == 1:
        # (b, 1, k) @ (b, k, 1): a batch of dot products.
        prod = self.squeeze(1) * batch2.squeeze(-1)
        return torch.sum(prod, dim=1, keepdim=True).unsqueeze(1)
    return NotImplemented
203
+
204
+
205
@register_decomposition([aten.addmm])
@pw_cast_for_opmath
def addmm(self, mat1, mat2, beta=1, alpha=1):
    """Decompose tiny CPU addmm shapes into pointwise ops Inductor fuses well."""
    if self.device.type != "cpu":
        return NotImplemented
    if mat1.size(0) == 1 and mat2.size(-1) == 1:
        # (1, k) @ (k, 1): a single dot product.
        dot = torch.sum(mat1.squeeze(0) * mat2.squeeze(-1), dim=0, keepdim=True)
        return alpha * dot.unsqueeze(0) + beta * self
    if mat1.size(0) == 1 and mat2.size(0) <= 16 and mat2.size(1) <= 16:
        # (1, k) @ small (k, n): broadcast-multiply and reduce over k.
        out = (mat1.T * mat2).sum(dim=0, keepdim=True)
        return alpha * out + beta * self
    return NotImplemented
218
+
219
+
220
@register_decomposition([aten.mm])
@pw_cast_for_opmath
def mm(self, input2):
    """Decompose matrix-vector-like mm shapes where a pointwise lowering wins."""
    from torch.fx.experimental.symbolic_shapes import (
        definitely_true,
        guard_size_oblivious,
    )

    # Our matrix vector multiplies only achieve peak bandwidth with
    # coordinate descent tuning; see the original TODO about investigating
    # and fixing that.
    if config.coordinate_descent_tuning and (
        self.shape[0] == 1 or input2.shape[1] == 1
    ):
        return (self.unsqueeze(2) * input2.unsqueeze(0)).sum(dim=1)
    if self.device.type == "cpu":
        # (m, 1) @ (1, n) for small total element counts: outer product as
        # a concat of scaled rows.
        is_small_outer = (
            guard_size_oblivious(self.size(-1) == 1)
            and guard_size_oblivious(self.size(0) > 0)
            and guard_size_oblivious(input2.size(0) == 1)
            and (self.dtype == input2.dtype)
            and definitely_true((torch.numel(self) + torch.numel(input2)) <= 32)
        )
        if is_small_outer:
            return torch.cat([self[i, :] * input2 for i in range(self.size(0))])
        # (1, k) @ (k, 1): a single dot product.
        if guard_size_oblivious(self.size(0) == 1) and guard_size_oblivious(
            input2.size(-1) == 1
        ):
            return torch.sum(
                self.squeeze(0) * input2.squeeze(-1), dim=0, keepdim=True
            ).unsqueeze(0)
    return NotImplemented
249
+
250
+
251
+ # This pass does two things:
252
+ # - Eliminate cat when there is only one tensor input
253
+ # - Normalize cat calls, so that legacy empty 1-D tensors are removed (NB: we
254
+ # don't remove ALL empty tensors, only the naughty ones)
255
@register_decomposition([aten.cat.default])
def cat(tensors, dim=0):
    """Normalize cat calls.

    - Eliminate cat when only one (surviving) tensor input remains.
    - Drop legacy empty 1-D tensor inputs (NB: only those naughty ones, not
      ALL empty tensors), re-dispatching recursively on the cleaned list.
    """
    from torch.fx.experimental.symbolic_shapes import guard_size_oblivious

    def keep(t):
        # For better or worse this is a valid cat:
        #
        #   torch.cat([torch.randn(2, 2, 4), torch.randn(0), torch.randn(3, 2, 4)])
        #
        # Drop such 1-D empties (guarding that survivors are non-zero) so
        # downstream passes like split_cat see a clean call.
        #
        # Size-obliviousness is fine here: a 1-D unbacked size could only
        # have reached this point if a deferred runtime assert already
        # pinned it to zero; any surviving unbacked size is known-nonzero.
        return len(t.shape) != 1 or guard_size_oblivious(t.shape[0] > 0)

    kept = [t for t in tensors if keep(t)]

    if len(kept) == 1:
        return kept[0].clone()
    if 1 < len(kept) < len(tensors):
        # Something was filtered on this call: redispatch recursively.
        return aten.cat.default(kept, dim)
    # Nothing was filtered -> decline, preventing infinite recursion.
    return NotImplemented
287
+
288
+
289
+ @register_decomposition([aten.angle])
290
+ def angle(x):
291
+ if x.is_complex():
292
+ return torch.where(
293
+ torch.isnan(x.real), float("nan"), torch.atan2(x.imag, x.real)
294
+ )
295
+
296
+ # when x is real number
297
+ # if x >= 0, return 0
298
+ # if x < 0, return pi
299
+ # if x is nan, return nan
300
+ _, dtype = elementwise_dtypes(
301
+ x,
302
+ type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
303
+ )
304
+ pi = torch.scalar_tensor(math.pi, dtype=dtype, device=x.device)
305
+ ret = torch.where(x < 0, pi, 0.0)
306
+ return torch.where(torch.isnan(x), float("nan"), ret)
307
+
308
+
309
@register_decomposition([aten.add])
def add(x, y, *, alpha=None):
    """Complex-tensor add, implemented by viewing both operands as their
    underlying real dtypes; declines unless both inputs are complex tensors."""
    both_complex = (
        torch.is_tensor(x)
        and x.is_complex()
        and torch.is_tensor(y)
        and y.is_complex()
    )
    if not both_complex:
        return NotImplemented
    scaled_y = y if alpha is None else alpha * y
    out_dtype = torch.promote_types(x.dtype, y.dtype)
    return (x.view(x.real.dtype) + scaled_y.view(y.real.dtype)).view(out_dtype)
320
+
321
+
322
+ @register_decomposition([aten.conj_physical])
323
+ def conj_physical(self):
324
+ assert not self.is_complex(), "TODO: implement this"
325
+ return self
326
+
327
+
328
+ @register_decomposition([aten.lift, aten.detach_])
329
+ def lift(self):
330
+ return self
331
+
332
+
333
+ @register_decomposition([aten.bernoulli.default])
334
+ def bernoulli(self, *, generator=None):
335
+ assert generator is None
336
+ return (torch.rand_like(self, dtype=torch.float32) < self).to(self.dtype)
337
+
338
+
339
@register_decomposition([aten.fmin, prims.fmin])
def fmin(self, other):
    # NaN-ignoring min: keep `self` whenever `other` is NaN or larger.
    keep_self = torch.isnan(other) | (other > self)
    return torch.where(keep_self, self, other)
342
+
343
+
344
@register_decomposition([aten.fmax, prims.fmax])
def fmax(self, other):
    # NaN-ignoring max: keep `self` whenever `other` is NaN or smaller.
    keep_self = torch.isnan(other) | (other < self)
    return torch.where(keep_self, self, other)
347
+
348
+
349
@register_decomposition(aten.amax)
def amax(self, dim=None, keepdim=False):
    # Boolean amax is logical `any`; other dtypes keep the default lowering.
    if self.dtype != torch.bool:
        return NotImplemented
    return torch.any(self, dim=dim, keepdim=keepdim)
354
+
355
+
356
@register_decomposition(aten.amin)
def amin(self, dim=None, keepdim=False):
    # Boolean amin is logical `all`; other dtypes keep the default lowering.
    if self.dtype != torch.bool:
        return NotImplemented
    return torch.all(self, dim=dim, keepdim=keepdim)
361
+
362
+
363
+ @register_decomposition([aten.narrow_copy])
364
+ def narrow_copy(self, dim, start, length):
365
+ return torch.narrow(self, dim, start, length).clone()
366
+
367
+
368
+ @register_decomposition([aten.expand_copy])
369
+ def expand_copy(self, size, *, implicit=False):
370
+ return aten.expand(self, size, implicit=implicit).clone()
371
+
372
+
373
+ @register_decomposition([aten.view_copy.default])
374
+ def view_copy_default(self, size):
375
+ return aten.view(self, size).clone()
376
+
377
+
378
+ @register_decomposition([aten.view_copy.dtype])
379
+ def view_copy_dtype(self, dtype):
380
+ return self.to(dtype).clone()
381
+
382
+
383
def get_like_layout(
    tensor: torch.Tensor, memory_format: Optional[torch.memory_format]
) -> torch.memory_format:
    """Resolve the effective memory format for a ``*_like`` op.

    ``preserve_format`` (or None) means "match the input", approximated via
    ``suggest_memory_format``; any other format is taken as-is.
    """
    # TODO: _to_copy tensor to stride permutation
    if memory_format is None or memory_format is torch.preserve_format:
        return utils.suggest_memory_format(tensor)
    return memory_format
391
+
392
+
393
+ @register_decomposition(aten.rand_like)
394
+ def rand_like(self, *, dtype=None, device=None, memory_format=None, **kwargs):
395
+ return torch.rand(
396
+ [*self.size()],
397
+ dtype=dtype or self.dtype,
398
+ device=device or self.device,
399
+ **kwargs,
400
+ ).to(memory_format=get_like_layout(self, memory_format))
401
+
402
+
403
+ @register_decomposition(aten.randn_like)
404
+ def randn_like(self, *, dtype=None, device=None, memory_format=None, **kwargs):
405
+ return torch.randn(
406
+ [*self.size()],
407
+ dtype=dtype or self.dtype,
408
+ device=device or self.device,
409
+ **kwargs,
410
+ ).to(memory_format=get_like_layout(self, memory_format))
411
+
412
+
413
+ @register_decomposition(aten.full_like)
414
+ def full_like(
415
+ self,
416
+ fill_value,
417
+ *,
418
+ dtype=None,
419
+ layout=None,
420
+ device=None,
421
+ pin_memory=False,
422
+ requires_grad=False,
423
+ memory_format=torch.preserve_format,
424
+ ):
425
+ return torch.full(
426
+ [*self.size()],
427
+ fill_value,
428
+ dtype=dtype or self.dtype,
429
+ layout=layout or self.layout,
430
+ device=device or self.device,
431
+ requires_grad=requires_grad,
432
+ ).to(memory_format=get_like_layout(self, memory_format))
433
+
434
+
435
+ @register_decomposition(aten.randint_like.default)
436
+ def randint_like(self, high, *, dtype=None, device=None, memory_format=None, **kwargs):
437
+ return aten.randint.low(
438
+ 0,
439
+ high,
440
+ [*self.size()],
441
+ dtype=dtype or self.dtype,
442
+ device=device or self.device,
443
+ **kwargs,
444
+ ).to(memory_format=get_like_layout(self, memory_format))
445
+
446
+
447
+ @register_decomposition(aten.randint_like.low_dtype)
448
+ def randint_like_low(
449
+ self, low, high, *, dtype=None, device=None, memory_format=None, **kwargs
450
+ ):
451
+ return aten.randint.low(
452
+ low,
453
+ high,
454
+ [*self.size()],
455
+ dtype=dtype or self.dtype,
456
+ device=device or self.device,
457
+ **kwargs,
458
+ ).to(memory_format=get_like_layout(self, memory_format))
459
+
460
+
461
+ @register_decomposition(aten.randint.default)
462
+ def randint(high, size, **kwargs):
463
+ return aten.randint.low(0, high, size, **kwargs)
464
+
465
+
466
+ # The difference between quantize_per_tensor.default and quantize_per_tensor.tensor is
467
+ # scale and zero_point is scalar or scalar tensor
468
+ @register_decomposition(quantized_decomposed.quantize_per_tensor.default)
469
+ def quantize_per_tensor_default_decomp_impl(
470
+ input: torch.Tensor,
471
+ scale: float,
472
+ zero_point: int,
473
+ quant_min: int,
474
+ quant_max: int,
475
+ dtype: torch.dtype,
476
+ ) -> torch.Tensor:
477
+ if input.dtype == torch.bfloat16:
478
+ input = input.to(torch.float32)
479
+ inv_scale = 1.0 / scale
480
+ return torch.clamp(
481
+ torch.round(input * inv_scale) + zero_point, quant_min, quant_max
482
+ ).to(dtype)
483
+
484
+
485
+ # The difference between dequantize_per_tensor.default and dequantize_per_tensor.tensor is
486
+ # scale and zero_point is scalar or scalar tensor
487
+ @register_decomposition(quantized_decomposed.dequantize_per_tensor.default)
488
+ def dequantize_per_tensor_default_decomp_impl(
489
+ input: torch.Tensor,
490
+ scale: float,
491
+ zero_point: int,
492
+ quant_min: int,
493
+ quant_max: int,
494
+ dtype: torch.dtype,
495
+ ) -> torch.Tensor:
496
+ return (input.to(torch.float32) - zero_point) * scale
497
+
498
+
499
+ @register_decomposition(quantized_decomposed.quantize_per_tensor.tensor)
500
+ def quantize_per_tensor_tensor_decomp_impl(
501
+ input: torch.Tensor,
502
+ scale: torch.Tensor,
503
+ zero_point: torch.Tensor,
504
+ quant_min: int,
505
+ quant_max: int,
506
+ dtype: torch.dtype,
507
+ ) -> torch.Tensor:
508
+ if input.dtype == torch.bfloat16:
509
+ input = input.to(torch.float32)
510
+ inv_scale = 1.0 / scale
511
+ return torch.clamp(
512
+ torch.round(input * inv_scale) + zero_point, quant_min, quant_max
513
+ ).to(dtype)
514
+
515
+
516
+ @register_decomposition(quantized_decomposed.dequantize_per_tensor.tensor)
517
+ def dequantize_per_tensor_tensor_decomp_impl(
518
+ input: torch.Tensor,
519
+ scale: torch.Tensor,
520
+ zero_point: torch.Tensor,
521
+ quant_min: int,
522
+ quant_max: int,
523
+ dtype: torch.dtype,
524
+ ) -> torch.Tensor:
525
+ return (input.to(torch.float32) - zero_point.to(torch.int32)) * scale.to(
526
+ torch.float32
527
+ )
528
+
529
+
530
+ @register_decomposition(torch.ops.quantized.embedding_bag_byte_unpack)
531
+ def q_embedding_bag_byte_unpack_decomp(packed):
532
+ def bitcast_u8_to_f32(u8):
533
+ x, y, z, w = (u8[..., n].to(torch.int32) for n in (0, 1, 2, 3))
534
+ if sys.byteorder == "little":
535
+ return (x + (y << 8) + (z << 16) + (w << 24)).view(torch.float32)[..., None]
536
+ else:
537
+ return ((x << 24) + (y << 16) + (z << 8) + w).view(torch.float32)[..., None]
538
+
539
+ scales = bitcast_u8_to_f32(packed[..., -8:-4])
540
+ offsets = bitcast_u8_to_f32(packed[..., -4:])
541
+ return packed[..., :-8].to(torch.float32) * scales + offsets
542
+
543
+
544
+ @register_decomposition([aten.grid_sampler_2d])
545
+ @pw_cast_for_opmath
546
+ def grid_sampler_2d(
547
+ a: torch.Tensor,
548
+ grid: torch.Tensor,
549
+ interpolation_mode: int = 0,
550
+ padding_mode: int = 0,
551
+ align_corners: bool = False,
552
+ ) -> torch.Tensor:
553
+ # We do not expand the grid (_expand_grid=False) on cpu for performance reasons
554
+ # Experimenting locally it was found that compiled CUDA code is accelerated by ~5x
555
+ # and CPU code by ~2x on bicubic mode, if we expand the grid from (N, H, W, 2) into (N, C, H, W, 2)
556
+ # However, this leads to a slowdown around ~0.8x on CPU bilinear mode, channels first.
557
+ # Thus we apply this hack to not expand the grid for this case.
558
+ _expand_grid = not (
559
+ a.device == torch.device("cpu")
560
+ and interpolation_mode == 0
561
+ and a.is_contiguous(memory_format=torch.contiguous_format)
562
+ )
563
+
564
+ output = decomp_grid_sampler_2d(
565
+ a,
566
+ grid=grid,
567
+ interpolation_mode=interpolation_mode,
568
+ padding_mode=padding_mode,
569
+ align_corners=align_corners,
570
+ _expand_grid=_expand_grid,
571
+ )
572
+ return output
573
+
574
+
575
+ @register_decomposition(aten._foreach_addcmul.Scalar)
576
+ def _foreach_addcmul_scalar(self, left_tensors, right_tensors, scalar=1):
577
+ return aten._foreach_add.List(
578
+ self, aten._foreach_mul.List(left_tensors, right_tensors), alpha=scalar
579
+ )
580
+
581
+
582
+ @register_decomposition(aten._foreach_addcdiv.Scalar)
583
+ def _foreach_addcdiv_scalar(self, left_tensors, right_tensors, scalar=1):
584
+ return aten._foreach_add.List(
585
+ self, aten._foreach_div.List(left_tensors, right_tensors), alpha=scalar
586
+ )
587
+
588
+
589
+ @register_decomposition(aten._foreach_lerp.Scalar)
590
+ def _foreach_lerp_scalar(start_tensors, end_tensors, weight):
591
+ return aten._foreach_add.List(
592
+ start_tensors,
593
+ aten._foreach_mul.Scalar(
594
+ aten._foreach_sub.List(end_tensors, start_tensors), weight
595
+ ),
596
+ )
597
+
598
+
599
@aten.miopen_batch_norm.default.py_impl(torch._C.DispatchKey.Autograd)
@register_decomposition(aten.miopen_batch_norm)
def miopen_batch_norm(
    input: torch.Tensor,
    weight: torch.Tensor,
    bias: typing.Optional[torch.Tensor],
    running_mean: typing.Optional[torch.Tensor],
    running_var: typing.Optional[torch.Tensor],
    training: bool,
    exponential_average_factor: float,
    epsilon: float,
):
    """Lower miopen_batch_norm to native_batch_norm.

    In inference mode the saved mean/invstd outputs are not meaningful, so
    empty placeholder tensors are returned in their place.
    """
    out, save_mean, save_invstd = aten.native_batch_norm(
        input,
        weight,
        bias,
        running_mean,
        running_var,
        training,
        exponential_average_factor,
        epsilon,
    )

    if training:
        return (out, save_mean, save_invstd)
    return (out, weight.new_zeros((0,)), weight.new_zeros((0,)))
629
+
630
+
631
@functools.lru_cache(None)
def fast_random_decomps():
    """Decomp table with Inductor's fast-RNG decomps layered on top (cached)."""
    table = dict(decompositions)
    table.update(extra_random_decomps)
    return table
634
+
635
+
636
def select_decomp_table():
    """Decomps can change based on config: with fallback_random the plain
    table is used, otherwise the fast-RNG-augmented one."""
    return decompositions if config.fallback_random else fast_random_decomps()
641
+
642
+
643
@register_decomposition(aten.masked_scatter)
def masked_scatter(self, mask, source):
    """CUDA-only two-step masked_scatter, matching eager CUDA; eager CPU
    uses a one-shot serial iteration, so non-CUDA declines."""
    if self.device.type != "cuda":
        return NotImplemented
    self, mask = aten.broadcast_tensors([self, mask])
    # Position of each selected element within `source`, via inclusive
    # cumsum shifted down by one.
    source_idx = mask.reshape(-1).cumsum(0) - 1
    return inductor_prims.masked_scatter_with_index(self, mask, source_idx, source)
652
+
653
+
654
@register_decomposition(quantized_decomposed.choose_qparams.tensor)
def choose_qparams_tensor(
    input: torch.Tensor, quant_min: int, quant_max: int, eps: float, dtype: torch.dtype
):
    """Compute (scale, zero_point) for affine quantization of ``input``."""
    lo, hi = torch.aminmax(input)
    raw_scale = (hi - lo) / float(quant_max - quant_min)
    # Clamp the scale away from zero.
    # NOTE(review): torch.Tensor([eps]) lives on the default (CPU) device;
    # presumably fine as traced, but confirm behavior for CUDA inputs.
    scale = torch.max(raw_scale, torch.Tensor([eps]))
    zero_point = quant_min - torch.round(lo / scale).to(torch.int)
    zero_point = torch.clamp(zero_point, quant_min, quant_max)
    return scale.to(torch.float64), zero_point.to(torch.int64)
664
+
665
+
666
@register_decomposition(aten.put)
def put(self, index, source, accumulate=False):
    """Out-of-place put: scatter ``source`` into a flattened copy of
    ``self`` at ``index``, then restore the original shape."""
    flat = torch.index_put(
        self.flatten(), [index], source.reshape(index.shape), accumulate
    )
    return flat.reshape(self.shape)
673
+
674
+
675
@register_decomposition(aten.put_)
def put_(self, index, source, accumulate=False):
    """In-place put, implemented via the out-of-place decomposition plus copy_."""
    return self.copy_(aten.put(self, index, source, accumulate=accumulate))
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/test_operators.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.library
2
+ from torch import Tensor
3
+ from torch.autograd import Function
4
+
5
+ _test_lib_def = torch.library.Library("_inductor_test", "DEF")
6
+ _test_lib_def.define("realize(Tensor self) -> Tensor", tags=torch.Tag.pt2_compliant_tag)
7
+
8
+ _test_lib_impl = torch.library.Library("_inductor_test", "IMPL")
9
+ for dispatch_key in ("CPU", "CUDA", "Meta"):
10
+ _test_lib_impl.impl("realize", lambda x: x.clone(), dispatch_key)
11
+
12
+
13
class Realize(Function):
    """Autograd wrapper for the test-only ``_inductor_test.realize`` op.

    Forward routes through the custom op; backward is the identity, since
    realize itself is an identity (clone) operation.
    """

    @staticmethod
    def forward(ctx, x):
        return torch.ops._inductor_test.realize(x)

    @staticmethod
    def backward(ctx, grad_output):
        # Identity op: gradient passes straight through.
        return grad_output
21
+
22
+
23
def realize(x: Tensor) -> Tensor:
    """Force ``x`` through the autograd-aware test realize op."""
    return Realize.apply(x)
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/virtualized.py ADDED
@@ -0,0 +1,351 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This file provides a number of "global" variables/handlers that are actually
3
+ thread local and dynamically scoped, with Inductor patching them to various
4
+ implementations depending on the situation.
5
+
6
+ These handlers are interacted with in a fairly stylized way. Typically,
7
+ we will import V from this module::
8
+
9
+ from .virtualized import V
10
+
11
+ Various handlers are accessible as attributes on this module; for example,
12
+ you might access ``V.graph.sizevars.size_hint`` to resolve a size hint associated with
13
+ a number.
14
+
15
+ There are a few distinct usage patterns for virtualized global variables:
16
+
17
+ 1. Implicit argument passing. Examples: ``V.current_node``, ``V.aot_compilation``.
18
+ Use ``V.set_current_node`` to change what the current node is while we're
19
+ executing some region of code, so code inside that region can query ``V.current_node``
20
+ to find out what it is. This is often more convenient than manually threading
21
+ the current node as an argument through all call stacks.
22
+
23
+ 2. Per-compilation global state. Examples: ``V.fake_mode``, ``V.graph``. For a
24
+ given ``compile_fx`` invocation, these typically don't change, but they are
25
+ associated with some internal state so they cannot just be global functions.
26
+ We install these objects at the beginning of compilation and then you can
27
+ conveniently access them without having to pass them around.
28
+
29
+ 3. Alternate define-by-run interpretations. Examples: ``V.ops``, ``V.kernel``.
30
+ A commonly used IR in Inductor is define-by-run: instead of maintaining
31
+ explicit syntax data structures, we instead represent loop bodies as
32
+ callable functions, which internally invoke operations defined on
33
+ ``V.ops``. To perform semantic analysis, print or code generate these
34
+ operations, we dynamically patch ``V.ops`` with an alternate handler with
35
+ the intended semantics and then run the callable function. For example, to
36
+ extract out a traditional (FX) graph representation of the define-by-run
37
+ IR, simply install a handler that records each ``ops`` call to a graph.
38
+
39
+ TODO: Define a parent class / protocol that defines all of the operations
40
+ V.ops is expected to support.
41
+
42
+ It is typically an error to access a virtualized global without having installed
43
+ an appropriate handler (you will get a NullHandler), although in some cases we
44
+ provide a default implementation.
45
+
46
+ One last thing: although most virtualized globals are accessed via ``V``, ``ops`` is
47
+ ubiquitous enough to have its own top level variable, so you will typically see
48
+ ``ops.constant(...)`` rather than ``V.ops.constant(...)``. In fact, these are not
49
+ equivalent; the former interface supports arithmetic overloads like ``x + y``
50
+ instead of forcing ``ops.add(x, y)``, so it should be preferred.
51
+
52
+ Some operators are seemingly unused, but they are implicitly used by ops_wrapper.
53
+ In particular, we typically have an operator for every basic pointwise PyTorch operation
54
+ supported.
55
+ """
56
+
57
+ from __future__ import annotations
58
+
59
+ from contextlib import AbstractContextManager, contextmanager
60
+ from threading import local
61
+ from typing import Any, Callable, Generic, List, Type, TYPE_CHECKING, TypeVar, Union
62
+
63
+ from .ops_handler import ( # noqa: F401
64
+ KernelFormatterHandler,
65
+ MockHandler,
66
+ OpsHandler,
67
+ ReductionType,
68
+ StoreMode,
69
+ WrapperHandler,
70
+ )
71
+
72
+ if TYPE_CHECKING:
73
+ import torch
74
+ from torch._inductor.debug import DebugContext
75
+ from torch._inductor.graph import GraphLowering
76
+ from torch._inductor.ir import InterpreterShim
77
+ from torch._subclasses import FakeTensorMode
78
+
79
+ threadlocal = local()
80
+
81
+ T = TypeVar("T")
82
+
83
+
84
+ class NullHandler:
85
+ """
86
+ Sentinel indicating that a global variable is unset ala None. Typically,
87
+ attempting to access the global variable before it's set is an error, but with
88
+ NullHandler it won't fail until you try to access an attribute on it.
89
+ """
90
+
91
+ pass
92
+
93
+
94
class Virtualized(Generic[T]):
    """
    Implements a global variable that redirects via thread local variable
    (NB: construct this class to create the global variable; this is not
    a singleton class!)

    This allows us to swap in different op implementations in codegen.

    NB: Despite the fact that we typically call these "handlers" (e.g.,
    NullHandler is the default value of the variable), we sometimes use
    these variables to store other things, like booleans.
    """

    def __init__(self, vname: str, default: Union[Callable[[], T], Type[NullHandler]]):
        # Attribute name used on the thread-local storage object.
        self._key: str = f"__torchinductor_{vname}"
        self._default = default

    def _set_handler(self, value: T) -> AbstractContextManager[None]:
        """Install ``value`` for this thread; the returned context manager
        restores the previously installed handler on exit."""
        prior = self._get_handler()
        setattr(threadlocal, self._key, value)

        @contextmanager
        def restore_prior():
            try:
                yield
            finally:
                self._set_handler(prior)

        return restore_prior()

    def _get_handler(self) -> T:
        """Return the handler installed on this thread, or the default."""
        try:
            return getattr(threadlocal, self._key)
        except AttributeError:
            # TODO: To be honest, I feel we probably should just error in
            # this case, instead of making a null handler that will
            # probably error when you getattr on it
            return self._default()  # type: ignore[return-value]

    def __getattr__(self, name: str) -> Any:
        # Forward attribute access to whatever handler is current.
        return getattr(self._get_handler(), name)
135
+
136
+
137
+ class NullKernelHandler(NullHandler):
138
+ """
139
+ We need access `V.kernel.removed_buffers` in DeferredLine class when there
140
+ is no kernel in the context. This happens when codegening the wrapper.
141
+ Initialize `removed_buffers` and `inplaced_to_remove` explicitly so we don't
142
+ need call 'getattr' with default value which is error prone to typo in
143
+ attribute name.
144
+ """
145
+
146
+ def __init__(self):
147
+ super().__init__()
148
+ self.removed_buffers = set()
149
+ self.inplaced_to_remove = set()
150
+ self.index_dtype = "tl.int64"
151
+
152
+
153
+ _ops: Virtualized[OpsHandler[Any]] = Virtualized("ops", MockHandler)
154
+ _graph: Virtualized[GraphLowering] = Virtualized("graph", NullHandler)
155
+ _real_inputs: Virtualized[List[torch.Tensor]] = Virtualized("real_inputs", NullHandler)
156
+ _fake_mode: Virtualized[FakeTensorMode] = Virtualized("fake_mode", NullHandler)
157
+ _kernel: Virtualized[NullKernelHandler] = Virtualized(
158
+ "kernel", NullKernelHandler
159
+ ) # TODO: improve type
160
+ _debug: Virtualized[DebugContext] = Virtualized("debug", NullHandler)
161
+ _interpreter: Virtualized[InterpreterShim] = Virtualized("interpreter", NullHandler)
162
+ _aot_compilation: Virtualized[bool] = Virtualized("aot_compilation", NullHandler)
163
+ _current_node: Virtualized[torch.fx.Node] = Virtualized("current_node", NullHandler)
164
+
165
+
166
+ class OpsValue:
167
+ """The return type of most ops calls.
168
+
169
+ This exists so we can overload magic methods, and write mathematical
170
+ expressions much more fluently. So instead of
171
+
172
+ ops.add(ops.mul(ops.mul(ops.sub(ops.mul(_Ap2, x), _Ap3), x), x), _1)
173
+
174
+ we can write
175
+
176
+ (_Ap2 * x - _Ap3) * x * x + _1
177
+
178
+ """
179
+
180
+ value: Any
181
+
182
+ def __init__(self, value):
183
+ self.value = value
184
+
185
+ def __str__(self):
186
+ return str(self.value)
187
+
188
+ def __repr__(self):
189
+ return f"OpsValue({self.value!r})"
190
+
191
+ def __add__(self, other):
192
+ return ops.add(self, other)
193
+
194
+ def __mul__(self, other):
195
+ return ops.mul(self, other)
196
+
197
+ def __sub__(self, other):
198
+ return ops.sub(self, other)
199
+
200
+ def __neg__(self):
201
+ return ops.neg(self)
202
+
203
+ def __truediv__(self, other):
204
+ return ops.truediv(self, other)
205
+
206
+ def __floordiv__(self, other):
207
+ return ops.floordiv(self, other)
208
+
209
+ def __mod__(self, other):
210
+ return ops.mod(self, other)
211
+
212
+ def __pow__(self, other):
213
+ return ops.pow(self, other)
214
+
215
+ def __lt__(self, other):
216
+ return ops.lt(self, other)
217
+
218
+ def __le__(self, other):
219
+ return ops.le(self, other)
220
+
221
+ def __eq__(self, other):
222
+ return ops.eq(self, other)
223
+
224
+ def __ne__(self, other):
225
+ return ops.ne(self, other)
226
+
227
+ def __gt__(self, other):
228
+ return ops.gt(self, other)
229
+
230
+ def __ge__(self, other):
231
+ return ops.ge(self, other)
232
+
233
+ def __and__(self, other):
234
+ return ops.bitwise_and(self, other)
235
+
236
+ def __or__(self, other):
237
+ return ops.bitwise_or(self, other)
238
+
239
+ def __xor__(self, other):
240
+ return ops.bitwise_xor(self, other)
241
+
242
+ def __invert__(self):
243
+ return ops.bitwise_not(self)
244
+
245
+ def __rshfit__(self, n):
246
+ return ops.bitwise_right_shift(self, n)
247
+
248
+ def __lshift__(self, n):
249
+ return ops.bitwise_left_shift(self, n)
250
+
251
+
252
class OpsWrapper:
    """Wraps IR values returned from ``_ops`` calls in ``OpsValue`` instances
    so mathematical expressions can be written with overloaded magic methods."""

    def __getattr__(self, name):
        def dispatch(*args, **kwargs):
            unwrapped_args = [OpsWrapper._unwrap(a) for a in args]
            unwrapped_kwargs = {
                k: OpsWrapper._unwrap(v) for k, v in kwargs.items()
            }
            raw = getattr(_ops, name)(*unwrapped_args, **unwrapped_kwargs)
            return OpsWrapper._wrap(raw)

        return dispatch

    @staticmethod
    def _unwrap(x):
        if isinstance(x, (list, tuple)):
            return tuple(map(OpsWrapper._unwrap, x))
        return x.value if isinstance(x, OpsValue) else x

    @staticmethod
    def _wrap(x):
        if isinstance(x, (list, tuple)):
            return tuple(OpsValue(v) for v in x)
        return OpsValue(x)

    @staticmethod
    def indirect_indexing(index, size, check=True):
        # Returns a sympy value, not an IR value, so the result stays
        # unwrapped.
        return _ops.indirect_indexing(OpsWrapper._unwrap(index), size, check)
284
+
285
+
286
+ ops = OpsWrapper()
287
+
288
+
289
+ class _V:
290
+ MockHandler = MockHandler
291
+ KernelFormatterHandler = KernelFormatterHandler
292
+ WrapperHandler = WrapperHandler
293
+
294
+ set_ops_handler: Callable[[Any], Any] = _ops._set_handler
295
+ get_ops_handler: Callable[[], Any] = _ops._get_handler
296
+ set_graph_handler: Callable[[GraphLowering], Any] = _graph._set_handler
297
+ set_real_inputs: Callable[[Any], Any] = _real_inputs._set_handler
298
+ get_real_inputs: Callable[[], Any] = _real_inputs._get_handler
299
+ set_fake_mode: Callable[[Any], Any] = _fake_mode._set_handler
300
+ get_fake_mode: Callable[[], Any] = _fake_mode._get_handler
301
+ set_kernel_handler: Callable[[Any], Any] = _kernel._set_handler
302
+ set_debug_handler: Callable[[Any], Any] = _debug._set_handler
303
+ set_interpreter_handler: Callable[[Any], Any] = _interpreter._set_handler
304
+ set_aot_compilation: Callable[[bool], Any] = _aot_compilation._set_handler
305
+ get_aot_compilation: Callable[[], Any] = _aot_compilation._get_handler
306
+ set_current_node: Callable[[Any], Any] = _current_node._set_handler
307
+ get_current_node: Callable[[], Any] = _current_node._get_handler
308
+
309
+ @property
310
+ def ops(self) -> OpsHandler[Any]:
311
+ """The operator handler specific to the current codegen task"""
312
+ return _ops._get_handler()
313
+
314
+ @property
315
+ def graph(self) -> GraphLowering:
316
+ """The graph currently being generated"""
317
+ return _graph._get_handler()
318
+
319
+ @property
320
+ def real_inputs(self):
321
+ """non-fake example inputs"""
322
+ return _real_inputs._get_handler()
323
+
324
+ @property
325
+ def fake_mode(self):
326
+ """The graph currently being generated"""
327
+ return _fake_mode._get_handler()
328
+
329
+ @property
330
+ def kernel(self):
331
+ """The kernel currently being generated"""
332
+ return _kernel._get_handler()
333
+
334
+ @property
335
+ def debug(self):
336
+ return _debug._get_handler()
337
+
338
+ @property
339
+ def interpreter(self):
340
+ return _interpreter._get_handler()
341
+
342
+ @property
343
+ def aot_compilation(self):
344
+ return _aot_compilation._get_handler()
345
+
346
+ @property
347
+ def current_node(self):
348
+ return _current_node._get_handler()
349
+
350
+
351
+ V = _V()
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/Activation.h ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/native/DispatchStub.h>
4
+ #include <c10/util/Exception.h>
5
+ #include <c10/util/string_view.h>
6
+
7
+ namespace c10 {
8
+ class Scalar;
9
+ }
10
+
11
+ namespace at {
12
+ struct TensorIterator;
13
+ struct TensorIteratorBase;
14
+ class TensorBase;
15
+ }
16
+
17
+ namespace at::native {
18
+
19
+ // These constants control the approximation behavior of gelu function.
20
+ enum class GeluType {
21
+ None, // Baseline Gelu
22
+ Tanh, // Tahn Gelu Approximation
23
+ END
24
+ };
25
+
26
+ static GeluType get_gelutype_enum(const c10::string_view approximate) {
27
+ if (approximate == "none") {
28
+ return GeluType::None;
29
+ } else if (approximate == "tanh") {
30
+ return GeluType::Tanh;
31
+ } else {
32
+ TORCH_CHECK(false, "approximate argument must be either none or tanh.");
33
+ }
34
+ }
35
+
36
+ static std::string gelutype_to_string(const GeluType type) {
37
+ switch(type) {
38
+ case GeluType::None: return "none";
39
+ case GeluType::Tanh: return "tanh";
40
+ default: TORCH_CHECK(false, "unknown GELU type: ", static_cast<int>(type));
41
+ }
42
+ }
43
+
44
+ using structured_activation_fn = void (*)(TensorIteratorBase&);
45
+ using structured_activation_backward_fn = void (*)(TensorIteratorBase&);
46
+
47
+ using activation_fn = void (*)(TensorIterator&);
48
+ using activation_backward_fn = void (*)(TensorIterator&);
49
+ using softplus_fn = void (*)(TensorIteratorBase&, const c10::Scalar&, const c10::Scalar&);
50
+ using softplus_backward_fn = void (*)(TensorIteratorBase&, const c10::Scalar&, const c10::Scalar&);
51
+ using threshold_fn = void (*)(TensorIteratorBase&, const c10::Scalar&, const c10::Scalar&);
52
+ using hardtanh_backward_fn = void (*)(TensorIterator&, const c10::Scalar&, const c10::Scalar&);
53
+ using hardsigmoid_fn = void(*)(TensorIteratorBase&);
54
+ using hardsigmoid_backward_fn = void(*)(TensorIteratorBase&);
55
+ using hardswish_fn = void(*)(TensorIterator&);
56
+ using hardswish_backward_fn = void(*)(TensorIterator&);
57
+ using shrink_fn = void (*)(TensorIteratorBase&, const c10::Scalar&);
58
+ using softshrink_fn = void (*)(TensorIteratorBase&, const c10::Scalar&);
59
+ using shrink_backward_fn = void (*)(TensorIteratorBase&, const c10::Scalar&);
60
+ using elu_fn = void (*)(TensorIteratorBase&, const c10::Scalar&, const c10::Scalar&, const c10::Scalar&);
61
+ using elu_backward_fn = void (*)(TensorIteratorBase&, const c10::Scalar&, const c10::Scalar&, const c10::Scalar&, bool);
62
+ using leaky_relu_fn = void (*)(TensorIteratorBase&, const c10::Scalar&);
63
+ using leaky_relu_backward_fn = void (*)(TensorIteratorBase&, const c10::Scalar&);
64
+ using log_sigmoid_cpu_fn = void (*)(TensorBase&, TensorBase&, const TensorBase&);
65
+ using gelu_fn = void (*)(TensorIteratorBase&, GeluType);
66
+ using gelu_backward_fn = void (*)(TensorIteratorBase&, GeluType);
67
+ using glu_jvp_fn = void (*)(TensorIteratorBase&);
68
+
69
+ DECLARE_DISPATCH(elu_fn, elu_stub);
70
+ DECLARE_DISPATCH(elu_backward_fn, elu_backward_stub);
71
+ DECLARE_DISPATCH(softplus_fn, softplus_stub);
72
+ DECLARE_DISPATCH(softplus_backward_fn, softplus_backward_stub);
73
+ DECLARE_DISPATCH(log_sigmoid_cpu_fn, log_sigmoid_cpu_stub);
74
+ DECLARE_DISPATCH(activation_backward_fn, log_sigmoid_backward_stub);
75
+ DECLARE_DISPATCH(threshold_fn, threshold_stub);
76
+ DECLARE_DISPATCH(gelu_fn, GeluKernel);
77
+ DECLARE_DISPATCH(gelu_backward_fn, GeluBackwardKernel);
78
+ DECLARE_DISPATCH(hardtanh_backward_fn, hardtanh_backward_stub);
79
+ DECLARE_DISPATCH(hardsigmoid_fn, hardsigmoid_stub);
80
+ DECLARE_DISPATCH(hardsigmoid_backward_fn, hardsigmoid_backward_stub);
81
+ DECLARE_DISPATCH(hardswish_fn, hardswish_stub);
82
+ DECLARE_DISPATCH(hardswish_backward_fn, hardswish_backward_stub);
83
+ DECLARE_DISPATCH(shrink_fn, hardshrink_stub);
84
+ DECLARE_DISPATCH(softshrink_fn, softshrink_stub);
85
+ DECLARE_DISPATCH(shrink_backward_fn, shrink_backward_stub);
86
+ DECLARE_DISPATCH(leaky_relu_fn, leaky_relu_stub);
87
+ DECLARE_DISPATCH(leaky_relu_backward_fn, leaky_relu_backward_stub);
88
+ DECLARE_DISPATCH(structured_activation_fn, glu_stub);
89
+ DECLARE_DISPATCH(activation_backward_fn, glu_backward_stub);
90
+ DECLARE_DISPATCH(glu_jvp_fn, glu_jvp_stub);
91
+ DECLARE_DISPATCH(structured_activation_fn, silu_stub);
92
+ DECLARE_DISPATCH(structured_activation_backward_fn, silu_backward_stub);
93
+ DECLARE_DISPATCH(structured_activation_fn, mish_stub);
94
+ DECLARE_DISPATCH(activation_backward_fn, mish_backward_stub);
95
+ DECLARE_DISPATCH(activation_fn, prelu_stub);
96
+ DECLARE_DISPATCH(activation_backward_fn, prelu_backward_stub);
97
+
98
+ } // namespace at::native
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/AdaptivePooling.h ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/core/Tensor.h>
4
+ #include <ATen/native/DispatchStub.h>
5
+ #include <c10/util/ArrayRef.h>
6
+ #include <c10/util/irange.h>
7
+ #include <cmath>
8
+
9
+ namespace at::native {
10
+
11
+ using adaptive_avg_pooling_fn = void(*)(Tensor& output, const Tensor& input, IntArrayRef output_size);
12
+ using adaptive_avg_pooling_backward_fn = void(*)(Tensor& grad_input, const Tensor& grad_output);
13
+ DECLARE_DISPATCH(adaptive_avg_pooling_fn, adaptive_avg_pool2d_kernel);
14
+ DECLARE_DISPATCH(adaptive_avg_pooling_backward_fn, adaptive_avg_pool2d_backward_kernel);
15
+
16
+ using adaptive_max_pooling_fn = void(*)(const Tensor& output, const Tensor& indices, const Tensor& input, IntArrayRef output_size);
17
+ using adaptive_max_pooling_backward_fn = void(*)(const Tensor& grad_input, const Tensor& grad_output, const Tensor& indices);
18
+ DECLARE_DISPATCH(adaptive_max_pooling_fn, adaptive_max_pool2d_kernel);
19
+ DECLARE_DISPATCH(adaptive_max_pooling_backward_fn, adaptive_max_pool2d_backward_kernel);
20
+
21
+ static inline int64_t start_index(int64_t a, int64_t b, int64_t c) {
22
+ return (a / b) * c + ((a % b) * c) / b;
23
+ }
24
+
25
+ static inline int64_t end_index(int64_t a, int64_t b, int64_t c) {
26
+ return 1 + ((a + 1) * c - 1) / b;
27
+ }
28
+
29
+ static inline void adaptive_pool_empty_output_check(const Tensor& gradOutput_, const char* arg_name) {
30
+ int64_t ndim = gradOutput_.ndimension();
31
+ for (const auto i : c10::irange(1, ndim)) {
32
+ TORCH_CHECK(gradOutput_.size(i) > 0,
33
+ arg_name, "(): Expected grad_output to have non-zero size for non-batch dimensions, "
34
+ "but grad_output has sizes ", gradOutput_.sizes(), " with dimension ", i,
35
+ " being empty");
36
+ }
37
+ }
38
+
39
+ } // namespace at::native
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/BucketizationUtils.h ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/core/Tensor.h>
4
+ #include <ATen/native/TypeProperties.h>
5
+ #include <ATen/ScalarOps.h>
6
+
7
+ #ifndef AT_PER_OPERATOR_HEADERS
8
+ #include <ATen/NativeFunctions.h>
9
+ #else
10
+ #include <ATen/ops/result_type.h>
11
+ #endif
12
+
13
+ namespace at::native {
14
+
15
+ // original values given by raw_*. If an original value is not contiguous, will make a contiguous copy to
16
+ // the corresponding trimmed_* value. Additionally, if the dtypes of the boundary and input tensor do not
17
+ // match, will change them to be a common super type so comparisons are done between the same types.
18
+ // For any trimmed_* tensor, if its outgoing value matches what it was incoming (typically null), then the
19
+ // corresponding raw_* version should be used since it was already contiguous of the right type.
20
+ inline void searchsorted_maybe_trim_input_tensors(
21
+ Tensor& trimmed_input,
22
+ Tensor& trimmed_boundaries,
23
+ Tensor& trimmed_sorter,
24
+ const Tensor& raw_input,
25
+ const Tensor& raw_boundaries,
26
+ const Tensor& raw_sorter) {
27
+ bool in_is_contiguous = raw_input.is_contiguous();
28
+ bool bd_is_contiguous = raw_boundaries.is_contiguous();
29
+ bool sort_is_contiguous = raw_sorter.is_contiguous();
30
+
31
+ if (!in_is_contiguous) {
32
+ TORCH_WARN_ONCE("torch.searchsorted(): input value tensor is non-contiguous, this will lower the performance due "
33
+ "to extra data copy when converting non-contiguous tensor to contiguous, please use contiguous input value "
34
+ "tensor if possible. This message will only appear once per program.");
35
+ trimmed_input = raw_input.contiguous();
36
+ }
37
+ if (!bd_is_contiguous) {
38
+ TORCH_WARN_ONCE("torch.searchsorted(): boundary tensor is non-contiguous, this will lower the performance due "
39
+ "to extra data copy when converting non-contiguous tensor to contiguous, please use contiguous boundary "
40
+ "tensor if possible. This message will only appear once per program.");
41
+ trimmed_boundaries = raw_boundaries.contiguous();
42
+ }
43
+ if (!sort_is_contiguous) {
44
+ TORCH_WARN_ONCE("torch.searchsorted(): sorter tensor is non-contiguous, this will lower the performance due "
45
+ "to extra data copy when converting non-contiguous tensor to contiguous, please use contiguous sorter "
46
+ "tensor if possible. This message will only appear once per program.");
47
+ trimmed_sorter = raw_sorter.contiguous();
48
+ }
49
+ if (raw_input.dtype() != raw_boundaries.dtype()) {
50
+ at::native::ResultTypeState state = {};
51
+ state = at::native::update_result_type_state(raw_boundaries, state);
52
+ state = at::native::update_result_type_state(raw_input, state);
53
+ ScalarType common_stype = at::native::result_type(state);
54
+
55
+ TORCH_INTERNAL_ASSERT(common_stype != ScalarType::Undefined);
56
+ if (common_stype != raw_input.scalar_type()) {
57
+ trimmed_input = in_is_contiguous ? raw_input.to(common_stype) : trimmed_input.to(common_stype);
58
+ }
59
+ if (common_stype != raw_boundaries.scalar_type()) {
60
+ trimmed_boundaries = bd_is_contiguous ? raw_boundaries.to(common_stype) : trimmed_boundaries.to(common_stype);
61
+ }
62
+ }
63
+ }
64
+
65
+ /* unused but needed for internal jagged tensor class */
66
+ inline void searchsorted_maybe_trim_input_tensors(
67
+ Tensor& trimmed_input,
68
+ Tensor& trimmed_boundaries,
69
+ const Tensor& raw_input,
70
+ const Tensor& raw_boundaries) {
71
+ Tensor trimmed_sorter;
72
+ Tensor raw_sorter;
73
+ return searchsorted_maybe_trim_input_tensors(
74
+ trimmed_input,
75
+ trimmed_boundaries,
76
+ trimmed_sorter,
77
+ raw_input,
78
+ raw_boundaries,
79
+ raw_sorter);
80
+ }
81
+
82
+ inline bool searchsorted_dims_matched_before_last_dim(const Tensor& boundaries, const Tensor& input) {
83
+ if (boundaries.dim() != input.dim()) {
84
+ return false;
85
+ }
86
+ const auto& dims_bd = boundaries.sizes();
87
+ const auto& dims_in = input.sizes();
88
+ for (int64_t dim = 0; dim + 1 < boundaries.dim(); ++dim) {
89
+ if (dims_bd[dim] != dims_in[dim]) {
90
+ return false;
91
+ }
92
+ }
93
+ return true;
94
+ }
95
+
96
+ inline Tensor searchsorted_scalar_tensor(const Scalar& scalar, const c10::Device& device) {
97
+ auto tensor = c10::scalar_to_tensor(scalar, device);
98
+ // This is to adopt the scalar promotion rules defined in native/TypeProperties.h
99
+ // So we have the same type promotion rules as binary operations.
100
+ tensor.unsafeGetTensorImpl()->set_wrapped_number(true);
101
+ return tensor;
102
+ }
103
+
104
+ inline void searchsorted_pre_check(
105
+ const Tensor& boundaries,
106
+ const Tensor& input,
107
+ const Tensor& output,
108
+ const bool out_int32,
109
+ const bool right,
110
+ const c10::optional<c10::string_view> side_opt,
111
+ const Tensor& sorter) {
112
+ if (side_opt) {
113
+ const c10::string_view side = *side_opt;
114
+ TORCH_CHECK(side == "left" || side == "right", "torch.searchsorted(): side can only be 'left' or 'right' but ",
115
+ "got ", side);
116
+
117
+ // assume the user has not explicitly set (right=False, side="right")
118
+ TORCH_CHECK(!right || side == "right", "torch.searchsorted(): side and right can't be set to opposites, got side "
119
+ "of ", side, " while right was True");
120
+ }
121
+
122
+ TORCH_CHECK(boundaries.device() == input.device(), "torch.searchsorted(): boundaries and input value tensors ",
123
+ "should have same device type, but got boundaries tensor device type ", boundaries.device(), " and input value ",
124
+ "tensor device type ", input.device());
125
+
126
+ if (sorter.defined()) {
127
+ TORCH_CHECK(sorter.device() == boundaries.device(), "torch.searchsorted(): sorter and boundary tensors should ",
128
+ "have same device type, but got sorter tensor device type ", sorter.device(), " and input value tensor ",
129
+ "device type ", boundaries.device());
130
+
131
+ TORCH_CHECK(sorter.sizes() == boundaries.sizes(), "torch.searchsorted(): boundary and sorter must have the same "
132
+ "size, but got boundary tensor ", boundaries.sizes(), "and got sorter tensor ", sorter.sizes());
133
+
134
+ TORCH_CHECK(sorter.scalar_type() == ScalarType::Long, "torch.searchsorted(): sorter must be a tensor of long ",
135
+ "dtype but got dtype ", sorter.scalar_type());
136
+
137
+ if (sorter.numel() > 0) {
138
+ auto minmax = sorter.aminmax();
139
+ int64_t vmin = std::get<0>(minmax).item().toLong();
140
+ int64_t vmax = std::get<1>(minmax).item().toLong();
141
+ TORCH_CHECK(vmin >= 0 && vmax < sorter.sizes().back(), "torch.searchsorted(): sorter index out of range");
142
+ }
143
+ }
144
+
145
+ TORCH_CHECK(input.dim() > 0 || (input.dim() == 0 && input.numel() == 1 && boundaries.dim() == 1),
146
+ "torch.searchsorted(): input value can be a scalar only when boundaries tensor dimension is 1, but we got ",
147
+ "boundaries tensor dim(", boundaries.dim(), ") and input value's dim(", input.dim(), ") numel(",
148
+ input.numel(), ")");
149
+
150
+ TORCH_CHECK(boundaries.dim() != 0, "torch.searchsorted(): boundaries tensor should have positive dimension, but ",
151
+ "got 0 dimension");
152
+
153
+ TORCH_CHECK(boundaries.dim() == 1 || searchsorted_dims_matched_before_last_dim(boundaries, input),
154
+ "torch.searchsorted(): boundaries tensor should be 1 dimension or the first N-1 dimensions of boundaries tensor ",
155
+ "and input value tensor must match, but we got boundaries tensor ", boundaries.sizes(), " and input value tensor ",
156
+ input.sizes());
157
+
158
+ ScalarType output_dtype = output.scalar_type();
159
+ TORCH_CHECK(
160
+ (output_dtype == ScalarType::Long && !out_int32) ||
161
+ (output_dtype == ScalarType::Int && out_int32),
162
+ "torch.searchsorted(): output tensor's dtype is wrong, it can only be Int(int32) or Long(int64) depending on ",
163
+ "whether out_int32 flag is True, but we got output tensor's dtype ", output_dtype,
164
+ " and out_int32 flag is ", (out_int32 ? "True" : "False"));
165
+
166
+ if (out_int32) {
167
+ TORCH_CHECK(boundaries.sizes().back() < INT_MAX,
168
+ "torch.searchsorted(): the size of boundaries' last dimension should be less than ", INT_MAX, ", but we got ",
169
+ boundaries.sizes().back());
170
+ }
171
+ }
172
+
173
+ } // namespace at::native
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/ConvUtils.h ADDED
@@ -0,0 +1,446 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ #include <ATen/core/Tensor.h>
3
+ #include <ATen/TensorUtils.h>
4
+ #include <ATen/detail/CUDAHooksInterface.h>
5
+ #include <ATen/native/DispatchStub.h>
6
+ #include <c10/util/env.h>
7
+ #include <c10/util/irange.h>
8
+
9
+ namespace at::native {
10
+
11
+ using conv_depthwise2d_backward_fn = std::tuple<at::Tensor,at::Tensor>(*)(
12
+ const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
13
+ at::IntArrayRef, at::IntArrayRef, std::array<bool, 2>);
14
+ DECLARE_DISPATCH(conv_depthwise2d_backward_fn, conv_depthwise2d_backward_stub);
15
+ using conv_depthwise3d_backward_fn = std::tuple<at::Tensor,at::Tensor,at::Tensor>(*)(
16
+ const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
17
+ at::IntArrayRef, at::IntArrayRef, std::array<bool, 3>);
18
+ DECLARE_DISPATCH(conv_depthwise3d_backward_fn, conv_depthwise3d_backward_stub);
19
+ using cudnn_convolution_backward_fn = std::tuple<at::Tensor,at::Tensor>(*)(
20
+ const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
21
+ at::IntArrayRef, int64_t, bool, bool, bool, std::array<bool,2>);
22
+ DECLARE_DISPATCH(cudnn_convolution_backward_fn, cudnn_convolution_backward_stub);
23
+ using mps_convolution_backward_fn = std::tuple<at::Tensor,at::Tensor,at::Tensor>(*)(
24
+ const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
25
+ at::IntArrayRef, int64_t, std::array<bool,3>);
26
+ DECLARE_DISPATCH(mps_convolution_backward_fn, mps_convolution_backward_stub);
27
+ using cudnn_convolution_transpose_backward_fn = std::tuple<at::Tensor,at::Tensor>(*)(
28
+ const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
29
+ at::IntArrayRef, at::IntArrayRef, int64_t, bool, bool, bool, std::array<bool,2>);
30
+ DECLARE_DISPATCH(cudnn_convolution_transpose_backward_fn, cudnn_convolution_transpose_backward_stub);
31
+ using miopen_convolution_backward_fn = std::tuple<at::Tensor,at::Tensor,at::Tensor>(*)(
32
+ const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
33
+ at::IntArrayRef, int64_t, bool, bool, std::array<bool,3>);
34
+ DECLARE_DISPATCH(miopen_convolution_backward_fn, miopen_convolution_backward_stub);
35
+ using miopen_convolution_transpose_backward_fn = std::tuple<at::Tensor,at::Tensor,at::Tensor>(*)(
36
+ const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
37
+ at::IntArrayRef, at::IntArrayRef, int64_t, bool, bool, std::array<bool,3>);
38
+ DECLARE_DISPATCH(miopen_convolution_transpose_backward_fn, miopen_convolution_transpose_backward_stub);
39
+ using miopen_depthwise_convolution_backward_fn = std::tuple<at::Tensor,at::Tensor,at::Tensor>(*)(
40
+ const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
41
+ at::IntArrayRef, int64_t, bool, bool, std::array<bool,3>);
42
+ DECLARE_DISPATCH(miopen_depthwise_convolution_backward_fn, miopen_depthwise_convolution_backward_stub);
43
+ using mkldnn_convolution_backward_fn = std::tuple<at::Tensor,at::Tensor,at::Tensor>(*)(
44
+ const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
45
+ at::IntArrayRef, int64_t, std::array<bool,3>);
46
+ DECLARE_DISPATCH(mkldnn_convolution_backward_fn, mkldnn_convolution_backward_stub);
47
+ using mkldnn_convolution_transpose_fn = Tensor(*)(const Tensor&, const Tensor&, const c10::optional<Tensor>&,
48
+ IntArrayRef, IntArrayRef, IntArrayRef, IntArrayRef, int64_t);
49
+ DECLARE_DISPATCH(mkldnn_convolution_transpose_fn, mkldnn_convolution_transpose_stub);
50
+ using mkldnn_convolution_transpose_backward_fn = std::tuple<at::Tensor,at::Tensor,at::Tensor>(*)(
51
+ const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
52
+ at::IntArrayRef, at::IntArrayRef, int64_t, std::array<bool,3>);
53
+ DECLARE_DISPATCH(mkldnn_convolution_transpose_backward_fn, mkldnn_convolution_transpose_backward_stub);
54
+ using slow_conv_dilated2d_backward_fn = std::tuple<at::Tensor,at::Tensor,at::Tensor>(*)(
55
+ const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
56
+ at::IntArrayRef, at::IntArrayRef, std::array<bool, 3>);
57
+ DECLARE_DISPATCH(slow_conv_dilated2d_backward_fn, slow_conv_dilated2d_backward_stub);
58
+ using slow_conv_dilated3d_backward_fn = std::tuple<at::Tensor,at::Tensor,at::Tensor>(*)(
59
+ const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
60
+ at::IntArrayRef, at::IntArrayRef, std::array<bool, 3>);
61
+ DECLARE_DISPATCH(slow_conv_dilated3d_backward_fn, slow_conv_dilated3d_backward_stub);
62
+ using slow_conv_transpose2d_backward_fn = std::tuple<at::Tensor,at::Tensor,at::Tensor>(*)(
63
+ const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
64
+ at::IntArrayRef, at::IntArrayRef, at::IntArrayRef, std::array<bool,3>);
65
+ DECLARE_DISPATCH(slow_conv_transpose2d_backward_fn, slow_conv_transpose2d_backward_stub);
66
+ using slow_conv_transpose3d_backward_fn = std::tuple<at::Tensor,at::Tensor,at::Tensor>(*)(
67
+ const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
68
+ at::IntArrayRef, at::IntArrayRef, at::IntArrayRef, std::array<bool,3>);
69
+ DECLARE_DISPATCH(slow_conv_transpose3d_backward_fn, slow_conv_transpose3d_backward_stub);
70
+
71
+ namespace {
72
+ static bool cudnnv8_heuristic_mode_b = c10::utils::check_env("TORCH_CUDNN_USE_HEURISTIC_MODE_B") == true;
73
+ }
74
+
75
+ static inline bool cudnnv8_enabled_check_debug() {
76
+ static bool cudnnv8_flag = c10::utils::check_env("TORCH_CUDNN_V8_API_DISABLED") != true;
77
+ static bool cudnnv8_debug = c10::utils::check_env("TORCH_CUDNN_V8_API_DEBUG") == true;
78
+ static uint8_t cudnnv8_debugcount = 0;
79
+ if (cudnnv8_debug == 1 && cudnnv8_debugcount < 10) {
80
+ TORCH_WARN("TORCH_CUDNN_V8_DEBUG ON, V8 ON: ", cudnnv8_flag, " TORCH_CUDNN_USE_HEURISTIC_MODE B: ", cudnnv8_heuristic_mode_b);
81
+ cudnnv8_debugcount++;
82
+ }
83
+ return cudnnv8_flag == 1;
84
+ }
85
+
86
+ static inline bool cudnnv8_use_heur_mode_b() {
87
+ return cudnnv8_heuristic_mode_b;
88
+ }
89
+
90
+ // Keep in sync with py::enum_ in Module.cpp
91
+ enum class ConvBackend {
92
+ CudaDepthwise2d,
93
+ CudaDepthwise3d,
94
+ Cudnn,
95
+ CudnnTranspose,
96
+ Empty,
97
+ Miopen,
98
+ MiopenDepthwise,
99
+ MiopenTranspose,
100
+ Mkldnn,
101
+ MkldnnTranspose,
102
+ MkldnnEmpty,
103
+ NnpackSpatial,
104
+ Overrideable,
105
+ Slow2d,
106
+ Slow3d,
107
+ SlowDilated2d,
108
+ SlowDilated3d,
109
+ SlowTranspose2d,
110
+ SlowTranspose3d,
111
+ Winograd3x3Depthwise,
112
+ Xnnpack2d,
113
+ Mps,
114
+ MpsTranspose,
115
+ };
116
+
117
+ // Overload for selecting the convolution backend from the full set of convolution inputs.
118
+ // This overload is exposed to python for testing, etc.
119
+ TORCH_API ConvBackend select_conv_backend(
120
+ const Tensor& input, const Tensor& weight, const c10::optional<Tensor>& bias_opt,
121
+ SymIntArrayRef stride, SymIntArrayRef padding, SymIntArrayRef dilation,
122
+ bool transposed, SymIntArrayRef output_padding, c10::SymInt groups, const at::OptionalSymIntArrayRef bias_sizes_opt);
123
+
124
+ TORCH_API at::MemoryFormat _determine_backend_memory_format(const Tensor& input,
125
+ const Tensor& weight,
126
+ const ConvBackend backend);
127
+
128
+ // ---------------------------------------------------------------------
129
+ //
130
+ // Math
131
+ //
132
+ // ---------------------------------------------------------------------
133
+
134
+ constexpr int input_batch_size_dim = 0; // also grad_input
135
+ constexpr int input_channels_dim = 1;
136
+ constexpr int output_batch_size_dim = 0; // also grad_output
137
+ constexpr int output_channels_dim = 1;
138
+ constexpr int weight_output_channels_dim = 0;
139
+ constexpr int weight_input_channels_dim = 1;
140
+
141
+ // Often written as 2 + max_dim (extra dims for batch size and channels)
142
+ constexpr int max_dim = 3;
143
+
144
+ // ---------------------------------------------------------------------
145
+ //
146
+ // Checking
147
+ //
148
+ // ---------------------------------------------------------------------
149
+
150
+ // Used on pad, stride and dilation
151
+ static void check_args(CheckedFrom c, IntArrayRef args, size_t expected_size, const char* arg_name)
152
+ {
153
+ TORCH_CHECK(args.size() <= expected_size,
154
+ "Too many ", arg_name, " values (", args.size(), ") supplied, expecting ",
155
+ expected_size, " (while checking arguments for ", c, ")");
156
+ TORCH_CHECK(args.size() >= expected_size,
157
+ "Not enough ", arg_name, " values (", args.size(), ") supplied, expecting ",
158
+ expected_size, " (while checking arguments for ", c, ")");
159
+
160
+ auto num_negative_values = std::count_if(args.begin(), args.end(), [](int x){return x < 0;});
161
+ if (num_negative_values > 0){
162
+ std::stringstream ss;
163
+ ss << arg_name << " should be greater than zero but got (";
164
+ std::copy(args.begin(), args.end() - 1, std::ostream_iterator<int>(ss,", "));
165
+ ss << args.back() << ")" << " (while checking arguments for " << c << ")";
166
+ AT_ERROR(ss.str());
167
+ }
168
+ }
169
+
170
+
171
+ // NOTE [ Convolution checks ]
172
+ //
173
+ // NB: For many call sites, it is not strictly necessary to check all of
174
+ // these relationships (for example, for forward convolution, we compute
175
+ // the size of output ourselves, so we don't actually need to check
176
+ // output. However, writing a single function that does everything
177
+ // means we get to reuse it for both forwards and all backwards
178
+ // variants, even when the set of "real" inputs varies. The magic of
179
+ // relational computing!
180
+ //
181
+ // (There is one downside, which is that it is slightly harder to write
182
+ // error messages which are able to distinguish between real inputs
183
+ // (which the user can change) and computed inputs (which the user can
184
+ // only indirectly affect). It would be an interesting exercise to
185
+ // come up with a general framework to handle such situations.)
186
+ static void convolution_shape_check(
187
+ CheckedFrom c,
188
+ const TensorGeometryArg& input, const TensorGeometryArg& weight, const TensorGeometryArg& output,
189
+ IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups)
190
+ {
191
+ check_args(c, padding, input->dim() - 2, "padding");
192
+ check_args(c, stride, padding.size(), "stride");
193
+ check_args(c, dilation, padding.size(), "dilation");
194
+
195
+ // Input
196
+ checkDimRange(c, input, 3, 6 /* exclusive */);
197
+ checkSize_symint(c, input, input_channels_dim, weight->size(1) * groups);
198
+
199
+ // Weight
200
+ checkSameDim(c, input, weight);
201
+
202
+ // TODO: check that output->size() matches output_sizes
203
+ // TODO: check that weight matches output->sizes()
204
+ checkSameDim(c, input, output);
205
+ }
206
+
207
+ // NB: conv_output_size and conv_input_size are not bijections,
208
+ // as conv_output_size loses information; this is why conv_input_size
209
+ // takes an extra output_padding argument to resolve the ambiguity.
210
+
211
+ template <typename T>
212
+ static inline std::vector<T> _conv_output_size(
213
+ ArrayRef<T> input_size, ArrayRef<T> weight_size,
214
+ ArrayRef<T> padding, ArrayRef<T> stride, ArrayRef<T> dilation = ArrayRef<T>()
215
+ ) {
216
+ // ASSERT(input_size.size() > 2)
217
+ // ASSERT(input_size.size() == weight_size.size())
218
+ bool has_dilation = !dilation.empty();
219
+ auto dim = input_size.size();
220
+ std::vector<T> output_size(dim);
221
+ output_size[0] = input_size[input_batch_size_dim];
222
+ output_size[1] = weight_size[weight_output_channels_dim];
223
+ for (const auto d : c10::irange(2, dim)) {
224
+ auto dilation_ = has_dilation ? dilation[d - 2] : 1;
225
+ auto kernel = dilation_ * (weight_size[d] - 1) + 1;
226
+ output_size[d] = (input_size[d] + (2 * padding[d - 2]) - kernel) / stride[d - 2] + 1;
227
+ }
228
+ return output_size;
229
+ }
230
+
231
+ static inline std::vector<int64_t> conv_output_size(
232
+ IntArrayRef input_size, IntArrayRef weight_size,
233
+ IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation = IntArrayRef()
234
+ ) {
235
+ return _conv_output_size(input_size, weight_size, padding, stride, dilation);
236
+ }
237
+
238
+ static inline std::vector<c10::SymInt> conv_output_size(
239
+ SymIntArrayRef input_size, SymIntArrayRef weight_size,
240
+ SymIntArrayRef padding, SymIntArrayRef stride, SymIntArrayRef dilation = SymIntArrayRef()
241
+ ) {
242
+ return _conv_output_size(input_size, weight_size, padding, stride, dilation);
243
+ }
244
+
245
+ template <typename T>
246
+ std::vector<T> _conv_input_size(
247
+ ArrayRef<T> output_size, ArrayRef<T> weight_size,
248
+ ArrayRef<T> padding, ArrayRef<T> output_padding, ArrayRef<T> stride, ArrayRef<T> dilation, T groups
249
+ ) {
250
+ // ASSERT(output_size.size() > 2)
251
+ // ASSERT(output_size.size() == weight_size.size())
252
+ auto dim = output_size.size();
253
+ std::vector<T> input_size(dim);
254
+ input_size[0] = output_size[output_batch_size_dim];
255
+ input_size[1] = weight_size[weight_input_channels_dim] * groups;
256
+ for (const auto d : c10::irange(2, dim)) {
257
+ auto kernel = (weight_size[d] - 1) * dilation[d - 2] + 1;
258
+ input_size[d] = (output_size[d] - 1) * stride[d - 2] - (padding[d - 2] * 2) +
259
+ kernel + output_padding[d - 2];
260
+ }
261
+ return input_size;
262
+ }
263
+
264
+ static inline std::vector<c10::SymInt> conv_input_size(
265
+ SymIntArrayRef output_size, SymIntArrayRef weight_size,
266
+ SymIntArrayRef padding, SymIntArrayRef output_padding, SymIntArrayRef stride, SymIntArrayRef dilation, c10::SymInt groups
267
+ ) {
268
+ return _conv_input_size(output_size, weight_size, padding, output_padding, stride, dilation, groups);
269
+ }
270
+
271
+ static inline std::vector<int64_t> conv_input_size(
272
+ IntArrayRef output_size, IntArrayRef weight_size,
273
+ IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups
274
+ ) {
275
+ return _conv_input_size(output_size, weight_size, padding, output_padding, stride, dilation, groups);
276
+ }
277
+
278
+ template <typename T>
279
+ std::vector<T> _conv_weight_size(
280
+ ArrayRef<T> input_size, ArrayRef<T> output_size,
281
+ ArrayRef<T> padding, ArrayRef<T> output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups
282
+ ) {
283
+ auto dim = input_size.size();
284
+ std::vector<T> weight_size(dim);
285
+ weight_size[0] = output_size[1];
286
+ weight_size[1] = input_size[1] / groups;
287
+ for (const auto d : c10::irange(2, dim)) {
288
+ auto kernel = input_size[d] - (output_size[d] - 1) * stride[d - 2]
289
+ + padding[d - 2] * 2 - output_padding[d - 2];
290
+ weight_size[d] = (kernel - 1) / dilation[d - 2] + 1;
291
+ }
292
+ return weight_size;
293
+ }
294
+
295
+ static inline std::vector<c10::SymInt> conv_weight_size(
296
+ SymIntArrayRef input_size, SymIntArrayRef output_size,
297
+ SymIntArrayRef padding, SymIntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups
298
+ ) {
299
+ return _conv_weight_size(input_size, output_size, padding, output_padding, stride, dilation, groups);
300
+ }
301
+
302
+ static inline std::vector<int64_t> conv_weight_size(
303
+ IntArrayRef input_size, IntArrayRef output_size,
304
+ IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups
305
+ ) {
306
+ return _conv_weight_size(input_size, output_size, padding, output_padding, stride, dilation, groups);
307
+ }
308
+
309
+ static inline Tensor reshape_bias(int64_t dim, const Tensor& bias) {
310
+ std::vector<int64_t> shape(dim, 1);
311
+ shape[1] = -1;
312
+ return bias.reshape(shape);
313
+ }
314
+
315
+ static inline at::MemoryFormat cudnn_conv_suggest_memory_format(const at::Tensor& input, const at::Tensor& weight) {
316
+ // disable NHWC for float64 input.
317
+ if (!at::detail::getCUDAHooks().compiledWithCuDNN() ||
318
+ input.scalar_type() == at::kDouble ||
319
+ weight.scalar_type() == at::kDouble) {
320
+ return at::MemoryFormat::Contiguous;
321
+ }
322
+ long cudnn_version = at::detail::getCUDAHooks().versionCuDNN();
323
+ auto input_memory_format = input.suggest_memory_format();
324
+ auto weight_memory_format = weight.suggest_memory_format();
325
+ auto weight_ndim = weight.ndimension();
326
+
327
+ bool can_use_cudnn_channels_last_2d = (cudnn_version >= 7603) && (weight_ndim == 4) && (
328
+ (input_memory_format == at::MemoryFormat::ChannelsLast) ||
329
+ (weight_memory_format == at::MemoryFormat::ChannelsLast)
330
+ );
331
+ if (can_use_cudnn_channels_last_2d) {
332
+ return at::MemoryFormat::ChannelsLast;
333
+ }
334
+
335
+ bool can_use_cudnn_channels_last_3d = (cudnn_version >= 8005) && (weight_ndim == 5) && (
336
+ (input_memory_format == at::MemoryFormat::ChannelsLast3d) ||
337
+ (weight_memory_format == at::MemoryFormat::ChannelsLast3d)
338
+ );
339
+ if (can_use_cudnn_channels_last_3d) {
340
+ return at::MemoryFormat::ChannelsLast3d;
341
+ }
342
+
343
+ return at::MemoryFormat::Contiguous;
344
+ }
345
+
346
+ // controls whether emptyCache will be called following cudnn conv benchmarking
347
+ TORCH_API void _cudnn_set_conv_benchmark_empty_cache(bool enable);
348
+ TORCH_API bool _cudnn_get_conv_benchmark_empty_cache();
349
+
350
+
351
+ static inline bool miopen_conv_use_channels_last(const at::Tensor& input, const at::Tensor& weight) {
352
+
353
+ // disable NHWC for float64 input.
354
+ if (!at::detail::getCUDAHooks().compiledWithMIOpen() ||
355
+ input.scalar_type() == at::kDouble ||
356
+ weight.scalar_type() == at::kDouble) {
357
+ return false;
358
+ }
359
+
360
+ bool can_use_miopen_channels_last_2d = false;
361
+ #if defined(USE_ROCM) && (ROCM_VERSION >= 40300)
362
+ // TODO: Remove PYTORCH_MIOPEN_SUGGEST_NHWC once ROCm officially supports NHWC in MIOpen
363
+ // See #64427
364
+ static c10::optional<bool> PYTORCH_MIOPEN_SUGGEST_NHWC = c10::utils::check_env("PYTORCH_MIOPEN_SUGGEST_NHWC");
365
+
366
+ auto input_memory_format = input.suggest_memory_format();
367
+ auto weight_memory_format = weight.suggest_memory_format();
368
+
369
+ can_use_miopen_channels_last_2d = PYTORCH_MIOPEN_SUGGEST_NHWC && *PYTORCH_MIOPEN_SUGGEST_NHWC && (
370
+ ( (input_memory_format == at::MemoryFormat::ChannelsLast) ||
371
+ (weight_memory_format == at::MemoryFormat::ChannelsLast) )
372
+ );
373
+ #endif
374
+
375
+ bool can_use_miopen_channels_last_3d = false;
376
+
377
+ return can_use_miopen_channels_last_2d || can_use_miopen_channels_last_3d;
378
+ }
379
+
380
+ static inline bool mkldnn_conv_use_channels_last(const at::Tensor& input, const at::Tensor& weight) {
381
+
382
+ // disable NHWC for float64 input.
383
+ if (input.scalar_type() == at::kDouble ||
384
+ weight.scalar_type() == at::kDouble) {
385
+ return false;
386
+ }
387
+
388
+ // disable NHWC for MkldnnCPU tensor.
389
+ if (input.is_mkldnn() || weight.is_mkldnn()) {
390
+ return false;
391
+ }
392
+
393
+ auto input_memory_format = input.suggest_memory_format();
394
+ auto weight_memory_format = weight.suggest_memory_format();
395
+
396
+ bool can_use_mkldnn_channels_last_2d =
397
+ (input_memory_format == at::MemoryFormat::ChannelsLast) ||
398
+ (weight_memory_format == at::MemoryFormat::ChannelsLast);
399
+
400
+ bool can_use_mkldnn_channels_last_3d =
401
+ (input_memory_format == at::MemoryFormat::ChannelsLast3d) ||
402
+ (weight_memory_format == at::MemoryFormat::ChannelsLast3d);
403
+
404
+ return can_use_mkldnn_channels_last_2d || can_use_mkldnn_channels_last_3d;
405
+ }
406
+
407
+ static inline bool thnn_conv_use_channels_last(const at::Tensor& input, const at::Tensor& weight) {
408
+
409
+ auto input_memory_format = input.suggest_memory_format();
410
+ auto weight_memory_format = weight.suggest_memory_format();
411
+
412
+ bool can_use_thnn_channels_last_2d = input.device().is_cpu() && (
413
+ (input_memory_format == at::MemoryFormat::ChannelsLast) || (
414
+ weight_memory_format == at::MemoryFormat::ChannelsLast));
415
+
416
+ return can_use_thnn_channels_last_2d;
417
+ }
418
+
419
+ static inline bool xpu_conv_use_channels_last(const at::Tensor& input, const at::Tensor& weight) {
420
+
421
+ // check layout only for xpu tensor.
422
+ if (!input.is_xpu() || !weight.is_xpu()) {
423
+ return false;
424
+ }
425
+
426
+ // disable NHWC for float64 input.
427
+ if (input.scalar_type() == at::kDouble ||
428
+ weight.scalar_type() == at::kDouble) {
429
+ return false;
430
+ }
431
+
432
+ auto input_memory_format = input.suggest_memory_format();
433
+ auto weight_memory_format = weight.suggest_memory_format();
434
+
435
+ bool can_use_xpu_channels_last_2d =
436
+ (input_memory_format == at::MemoryFormat::ChannelsLast) ||
437
+ (weight_memory_format == at::MemoryFormat::ChannelsLast);
438
+
439
+ bool can_use_xpu_channels_last_3d =
440
+ (input_memory_format == at::MemoryFormat::ChannelsLast3d) ||
441
+ (weight_memory_format == at::MemoryFormat::ChannelsLast3d);
442
+
443
+ return can_use_xpu_channels_last_2d || can_use_xpu_channels_last_3d;
444
+ }
445
+
446
+ } // namespace at::native
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/Cross.h ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/native/DispatchStub.h>
4
+
5
+ namespace at {
6
+ class Tensor;
7
+
8
+ namespace native {
9
+
10
+ using cross_fn = void(*)(const Tensor&, const Tensor&, const Tensor&, const int64_t d);
11
+
12
+ DECLARE_DISPATCH(cross_fn, cross_stub);
13
+
14
+ }} // namespace at::native
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/DistributionTemplates.h ADDED
@@ -0,0 +1,394 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/core/Tensor.h>
4
+ #include <ATen/Dispatch.h>
5
+ #include <ATen/Dispatch_v2.h>
6
+ #include <ATen/Generator.h>
7
+ #include <ATen/ExpandUtils.h>
8
+ #include <ATen/Tensor.h>
9
+ #include <ATen/MemoryOverlap.h>
10
+ #include <ATen/NamedTensorUtils.h>
11
+ #include <ATen/native/Resize.h>
12
+ #include <ATen/native/TensorIterator.h>
13
+ #include <c10/util/Optional.h>
14
+ #include <limits>
15
+ #include <cmath>
16
+
17
+ #ifndef AT_PER_OPERATOR_HEADERS
18
+ #include <ATen/Functions.h>
19
+ #else
20
+ #include <ATen/ops/empty_like.h>
21
+ #include <ATen/ops/empty.h>
22
+ #include <ATen/ops/full.h>
23
+ #include <ATen/ops/view_as_real.h>
24
+ #endif
25
+
26
+ namespace at::native::templates {
27
+
28
+ // ==================================================== Random ========================================================
29
+
30
+ // The purpose of `update_from` and `update_to` is to find the closest valid int64_t number that can be used as actual `from`.
31
+ // The current implementation of `random_` uses uint64_t arithmetics and casts the result to the target dtype(scalar_t).
32
+ // This casting can result in generating numbers that happen to be greater or equal to `to` value. For instance:
33
+ //
34
+ // auto actual = torch::empty({3, 3}, torch::half);
35
+ // actual.random_(0, 65504);
36
+ //
37
+ // If random's uint64_t arithmetics produces 65503 as a random value after casting to torch::half it becomes 65504
38
+ // and violates the requirement that random value must be less than `to`. To resolve this issue `update_from` and `update_to`
39
+ // moves `from` to the right and `to` to the left to the next closest value that won't go outside [from, to) after casting to
40
+ // the target dtype. For `to` = 65504 it moves left for (1 << (log2(to) - 11 + 1)) = 32 and becomes 65472, which is previous
41
+ // available number for torch::half dtype.
42
+ template<typename scalar_t>
43
+ int64_t update_from(int64_t from) {
44
+ static_assert(
45
+ std::is_floating_point<scalar_t>::value ||
46
+ std::is_same<scalar_t, at::Half>::value ||
47
+ std::is_same<scalar_t, at::BFloat16>::value, "scalar_t must be floating-point type");
48
+ const auto from_plus_1 = static_cast<int64_t>(static_cast<scalar_t>(from + 1));
49
+ if (from_plus_1 < from) {
50
+ int64_t from_ = std::abs(from + 1);
51
+ int n = 0;
52
+ while (from_ >>= 1) ++n;
53
+ // NOLINTNEXTLINE(clang-analyzer-core.UndefinedBinaryOperatorResult)
54
+ from = from_plus_1 + (1LL << (n - std::numeric_limits<scalar_t>::digits + 1));
55
+ }
56
+ return from;
57
+ }
58
+
59
+ template<typename scalar_t>
60
+ int64_t update_to(int64_t to) {
61
+ static_assert(
62
+ std::is_floating_point<scalar_t>::value ||
63
+ std::is_same<scalar_t, at::Half>::value ||
64
+ std::is_same<scalar_t, at::BFloat16>::value, "scalar_t must be floating-point type");
65
+ const auto to_minus_1 = static_cast<int64_t>(static_cast<scalar_t>(to - 1));
66
+ if (to_minus_1 >= to) {
67
+ int64_t to_ = std::abs(to - 1);
68
+ int n = 0;
69
+ while (to_ >>= 1) ++n;
70
+ // NOLINTNEXTLINE(clang-analyzer-core.UndefinedBinaryOperatorResult)
71
+ to = to_minus_1 - (1LL << (n - std::numeric_limits<scalar_t>::digits + 1));
72
+ }
73
+ return to;
74
+ }
75
+
76
+ // Return earlier for not invoking kernel.
77
+ // See https://github.com/pytorch/pytorch/issues/103418 for more details
78
+ #define CHECK_EMPTY_AND_RETURN(tensor) \
79
+ if (tensor.numel() == 0) { \
80
+ return tensor; \
81
+ }
82
+
83
+ template<template<typename> class random_kernel, typename RNG>
84
+ at::Tensor& random_impl(at::Tensor& self, c10::optional<Generator> generator) {
85
+ CHECK_EMPTY_AND_RETURN(self);
86
+ auto iter = at::TensorIterator::borrowing_nullary_op(self);
87
+ random_kernel<RNG>()(iter, generator);
88
+ return self;
89
+ }
90
+
91
+ #define CHECK_OUT_OF_BOUNDS(var, name, min, max, dtype) \
92
+ TORCH_CHECK(var >= min && var <= max, name , " is out of bounds for ", dtype); \
93
+
94
+ #define WARN_OUT_OF_BOUNDS(var, name, digits, dtype) \
95
+ if (var < -(1LL << digits) || var > (1LL << digits)) { \
96
+ TORCH_WARN(name , " is out of bounds [-(2^", digits, "), 2^", digits, "]. ", \
97
+ "Due to precision limitations ", dtype, " can support discrete uniform distribution only within this range. ", \
98
+ "This warning will become an error in version 1.7 release, please fix the code in advance"); \
99
+ }
100
+
101
+ static void check_from_to_in_range(int64_t from, int64_t to_inc, caffe2::TypeMeta dtype) {
102
+ const auto scalar_type = typeMetaToScalarType(dtype);
103
+ if (isFloatingType(scalar_type)) {
104
+ AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, scalar_type, "check_random_fp_bounds", [&] {
105
+ const auto min = static_cast<double>(std::numeric_limits<scalar_t>::lowest());
106
+ const auto max = static_cast<double>(std::numeric_limits<scalar_t>::max());
107
+ CHECK_OUT_OF_BOUNDS(from, "from", min, max, dtype);
108
+ CHECK_OUT_OF_BOUNDS(to_inc, "to - 1", min, max, dtype);
109
+
110
+ constexpr auto digits = std::numeric_limits<scalar_t>::digits;
111
+ WARN_OUT_OF_BOUNDS(from, "from", digits, dtype);
112
+ WARN_OUT_OF_BOUNDS(to_inc, "to - 1", digits, dtype);
113
+ });
114
+ } else if (scalar_type == kUInt64) {
115
+ // When you do a comparison between int64_t and uint64_t, the usual
116
+ // arithmetic conversions say that the int64_t value is promoted to
117
+ // unsigned. But this conversion wraps around: if I had -1 as my int64_t,
118
+ // then it will promote to 0xFFFFFFFFFFFFFFFF in uint64_t. This is never
119
+ // the right thing to do.
120
+ CHECK_OUT_OF_BOUNDS(from, "from", 0, INT64_MAX, dtype);
121
+ CHECK_OUT_OF_BOUNDS(to_inc, "to - 1", 0, INT64_MAX, dtype);
122
+ } else if (isIntegralType(scalar_type, /*includeBool=*/true)) {
123
+ AT_DISPATCH_V2(scalar_type, "check_random_integral_bounds", AT_WRAP([&]() {
124
+ const auto min = static_cast<int64_t>(std::numeric_limits<scalar_t>::lowest());
125
+ const auto max = static_cast<int64_t>(std::numeric_limits<scalar_t>::max());
126
+ CHECK_OUT_OF_BOUNDS(from, "from", min, max, dtype);
127
+ CHECK_OUT_OF_BOUNDS(to_inc, "to - 1", min, max, dtype);
128
+ }), AT_EXPAND(AT_INTEGRAL_TYPES), kUInt16, kUInt32, kBool);
129
+ } else {
130
+ TORCH_CHECK(false, "check_random_bounds handles only integral, floating-point and boolean types");
131
+ }
132
+ }
133
+
134
+ template<template<typename> class random_from_to_kernel, typename RNG>
135
+ at::Tensor& random_from_to_impl(at::Tensor& self, int64_t from, c10::optional<int64_t> to_opt, c10::optional<Generator> generator) {
136
+ uint64_t range = 0;
137
+ auto iter = at::TensorIterator::borrowing_nullary_op(self);
138
+ if (to_opt.has_value()) {
139
+ // [from, to)
140
+ int64_t to = *to_opt;
141
+ TORCH_CHECK(from < to, "random_ expects 'from' to be less than 'to', but got from=", from, " >= to=", to);
142
+ if (isFloatingType(iter.dtype())) {
143
+ AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, self.scalar_type(), "random_update_from_to", [&] {
144
+ from = update_from<scalar_t>(from);
145
+ to = update_to<scalar_t>(to);
146
+ TORCH_CHECK(from < to, "random_ expects 'from' casted to dtype to be less than 'to' casted to dtype, but got from=", from, " >= to=", to);
147
+ });
148
+ }
149
+ check_from_to_in_range(from, to - 1, self.dtype());
150
+ CHECK_EMPTY_AND_RETURN(self);
151
+ range = static_cast<uint64_t>(to) - static_cast<uint64_t>(from);
152
+ random_from_to_kernel<RNG>()(iter, range, from, generator);
153
+ } else if (from != std::numeric_limits<int64_t>::lowest()) {
154
+ // [from, std::numeric_limits<int64_t>::max()]
155
+ int64_t to_inc = 0;
156
+ if (isFloatingType(iter.dtype())) {
157
+ AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, self.scalar_type(), "random_from_to_range_calc", [&] {
158
+ constexpr int64_t scalar_t_max = static_cast<int64_t>(1) << std::numeric_limits<scalar_t>::digits;
159
+ to_inc = scalar_t_max > std::numeric_limits<int64_t>::max() ? std::numeric_limits<int64_t>::max() : static_cast<int64_t>(scalar_t_max);
160
+ from = update_from<scalar_t>(from);
161
+ TORCH_CHECK(from < to_inc, "random_ expects 'from' casted to dtype to be less than or equal to 'to_inc' casted to dtype, but got from=", from, " > to_inc=", to_inc);
162
+ });
163
+ } else if (isIntegralType(iter.dtype(), /*includeBool=*/true)) {
164
+ AT_DISPATCH_V2(self.scalar_type(), "random_from_to_range_calc", AT_WRAP([&] {
165
+ if constexpr (std::is_same_v<scalar_t, bool>) {
166
+ to_inc = static_cast<int64_t>(true);
167
+ } else {
168
+ to_inc = static_cast<int64_t>(std::numeric_limits<scalar_t>::max());
169
+ }
170
+ }), AT_EXPAND(AT_INTEGRAL_TYPES_V2), kBool);
171
+ } else {
172
+ TORCH_CHECK(false, "random_from_to_impl handles only integral, floating-point and boolean types");
173
+ }
174
+ check_from_to_in_range(from, to_inc, self.dtype());
175
+ CHECK_EMPTY_AND_RETURN(self);
176
+ range = static_cast<uint64_t>(to_inc) - static_cast<uint64_t>(from) + 1;
177
+ random_from_to_kernel<RNG>()(iter, range, from, generator);
178
+ } else {
179
+ // [std::numeric_limits<int64_t>::lowest(), std::numeric_limits<int64_t>::max()]
180
+ // range = 2^64
181
+ CHECK_EMPTY_AND_RETURN(self);
182
+ random_from_to_kernel<RNG>()(iter, generator);
183
+ }
184
+ return self;
185
+ }
186
+
187
+ // ==================================================== Normal ========================================================
188
+
189
+ #define CHECK_NORMAL_TENSOR_STD(std) \
190
+ do { \
191
+ TORCH_CHECK( \
192
+ !std.is_complex(), \
193
+ "normal expects standard deviation to be non-complex"); \
194
+ TORCH_CHECK( \
195
+ std.numel() == 0 || std.is_meta() || std.min().ge(0).item<bool>(), \
196
+ "normal expects all elements of std >= 0.0"); \
197
+ } while (0)
198
+
199
+ #define CHECK_NORMAL_STD(std) \
200
+ TORCH_CHECK(std >= 0.0, "normal expects std >= 0.0, but found std ", std);
201
+
202
+ template<template<typename> class normal_kernel, typename RNG>
203
+ Tensor& normal_impl_(Tensor& self, double mean, double std, c10::optional<Generator> gen) {
204
+ CHECK_NORMAL_STD(std);
205
+ CHECK_EMPTY_AND_RETURN(self);
206
+
207
+ if (self.is_complex()) {
208
+ auto float_tensor = at::view_as_real(self);
209
+ // variance for normal distribution of the real and imaginary values
210
+ // is half of the input variance
211
+ normal_kernel<RNG>()(float_tensor, mean, std/(std::sqrt(2)), gen);
212
+ } else {
213
+ normal_kernel<RNG>()(self, mean, std, gen);
214
+ }
215
+ return self;
216
+ }
217
+
218
+ template<template<typename> class normal_kernel, typename RNG>
219
+ Tensor& normal_out_impl(Tensor& output, const Tensor& mean, double std, c10::optional<Generator> gen) {
220
+ CHECK_NORMAL_STD(std);
221
+ auto std_tensor = at::empty_like(output, MemoryFormat::Contiguous);
222
+ auto shape = at::infer_size(mean.sizes(), std_tensor.sizes());
223
+ at::native::resize_output(output, shape);
224
+ normal_impl_<normal_kernel, RNG>(output, 0, std, gen);
225
+ output.add_(mean);
226
+ return output;
227
+ }
228
+
229
+ template<template<typename> class normal_kernel, typename RNG>
230
+ Tensor& normal_out_impl(Tensor& output, double mean, const Tensor& std, c10::optional<Generator> gen) {
231
+ CHECK_NORMAL_TENSOR_STD(std);
232
+ auto mean_tensor = at::full({}, mean, output.options());
233
+ auto shape = at::infer_size(mean_tensor.sizes(), std.sizes());
234
+ at::native::resize_output(output, shape);
235
+ normal_impl_<normal_kernel, RNG>(output, 0, 1, gen);
236
+ // CUDA NB: addcmul_out copies the tensor to be added into the output.
237
+ // The previous function here was addcmul_out(output, mean_tensor, output, std, 1);
238
+ // The third argument is not a constant reference and hence the samples in output are overwritten.
239
+ // Consequently, the computation performed is mean_tensor + mean_tensor * std instead of mean_tensor + output * std
240
+ output.mul_(std).add_(mean_tensor);
241
+ return output;
242
+ }
243
+
244
+ template<template<typename> class normal_kernel, typename RNG>
245
+ Tensor& normal_out_impl(Tensor& output, const Tensor& mean, const Tensor& std, c10::optional<Generator> gen) {
246
+ CHECK_NORMAL_TENSOR_STD(std);
247
+ auto shape = at::infer_size(mean.sizes(), std.sizes());
248
+ at::native::resize_output(output, shape);
249
+ normal_impl_<normal_kernel, RNG>(output, 0, 1, gen);
250
+ // CUDA NB: addcmul_out copies the tensor to be added into the output.
251
+ // The previous function here was addcmul_out(output, mean, output, std, 1);
252
+ // The third argument is not a constant reference and hence the samples in output are overwritten.
253
+ // Consequently, the computation performed is mean + mean * std instead of mean + output * std
254
+ output.mul_(std).add_(mean);
255
+ return output;
256
+ }
257
+
258
+ template<template<typename> class normal_kernel, typename RNG>
259
+ Tensor normal_impl(const Tensor& mean, double std, c10::optional<Generator> gen) {
260
+ CHECK_NORMAL_STD(std);
261
+ Tensor ret = at::empty_like(mean, MemoryFormat::Contiguous);
262
+ normal_out_impl<normal_kernel, RNG>(ret, mean, std, gen);
263
+ return ret;
264
+ }
265
+
266
+ template<template<typename> class normal_kernel, typename RNG>
267
+ Tensor normal_impl(double mean, const Tensor& std, c10::optional<Generator> gen) {
268
+ CHECK_NORMAL_TENSOR_STD(std);
269
+ Tensor ret = at::empty_like(std, MemoryFormat::Contiguous);
270
+ normal_out_impl<normal_kernel, RNG>(ret, mean, std, gen);
271
+ return ret;
272
+ }
273
+
274
+ template<template<typename> class normal_kernel, typename RNG>
275
+ Tensor normal_impl(const Tensor& mean, const Tensor& std, c10::optional<Generator> gen) {
276
+ CHECK_NORMAL_TENSOR_STD(std);
277
+ auto shape = at::infer_size(mean.sizes(), std.sizes());
278
+ Tensor ret = at::empty(shape, mean.options(), MemoryFormat::Contiguous);
279
+ normal_out_impl<normal_kernel, RNG>(ret, mean, std, gen);
280
+ return ret;
281
+ }
282
+
283
+ // ==================================================== Uniform =======================================================
284
+
285
+ template<template<typename> class uniform_kernel, typename RNG>
286
+ at::Tensor& uniform_impl_(at::Tensor& self, double from, double to, c10::optional<Generator> generator) {
287
+ if (self.is_complex()) {
288
+ CHECK_EMPTY_AND_RETURN(self);
289
+ auto float_tensor = at::view_as_real(self);
290
+ uniform_impl_<uniform_kernel, RNG>(float_tensor, from, to, generator);
291
+ } else {
292
+ AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, self.scalar_type(), "check_uniform_bounds", [&] {
293
+ const auto dtype = self.dtype();
294
+ const auto min = static_cast<double>(std::numeric_limits<scalar_t>::lowest());
295
+ const auto max = static_cast<double>(std::numeric_limits<scalar_t>::max());
296
+ CHECK_OUT_OF_BOUNDS(from, "from", min, max, dtype);
297
+ CHECK_OUT_OF_BOUNDS(to, "to", min, max, dtype);
298
+ TORCH_CHECK(from <= to, "uniform_ expects to return a [from, to) range, but found from=", from, " > to=", to);
299
+ TORCH_CHECK((to - from) <= std::numeric_limits<scalar_t>::max(),
300
+ "uniform_ expects to-from <= std::numeric_limits<", toString(self.scalar_type()),
301
+ ">::max(), but found to=", to, " and from=", from,
302
+ " which result in to-from to exceed the limit");
303
+ from = std::min(std::max(from, min), max);
304
+ to = std::max(std::min(to, max), min);
305
+ });
306
+ CHECK_EMPTY_AND_RETURN(self);
307
+ auto iter = at::TensorIterator::borrowing_nullary_op(self);
308
+ uniform_kernel<RNG>()(iter, from, to, generator);
309
+ }
310
+ return self;
311
+ }
312
+
313
+ // ================================================== LogNormal =======================================================
314
+
315
+ template<template<typename> class log_normal_kernel, typename RNG>
316
+ at::Tensor& log_normal_impl_(at::Tensor& self, double mean, double std, c10::optional<Generator> gen) {
317
+ TORCH_CHECK(std > 0.0, "log_normal_ expects std > 0.0, but found std=", std);
318
+ CHECK_EMPTY_AND_RETURN(self);
319
+ auto iter = TensorIterator::borrowing_nullary_op(self);
320
+ log_normal_kernel<RNG>()(iter, mean, std, gen);
321
+ return self;
322
+ }
323
+
324
+ // =================================================== Geometric ======================================================
325
+
326
+ template<template<typename> class geometric_kernel, typename RNG>
327
+ Tensor& geometric_impl_(Tensor& self, double p, c10::optional<Generator> gen) {
328
+ TORCH_CHECK(0 < p && p < 1, "geometric_ expects p to be in (0, 1), but got p=", p);
329
+ CHECK_EMPTY_AND_RETURN(self);
330
+ auto iter = TensorIterator::borrowing_nullary_op(self);
331
+ geometric_kernel<RNG>()(iter, p, gen);
332
+ return self;
333
+ }
334
+
335
+ // ================================================== Exponential =====================================================
336
+
337
+ template<template<typename> class exponential_kernel, typename RNG>
338
+ Tensor& exponential_impl_(Tensor& self, double lambda, c10::optional<Generator> gen) {
339
+ TORCH_CHECK(lambda > 0.0, "exponential_ expects lambda > 0.0, but found lambda=", lambda);
340
+ CHECK_EMPTY_AND_RETURN(self);
341
+ auto iter = TensorIterator::borrowing_nullary_op(self);
342
+ exponential_kernel<RNG>()(iter, lambda, gen);
343
+ return self;
344
+ }
345
+
346
+ // ==================================================== Cauchy ========================================================
347
+
348
+ template<template<typename> class cauchy_kernel, typename RNG>
349
+ Tensor& cauchy_impl_(Tensor& self, double median, double sigma, c10::optional<Generator> gen) {
350
+ // TODO: instead of variable name 'sigma', use 'gamma' or 'scale'
351
+ // the variance, squared sigma, is undefined for cauchy distribution
352
+ TORCH_CHECK(sigma > 0.0, "cauchy_ expects sigma > 0.0, but found sigma=", sigma);
353
+ TORCH_CHECK(at::isFloatingType(self.scalar_type()), "Cauchy distribution is a continuous probability distribution. dtype must be a floating point but you specified ", self.dtype());
354
+ CHECK_EMPTY_AND_RETURN(self);
355
+ auto iter = TensorIterator::borrowing_nullary_op(self);
356
+ cauchy_kernel<RNG>()(iter, median, sigma, gen);
357
+ return self;
358
+ }
359
+
360
+ // ==================================================== Bernoulli =====================================================
361
+
362
+ template<template<typename> class bernoulli_tensor_kernel, typename RNG>
363
+ Tensor& bernoulli_impl_(Tensor& self, const Tensor& p_, c10::optional<Generator> gen) {
364
+ CHECK_EMPTY_AND_RETURN(self);
365
+ NoNamesGuard guard;
366
+ at::assert_no_internal_overlap(self);
367
+ bernoulli_tensor_kernel<RNG>()(self, p_, gen);
368
+ return self;
369
+ }
370
+
371
+ template<template<typename> class bernoulli_scalar_kernel, typename RNG>
372
+ Tensor& bernoulli_impl_(Tensor& self, double p, c10::optional<Generator> gen) {
373
+ TORCH_CHECK(0 <= p && p <= 1, "bernoulli_ expects p to be in [0, 1], but got p=", p);
374
+ CHECK_EMPTY_AND_RETURN(self);
375
+ at::assert_no_internal_overlap(self);
376
+ bernoulli_scalar_kernel<RNG>()(self, p, gen);
377
+ return self;
378
+ }
379
+
380
+ template<template<typename> class bernoulli_tensor_kernel, typename RNG>
381
+ Tensor& bernoulli_out_impl(Tensor& result, const Tensor& self, c10::optional<Generator> gen) {
382
+ // result.resize_as_(self) requires self to have same dtype as result, so we
383
+ // use resize_ instead.
384
+ // TODO: Fix resize_as_. See pytorch/pytorch#11665.
385
+ result.resize_(self.sizes());
386
+ bernoulli_impl_<bernoulli_tensor_kernel, RNG>(result, self, gen);
387
+ namedinference::propagate_names(result, self);
388
+ return result;
389
+ }
390
+
391
+ #undef CHECK_OUT_OF_BOUNDS
392
+ #undef WARN_OUT_OF_BOUNDS
393
+
394
+ } // namespace at::native::templates
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/Histogram.h ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/core/Tensor.h>
4
+ #include <ATen/native/DispatchStub.h>
5
+
6
+ namespace at::native {
7
+
8
+ using histogramdd_fn = void(*)(const Tensor&, const c10::optional<Tensor>&, bool, Tensor&, const TensorList&);
9
+ using histogramdd_linear_fn = void(*)(const Tensor&, const c10::optional<Tensor>&, bool, Tensor&, const TensorList&, bool);
10
+ using histogram_select_outer_bin_edges_fn = void(*)(const Tensor& input, const int64_t N, std::vector<double> &leftmost_edges, std::vector<double> &rightmost_edges);
11
+
12
+ DECLARE_DISPATCH(histogramdd_fn, histogramdd_stub);
13
+ DECLARE_DISPATCH(histogramdd_linear_fn, histogramdd_linear_stub);
14
+ DECLARE_DISPATCH(histogram_select_outer_bin_edges_fn, histogram_select_outer_bin_edges_stub);
15
+
16
+ } // namespace at::native
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/IndexKernel.h ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ #include <ATen/native/DispatchStub.h>
3
+ #include <c10/util/ArrayRef.h>
4
+
5
+ namespace at {
6
+ class Tensor;
7
+ class TensorBase;
8
+ struct TensorIterator;
9
+ struct TensorIteratorBase;
10
+ }
11
+
12
+ namespace c10 {
13
+ class Scalar;
14
+ }
15
+
16
+ namespace at::native {
17
+
18
+ using index_fn = void(*)(TensorIteratorBase &, IntArrayRef indexed_sizes, IntArrayRef indexed_strides);
19
+ using index_fill_fn = void(*)(TensorIterator & iter, int64_t dim, int64_t self_dim_size, int64_t self_dim_stride, const Scalar& source);
20
+ using index_copy_fn = void(*)(TensorIterator & iter, int64_t dim, int64_t self_dim_size, int64_t self_dim_stride);
21
+ using index_put_fn = void(*)(TensorIterator &, IntArrayRef indexed_sizes, IntArrayRef indexed_strides, bool accumulate);
22
+ using put_fn = void(*)(TensorIterator & iter, const TensorBase& self, const bool accumulate);
23
+ using take_fn = void(*)(TensorIterator & iter, const TensorBase& input);
24
+ using flip_fn = void(*)(TensorIterator &, const bool);
25
+ using masked_fill_fn = void(*)(TensorIterator &, const Scalar& scalar);
26
+ using masked_select_fn = void(*)(TensorIterator &, int64_t orig_stride);
27
+ using masked_scatter_fn = void(*)(TensorIterator &, const TensorBase &);
28
+
29
+ DECLARE_DISPATCH(index_fn, index_stub);
30
+ DECLARE_DISPATCH(index_fill_fn, index_fill_stub);
31
+ DECLARE_DISPATCH(index_copy_fn, index_copy_stub);
32
+ DECLARE_DISPATCH(index_put_fn, index_put_stub);
33
+ DECLARE_DISPATCH(put_fn, put_stub);
34
+ DECLARE_DISPATCH(take_fn, take_stub);
35
+ DECLARE_DISPATCH(flip_fn, flip_stub);
36
+ DECLARE_DISPATCH(masked_fill_fn, masked_fill_stub);
37
+ DECLARE_DISPATCH(masked_select_fn, masked_select_serial_stub);
38
+ DECLARE_DISPATCH(masked_select_fn, masked_select_stub);
39
+ DECLARE_DISPATCH(masked_scatter_fn, masked_scatter_stub);
40
+
41
+ } // namespace at::native
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/IndexingUtils.h ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ #include <ATen/ExpandUtils.h>
3
+ #include <ATen/native/CanUse32BitIndexMath.h>
4
+ #include <ATen/native/TensorIterator.h>
5
+ #include <ATen/core/IListRef.h>
6
+ #include <c10/util/irange.h>
7
+
8
+ namespace at::native {
9
+
10
// Reports (via TORCH_CHECK_INDEX, which throws) that a boolean/byte mask's
// shape does not match the indexed tensor at the given dimension.
// `idx` is the dimension in `self`, `maskIdx` the dimension in `mask`.
[[noreturn]]
static void invalid_mask(const Tensor & self, int64_t idx, const Tensor & mask, int64_t maskIdx) {
  TORCH_CHECK_INDEX(false, "The shape of the mask ", mask.sizes(), " at index ", maskIdx,
  " does not match the shape of the indexed tensor ", self.sizes(), " at index ", idx);
}
15
+
16
+
17
+ static C10_UNUSED std::vector<Tensor> expandTensors(const Tensor & self, IOptTensorListRef indices) {
18
+ // If indices come in as ByteTensor or BoolTensor (masks), expand them into the equivalent indexing by LongTensors
19
+ std::vector<Tensor> result;
20
+ for (const auto& index_opt : indices) {
21
+ if (!index_opt.has_value()) {
22
+ result.emplace_back();
23
+ } else {
24
+ const auto& index = *index_opt;
25
+ if (index.scalar_type() == kByte || index.scalar_type() == kBool) {
26
+ if (index.scalar_type() == kByte) {
27
+ TORCH_WARN("indexing with dtype torch.uint8 is now deprecated," \
28
+ " please use a dtype torch.bool instead.");
29
+ }
30
+ // The sizes of the ByteTensor mask or bool tensor must match the sizes of the
31
+ // corresponding dimensions in self
32
+ for (const auto j : c10::irange(index.dim())) {
33
+ int64_t srcIdx = static_cast<int64_t>(result.size() + j);
34
+ if (index.size(j) != self.size(srcIdx)) {
35
+ invalid_mask(self, srcIdx, index, j);
36
+ }
37
+ }
38
+ // Replace with nonzeros
39
+ auto nonzero = index.nonzero();
40
+ for (const auto j : c10::irange(index.dim())) {
41
+ result.emplace_back(nonzero.select(1, j));
42
+ }
43
+ } else {
44
+ result.emplace_back(index);
45
+ }
46
+ }
47
+ }
48
+ return result;
49
+ }
50
+
51
+ static C10_UNUSED void checkIndexTensorTypes(IOptTensorListRef indices, bool allow_int=false) {
52
+ for (const auto& tensor : indices) {
53
+ if (tensor.has_value() && tensor->defined()) {
54
+ auto scalarType = tensor->scalar_type();
55
+ if (allow_int) {
56
+ if (scalarType != kLong && scalarType != kByte && scalarType != kBool && scalarType != kInt) {
57
+ TORCH_CHECK_INDEX(false, "tensors used as indices must be long, int, byte or bool tensors");
58
+ }
59
+ } else {
60
+ if (scalarType != kLong && scalarType != kByte && scalarType != kBool) {
61
+ TORCH_CHECK_INDEX(false, "tensors used as indices must be long, byte or bool tensors");
62
+ }
63
+ }
64
+ }
65
+ }
66
+ }
67
+
68
+ inline torch::List<c10::optional<Tensor>> toListOfOptionalTensors(ArrayRef<Tensor> list) {
69
+ torch::List<c10::optional<Tensor>> result;
70
+ result.reserve(list.size());
71
+ for (const Tensor& a : list) {
72
+ result.push_back(a);
73
+ }
74
+ return result;
75
+ }
76
+
77
+ inline torch::List<c10::optional<Tensor>> toListOfOptionalTensors(ArrayRef<IValue> list) {
78
+ torch::List<c10::optional<Tensor>> result;
79
+ result.reserve(list.size());
80
+ for (const IValue& a : list) {
81
+ result.push_back(a.isTensor() ? c10::optional<Tensor>(a.toTensor()) : c10::optional<Tensor>());
82
+ }
83
+ return result;
84
+ }
85
+
86
// Returns true if all the defined (non-null) tensors in `tl` are adjacent,
// i.e. no undefined entry sits between the first and last defined ones.
static C10_UNUSED bool hasContiguousSubspace(TensorList tl) {
  // true if all the non-null tensors are adjacent
  auto isDefined = [](const Tensor & tensor){ return tensor.defined(); };
  auto isNull = [](const Tensor & tensor){ return !tensor.defined(); };
  // `start` points at the first defined tensor; `stop.base()` is one past the
  // last defined tensor (reverse_iterator::base converts back to forward).
  auto start = std::find_if(tl.begin(), tl.end(), isDefined);
  auto stop = std::find_if(tl.rbegin(), tl.rend(), isDefined);
  // Any null entry found strictly inside [start, stop.base()) breaks adjacency.
  auto it = std::find_if(start, stop.base(), isNull);
  return it == stop.base();
}
95
+
96
+
97
+ // Transposes the tensor and indices together so that all the non-null indices
98
+ // index the first k dimensions of the tensor. Returns the transposed tensor
99
+ // and the reordered indices. For example:
100
+ // transposeToFront(tensor, {nullptr, a, nullptr, b})
101
+ // returns
102
+ // tensor.permute([1, 3, 0, 2]), {a, b, nullptr, nullptr}
103
+ static C10_UNUSED std::tuple<Tensor, std::vector<Tensor>>
104
+ transposeToFront(const Tensor& self, TensorList indices) {
105
+ std::vector<int64_t> dims;
106
+ std::vector<Tensor> transposedIndices;
107
+ dims.reserve(self.dim());
108
+ for (const auto i : c10::irange(self.dim())) {
109
+ if (indices[i].defined()) {
110
+ dims.push_back(i);
111
+ transposedIndices.emplace_back(indices[i]);
112
+ }
113
+ }
114
+ for (const auto i : c10::irange(self.dim())) {
115
+ if (!indices[i].defined()) {
116
+ dims.push_back(i);
117
+ transposedIndices.emplace_back();
118
+ }
119
+ }
120
+ return std::make_tuple(self.permute(dims), std::move(transposedIndices));
121
+ }
122
+
123
+ inline std::tuple<Tensor, std::vector<Tensor>, std::vector<int64_t>>
124
+ transposeToFrontAndInvPerm(const Tensor& self, TensorList indices) {
125
+ std::vector<int64_t> dims;
126
+ std::vector<int64_t> invPerm;
127
+ std::vector<Tensor> transposedIndices;
128
+ dims.reserve(self.dim());
129
+ invPerm.resize(self.dim());
130
+ for (const auto i : c10::irange(self.dim())) {
131
+ if (indices[i].defined()) {
132
+ dims.push_back(i);
133
+ transposedIndices.emplace_back(indices[i]);
134
+ }
135
+ }
136
+ for (const auto i : c10::irange(self.dim())) {
137
+ if (!indices[i].defined()) {
138
+ dims.push_back(i);
139
+ transposedIndices.emplace_back();
140
+ }
141
+ }
142
+ for (const auto i : c10::irange(self.dim())) {
143
+ invPerm[dims[i]] = i;
144
+ }
145
+ return std::make_tuple(self.permute(dims), std::move(transposedIndices), std::move(invPerm));
146
+ }
147
+
148
// Precomputed state for an advanced-indexing operation, consumed by the
// TensorIterator-based index kernels. The constructor is defined in the
// corresponding .cpp (not visible here).
struct AdvancedIndex {
  AdvancedIndex(const Tensor& src, TensorList indices);

  Tensor src;                    // the tensor being indexed
  std::vector<Tensor> indices;   // one index tensor per indexed dimension
  DimVector indexed_sizes;       // sizes of the indexed dimensions of src
  DimVector indexed_strides;     // strides of the indexed dimensions of src
  int64_t dims_before;           // number of src dims before the indexed block
  int64_t dims_after;            // number of src dims after the indexed block
};
158
+
159
+
160
+ } //namespace at::native
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/MathBitsFallback.h ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <ATen/core/Tensor.h>
2
+ #include <ATen/core/dispatch/Dispatcher.h>
3
+ #include <ATen/core/op_registration/op_registration.h>
4
+ #include <ATen/native/UnaryOps.h>
5
+ #include <ATen/native/Resize.h>
6
+ #include <c10/util/irange.h>
7
+ #include <torch/library.h>
8
+
9
+ #ifndef AT_PER_OPERATOR_HEADERS
10
+ #include <ATen/Functions.h>
11
+ #else
12
+ #include <ATen/ops/clone.h>
13
+
14
+ #include <utility>
15
+ #endif
16
+
17
+ namespace at::native {
18
+ // This fallback should only be used for operations that are self inverse and have a corresponding tensor
19
+ // bit (internally implemented using DispatchKey) to maintain the state on tensor using tensor bit.
20
+ // Currently there are two tensor bits that trigger this fallback: conjugate bit and negative bit.
21
+ // Conjugate bit is set on a tensor when `.conj()` is called and neg bit is set on a tensor when `.conj().imag` is called.
22
+
23
+ // NOTE: To use this fallback, `clone` and `copy_` should fully understand and be able to correctly handle the semantic of your math bit.
24
// Boxed-dispatch fallback shared by the conjugate and negative tensor bits.
// Subclasses supply `is_bit_set` for their particular bit; `fallback_impl`
// materializes (clones) flagged inputs before redispatching below `key`.
struct MathOpFallback {
  MathOpFallback(DispatchKey key_, string op_name_) : key(key_), op_name(std::move(op_name_)) {}
  // Returns whether this fallback's math bit is set on the given tensor.
  virtual bool is_bit_set(const Tensor&) = 0;
  void fallback_impl(const c10::OperatorHandle& op, DispatchKeySet dispatch_keys, torch::jit::Stack* stack) {
    /*
      Situations to handle:
        1. Out-of-place operation. Easy: materialize all inputs and call it a day.
        2. Inplace operation. Desugar x.add_(2) into x.conj_().add_(2).conj_().
           Materialize other inputs as in (1).
        3. out= operation. Desugar add(x, 2, out=y) into y.copy_(add(x, 2))
           Materialize other inputs as in (1).

      It is important to be able to tell if we READ from an argument and if we
      WRITE to an argument. Conservative approach is to assume that we always
      READ from an argument, but in out= operations you can skip
      conjugating inputs on entry that never get used. In the current schema we
      can't easily tell if the operation is in in-place or out= operation.

      Note:
      1. Mutable tensorlists containing tensors whose math bit set to true are disallowed.
      2. Mutable tensors with math bit set to true are unconditionally cloned to ensure
         correct behavior in the case when the mutable tensor shares memory with non mutable arguments.

         If we were to in-place resolve the math bit for mutable inputs, then the non-mutable inputs sharing partial or full memory
         with these mutable inputs would read into wrong values in the following cases:
         1. Non mutable inputs have their math bit set to false.
         2. Math bit for mutable input(s) is resolved before the non mutable inputs (with bit set to true and sharing memory
            with one or more mutable arg(s)) are cloned.
         At the end, the final value of the mutable arguments from the stack are copied into the original input mutable tensor inputs.
    */
    const auto& arguments = op.schema().arguments();
    const auto num_arguments = arguments.size();
    // Arguments occupy the top `num_arguments` slots of the stack.
    const auto stack_start = stack->size() - num_arguments;

    c10::optional<bool> is_write;
    for (const auto i : c10::irange(num_arguments)) {
      // Three possible states:
      // 1. alias_info has no value --> out-of-place operation
      // 2. alias_info does have a value, alias_info->is_write=True --> in-place or out= operation
      // 3. alias_info does have a value, alias_info->is_write=False --> view operation
      const AliasInfo* alias_info = arguments[i].alias_info();
      if (alias_info != nullptr) {
        if (is_write.has_value()) {
          TORCH_CHECK(*is_write == alias_info->isWrite(),
            "Unsupported operator for ", op_name, " fallback: ", op.schema().name(),
            op_name, " fallback doesn't work for operators with a mix "
            "mutable and non-mutable inputs that alias with outputs, "
            "this must be implemented manually.  "
            "If you got this error on a core op, please report a bug to PyTorch.");
        } else {
          is_write = alias_info->isWrite();
        }
      }
    }

    if (is_write.has_value() && !*is_write) {
      // We assume that view operators automatically handle the math bit
      // correctly by propagating the dispatch key in key_set.
      // This is not necessarily always right, so you should test these cases.
      op.redispatchBoxed(dispatch_keys & c10::DispatchKeySet(DispatchKeySet::FULL_AFTER, key), stack);
      return;
    }

    // Mutable inputs with math bit set to True and their clones
    std::vector<std::pair<Tensor, Tensor>> mutable_inputs_with_their_clones;
    for (const auto i : c10::irange(num_arguments)) {
      auto& ivalue = (*stack)[stack_start + i];
      if (!(ivalue.isTensor() || ivalue.isTensorList())) {
        continue;
      }
      const auto& argument = arguments[i];
      bool mut_arg = false;
      if (argument.alias_info()) {
        // Was already tested by is_write loop above
        TORCH_INTERNAL_ASSERT_DEBUG_ONLY(argument.alias_info()->isWrite());
        mut_arg = true;
      }
      if (ivalue.isTensor()) {
        if (!is_bit_set(ivalue.toTensor())) {
          continue;
        }
        // Materialize the bit: clone resolves it, and the clone replaces the
        // original on the stack.
        auto tensor = std::move(ivalue).toTensor();
        auto resolved_tensor = at::clone(tensor);
        if (mut_arg) {
          TORCH_CHECK(mutable_inputs_with_their_clones.empty(), op_name, " fallback does not support operators with more than one mutable tensors with ",
            op_name, "bit set to true.");
          mutable_inputs_with_their_clones.emplace_back(std::move(tensor), resolved_tensor);
        }
        (*stack)[stack_start + i] = std::move(resolved_tensor);
      } else if (ivalue.isTensorList()) {
        auto tensors = std::move(ivalue).toTensorList();
        for(const auto j : c10::irange(tensors.size())) {
          const auto& tensor = tensors[j];
          if (!is_bit_set(tensor)) {
            continue;
          }
          TORCH_CHECK(!mut_arg, " fallback doesn't currently support mutable TensorLists with ",
              op_name, " inputs. Please materialize all the ", op_name, " input tensor(s) in the mutable TensorList inputs before calling ",
              op.schema().name());
          tensors[j] = at::clone(tensor);
        }
        (*stack)[stack_start + i] = std::move(tensors);
      }
    }

    op.redispatchBoxed(dispatch_keys & c10::DispatchKeySet(DispatchKeySet::FULL_AFTER, key), stack);

    TORCH_INTERNAL_ASSERT(mutable_inputs_with_their_clones.size() <= 1);

    // At most one entry (enforced above); the op's return now sits at
    // stack_start after the boxed call popped the arguments.
    for (std::pair<Tensor, Tensor> mut_tensors: mutable_inputs_with_their_clones) {
      auto& mutable_input = mut_tensors.first;
      auto& cloned_mutable_input = mut_tensors.second;
      auto& ivalue = (*stack)[stack_start];
      auto returned_output = std::move(ivalue).toTensor();

      // sanity check to ensure that the tensor in stack aliases the cloned_mutable_input
      TORCH_INTERNAL_ASSERT(cloned_mutable_input.is_same(returned_output));

      // necessary for out= arg
      at::native::resize_output(mutable_input, returned_output.sizes());

      // Copy the computed result back into the caller's original tensor and
      // hand that tensor back on the stack.
      mutable_input.copy_(returned_output);
      (*stack)[stack_start] = std::move(mutable_input);
    }
  }

  virtual ~MathOpFallback() = default;

  DispatchKey key;     // the tensor-bit dispatch key this fallback services
  string op_name;      // human-readable bit name used in error messages
};
156
+
157
+ } // namespace at::native
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/MaxPooling.h ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/core/Tensor.h>
4
+ #include <ATen/Parallel.h>
5
+ #include <ATen/native/DispatchStub.h>
6
+ #include <ATen/native/Pool.h>
7
+
8
+ namespace at::native {
9
+
10
// Validates all max_pool1d arguments and the resulting output width.
// Throws (TORCH_CHECK) with a user-facing message on the first violated
// constraint; check order determines which error the user sees, so it is
// intentionally kept as-is.
static void check_max_pool1d(
    const Tensor& self,
    IntArrayRef kernel_size,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation,
    bool ceil_mode) {

  TORCH_CHECK(
      self.dim() == 2 || self.dim() == 3,
      "max_pool1d() Expected 2D or 3D input tensor, but got ", self.sym_sizes());
  TORCH_CHECK(
      kernel_size.size() == 1,
      "max_pool1d() kernel_size must be an int, list of ints or tuple of ints of size 1 but got size ",
      kernel_size.size());
  TORCH_CHECK(
      stride.empty() || stride.size() == 1,
      "max_pool1d() stride must be None, an int, list of ints, or tuple of ints of size 1 but got size ",
      stride.size());
  TORCH_CHECK(
      padding.size() == 1,
      "max_pool1d() padding must be an int, list of ints, or tuple of ints of size 1 but got size ",
      padding.size());
  TORCH_CHECK(
      dilation.size() == 1,
      "max_pool1d() dilation must be an int, list of ints or tuple of ints of size 1 but got size ",
      dilation.size());

  // If stride=None then set it to kernel_size
  if (stride.empty()) {
    stride = kernel_size;
  }

  TORCH_CHECK(
      kernel_size[0] > 0,
      "max_pool1d() kernel_size must be greater than zero, but got ",
      kernel_size[0]);
  TORCH_CHECK(
      stride[0] > 0, "max_pool1d() stride must be greater than zero, but got ", stride[0]);
  TORCH_CHECK(
      padding[0] >= 0, "max_pool1d() padding must be non-negative, but got ", padding[0]);
  TORCH_CHECK(
      padding[0] <= kernel_size[0] / 2,
      "max_pool1d() padding should be at most half of kernel size, but got padding=",
      padding[0],
      " and kernel_size=",
      kernel_size[0]);
  TORCH_CHECK(
      dilation[0] > 0, "max_pool1d() dilation must be greater than zero, but got ", dilation[0]);

  // guard_int: symbolic sizes must be concrete here to compute the output width.
  const int64_t OW = pooling_output_shape(self.sym_size(-1).guard_int(__FILE__, __LINE__), kernel_size[0], padding[0], stride[0], dilation[0], ceil_mode);
  TORCH_CHECK(OW > 0, "max_pool1d() Invalid computed output size: ", OW);
}
63
+
64
+ // TODO(Heitor) Template by dimension
65
+ struct PoolingParams1D {
66
+ int64_t NB; // Number of batches
67
+ int64_t NC; // Number of channels
68
+ int64_t IW; // Input width
69
+ int64_t OW; // Output width
70
+ int64_t KW; // Kernel width
71
+ int64_t SJ; // Column stride
72
+ int64_t PJ; // Column padding
73
+ int64_t DJ; // Column dilation
74
+
75
+ // Return index of input element for the given kernel and output index
76
+ inline int64_t index(int64_t kj, int64_t oj) const {
77
+ return oj * SJ + kj * DJ - PJ;
78
+ }
79
+
80
+ // Return index of first output within bounds for this kernel index
81
+ inline int64_t valid_output_start(int64_t kj) const {
82
+ int64_t ij = index(kj, 0);;
83
+ return ij < 0 ? at::divup(-ij, SJ) : 0;
84
+ }
85
+
86
+ // Return index one past last output within bounds for this kernel index
87
+ inline int64_t valid_output_end(int64_t kj) const {
88
+ int64_t ij = index(kj, OW - 1);
89
+ return ij >= IW ? OW - at::divup(ij - (IW - 1), SJ) : OW;
90
+ }
91
+ };
92
+
93
// Kernel signature for max_pool1d: (output, input, pooling geometry).
using pooling_fn = void (*)(Tensor&, const Tensor&, const PoolingParams1D&);

// Per-backend implementation registered via REGISTER_DISPATCH.
DECLARE_DISPATCH(pooling_fn, max_pool1d_stub);
96
+
97
+ } // namespace at::native
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/NonEmptyUtils.h ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <ATen/core/TensorBase.h>
2
+ #include <algorithm>
3
+ #include <vector>
4
+
5
+ namespace at::native {
6
+
7
// Treat a 0-d (scalar) tensor as having a single dimension: clamps the
// dimension count to at least 1.
inline int64_t ensure_nonempty_dim(int64_t dim) {
  return dim < 1 ? 1 : dim;
}
10
+
11
+ inline int64_t ensure_nonempty_size(const TensorBase &t, int64_t dim) {
12
+ return t.dim() == 0 ? 1 : t.size(dim);
13
+ }
14
+
15
+ inline int64_t ensure_nonempty_stride(const TensorBase &t, int64_t dim) {
16
+ return t.dim() == 0 ? 1 : t.stride(dim);
17
+ }
18
+
19
using IdxVec = std::vector<int64_t>;
// Make an empty size/stride vector usable as a 1-d shape by giving it a
// single element of 1; non-empty vectors pass through unchanged.
inline IdxVec ensure_nonempty_vec(IdxVec vec) {
  if (!vec.empty()) {
    return vec;
  }
  vec.push_back(1);
  return vec;
}
26
+
27
+ } // namespace at::native
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/Padding.h ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/core/Tensor.h>
4
+ #include <ATen/native/DispatchStub.h>
5
+
6
+ namespace at::native {
7
+
8
// Kernel signature shared by all padding ops: (output, input, padding).
// The same signature serves forward and backward (grad_input, grad_output, padding).
using padding_fn = void (*)(const Tensor&, const Tensor&, IntArrayRef);

// reflection padding
DECLARE_DISPATCH(padding_fn, reflection_pad1d_kernel);
DECLARE_DISPATCH(padding_fn, reflection_pad1d_backward_kernel);
DECLARE_DISPATCH(padding_fn, reflection_pad2d_kernel);
DECLARE_DISPATCH(padding_fn, reflection_pad2d_backward_kernel);
DECLARE_DISPATCH(padding_fn, reflection_pad3d_kernel);
DECLARE_DISPATCH(padding_fn, reflection_pad3d_backward_kernel);

// replication padding
DECLARE_DISPATCH(padding_fn, replication_pad1d_kernel);
DECLARE_DISPATCH(padding_fn, replication_pad1d_backward_kernel);
DECLARE_DISPATCH(padding_fn, replication_pad2d_kernel);
DECLARE_DISPATCH(padding_fn, replication_pad2d_backward_kernel);
DECLARE_DISPATCH(padding_fn, replication_pad3d_kernel);
DECLARE_DISPATCH(padding_fn, replication_pad3d_backward_kernel);
25
+
26
+ namespace padding {
27
+
28
+ template <int dim>
29
+ static inline void check_valid_input(const Tensor& input, IntArrayRef padding) {
30
+
31
+ TORCH_CHECK(padding.size() == 2 * dim,
32
+ "padding size is expected to be ", 2 * dim,
33
+ ", but got: ", padding.size());
34
+
35
+ int input_dim = input.dim();
36
+
37
+ bool is_batch_mode = input_dim == (dim + 2);
38
+
39
+ bool valid_batch_mode = is_batch_mode;
40
+ bool valid_non_batch_mode = !is_batch_mode;
41
+
42
+ if (is_batch_mode) {
43
+ // allow batch size of 0-dim.
44
+ for (const auto d : c10::irange(1, input_dim)) {
45
+ valid_batch_mode = valid_batch_mode && input.size(d) != 0;
46
+ }
47
+ } else {
48
+ for (const auto d : c10::irange(0, input_dim)) {
49
+ valid_non_batch_mode = valid_non_batch_mode && input.size(d) != 0;
50
+ }
51
+ }
52
+
53
+ // allow empty batch size but not other dimensions.
54
+ TORCH_CHECK(valid_batch_mode || valid_non_batch_mode,
55
+ "Expected ", dim + 1, "D or ", dim + 2,
56
+ "D (batch mode) tensor with possibly 0 batch size and other non-zero dimensions for input, but got: ",
57
+ input.sizes());
58
+ }
59
+
60
+ } // namespace padding
61
+
62
+ } // at::native
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/PointwiseOps.h ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Ternary and higher-order pointwise operations
2
+ #pragma once
3
+
4
+ #include <ATen/native/DispatchStub.h>
5
+
6
+ namespace c10 {
7
+ class Scalar;
8
+ }
9
+
10
+ namespace at {
11
+
12
+ struct TensorIterator;
13
+ struct TensorIteratorBase;
14
+
15
+ namespace native {
16
+
17
+ using pointwise_fn = void (*)(TensorIterator&, const Scalar& scalar);
18
+ using structured_pointwise_fn = void (*)(TensorIteratorBase&, const Scalar& scalar);
19
+ using pointwise_fn_double = void (*)(TensorIterator&, const Scalar&, double);
20
+
21
+ DECLARE_DISPATCH(structured_pointwise_fn, addcmul_stub);
22
+ DECLARE_DISPATCH(structured_pointwise_fn, addcdiv_stub);
23
+ DECLARE_DISPATCH(pointwise_fn_double, smooth_l1_backward_stub);
24
+ DECLARE_DISPATCH(pointwise_fn_double, huber_backward_stub);
25
+ DECLARE_DISPATCH(pointwise_fn, mse_backward_stub);
26
+
27
+ } // namespace native
28
+ } // namespace at
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/Pool.h ADDED
@@ -0,0 +1,340 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <ATen/core/Tensor.h>
2
+ #include <ATen/div_rtn.h>
3
+ #include <ATen/TensorUtils.h>
4
+ #include <ATen/native/DispatchStub.h>
5
+ #include <c10/util/irange.h>
6
+
7
+ #include <utility>
8
+
9
+ #pragma once
10
+
11
+ namespace at::native {
12
+
13
// Kernel signatures for 2-d/3-d max and average pooling; each stub is filled
// in per backend via REGISTER_DISPATCH.
using max_pool2d_fn = void(*)(const Tensor& output, const Tensor& indices, const Tensor& input,
    int kW, int kH, int dW, int dH, int padW, int padH, int dilationW, int dilationH);
using max_pool2d_backward_fn = void(*)(const Tensor& grad_input, const Tensor& grad_output, const Tensor& indices);

DECLARE_DISPATCH(max_pool2d_fn, max_pool2d_kernel);
DECLARE_DISPATCH(max_pool2d_backward_fn, max_pool2d_backward_kernel);

// average pooling has same signature for forward and backward
using avg_pool2d_fn = void(*)(const Tensor& output, const Tensor& input, int64_t kW, int64_t kH,
    int64_t dW, int64_t dH, int64_t padW, int64_t padH, bool count_include_pad, c10::optional<int64_t> divisor_override);
using avg_pool2d_backward_fn = void(*)(const Tensor& output, const Tensor& input, int kW, int kH,
    int dW, int dH, int padW, int padH, bool count_include_pad, c10::optional<int64_t> divisor_override);

DECLARE_DISPATCH(avg_pool2d_fn, avg_pool2d_kernel);
DECLARE_DISPATCH(avg_pool2d_backward_fn, avg_pool2d_backward_kernel);

using max_pool3d_fn = void(*)(Tensor& output, Tensor& indices, const Tensor& input,
    int kW, int kH, int kD, int dW, int dH, int dD, int pW, int pH, int pD, int dilationW, int dilationH, int dilationD);
using max_pool3d_backward_fn = void(*)(Tensor& grad_input, const Tensor& grad_output, const Tensor& indices);

DECLARE_DISPATCH(max_pool3d_fn, max_pool3d_kernel);
DECLARE_DISPATCH(max_pool3d_backward_fn, max_pool3d_backward_kernel);
35
+ namespace {
36
+
37
// Narrowing cast that throws (TORCH_CHECK) if `v` does not fit in dest_t.
template <typename dest_t, typename src_t>
static inline dest_t safe_downcast(src_t v)
{
  // Reject values outside the representable range of the destination type.
  constexpr auto lowest = std::numeric_limits<dest_t>::min();
  constexpr auto highest = std::numeric_limits<dest_t>::max();
  TORCH_CHECK(lowest <= v && v <= highest,
              "integer out of range");

  return static_cast<dest_t>(v);
}
46
+
47
+ template<typename T>
48
+ static inline T pooling_output_shape_pad_lr(
49
+ T inputSize, T kernelSize, T pad_l, T pad_r, T stride, T dilation,
50
+ bool ceil_mode) {
51
+ T outputSize = div_rtn<T>(
52
+ inputSize + pad_l + pad_r - dilation * (kernelSize - 1) - 1 +
53
+ (ceil_mode ? stride - 1 : 0), stride) + 1;
54
+ if (ceil_mode) {
55
+ // ensure that the last pooling starts inside the image
56
+ // needed to avoid problems in ceil mode
57
+ if ((outputSize - 1) * stride >= inputSize + pad_l) {
58
+ --outputSize;
59
+ }
60
+ }
61
+ return outputSize;
62
+ }
63
+
64
// Output length for symmetric padding; validates stride and padding before
// delegating to pooling_output_shape_pad_lr.
template<typename T>
static inline T pooling_output_shape(
      T inputSize, T kernelSize, T pad, T stride, T dilation, bool ceil_mode) {
    TORCH_CHECK(stride != 0, "stride should not be zero");
    TORCH_CHECK(pad >= 0,
                "pad must be non-negative, but got pad: ", pad);
    // Effective kernel size accounts for dilation.
    TORCH_CHECK(pad <= ((kernelSize - 1) * dilation + 1) / 2,
                "pad should be at most half of effective kernel size, but got pad=",
                pad, ", kernel_size=", kernelSize, " and dilation=", dilation)
    return pooling_output_shape_pad_lr(
        inputSize, kernelSize, pad, pad, stride, dilation, ceil_mode);
}
76
+
77
// Computes (left, right) padding for "same"-mode pooling so the output length
// is ceil(inputSize / stride). Prefers a symmetric split when possible.
template <typename T>
std::pair<T, T> _pooling_same_mode_padding_lr(
    T inputSize, T kernelSize, T stride, T dilation) {
  // NOTE: with strides, the output shape is ceil(inputSize/stride)
  auto pad_total = T(dilation) * (kernelSize - 1);

  // If the total is odd but the floor in the output-size formula leaves some
  // slack, drop one unit so the padding splits evenly.
  if (stride > 2 && (pad_total % 2 == 1)) {
    auto slack = inputSize % stride - 1;
    if (slack > 0) {
      pad_total = pad_total - 1;
    }
  }

  auto pad_left = pad_total / 2;
  return {pad_left, pad_total - pad_left};
}
95
+
96
// Concrete int64_t entry point for "same"-mode padding computation.
inline std::pair<int64_t, int64_t> pooling_same_mode_padding_lr(
    int64_t inputSize, int64_t kernelSize, int64_t stride, int64_t dilation) {
  return _pooling_same_mode_padding_lr(inputSize, kernelSize, stride, dilation);
}
100
+
101
// Symbolic-shape (SymInt) entry point; moves arguments since SymInt is not
// trivially copyable.
inline std::pair<c10::SymInt, c10::SymInt> pooling_same_mode_padding_lr(
    c10::SymInt inputSize, c10::SymInt kernelSize, c10::SymInt stride, c10::SymInt dilation) {
  return _pooling_same_mode_padding_lr(std::move(inputSize), std::move(kernelSize), std::move(stride), std::move(dilation));
}
105
+
106
// AveragePool2d/DilatedMaxPool2d (forward)
// Validates kernel/stride/padding/dilation values, input rank and layout, and
// the computed output geometry; throws TORCH_CHECK errors on violation.
static inline void
pool2d_shape_check(
  const Tensor& input,
  int kH, int kW, int dH, int dW, int padH, int padW, int dilationH, int dilationW,
  int64_t nInputPlane,
  int64_t inputHeight, int64_t inputWidth,
  int64_t outputHeight, int64_t outputWidth, MemoryFormat memory_format)
{
  const int64_t ndim = input.ndimension();
  const int64_t nOutputPlane = nInputPlane;

  TORCH_CHECK(kW > 0 && kH > 0,
              "kernel size should be greater than zero, but got ",
              "kH: ", kH, " kW: ", kW);
  TORCH_CHECK(dW > 0 && dH > 0,
              "stride should be greater than zero, but got "
              "dH: ", dH, " dW: ", dW);
  TORCH_CHECK(dilationH > 0 && dilationW > 0,
              "dilation should be greater than zero, but got ",
              "dilationH: ", dilationH, " dilationW: ", dilationW);

  bool valid_dims = input.size(1) != 0 && input.size(2) != 0;
  if (memory_format == at::MemoryFormat::ChannelsLast){
    // Expect tensor in NHWC format and allow 0-dim only for N.
    TORCH_CHECK((ndim == 4 && valid_dims && input.size(3) != 0),
                "Expected 4D (batch mode) tensor expected for input with channels_last layout"
                " with optional 0 dim batch size for input, but got: ", input.sizes());
  } else {
    TORCH_CHECK((ndim == 3 && input.size(0) != 0 && valid_dims) ||
                (ndim == 4 && valid_dims && input.size(3) != 0),
                "Expected 3D or 4D (batch mode) tensor with optional 0 dim batch size for input, but got:",
                input.sizes());
  }

  TORCH_CHECK(kW/2 >= padW && kH/2 >= padH,
              "pad should be smaller than or equal to half of kernel size, but got ",
              "padW = ", padW, ", padH = ", padH, ", kW = ", kW, ", kH = ", kH);

  TORCH_CHECK(outputWidth >= 1 && outputHeight >= 1,
              "Given input size: (",
              nInputPlane, "x", inputHeight, "x", inputWidth, "). ",
              "Calculated output size: (",
              nOutputPlane, "x", outputHeight, "x", outputWidth, "). ",
              "Output size is too small");
}
152
+
153
// DilatedMaxPool2d (backward)
// Re-runs the forward shape check, then verifies grad_output and indices
// match the expected output geometry in their last three dimensions.
static inline void
max_pool2d_backward_shape_check(
  const Tensor& input,
  const Tensor& gradOutput,
  const Tensor& indices,
  int kH, int kW, int dH, int dW, int padH, int padW, int dilationH, int dilationW,
  int64_t nInputPlane,
  int64_t inputHeight, int64_t inputWidth,
  int64_t outputHeight, int64_t outputWidth, MemoryFormat memory_format)
{
  pool2d_shape_check(
    input,
    kH, kW, dH, dW, padH, padW, dilationH, dilationW,
    nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth, memory_format);

  const int64_t ndim = input.ndimension();
  const int64_t nOutputPlane = nInputPlane;

  check_dim_size(gradOutput, ndim, ndim-3, nOutputPlane);
  check_dim_size(gradOutput, ndim, ndim-2, outputHeight);
  check_dim_size(gradOutput, ndim, ndim-1, outputWidth);

  check_dim_size(indices, ndim, ndim-3, nOutputPlane);
  check_dim_size(indices, ndim, ndim-2, outputHeight);
  check_dim_size(indices, ndim, ndim-1, outputWidth);
}
180
+
181
// AveragePool2d (backward)
// Forward shape check with dilation fixed to 1 (avg pooling has none), then
// validates grad_output against the expected output geometry.
static inline void
avg_pool2d_backward_shape_check(
  const Tensor& input,
  const Tensor& gradOutput,
  int64_t /*nbatch*/,
  int kH, int kW, int dH, int dW, int padH, int padW,
  int64_t nInputPlane,
  int64_t inputHeight, int64_t inputWidth,
  int64_t outputHeight, int64_t outputWidth,
  MemoryFormat memory_format)
{
  pool2d_shape_check(
    input,
    kH, kW, dH, dW, padH, padW, 1, 1,
    nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth,
    memory_format);

  const int64_t ndim = input.ndimension();
  const int64_t nOutputPlane = nInputPlane;

  check_dim_size(gradOutput, ndim, ndim-3, nOutputPlane);
  check_dim_size(gradOutput, ndim, ndim-2, outputHeight);
  check_dim_size(gradOutput, ndim, ndim-1, outputWidth);
}
206
+
207
+ // AveragePool3d/DilatedMaxPool3d (forward)
208
+ static inline void
209
+ pool3d_shape_check(
210
+ const Tensor& input,
211
+ int64_t nslices,
212
+ int kT, int kH, int kW,
213
+ int dT, int dH, int dW,
214
+ int pT, int pH, int pW,
215
+ int dilationT, int dilationH, int dilationW,
216
+ int64_t itime, int64_t iheight, int64_t iwidth,
217
+ int64_t otime, int64_t oheight, int64_t owidth,
218
+ const char *fn_name,
219
+ bool check_input_size=false)
220
+ {
221
+ const int64_t ndim = input.ndimension();
222
+
223
+ TORCH_CHECK(kT > 0 && kW > 0 && kH > 0,
224
+ "kernel size should be greater than zero, but got ",
225
+ "kT: ", kT, " kH: ", kH, " kW: ", kW);
226
+ TORCH_CHECK(dT > 0 && dW > 0 && dH > 0,
227
+ "stride should be greater than zero, but got ",
228
+ "dT: ", dT, " dH: ", dH, " dW: ", dW);
229
+ TORCH_CHECK(dilationT > 0 && dilationW > 0 && dilationH > 0,
230
+ "dilation should be greater than zero, but got ",
231
+ "dilationT: ", dilationT, " dilationH: ", dilationH, " dilationW: ", dilationW);
232
+
233
+ TORCH_CHECK(ndim == 4 || ndim == 5,
234
+ fn_name, ": Expected 4D or 5D tensor for input, but got: ", input.sizes());
235
+
236
+ for (const auto i : c10::irange(ndim)) {
237
+ if (ndim == 5 && i == 0) {
238
+ // size of batch-dim can be 0.
239
+ continue;
240
+ }
241
+ TORCH_CHECK(
242
+ input.size(i) > 0,
243
+ fn_name,
244
+ ": Expected input's non-batch dimensions to have positive length,"
245
+ " but input has a shape of ",
246
+ input.sizes(),
247
+ " and non-batch dimension ",
248
+ input.size(i),
249
+ " has length zero!")
250
+ }
251
+
252
+ if (check_input_size) { // AveragePool3d
253
+ TORCH_CHECK(itime >= kT && iheight >= kH && iwidth >= kW,
254
+ "input image ", "(T: ", itime, " H: ", iheight, " W: ", iwidth, ") smaller than ",
255
+ "kernel size ", "(kT: ", kT, " kH: ", kH, " kW: ", kW, ")");
256
+ }
257
+
258
+ TORCH_CHECK(kT/2 >= pT && kW/2 >= pW && kH/2 >= pH,
259
+ "pad should be smaller than or equal to half of kernel size, but got "
260
+ "kT: ", kT, " kW: ", kW, " kH: ", kH, " padT: ", pT, " padW: ", pW, " padH: ", pH);
261
+
262
+ TORCH_CHECK(otime >= 1 && owidth >= 1 && oheight >= 1,
263
+ "Given input size: (",
264
+ nslices,"x", itime, "x", iheight, "x", iwidth, "). ",
265
+ "Calculated output size: (",
266
+ nslices, "x", otime, "x", oheight, "x", owidth, "). ",
267
+ "Output size is too small");
268
+ }
269
+
270
+ static inline void
271
+ max_pool3d_backward_shape_check(
272
+ const Tensor& input,
273
+ const Tensor& gradOutput,
274
+ const Tensor& indices,
275
+ int64_t nslices,
276
+ int kT, int kH, int kW,
277
+ int dT, int dH, int dW,
278
+ int pT, int pH, int pW,
279
+ int dilationT, int dilationH, int dilationW,
280
+ int64_t itime, int64_t iheight, int64_t iwidth,
281
+ int64_t otime, int64_t oheight, int64_t owidth,
282
+ const char* fn_name)
283
+ {
284
+ const int64_t ndim = input.ndimension();
285
+
286
+ pool3d_shape_check(
287
+ input,
288
+ nslices,
289
+ kT, kH, kW,
290
+ dT, dH, dW,
291
+ pT, pH, pW,
292
+ dilationT, dilationH, dilationW,
293
+ itime, iheight, iwidth,
294
+ otime, oheight, owidth, fn_name);
295
+
296
+ check_dim_size(gradOutput, ndim, ndim-4, nslices);
297
+ check_dim_size(gradOutput, ndim, ndim-3, otime);
298
+ check_dim_size(gradOutput, ndim, ndim-2, oheight);
299
+ check_dim_size(gradOutput, ndim, ndim-1, owidth);
300
+
301
+ check_dim_size(indices, ndim, ndim-4, nslices);
302
+ check_dim_size(indices, ndim, ndim-3, otime);
303
+ check_dim_size(indices, ndim, ndim-2, oheight);
304
+ check_dim_size(indices, ndim, ndim-1, owidth);
305
+ }
306
+
307
+ static inline void
308
+ avg_pool3d_backward_shape_check(
309
+ const Tensor& input,
310
+ const Tensor& gradOutput,
311
+ int64_t nslices,
312
+ int kT, int kH, int kW,
313
+ int dT, int dH, int dW,
314
+ int pT, int pH, int pW,
315
+ int64_t itime, int64_t iheight, int64_t iwidth,
316
+ int64_t otime, int64_t oheight, int64_t owidth,
317
+ const char *fn_name)
318
+ {
319
+ const int64_t ndim = input.ndimension();
320
+
321
+ pool3d_shape_check(
322
+ input,
323
+ nslices,
324
+ kT, kH, kW,
325
+ dT, dH, dW,
326
+ pT, pH, pW,
327
+ 1, 1, 1,
328
+ itime, iheight, iwidth,
329
+ otime, oheight, owidth,
330
+ fn_name, true);
331
+
332
+ check_dim_size(gradOutput, ndim, ndim-4, nslices);
333
+ check_dim_size(gradOutput, ndim, ndim-3, otime);
334
+ check_dim_size(gradOutput, ndim, ndim-2, oheight);
335
+ check_dim_size(gradOutput, ndim, ndim-1, owidth);
336
+ }
337
+
338
+ } // anonymous namespace
339
+
340
+ } // namespace at::native
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/RNN.h ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/core/Tensor.h>
4
+ #include <ATen/native/DispatchStub.h>
5
+
6
+ namespace at::native {
7
+
8
+ using lstm_fn = void(*)(Tensor&, Tensor&, Tensor&, const Tensor&, TensorList, TensorList, bool, int64_t, double, bool, bool, bool);
9
+ using rnn_fn = void(*)(Tensor&, Tensor&, const Tensor&, const Tensor&, TensorList, bool, int64_t, double, bool, bool, bool);
10
+ using lstm_packed_fn = void(*)(Tensor&, Tensor&, Tensor&, const Tensor&, const Tensor&, TensorList, TensorList, bool, int64_t, double, bool, bool);
11
+ using rnn_packed_fn = void(*)(Tensor&, Tensor&, const Tensor&, const Tensor&, const Tensor&, TensorList, bool, int64_t, double, bool, bool);
12
+
13
+ DECLARE_DISPATCH(lstm_fn, lstm_cudnn_stub);
14
+ DECLARE_DISPATCH(lstm_fn, lstm_miopen_stub);
15
+ DECLARE_DISPATCH(lstm_fn, lstm_mkldnn_stub);
16
+ DECLARE_DISPATCH(rnn_fn, gru_cudnn_stub);
17
+ DECLARE_DISPATCH(rnn_fn, gru_miopen_stub);
18
+ DECLARE_DISPATCH(rnn_fn, rnn_tanh_cudnn_stub);
19
+ DECLARE_DISPATCH(rnn_fn, rnn_tanh_miopen_stub);
20
+ DECLARE_DISPATCH(rnn_fn, rnn_relu_cudnn_stub);
21
+ DECLARE_DISPATCH(rnn_fn, rnn_relu_miopen_stub);
22
+ DECLARE_DISPATCH(lstm_packed_fn, lstm_packed_cudnn_stub);
23
+ DECLARE_DISPATCH(lstm_packed_fn, lstm_packed_miopen_stub);
24
+ DECLARE_DISPATCH(rnn_packed_fn, gru_packed_cudnn_stub);
25
+ DECLARE_DISPATCH(rnn_packed_fn, gru_packed_miopen_stub);
26
+ DECLARE_DISPATCH(rnn_packed_fn, rnn_tanh_packed_cudnn_stub);
27
+ DECLARE_DISPATCH(rnn_packed_fn, rnn_tanh_packed_miopen_stub);
28
+ DECLARE_DISPATCH(rnn_packed_fn, rnn_relu_packed_cudnn_stub);
29
+ DECLARE_DISPATCH(rnn_packed_fn, rnn_relu_packed_miopen_stub);
30
+
31
+ inline void check_attributes(const Tensor& input, const TensorList& params, const TensorList& hiddens, bool check_dtype=false) {
32
+ auto input_device = input.device();
33
+ auto input_dtype = input.scalar_type();
34
+
35
+ auto check_tensors = [&](const std::string& name, const Tensor& t) {
36
+ if (!t.defined()) return;
37
+ auto t_device = t.device();
38
+ TORCH_CHECK(input_device == t_device,
39
+ "Input and ", name, " tensors are not at the same device, found input tensor at ",
40
+ input_device, " and ", name, " tensor at ", t_device);
41
+ if (check_dtype) {
42
+ auto t_dtype = t.scalar_type();
43
+ TORCH_CHECK(input_dtype == t_dtype,
44
+ "Input and ", name, " tensors are not the same dtype, found input tensor with ",
45
+ input_dtype, " and ", name, " tensor with ", t_dtype);
46
+ }
47
+ };
48
+
49
+ for (const auto& h : hiddens) check_tensors("hidden", h);
50
+ for (const auto& p : params) check_tensors("parameter", p);
51
+ }
52
+
53
+ } // namespace at::native
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/Repeat.h ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/core/Tensor.h>
4
+ #include <ATen/TensorOperators.h>
5
+
6
+ #ifndef AT_PER_OPERATOR_HEADERS
7
+ #include <ATen/Functions.h>
8
+ #else
9
+ #include <ATen/ops/empty.h>
10
+ #include <ATen/ops/empty_like.h>
11
+ #endif
12
+
13
+ namespace at::native {
14
+
15
+ template <
16
+ typename index_t,
17
+ void compute(index_t*, int64_t*, index_t*, int64_t, int64_t)>
18
+ static inline Tensor repeat_interleave_common(
19
+ const Tensor& repeats,
20
+ c10::optional<int64_t> output_size) {
21
+ TORCH_CHECK(
22
+ repeats.dim() == 1, "repeat_interleave only accept 1D vector as repeat");
23
+ TORCH_CHECK(
24
+ repeats.scalar_type() == at::kLong || repeats.scalar_type() == at::kInt,
25
+ "repeats has to be Long or Int tensor");
26
+ if (repeats.size(0) == 0) {
27
+ return at::empty_like(repeats, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
28
+ }
29
+ Tensor repeats_ = repeats.contiguous();
30
+ Tensor cumsum = repeats.cumsum(0);
31
+ int64_t total;
32
+ if (output_size.has_value()) {
33
+ total = output_size.value();
34
+ } else {
35
+ total = cumsum[-1].item<int64_t>();
36
+ TORCH_CHECK(
37
+ (repeats >= 0).all().item<uint8_t>(), "repeats can not be negative");
38
+ }
39
+
40
+ Tensor result = at::empty({total}, repeats.options());
41
+ index_t* repeat_ptr = repeats_.data_ptr<index_t>();
42
+ int64_t* cumsum_ptr = cumsum.data_ptr<int64_t>();
43
+ index_t* result_ptr = result.data_ptr<index_t>();
44
+ compute(repeat_ptr, cumsum_ptr, result_ptr, repeats.size(0), total);
45
+ return result;
46
+ }
47
+
48
+ } // namespace at::native
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/Resize.h ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/core/Tensor.h>
4
+ #include <ATen/native/ResizeCommon.h>
5
+ #include <ATen/EmptyTensor.h>
6
+ #include <ATen/TensorUtils.h>
7
+
8
+ #include <c10/core/CPUAllocator.h>
9
+
10
+ #include <utility>
11
+
12
+
13
+ namespace at::native {
14
+
15
+ // TODO: make all operations that resize given outputs use this function
16
+ // for consistency and maintainability.
17
+ // Some operations like `cat` might not be able to make the use of
18
+ // resize_output directly. For more details to understand how it works in `cat`,
19
+ // see https://github.com/pytorch/pytorch/pull/62560#discussion_r687363362
20
+ // Resizes outputs
21
+ // Functions accepting output tensors, like with the "out" kwarg, should
22
+ // call this function to handle resizing their output tensor.
23
+ // Issues a warning if the output tensor has one or more elements and
24
+ // needs resizing
25
+ // NOTE: In the future the warning will become an error
26
+ // Returns a bool saying whether or not the resize actually happened or not
27
+ TORCH_API bool resize_output(const Tensor& output, IntArrayRef shape);
28
+ // WARNING: Do NOT call this directly. If you are resizing an output and want
29
+ // to support dynamic shapes call at::resize__symint and resize_output_check_symint.
30
+ // For more details, see: https://github.com/pytorch/pytorch/pull/111530/files#r1365845272
31
+ TORCH_API bool resize_output_symint(const Tensor& output, SymIntArrayRef shape);
32
+
33
+ // Utility for resize_output
34
+ // Returns a bool saying resize should happen or not and
35
+ // raises a warning if resizing for one or more elements
36
+ TORCH_API bool resize_output_check(const Tensor& output, IntArrayRef shape);
37
+ TORCH_API bool resize_output_check_symint(const Tensor& output, SymIntArrayRef shape);
38
+
39
+ TORCH_API void resize_bytes_cpu(StorageImpl* storage, size_t size_bytes);
40
+ TORCH_API void resize_bytes_meta(StorageImpl* storage, c10::SymInt size_bytes);
41
+ TORCH_API void resize_bytes_nocuda(const Storage& storage, c10::SymInt size_bytes);
42
+
43
+ static inline void maybe_resize_storage_cpu(TensorImpl* self, size_t new_size_bytes) {
44
+ // It does not make sense to try to resize a storage
45
+ // to hold 0 elements, and this can break
46
+ // if storage_offset is positive but
47
+ // new_size is 0, so just bail in that case
48
+ // (same comment is in cuda/Resize.h)
49
+ if (self->numel() == 0) {
50
+ return;
51
+ }
52
+
53
+ const Storage& storage = self->unsafe_storage();
54
+ if (!storage) {
55
+ auto new_storage = c10::make_intrusive<StorageImpl>(
56
+ StorageImpl::use_byte_size_t(),
57
+ new_size_bytes,
58
+ c10::GetCPUAllocator(),
59
+ true);
60
+ self->set_storage_keep_dtype(std::move(new_storage));
61
+ } else if (new_size_bytes > storage.nbytes()) {
62
+ resize_bytes_cpu(storage.unsafeGetStorageImpl(), new_size_bytes);
63
+ }
64
+ }
65
+
66
+ TORCH_API TensorImpl* resize_impl_cpu_(
67
+ TensorImpl* self,
68
+ IntArrayRef size,
69
+ at::OptionalIntArrayRef stride,
70
+ bool resize_storage = true);
71
+
72
+ template <typename T>
73
+ T maybe_convert_symint(c10::SymInt) = delete;
74
+
75
+ template <>
76
+ inline c10::SymInt maybe_convert_symint(c10::SymInt x) { return x; }
77
+
78
+ template <>
79
+ inline int64_t maybe_convert_symint(c10::SymInt x) { return x.guard_int(__FILE__, __LINE__); }
80
+
81
+ template <typename T>
82
+ static inline void checkInBoundsForStorage(
83
+ ArrayRef<T> size,
84
+ ArrayRef<T> stride,
85
+ T storage_offset,
86
+ const caffe2::TypeMeta& data_type,
87
+ const Storage& new_storage) {
88
+ T storage_size_bytes =
89
+ at::detail::computeStorageNbytes(size, stride, data_type.itemsize());
90
+ T storage_offset_bytes = storage_offset * data_type.itemsize();
91
+ if (storage_size_bytes == 0) {
92
+ // NB: (a tensor with arbitrary 0 dims)'s storage can have any numel.
93
+ return;
94
+ }
95
+ T new_storage_size_bytes = maybe_convert_symint<T>(new_storage.sym_nbytes());
96
+ TORCH_CHECK(
97
+ storage_size_bytes + storage_offset_bytes <= new_storage_size_bytes,
98
+ "setStorage: sizes ",
99
+ size,
100
+ ", strides ",
101
+ stride,
102
+ ","
103
+ " storage offset ",
104
+ storage_offset,
105
+ ", and itemsize ",
106
+ data_type.itemsize(),
107
+ " requiring a storage size of ",
108
+ storage_size_bytes + storage_offset_bytes,
109
+ " are out of bounds for storage of size ",
110
+ new_storage_size_bytes);
111
+ }
112
+
113
+ template <typename T>
114
+ static inline void checkSetStorage(Tensor& result, Storage storage, T storage_offset,
115
+ ArrayRef<T> size, ArrayRef<T> stride) {
116
+ // FIXME: stride should be optional
117
+ if (stride.data()) {
118
+ TORCH_CHECK(size.size() == stride.size(), "unequal size length (", size.size(),
119
+ ") and stride length (", stride.size(), ")");
120
+ }
121
+
122
+ #ifdef DEBUG
123
+ TORCH_CHECK(size.size() <= INT_MAX, "size length (", size.size(), ") greater than INT_MAX");
124
+ #endif
125
+
126
+ // storage: note this can't be replaced with result.set_(storage) as the semantics of that
127
+ // function is to set the tensor size to be equal to the size of the storage.
128
+ if (!result.storage().is_alias_of(storage)) {
129
+ // Caffe2 might have tensors whose storages are null, but we
130
+ // don't allow it in PyTorch.
131
+ TORCH_INTERNAL_ASSERT(storage);
132
+ TORCH_INTERNAL_ASSERT(result.storage());
133
+
134
+ // We used to allow this, but this breaks device caching.
135
+ // Let's put an actual error message for this one.
136
+ TORCH_CHECK(result.storage().device() == storage.device(),
137
+ "Attempted to set the storage of a tensor on device \"", result.storage().device(),
138
+ "\" to a storage on different device \"", storage.device(),
139
+ "\". This is no longer allowed; the devices must match.");
140
+ result.unsafeGetTensorImpl()->set_storage_keep_dtype(std::move(storage));
141
+ }
142
+
143
+ // storageOffset
144
+ TORCH_CHECK(storage_offset >= 0, "Tensor: invalid storage offset ", storage_offset);
145
+ }
146
+
147
+ /**
148
+ * Set self's sizes, strides, and storage_offset.
149
+ * (size, stride, storage_offset) must be in bounds for self's storage.
150
+ */
151
+ template <typename T>
152
+ inline void setStrided(
153
+ const Tensor& self,
154
+ ArrayRef<T> size,
155
+ ArrayRef<T> stride,
156
+ T storage_offset) {
157
+ TORCH_CHECK(size.size() == stride.size(), "mismatch in length of strides and shape");
158
+ for (const auto& val : stride) {
159
+ TORCH_CHECK(val >= 0,
160
+ "as_strided: Negative strides are not supported at the moment, "
161
+ "got strides: ", stride);
162
+ }
163
+
164
+ auto* self_ = self.unsafeGetTensorImpl();
165
+ checkInBoundsForStorage(
166
+ size, stride, storage_offset, self_->dtype(), self_->storage());
167
+
168
+ /* storage offset */
169
+ TORCH_CHECK(storage_offset >= 0, "Tensor: invalid storage offset ", storage_offset);
170
+ self_->set_sizes_and_strides(size, stride, c10::make_optional(storage_offset));
171
+ }
172
+
173
+ } // namespace at::native
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/ResizeCommon.h ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/core/Tensor.h>
4
+ #include <ATen/native/TensorFactories.h>
5
+ #include <ATen/NamedTensorUtils.h>
6
+ #include <c10/util/irange.h>
7
+
8
+ #ifndef AT_PER_OPERATOR_HEADERS
9
+ #include <ATen/NativeFunctions.h>
10
+ #else
11
+ #include <ATen/ops/empty.h>
12
+ #endif
13
+
14
+ namespace at::native {
15
+
16
+ template <typename T>
17
+ inline T storage_size_for(ArrayRef<T> size, ArrayRef<T> stride) {
18
+ TORCH_INTERNAL_ASSERT_DEBUG_ONLY(size.size() == stride.size(),
19
+ "storage_size_for(size, stride) requires that size and stride ",
20
+ "have the same size as a precondition.");
21
+ T storage_size = 1;
22
+ for (const auto dim : c10::irange(size.size())) {
23
+ if (size[dim] == 0) {
24
+ storage_size = 0;
25
+ break;
26
+ }
27
+ storage_size += (size[dim] - 1) * stride[dim];
28
+ }
29
+ return storage_size;
30
+ }
31
+
32
+ inline const Tensor& resize_named_tensor_(
33
+ const Tensor& self,
34
+ IntArrayRef size,
35
+ c10::optional<MemoryFormat> optional_memory_format) {
36
+ TORCH_INTERNAL_ASSERT(self.has_names());
37
+ TORCH_CHECK(
38
+ self.sizes() == size,
39
+ "Cannot resize named tensor with resize_ or resize_as_ (tried to resize "
40
+ "Tensor",
41
+ self.names(),
42
+ " with size ",
43
+ self.sizes(),
44
+ " to ",
45
+ size,
46
+ "). This may be caused by passing a named tensor ",
47
+ "as an `out=` argument; please ensure that the sizes are the same. ");
48
+ TORCH_CHECK(
49
+ !optional_memory_format.has_value(),
50
+ "Unsupported memory format for named tensor resize ",
51
+ optional_memory_format.value());
52
+ return self;
53
+ }
54
+
55
+ // For deterministic output, fill new elements that were added after a storage
56
+ // resize with NaN or MAX_INT. `old_storage_nbytes` is the size of the storage
57
+ // before the resize happened.
58
+ inline const Tensor& fill_resize_deterministic_(const Tensor& tensor, int64_t old_storage_nbytes) {
59
+ const at::Storage& storage = tensor.unsafeGetTensorImpl()->unsafe_storage();
60
+ int64_t new_storage_nbytes = storage.nbytes();
61
+ int64_t old_storage_numel = old_storage_nbytes / tensor.itemsize();
62
+ int64_t new_storage_numel = new_storage_nbytes / tensor.itemsize();
63
+ if (new_storage_numel > old_storage_numel) {
64
+ at::Tensor tensor_view = at::empty({}, at::TensorOptions().dtype(tensor.scalar_type()).device(tensor.device()));
65
+ tensor_view.set_(
66
+ storage,
67
+ /*storage_offset=*/old_storage_numel,
68
+ /*size=*/{new_storage_numel - old_storage_numel},
69
+ /*stride=*/{1});
70
+ at::native::fill_empty_deterministic_(tensor_view);
71
+ }
72
+ return tensor;
73
+ }
74
+
75
+ } // namespace at::native
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/SharedReduceOps.h ADDED
@@ -0,0 +1,544 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ // Please note that this file is
3
+ // used across both CPU and GPU.
4
+
5
+ #include <type_traits>
6
+ #include <complex>
7
+ #include <c10/macros/Macros.h>
8
+ #include <ATen/detail/FunctionTraits.h>
9
+ #include <ATen/NumericUtils.h>
10
+ #if defined(__CUDACC__)
11
+ #include <ATen/cuda/DeviceUtils.cuh>
12
+ #include <ATen/native/cuda/DeviceSqrt.cuh>
13
+ #elif defined(__HIPCC__)
14
+ #include <ATen/hip/DeviceUtils.cuh>
15
+ #include <ATen/native/hip/DeviceSqrt.cuh>
16
+ #endif
17
+ #if defined(__CUDACC__) || defined(__HIPCC__)
18
+ #include <thrust/pair.h>
19
+ #else
20
+ #include <cmath>
21
+ #define device_sqrt std::sqrt
22
+ #endif
23
+ #if defined(__CUDACC__) || defined(__HIPCC__)
24
+ template <typename scalar_t>
25
+ inline C10_DEVICE scalar_t max_propagate_nan(scalar_t a, scalar_t b) {
26
+ #if defined(__HIPCC__)
27
+ // TODO: remove this special case for HIP when issue is fixed:
28
+ // https://github.com/ROCm-Developer-Tools/HIP/issues/2209
29
+ scalar_t max = at::_isnan(a) ? a : (at::_isnan(b) ? b : std::max(a, b));
30
+ #else
31
+ scalar_t max = at::_isnan(b) ? b : std::max(a, b);
32
+ #endif
33
+ return max;
34
+ }
35
+ template <typename scalar_t>
36
+ inline C10_DEVICE scalar_t min_propagate_nan(scalar_t a, scalar_t b) {
37
+ #if defined(__HIPCC__)
38
+ // TODO: remove this special case for HIP when issue is fixed:
39
+ // https://github.com/ROCm-Developer-Tools/HIP/issues/2209
40
+ scalar_t min = at::_isnan(a) ? a : (at::_isnan(b) ? b : std::min(a, b));
41
+ #else
42
+ scalar_t min = at::_isnan(b) ? b : std::min(a, b);
43
+ #endif
44
+ return min;
45
+ }
46
+ #define MAX(X, Y) max_propagate_nan(X,Y)
47
+ #define MIN(X, Y) min_propagate_nan(X,Y)
48
+ #else
49
+ #include <ATen/native/cpu/zmath.h>
50
+ #define MAX(X, Y) max_impl(X,Y)
51
+ #define MIN(X, Y) min_impl(X,Y)
52
+ #endif
53
+
54
+ // ROCM hcc doesn't work well with using std:: in kernel functions
55
+ #if defined(__CUDA_ARCH__)
56
+ #include <c10/cuda/CUDAMathCompat.h>
57
+ #define compat_pow c10::cuda::compat::pow
58
+ #elif defined(__HIPCC__)
59
+ #include <c10/hip/HIPMathCompat.h>
60
+ #define compat_pow c10::hip::compat::pow
61
+ #else
62
+ #define compat_pow std::pow
63
+ #endif
64
+
65
+ namespace at { namespace native {
66
+
67
+ namespace detail {
68
+
69
+ #if defined(__CUDACC__) || defined(__HIPCC__)
70
+ template <typename T1, typename T2> using pair = thrust::pair<T1, T2>;
71
+ #else
72
+ template <typename T1, typename T2> using pair = std::pair<T1, T2>;
73
+ #endif
74
+
75
+ } // namespace detail
76
+
77
+ template <typename scalar_t, typename index_t>
78
+ struct WelfordData {
79
+ scalar_t mean;
80
+ scalar_t m2;
81
+ index_t n;
82
+ scalar_t nf;
83
+
84
+ C10_HOST_DEVICE WelfordData() : mean(0), m2(0), n(0), nf(0) {}
85
+
86
+ C10_HOST_DEVICE WelfordData(
87
+ scalar_t mean,
88
+ scalar_t m2,
89
+ index_t n,
90
+ scalar_t nf)
91
+ : mean(mean), m2(m2), n(n), nf(nf) {}
92
+ };
93
+
94
+
95
+ template <typename scalar_t, typename acc_scalar_t, typename index_t, typename res_t>
96
+ struct WelfordOps {
97
+ acc_scalar_t correction;
98
+ bool take_sqrt;
99
+ public:
100
+ using acc_t = WelfordData<acc_scalar_t, index_t>;
101
+ inline C10_DEVICE acc_t reduce(acc_t acc, scalar_t data, index_t /*idx*/) const {
102
+ // We accumulate n in index_t to avoid cumulative rounding error, but still
103
+ // need nf for use in combine where int32 may overflow.
104
+ index_t new_n = acc.n + 1;
105
+ acc_scalar_t new_nf = static_cast<acc_scalar_t>(new_n);
106
+ acc_scalar_t delta = data - acc.mean;
107
+ acc_scalar_t new_mean = acc.mean + delta / new_nf;
108
+ acc_scalar_t new_delta = data - new_mean;
109
+ return {
110
+ new_mean,
111
+ acc.m2 + delta * new_delta,
112
+ new_n,
113
+ new_nf,
114
+ };
115
+ }
116
+ inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const {
117
+ if (a.nf == 0) {
118
+ return b;
119
+ }
120
+ if (b.nf == 0) {
121
+ return a;
122
+ }
123
+ acc_scalar_t delta = b.mean - a.mean;
124
+ acc_scalar_t new_count = a.nf + b.nf;
125
+ acc_scalar_t nb_over_n = b.nf / new_count;
126
+ return {
127
+ a.mean + delta * nb_over_n,
128
+ a.m2 + b.m2 + delta * delta * a.nf * nb_over_n,
129
+ // setting acc.n as -1 since acc.n might not be able to represent the count
130
+ // correctly within its range, setting it to -1 to avoid confusion
131
+ -1,
132
+ new_count
133
+ };
134
+ }
135
+ inline C10_DEVICE res_t project(acc_t acc) const __ubsan_ignore_float_divide_by_zero__ {
136
+ const auto mean = static_cast<scalar_t>(acc.mean);
137
+ const auto divisor = acc.nf > correction ? acc.nf - correction : 0;
138
+ const auto var = acc.m2 / divisor;
139
+ res_t results(take_sqrt ? device_sqrt(var) : var, mean);
140
+ return results;
141
+ }
142
+
143
+ static C10_DEVICE acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) {
144
+ return acc;
145
+ }
146
+
147
+ #if defined(__CUDACC__) || defined(__HIPCC__)
148
+ inline __device__ acc_t warp_shfl_down(acc_t acc, int offset) const {
149
+ return {
150
+ WARP_SHFL_DOWN(acc.mean, offset)
151
+ , WARP_SHFL_DOWN(acc.m2, offset)
152
+ , WARP_SHFL_DOWN(acc.n, offset)
153
+ , WARP_SHFL_DOWN(acc.nf, offset)
154
+ };
155
+ }
156
+ #endif
157
+ C10_HOST_DEVICE WelfordOps(acc_scalar_t correction, bool take_sqrt)
158
+ : correction(correction), take_sqrt(take_sqrt) {}
159
+ };
160
+
161
+ template <typename scalar_t, typename acc_t=scalar_t, typename factor_t=acc_t, typename out_t = acc_t>
162
+ struct MeanOps {
163
+ factor_t factor;
164
+
165
+ inline C10_DEVICE acc_t reduce(acc_t a, scalar_t b, int64_t /*idx*/) const {
166
+ return combine(a, static_cast<acc_t>(b));
167
+ }
168
+
169
+ inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const {
170
+ return a + b;
171
+ }
172
+
173
+ inline C10_DEVICE out_t project(acc_t a) const {
174
+ return a * factor;
175
+ }
176
+
177
+ static C10_DEVICE acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) {
178
+ return acc;
179
+ }
180
+
181
+ #if defined(__CUDACC__) || defined(__HIPCC__)
182
+ inline C10_DEVICE acc_t warp_shfl_down(acc_t data, int offset) const {
183
+ return WARP_SHFL_DOWN(data, offset);
184
+ }
185
+ #endif
186
+
187
+ MeanOps(factor_t factor): factor(factor) {
188
+ }
189
+ };
190
+
191
+ // This accumulator template is used to calculate the minimum absolute value of
192
+ // a set of numbers.
193
+ // `scalar_t` is the type of the input and `acc_t` is the type of the accumulated
194
+ // value. These types differ for complex number input support.
195
+ template <typename scalar_t, typename acc_t = scalar_t, typename out_t = acc_t>
196
+ struct AbsMinOps {
197
+
198
+ inline C10_DEVICE acc_t reduce(acc_t acc, scalar_t data, int64_t /*idx*/) const {
199
+ return MIN(acc, static_cast<acc_t>(std::abs(data)));
200
+ }
201
+
202
+ inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const {
203
+ return MIN(a, b);
204
+ }
205
+
206
+ inline C10_DEVICE out_t project(acc_t a) const {
207
+ return a;
208
+ }
209
+
210
+ static C10_DEVICE acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) {
211
+ return acc;
212
+ }
213
+
214
+ #if defined(__CUDACC__) || defined(__HIPCC__)
215
+ inline C10_DEVICE acc_t warp_shfl_down(acc_t acc, int offset) const {
216
+ return WARP_SHFL_DOWN(acc, offset);
217
+ }
218
+ #endif
219
+ };
220
+
221
+ // This accumulator template is used to calculate the maximum absolute value of
222
+ // a set of numbers.
223
+ // `scalar_t` is the type of the input and `acc_t` is the type of the accumulated
224
+ // value. These types differ for complex number input support.
225
+ template <typename scalar_t, typename acc_t = scalar_t, typename out_t = acc_t>
226
+ struct AbsMaxOps {
227
+ inline C10_DEVICE acc_t reduce(acc_t acc, scalar_t data, int64_t /*idx*/) const {
228
+ return MAX(acc, static_cast<acc_t>(std::abs(data)));
229
+ }
230
+
231
+ inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const {
232
+ return MAX(a, b);
233
+ }
234
+
235
+ inline C10_DEVICE out_t project(acc_t a) const {
236
+ return a;
237
+ }
238
+
239
+ static C10_DEVICE acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) {
240
+ return acc;
241
+ }
242
+
243
+ #if defined(__CUDACC__) || defined(__HIPCC__)
244
+ inline C10_DEVICE acc_t warp_shfl_down(acc_t acc, int offset) const {
245
+ return WARP_SHFL_DOWN(acc, offset);
246
+ }
247
+ #endif
248
+ };
249
+
250
+ // This accumulator template is used to calculate the norm of the absolute value
251
+ // of a set of numbers.
252
+ // `scalar_t` is the type of the input and `acc_t` is the type of the accumulated
253
+ // value. These types differ for complex number input support.
254
+ template <typename scalar_t, typename acc_t = scalar_t, typename out_t = acc_t>
255
+ struct NormOps {
256
+ acc_t norm_;
257
+
258
+ inline C10_DEVICE acc_t reduce(acc_t acc, scalar_t data, int64_t /*idx*/) const {
259
+ return acc + compat_pow(static_cast<acc_t>(std::abs(data)), norm_);
260
+ }
261
+
262
+ inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const {
263
+ return a + b;
264
+ }
265
+
266
+ inline C10_DEVICE out_t project(acc_t a) const {
267
+ return compat_pow(a, static_cast<acc_t>(1.0) / norm_);
268
+ }
269
+
270
+ static C10_DEVICE acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) {
271
+ return acc;
272
+ }
273
+
274
+ #if defined(__CUDACC__) || defined(__HIPCC__)
275
+ inline C10_DEVICE acc_t warp_shfl_down(acc_t acc, int offset) const {
276
+ return WARP_SHFL_DOWN(acc, offset);
277
+ }
278
+ #endif
279
+
280
+ NormOps(acc_t norm_): norm_(norm_) {
281
+ }
282
+ };
283
+
284
+ // This accumulator template is used to calculate the order zero norm of the
285
+ // absolute value of a set of numbers.
286
+ // `scalar_t` is the type of the input and `acc_t` is the type of the accumulated
287
+ // value. These types differ for complex number input support.
288
+ template <typename scalar_t, typename acc_t = scalar_t, typename out_t = acc_t>
289
+ struct NormZeroOps {
290
+ inline C10_DEVICE acc_t reduce(acc_t acc, scalar_t data, int64_t /*idx*/) const {
291
+ return acc + (data == static_cast<scalar_t>(0) ? static_cast<acc_t>(0) : static_cast<acc_t>(1));
292
+ }
293
+
294
+ inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const {
295
+ return a + b;
296
+ }
297
+
298
+ inline C10_DEVICE out_t project(acc_t a) const {
299
+ return a;
300
+ }
301
+
302
+ static C10_DEVICE acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) {
303
+ return acc;
304
+ }
305
+
306
+
307
+ #if defined(__CUDACC__) || defined(__HIPCC__)
308
+ inline C10_DEVICE acc_t warp_shfl_down(acc_t acc, int offset) const {
309
+ return WARP_SHFL_DOWN(acc, offset);
310
+ }
311
+ #endif
312
+ };
313
+
314
+ // This accumulator template is used to calculate the order one norm of the
315
+ // absolute value of a set of numbers.
316
+ // `scalar_t` is the type of the input and `acc_t` is the type of the accumulated
317
+ // value. These types differ for complex number input support.
318
+ template <typename scalar_t, typename acc_t = scalar_t, typename out_t = acc_t>
319
+ struct NormOneOps {
320
+ inline C10_DEVICE acc_t reduce(acc_t acc, scalar_t data, int64_t /*idx*/) const {
321
+ return acc + static_cast<acc_t>(std::abs(data));
322
+ }
323
+
324
+ inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const {
325
+ return a + b;
326
+ }
327
+
328
+ inline C10_DEVICE out_t project(acc_t a) const {
329
+ return a;
330
+ }
331
+
332
+ static C10_DEVICE acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) {
333
+ return acc;
334
+ }
335
+
336
+ #if defined(__CUDACC__) || defined(__HIPCC__)
337
+ inline C10_DEVICE acc_t warp_shfl_down(acc_t acc, int offset) const {
338
+ return WARP_SHFL_DOWN(acc, offset);
339
+ }
340
+ #endif
341
+ };
342
+
343
+
344
// Tag type carrying the desired accumulator type `acc_t` into the
// abs_if_complex overloads below; it exists purely for overload selection.
template<typename acc_t>
struct AbsSwitch {};

// Real input: no absolute value needed, only a cast to the accumulator type.
template<typename scalar_t, typename acc_t>
inline C10_DEVICE acc_t abs_if_complex(scalar_t data, AbsSwitch<acc_t>) {
  return static_cast<acc_t>(data);
}

// std::complex input: reduce to its (real-valued) magnitude before casting.
template<typename scalar_t, typename acc_t>
inline C10_DEVICE acc_t abs_if_complex(std::complex<scalar_t> data, AbsSwitch<acc_t>) {
  return static_cast<acc_t>(std::abs(data));
}

// c10::complex input: reduce to its (real-valued) magnitude before casting.
template<typename scalar_t, typename acc_t>
inline C10_DEVICE acc_t abs_if_complex(c10::complex<scalar_t> data, AbsSwitch<acc_t>) {
  return static_cast<acc_t>(std::abs(data));
}
361
+
362
+ // This accumulator template is used to calculate the order two norm of the
363
+ // absolute value of a set of numbers.
364
+ // `scalar_t` is the type of the input and `acc_t` is the type of the accumulated
365
+ // value. These types differ for complex number input support.
366
+ template <typename scalar_t, typename acc_t = scalar_t, typename out_t = acc_t>
367
+ struct NormTwoOps {
368
+ inline C10_DEVICE acc_t reduce(acc_t acc, scalar_t data, int64_t /*idx*/) const {
369
+ acc_t data_ = abs_if_complex(data, AbsSwitch<acc_t>());
370
+ return acc + data_ * data_;
371
+ }
372
+
373
+ inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const {
374
+ return a + b;
375
+ }
376
+
377
+ inline C10_DEVICE out_t project(acc_t a) const {
378
+ return device_sqrt(a);
379
+ }
380
+
381
+ static C10_DEVICE acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) {
382
+ return acc;
383
+ }
384
+
385
+ #if defined(__CUDACC__) || defined(__HIPCC__)
386
+ inline C10_DEVICE acc_t warp_shfl_down(acc_t acc, int offset) const {
387
+ return WARP_SHFL_DOWN(acc, offset);
388
+ }
389
+ #endif
390
+ };
391
+
392
+ template <typename acc_t, typename data_t>
393
+ struct NanSumOps {
394
+ inline C10_DEVICE acc_t reduce(acc_t a, data_t b, int64_t /*idx*/) const {
395
+ return a + (at::_isnan(b) ? acc_t{0.} : acc_t{b});
396
+ }
397
+
398
+ inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const {
399
+ return a + b;
400
+ }
401
+
402
+ inline C10_DEVICE data_t project(acc_t a) const {
403
+ return data_t{a};
404
+ }
405
+
406
+ static C10_DEVICE acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) {
407
+ return acc;
408
+ }
409
+
410
+ #if defined(__CUDACC__) || defined(__HIPCC__)
411
+ inline C10_DEVICE acc_t warp_shfl_down(acc_t data, int offset) const {
412
+ return WARP_SHFL_DOWN(data, offset);
413
+ }
414
+ #endif
415
+ };
416
+
417
+ namespace detail {
418
+
419
+ template <typename scalar_t>
420
+ struct LessOrNan {
421
+ C10_DEVICE bool operator () (scalar_t a, scalar_t b, int64_t idx_a, int64_t idx_b) const {
422
+ // If (a == b), then choose the one with lower idx, else min(a, b)
423
+ if (at::_isnan(a)) {
424
+ if (at::_isnan(b)) {
425
+ return idx_a < idx_b;
426
+ }
427
+ return true;
428
+ }
429
+ return (a == b) ? idx_a < idx_b : (a < b);
430
+ }
431
+ };
432
+
433
+ template <typename scalar_t>
434
+ struct GreaterOrNan {
435
+ C10_DEVICE bool operator () (scalar_t a, scalar_t b, int64_t idx_a, int64_t idx_b) const {
436
+ // If (a == b), then choose the one with lower idx, else max(a, b)
437
+ if (at::_isnan(a)) {
438
+ if (at::_isnan(b)) {
439
+ return idx_a < idx_b;
440
+ }
441
+ return true;
442
+ }
443
+ return (a == b) ? idx_a < idx_b : (a > b);
444
+ }
445
+ };
446
+
447
// Reduction ops tracking a (value, index) pair, parameterized by a comparator
// (LessOrNan / GreaterOrNan) that decides which pair wins. Backs min/max
// reductions that also report the winning element's index.
template <typename comp_t>
struct MinMaxReductionOps {
  using scalar_t = typename binary_function_traits<comp_t>::arg1_t;
  using index_t = int64_t;
  using arg_t = detail::pair<scalar_t, index_t>;

  // The full (value, index) pair is the result.
  static C10_DEVICE arg_t project(arg_t arg) {
    return arg;
  }

  // Keep the current pair if it beats (val, idx) under comp_t, else replace.
  static C10_DEVICE arg_t reduce(arg_t arg, scalar_t val, int64_t idx) {
    return comp_t{}(arg.first, val, arg.second, idx) ? arg : arg_t(val, idx);
  }

  // Merge two partial results by keeping the winning pair.
  static C10_DEVICE arg_t combine(arg_t a, arg_t b) {
    return comp_t{}(a.first, b.first, a.second, b.second) ? a : b;
  }

  // Shift a chunk-local index into the global index space.
  static C10_DEVICE arg_t translate_idx(arg_t a, int64_t base_idx) {
    return {a.first, a.second + base_idx};
  }

#if defined(__CUDACC__) || defined(__HIPCC__)
  static C10_DEVICE arg_t warp_shfl_down(arg_t arg, int offset) {
    return arg_t(WARP_SHFL_DOWN(arg.first, offset),
                 WARP_SHFL_DOWN(arg.second, offset));
  }
#endif
};
476
+
477
// Same as MinMaxReductionOps, but the final result is only the index of the
// winning element (argmin/argmax) rather than the (value, index) pair.
template <typename comp_t>
struct ArgReductionOps : public MinMaxReductionOps<comp_t> {
  using typename MinMaxReductionOps<comp_t>::scalar_t;
  using typename MinMaxReductionOps<comp_t>::index_t;
  using typename MinMaxReductionOps<comp_t>::arg_t;

  // Discard the value; only the index is reported.
  static C10_DEVICE index_t project(arg_t arg) {
    return arg.second;
  }
};
487
+
488
+ } // namespace detail
489
+
490
// Argmax: index of the largest element (NaN wins; lower index breaks ties).
template <typename scalar_t>
struct ArgMaxOps :
  public detail::ArgReductionOps<detail::GreaterOrNan<scalar_t>> {
};

// Argmin: index of the smallest element (NaN wins; lower index breaks ties).
template <typename scalar_t>
struct ArgMinOps :
  public detail::ArgReductionOps<detail::LessOrNan<scalar_t>> {
};

// Min: (value, index) pair of the smallest element, NaN-propagating.
template <typename scalar_t>
struct MinOps :
  public detail::MinMaxReductionOps<detail::LessOrNan<scalar_t>> {
};

// Max: (value, index) pair of the largest element, NaN-propagating.
template <typename scalar_t>
struct MaxOps :
  public detail::MinMaxReductionOps<detail::GreaterOrNan<scalar_t>> {
};
509
+
510
// Reduction ops computing the minimum and maximum simultaneously in a single
// pass. The accumulator is a (min, max) pair; NaN on either side of a
// comparison propagates into the result.
template <typename scalar_t, typename acc_scalar_t, typename index_t>
struct MinMaxOps {
  using acc_t = detail::pair<acc_scalar_t, acc_scalar_t>;
  // Fold one element in by treating it as a degenerate (data, data) pair.
  inline C10_DEVICE acc_t reduce(acc_t acc, scalar_t data, index_t /*idx*/) const {
    return combine(acc, {data, data});
  }

  // Merge two partial (min, max) pairs. If either operand of a comparison is
  // NaN, the NaN is selected, so NaN propagates.
  inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const {
    auto min_val = (at::_isnan(a.first) || a.first < b.first) ? a.first : b.first;
    auto max_val = (at::_isnan(a.second) || a.second > b.second) ? a.second : b.second;

    return {min_val, max_val};
  }

  // The accumulated (min, max) pair is the final result.
  inline C10_DEVICE acc_t project(acc_t acc) const {
    return acc;
  }

  // Element indices do not participate in this reduction.
  static C10_DEVICE acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) {
    return acc;
  }

#if defined(__CUDACC__) || defined(__HIPCC__)
  inline C10_DEVICE acc_t warp_shfl_down(acc_t acc, int offset) const {
    return {
      WARP_SHFL_DOWN(acc.first, offset), WARP_SHFL_DOWN(acc.second, offset)
    };
  }
#endif
};
540
+
541
+ }} // namespace at::native
542
+
543
+ #undef MAX
544
+ #undef MIN
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/SparseTensorUtils.h ADDED
@@ -0,0 +1,190 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/Parallel.h>
4
+ #include <ATen/SparseTensorImpl.h>
5
+ #include <ATen/core/Tensor.h>
6
+
7
+ #ifndef AT_PER_OPERATOR_HEADERS
8
+ #include <ATen/Functions.h>
9
+ #else
10
+ #include <ATen/ops/empty.h>
11
+ #include <ATen/ops/tensor.h>
12
+ #endif
13
+
14
+ namespace at::sparse {
15
+
16
+ // Just for documentary purposes
17
+ using SparseTensor = Tensor;
18
+ using SparseType = Type;
19
+
20
// This is an internal utility function for getting at the SparseTensorImpl,
// so that we can write sparse tensor specific accessors for special fields
// in SparseTensor. You should only use this for writing low level
// setters/getters for SparseTensorImpl fields; otherwise, you should use
// the low level setters/getters that were implemented using this.
//
// This may be called repeatedly, so make sure it's pretty cheap.
inline SparseTensorImpl* get_sparse_impl(const SparseTensor& self) {
  // Guard against misuse on dense tensors before the unchecked downcast.
  TORCH_INTERNAL_ASSERT(
      self.is_sparse(), "_internal_get_SparseTensorImpl: not a sparse tensor");
  return static_cast<SparseTensorImpl*>(self.unsafeGetTensorImpl());
}
32
+
33
// Takes indices and values and directly puts them into the sparse tensor, no
// copy. This used to be called THSTensor_(_move)
inline void alias_into_sparse(
    const SparseTensor& self,
    const Tensor& indices,
    const Tensor& values) {
  // "unsafe": no consistency checks are performed on the aliased tensors.
  get_sparse_impl(self)->set_indices_and_values_unsafe(indices, values);
}
41
+
42
// Take indices and values and makes a (data) copy of them to put into the
// sparse indices/values. This used to be called THSTensor_(_set)
inline void copy_into_sparse(
    const SparseTensor& self,
    const Tensor& indices,
    const Tensor& values,
    bool non_blocking) {
  // copy=true forces fresh allocations, so `self` never aliases the inputs.
  alias_into_sparse(
      self,
      indices.to(self._indices().options(), non_blocking, /*copy=*/true),
      values.to(self._values().options(), non_blocking, /*copy=*/true));
}
54
+
55
// TODO: put this into the public API
// True iff both tensors share the exact same TensorImpl (object identity,
// not element-wise equality).
inline bool is_same_tensor(const Tensor& lhs, const Tensor& rhs) {
  return lhs.unsafeGetTensorImpl() == rhs.unsafeGetTensorImpl();
}
59
+
60
// True iff both sparse tensors split their dimensions into sparse and dense
// parts in the same way.
inline bool is_same_density(const SparseTensor& self, const SparseTensor& src) {
  return self.sparse_dim() == src.sparse_dim() &&
      self.dense_dim() == src.dense_dim();
}
64
+
65
+ // Give us a new values tensor, with the same dimensionality
66
+ // as 'values' but with a new number of non-zero elements.
67
+ // TODO: Expose this for real in ATen, some day?
68
+ // NB: Doesn't preserve data.
69
+ inline Tensor new_values_with_size_of(const Tensor& values, int64_t nnz) {
70
+ std::vector<int64_t> size = values.sizes().vec();
71
+ size[0] = nnz;
72
+ return at::empty(size, values.options());
73
+ }
74
+
75
+ // NOTE [ Flatten Sparse Indices ]
76
+ // This helper function flattens a sparse indices tensor (a Tensor) into a 1D
77
+ // indices tensor. E.g.,
78
+ // input = [[2, 4, 0],
79
+ // [3, 1, 10]]
80
+ // full_size = [2, 12]
81
+ // output = [ 2 * 12 + 3, 4 * 12 + 1, 0 * 12 + 10 ] = [27, 49, 10]
82
+ //
83
+ // In other words, assuming that each `indices[i, :]` is a valid index to a
84
+ // tensor `t` of shape `full_size`. This returns the corresponding indices to
85
+ // the flattened tensor `t.reshape( prod(full_size[:indices.size(0)]), -1 )`.
86
// If force_clone is true, the result is forced to be a clone of self.
88
+ TORCH_API Tensor flatten_indices(
89
+ const Tensor& indices,
90
+ IntArrayRef full_size,
91
+ bool force_clone = false);
92
+
93
+ // Flatten sparse tensor's indices from nD to 1D, similar to NOTE [ Flatten
94
+ // Sparse Indices ], except this one allows partial flatten: only flatten on
95
+ // specified dims. Note that the flatten indices might be uncoalesced if
96
+ // dims_to_flatten.size() < sparse_dim. Also if input indices is already
97
+ // coalesced, the flattened indices will also be sorted.
98
+ //
99
+ // args:
100
+ // indices: sparse tensor indices
101
+ // sizes: sparse tensor sizes
102
+ // dims_to_flatten: a list of dim index to flatten
103
+ //
104
+ // Ex1:
105
+ // indices = [[2, 4, 0],
106
+ // [3, 1, 3]]
107
+ // sizes = [2, 12]
108
+ // dims_to_flatten = [0, 1]
109
+ // new_indices = [ 2 * 12 + 3, 4 * 12 + 1, 0 * 12 + 3 ] = [27, 49, 3]
110
+ //
111
+ // Ex2:
112
+ // dims_to_flatten = [1]
113
+ // new_indices = [ 3, 1, 3 ] # uncoalesced
114
+ TORCH_API Tensor flatten_indices_by_dims(
115
+ const Tensor& indices,
116
+ const IntArrayRef& sizes,
117
+ const IntArrayRef& dims_to_flatten);
118
+
119
+ // Find the CSR representation for a row `indices` from the COO format
120
+ TORCH_API Tensor coo_to_csr(const int64_t* indices, int64_t dim, int64_t nnz);
121
+
122
+ TORCH_API Tensor zeros_like_with_indices(const Tensor& t);
123
+
124
// Holds a tensor's sizes and strides in fixed-capacity stack arrays
// (capacity static_shape_max_len). Only the first `sizes.size()` entries of
// each array are written; the remaining entries are left uninitialized.
template <size_t static_shape_max_len>
class TensorGeometryHolder {
  using geometry_holder_t = std::array<int64_t, static_shape_max_len>;

 public:
  // `options` is unused in this primary template; it is accepted only to
  // keep the signature in sync with the TensorGeometryHolder<0>
  // specialization below.
  explicit TensorGeometryHolder(
      IntArrayRef sizes,
      IntArrayRef strides,
      TensorOptions options = {}) {
    std::copy(sizes.begin(), sizes.end(), t_sizes.begin());
    std::copy(strides.begin(), strides.end(), t_strides.begin());
  }

  explicit TensorGeometryHolder(const Tensor& t)
      : TensorGeometryHolder(t.sizes(), t.strides()) {}

  // Returns the (sizes, strides) arrays by value as a tuple.
  auto operator*() const {
    return std::make_tuple(t_sizes, t_strides);
  }

 private:
  geometry_holder_t t_sizes;
  geometry_holder_t t_strides;
};
148
+
149
// Specialization for when the dimensionality is not statically bounded:
// sizes and strides live in a tensor on the target device instead of stack
// arrays, and dereferencing yields raw int64_t pointers into that storage.
template <>
class TensorGeometryHolder<0> {
  using geometry_holder_t = Tensor;

 public:
  explicit TensorGeometryHolder(
      IntArrayRef sizes,
      IntArrayRef strides,
      TensorOptions options) {
    const int64_t t_ndims = sizes.size();
    // Build a 2 x ndims CPU tensor (row 0: sizes, row 1: strides) first,
    // then move it to the target device in a single transfer.
    const auto cpu_options = TensorOptions(options).dtype(kLong).device(kCPU);
    Tensor t_sizes_and_strides_cpu = at::empty({2, t_ndims}, cpu_options);
    t_sizes_and_strides_cpu.select(0, 0).copy_(at::tensor(sizes, cpu_options));
    t_sizes_and_strides_cpu.select(0, 1).copy_(
        at::tensor(strides, cpu_options));
    const Tensor t_sizes_and_strides =
        t_sizes_and_strides_cpu.to(options.device());
    // Views into the device tensor; they keep the buffer alive.
    t_sizes = t_sizes_and_strides.select(0, 0);
    t_strides = t_sizes_and_strides.select(0, 1);
  }

  explicit TensorGeometryHolder(const Tensor& t)
      : TensorGeometryHolder(t.sizes(), t.strides(), t.options()) {}

  // Returns raw (sizes_ptr, strides_ptr) into the device-side buffers.
  auto operator*() const {
    return std::make_tuple(
        t_sizes.template data_ptr<int64_t>(),
        t_strides.template data_ptr<int64_t>());
  }

 private:
  geometry_holder_t t_sizes;
  geometry_holder_t t_strides;
};
183
+
184
+ // Return all indices of a tensor with the given shape.
185
+ //
186
+ // full_coo_indices(shape) is equivalent to
187
+ // torch.ones(shape).nonzero().transpose(-2, -1) but much faster.
188
+ TORCH_API Tensor full_coo_indices(IntArrayRef sizes, TensorOptions options);
189
+
190
+ } // namespace at::sparse
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/StridedRandomAccessor.h ADDED
@@ -0,0 +1,301 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ namespace at::native {
4
+
5
+ // (Const)StridedRandomAccessor is a
6
+ // (const) random access iterator defined over
7
+ // a strided array.
8
+
9
+ // The traits below are to introduce __restrict__
10
+ // modifier on different platforms.
11
+
12
// Default pointer traits: a plain, unannotated pointer type.
template <typename T>
struct DefaultPtrTraits {
  using PtrType = T*;
};
16
+
17
+ #if (defined(_WIN32) || defined(_WIN64))
18
+ #define RESTRICT __restrict
19
+ #else
20
+ #define RESTRICT __restrict__
21
+ #endif
22
+
23
// Pointer traits adding the restrict qualifier, telling the compiler the
// pointer does not alias any other pointer in scope.
template <typename T>
struct RestrictPtrTraits {
  using PtrType = T* RESTRICT;
};
27
+
28
// A const random-access iterator over a strided array: element i is read
// from ptr[i * stride]. Usable from both host and device code.
template <
  typename T,
  typename index_t = int64_t,
  template <typename U> class PtrTraits = DefaultPtrTraits
>
class ConstStridedRandomAccessor {
public:
  using difference_type = index_t;
  using value_type = const T;
  using pointer = const typename PtrTraits<T>::PtrType;
  using reference = const value_type&;
  using iterator_category = std::random_access_iterator_tag;

  using PtrType = typename PtrTraits<T>::PtrType;
  using index_type = index_t;

  // Constructors {
  C10_HOST_DEVICE
  ConstStridedRandomAccessor(PtrType ptr, index_t stride)
    : ptr{ptr}, stride{stride}
  {}

  // Contiguous data: stride defaults to 1.
  C10_HOST_DEVICE
  explicit ConstStridedRandomAccessor(PtrType ptr)
    : ptr{ptr}, stride{static_cast<index_t>(1)}
  {}

  // Default-constructed: null pointer, unit stride.
  C10_HOST_DEVICE
  ConstStridedRandomAccessor()
    : ptr{nullptr}, stride{static_cast<index_t>(1)}
  {}
  // }

  // Pointer-like operations {
  C10_HOST_DEVICE
  reference operator*() const {
    return *ptr;
  }

  C10_HOST_DEVICE
  const value_type* operator->() const {
    return reinterpret_cast<const value_type*>(ptr);
  }

  // Subscript honors the stride: idx maps to ptr[idx * stride].
  C10_HOST_DEVICE
  reference operator[](index_t idx) const {
    return ptr[idx * stride];
  }
  // }

  // Prefix/postfix increment/decrement {
  C10_HOST_DEVICE
  ConstStridedRandomAccessor& operator++() {
    ptr += stride;
    return *this;
  }

  C10_HOST_DEVICE
  ConstStridedRandomAccessor operator++(int) {
    ConstStridedRandomAccessor copy(*this);
    ++*this;
    return copy;
  }

  C10_HOST_DEVICE
  ConstStridedRandomAccessor& operator--() {
    ptr -= stride;
    return *this;
  }

  C10_HOST_DEVICE
  ConstStridedRandomAccessor operator--(int) {
    ConstStridedRandomAccessor copy(*this);
    --*this;
    return copy;
  }
  // }

  // Arithmetic operations {
  C10_HOST_DEVICE
  ConstStridedRandomAccessor& operator+=(index_t offset) {
    ptr += offset * stride;
    return *this;
  }

  C10_HOST_DEVICE
  ConstStridedRandomAccessor operator+(index_t offset) const {
    return ConstStridedRandomAccessor(ptr + offset * stride, stride);
  }

  C10_HOST_DEVICE
  friend ConstStridedRandomAccessor operator+(
    index_t offset,
    const ConstStridedRandomAccessor& accessor
  ) {
    return accessor + offset;
  }

  C10_HOST_DEVICE
  ConstStridedRandomAccessor& operator-=(index_t offset) {
    ptr -= offset * stride;
    return *this;
  }

  C10_HOST_DEVICE
  ConstStridedRandomAccessor operator-(index_t offset) const {
    return ConstStridedRandomAccessor(ptr - offset * stride, stride);
  }

  // Note that this operator is well-defined when `this` and `other`
  // represent the same sequences, i.e. when
  // 1. this.stride == other.stride,
  // 2. |other - this| / this.stride is an Integer.
  C10_HOST_DEVICE
  difference_type operator-(const ConstStridedRandomAccessor& other) const {
    return (ptr - other.ptr) / stride;
  }
  // }

  // Comparison operators {
  C10_HOST_DEVICE
  bool operator==(const ConstStridedRandomAccessor& other) const {
    return (ptr == other.ptr) && (stride == other.stride);
  }

  C10_HOST_DEVICE
  bool operator!=(const ConstStridedRandomAccessor& other) const {
    return !(*this == other);
  }

  // Ordering compares raw pointers only; it is meaningful when both
  // accessors walk the same underlying sequence.
  C10_HOST_DEVICE
  bool operator<(const ConstStridedRandomAccessor& other) const {
    return ptr < other.ptr;
  }

  C10_HOST_DEVICE
  bool operator<=(const ConstStridedRandomAccessor& other) const {
    return (*this < other) || (*this == other);
  }

  C10_HOST_DEVICE
  bool operator>(const ConstStridedRandomAccessor& other) const {
    return !(*this <= other);
  }

  C10_HOST_DEVICE
  bool operator>=(const ConstStridedRandomAccessor& other) const {
    return !(*this < other);
  }
  // }

protected:
  PtrType ptr;
  index_t stride;
};
183
+
184
// Mutable counterpart of ConstStridedRandomAccessor: identical strided
// iteration, but dereferencing yields non-const references. Operators are
// re-declared so they return StridedRandomAccessor rather than the base.
template <
  typename T,
  typename index_t = int64_t,
  template <typename U> class PtrTraits = DefaultPtrTraits
>
class StridedRandomAccessor
  : public ConstStridedRandomAccessor<T, index_t, PtrTraits> {
public:
  using difference_type = index_t;
  using value_type = T;
  using pointer = typename PtrTraits<T>::PtrType;
  using reference = value_type&;

  using BaseType = ConstStridedRandomAccessor<T, index_t, PtrTraits>;
  using PtrType = typename PtrTraits<T>::PtrType;

  // Constructors {
  C10_HOST_DEVICE
  StridedRandomAccessor(PtrType ptr, index_t stride)
    : BaseType(ptr, stride)
  {}

  // Contiguous data: stride defaults to 1 (via the base constructor).
  C10_HOST_DEVICE
  explicit StridedRandomAccessor(PtrType ptr)
    : BaseType(ptr)
  {}

  C10_HOST_DEVICE
  StridedRandomAccessor()
    : BaseType()
  {}
  // }

  // Pointer-like operations {
  C10_HOST_DEVICE
  reference operator*() const {
    return *this->ptr;
  }

  C10_HOST_DEVICE
  value_type* operator->() const {
    return reinterpret_cast<value_type*>(this->ptr);
  }

  // Subscript honors the stride: idx maps to ptr[idx * stride].
  C10_HOST_DEVICE
  reference operator[](index_t idx) const {
    return this->ptr[idx * this->stride];
  }
  // }

  // Prefix/postfix increment/decrement {
  C10_HOST_DEVICE
  StridedRandomAccessor& operator++() {
    this->ptr += this->stride;
    return *this;
  }

  C10_HOST_DEVICE
  StridedRandomAccessor operator++(int) {
    StridedRandomAccessor copy(*this);
    ++*this;
    return copy;
  }

  C10_HOST_DEVICE
  StridedRandomAccessor& operator--() {
    this->ptr -= this->stride;
    return *this;
  }

  C10_HOST_DEVICE
  StridedRandomAccessor operator--(int) {
    StridedRandomAccessor copy(*this);
    --*this;
    return copy;
  }
  // }

  // Arithmetic operations {
  C10_HOST_DEVICE
  StridedRandomAccessor& operator+=(index_t offset) {
    this->ptr += offset * this->stride;
    return *this;
  }

  C10_HOST_DEVICE
  StridedRandomAccessor operator+(index_t offset) const {
    return StridedRandomAccessor(this->ptr + offset * this->stride, this->stride);
  }

  C10_HOST_DEVICE
  friend StridedRandomAccessor operator+(
    index_t offset,
    const StridedRandomAccessor& accessor
  ) {
    return accessor + offset;
  }

  C10_HOST_DEVICE
  StridedRandomAccessor& operator-=(index_t offset) {
    this->ptr -= offset * this->stride;
    return *this;
  }

  C10_HOST_DEVICE
  StridedRandomAccessor operator-(index_t offset) const {
    return StridedRandomAccessor(this->ptr - offset * this->stride, this->stride);
  }

  // Note that here we call BaseType::operator- version
  C10_HOST_DEVICE
  difference_type operator-(const BaseType& other) const {
    return (static_cast<const BaseType&>(*this) - other);
  }
  // }
};
300
+
301
+ } // namespace at::native
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/TensorAdvancedIndexingUtils.h ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ #include <ATen/core/Tensor.h>
3
+ #include <ATen/native/IndexingUtils.h>
4
+ #include <ATen/native/TensorIterator.h>
5
+
6
+ namespace at::native {
7
+ namespace {
8
+ static std::string shapes_as_str(TensorList tensors) {
9
+ std::ostringstream os;
10
+ bool first = true;
11
+ for (auto& tensor : tensors) {
12
+ if (tensor.defined()) {
13
+ if (!first) {
14
+ os << ", ";
15
+ }
16
+ os << tensor.sizes();
17
+ first = false;
18
+ }
19
+ }
20
+ return os.str();
21
+ }
22
+ } // anonymous namespace
23
+
24
// Decide whether an index_put_(indices, value) call can be lowered to the
// much cheaper masked_fill_. This is only possible when `value` is a
// single-element CPU tensor and the indices consist of exactly one
// boolean/byte mask (all other entries null). Returns (true, mask unsqueezed
// to broadcast over the trailing dims of `self`) on success, and
// (false, undefined Tensor) otherwise.
static std::tuple<bool, Tensor> canDispatchToMaskedFill(const Tensor& self, const torch::List<c10::optional<at::Tensor>>& indices,
const Tensor& value){
  if (!(value.numel() ==1 && value.device().is_cpu())){
    return std::make_tuple(false,Tensor());
  }
  // num_ind counts how many leading dimensions of `self` the indices seen so
  // far cover (1 per null entry, ndim per mask).
  int64_t num_ind = 0;
  Tensor mask;
  auto self_device = self.device();
  for (const c10::optional<Tensor>& i: indices) {
    if (!i.has_value() || !(*i).defined()){
      // Null entry: spans one dimension of `self` untouched.
      num_ind++;
    } else {
      const Tensor &index = *i;
      // Bail out on integer indices, cross-device masks, or a second mask.
      if ((index.scalar_type() != kByte && index.scalar_type() != kBool) ||
          index.device() != self_device || mask.defined()){
        return std::make_tuple(false, Tensor());
      } else {
        mask = index;
        // The mask's shape must line up with the dimensions of `self` it
        // covers, starting at position num_ind.
        for (const auto j : c10::irange(index.dim())) {
          int64_t srcIdx = num_ind + j;
          TORCH_CHECK_INDEX(index.size(j) == self.size(srcIdx), "The shape of the mask ", index.sizes(), " at index ", j,
            " does not match the shape of the indexed tensor ", self.sizes(), " at index ", srcIdx);
        }
        num_ind += mask.ndimension();
      }
    }
  }
  // Append singleton dims so the mask broadcasts over the remaining
  // (uncovered) trailing dimensions of `self`.
  for (C10_UNUSED const auto i : c10::irange(num_ind, self.ndimension())) {
    mask = mask.unsqueeze(-1);
  }
  return std::make_tuple(true, mask);
}
56
+
57
// Normalize user-supplied indices in `orig` into the canonical form expected
// by the advanced-indexing kernels (long dtype, broadcast, padded to
// self.dim(), contiguous subspace, on self's device), and wrap them together
// with the (possibly transposed) `self` in an AdvancedIndex.
static AdvancedIndex make_info(Tensor self, IOptTensorListRef orig) {
  checkIndexTensorTypes(orig, /*allow_int*/ true);
  // first expand BoolTensor (masks) or ByteTensor (masks) into 1 or more LongTensors
  auto indices = expandTensors(self, orig);
  // next broadcast all index tensors together
  try {
    indices = expand_outplace(indices);
  } catch (std::exception& e) {
    TORCH_CHECK_INDEX(false, "shape mismatch: indexing tensors could not be broadcast together"
                      " with shapes ", shapes_as_str(indices));
  }
  // add missing null Tensors so that it matches self.dim()
  while (indices.size() < (size_t)self.dim()) {
    indices.emplace_back();
  }
  // if the non-null indices are not all adjacent, transpose self and indices
  // together so that they're adjacent at the front
  if (!hasContiguousSubspace(indices)) {
    std::tie(self, indices) = transposeToFront(self, indices);
  }
  // Ensure indices are on the same device as self
  for (auto & indice : indices) {
    if (indice.defined() && indice.device() != self.device()) {
      indice = indice.to(self.device());
    }
  }
  // Kernels operate on int64 indices; widen any int32 index tensors.
  for (auto & indice : indices) {
    if (indice.defined() && indice.dtype() == at::kInt) {
      indice = indice.to(at::kLong);
    }
  }

  return AdvancedIndex(self, indices);
}
91
+
92
+ } // namespace at::native
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/TensorDimApply.h ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ #include <ATen/core/Tensor.h>
3
+ #include <c10/util/irange.h>
4
+
5
+ namespace at::native {
6
//input tensors are non-zero dim and non-empty
// Invoke `func` once for every 1-D slice of self/values/indices taken along
// dimension `dim`. `func` receives raw element pointers positioned at the
// start of the current slice, plus the slice length and per-tensor strides.
// All three tensors are walked in lockstep with an odometer over every
// dimension other than `dim`; `counter` holds the current multi-dim position.
template<typename T1, typename T2, typename Function>
void tensor_dim_apply3(const Tensor& self, Tensor& values, Tensor& indices, int64_t dim, Function func) {
  int ndims = self.dim();
  int tensor_dim_apply_has_finished = 0;
  std::vector<int64_t> counter(ndims, 0);
  const T1* self_data = self.const_data_ptr<T1>();
  T1* values_data = values.data_ptr<T1>();
  T2* indices_data = indices.data_ptr<T2>();
  int64_t self_stride = self.stride(dim);
  int64_t values_stride = values.stride(dim);
  int64_t indices_stride = indices.stride(dim);
  // NOTE(review): slice length is narrowed to int — assumes self.size(dim)
  // fits in int; confirm for very large tensors.
  int self_dim_size = self.size(dim);

  while (!tensor_dim_apply_has_finished) {
    func(self_data, values_data, indices_data, self_dim_size, self_stride, values_stride, indices_stride);
    if (ndims == 1) {
      // Only the reduced dimension exists, so there was exactly one slice.
      break;
    }
    // Advance the odometer: bump the first non-`dim` dimension that does not
    // wrap; on wrap, rewind it and carry into the next dimension.
    for (const auto dim_i : c10::irange(ndims)) {
      if (dim_i == dim) {
        if (dim_i == (ndims - 1)) {
          // `dim` is the last dimension: nothing left to carry into.
          tensor_dim_apply_has_finished = 1;
          break;
        }
        continue;
      }
      counter[dim_i]++;
      self_data += self.stride(dim_i);
      values_data += values.stride(dim_i);
      indices_data += indices.stride(dim_i);

      if (counter[dim_i] == self.size(dim_i)) {
        if (dim_i == ndims-1) {
          // The most significant dimension wrapped: traversal complete.
          tensor_dim_apply_has_finished = 1;
          break;
        } else {
          // Rewind this dimension's pointer offsets and carry onward.
          self_data -= counter[dim_i]*self.stride(dim_i);
          values_data -= counter[dim_i]*values.stride(dim_i);
          indices_data -= counter[dim_i]*indices.stride(dim_i);
          counter[dim_i] = 0;
        }
      } else {
        break;
      }
    }
  }
}
55
+ } // namespace at::native
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/TensorFactories.h ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/core/Tensor.h>
4
+ #include <ATen/EmptyTensor.h>
5
+ #include <ATen/TensorIterator.h>
6
+ #include <ATen/Dispatch.h>
7
+ #include <ATen/Dispatch_v2.h>
8
+ #include <ATen/native/DispatchStub.h>
9
+
10
+ #ifndef AT_PER_OPERATOR_HEADERS
11
+ #include <ATen/Functions.h>
12
+ #else
13
+ #include <ATen/ops/scalar_tensor.h>
14
+ #endif
15
+
16
+ namespace at::native {
17
// Number of elements in the lower triangle ("tril") of a row x col matrix
// with diagonal offset `offset`. Two geometric cases arise:
//
// Case 1 - Trapezoid (triangle as a special case): row + offset <= col.
//   Every occupied row is part of one trapezoid, e.g. (offset > 0)
//     1 1 0 0 0
//     1 1 1 0 0
//     1 1 1 1 0
//   or (offset <= 0)
//     0 0 0
//     1 0 0
//     1 1 0
//   The size follows from the first and last occupied row lengths.
//
// Case 2 - Trapezoid + rectangle: row + offset > col. The bottom rows are
//   completely full, e.g.
//     1 1 0
//     1 1 1
//     1 1 1
//   so the full rows are counted separately as a rectangle.
inline int64_t get_tril_size(int64_t row, int64_t col, int64_t offset) {
  // An empty matrix has an empty tril.
  if (row == 0 || col == 0) {
    return 0;
  }
  // Length of the first row of the tril: clamped by col for positive
  // offsets, otherwise 0 or 1 depending on whether the band reaches row 0.
  const int64_t first_row_len = offset > 0
      ? std::min<int64_t>(col, 1 + offset)
      : static_cast<int64_t>(row + offset > 0);
  // Length of the last row of the tril, clamped to [0, col].
  const int64_t last_row_len =
      std::max<int64_t>(0, std::min<int64_t>(col, row + offset));
  // Number of rows containing at least one element, clamped to [0, row].
  const int64_t occupied_rows =
      std::max<int64_t>(0, std::min<int64_t>(row, row + offset));
  const int64_t trapezoid_rows = last_row_len - first_row_len + 1;

  // Arithmetic-series sum for the top trapezoid.
  int64_t total = ((first_row_len + last_row_len) * trapezoid_rows) >> 1;

  // Any remaining occupied rows below the trapezoid are completely full.
  const int64_t rect_rows = occupied_rows - trapezoid_rows;
  if (rect_rows > 0) {
    total += rect_rows * col;
  }

  return total;
}
64
+
65
+ inline void check_args(
66
+ int64_t row, int64_t col, c10::optional<Layout> layout_opt) {
67
+ TORCH_CHECK(row >= 0, "row must be non-negative, got", row);
68
+ TORCH_CHECK(col >= 0, "col must be non-negative, got", col);
69
+ if (layout_opt.has_value()) {
70
+ TORCH_CHECK(
71
+ *layout_opt == at::kStrided,
72
+ "only support layout=torch.strided, got",
73
+ *layout_opt)
74
+ }
75
+ }
76
+
77
+ using at::check_size_nonnegative;
78
+
79
// assumes maximum value in created tensor is n-1 (e.g., torch.randperm(n))
// Verifies that every value in [0, n) is exactly representable in `tensor`'s
// dtype: first that n-1 fits into the type at all, then — for floating-point
// types — that the contiguous-integer limit (2^mantissa_bits + 1) is not
// exceeded.
inline void check_supported_max_int_with_precision(int64_t n, const Tensor& tensor) {
  // match defined() to behavior of checks below
  TORCH_CHECK(at::scalar_tensor(n>0?n-1:n, tensor.options()).defined(),
              "n is too large for result tensor type: '", tensor.toString(), "'");

  // Ensure sufficient precision for floating point representation.
  switch (tensor.scalar_type()) {
    case at::ScalarType::Half:
      TORCH_CHECK(n <= (int64_t(1) << 11) + 1, "n cannot be greater than 2049 for Half type.");
      break;
    case at::ScalarType::Float:
      TORCH_CHECK(n <= (int64_t(1) << 24) + 1, "n cannot be greater than 2^24+1 for Float type.");
      break;
    case at::ScalarType::Double:  // Unlikely to happen, but doesn't hurt to check
      TORCH_CHECK(n <= (int64_t(1) << 53) + 1, "n cannot be greater than 2^53+1 for Double type.");
      break;
    default:
      break;
  }
}
100
+
101
// Called by `empty*` functions when deterministic algorithms are enabled to
// fill the tensor with NaN if it is floating point or complex type, or fill
// with max value if it is integer type
inline Tensor& fill_empty_deterministic_(Tensor& tensor) {
  if (tensor.is_floating_point() || tensor.is_complex()) {
    // Poison float/complex memory with quiet NaN so any read of "empty"
    // memory is deterministic and conspicuous.
    AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
        kBFloat16, kHalf, tensor.scalar_type(), "fill_empty_deterministic_", [&]() {
          tensor.fill_(std::numeric_limits<scalar_t>::quiet_NaN());
        });
  } else {
    // Integer (and bool) tensors get the dtype's maximum value instead.
    AT_DISPATCH_V2(
        tensor.scalar_type(), "fill_empty_deterministic_", AT_WRAP([&]() {
          tensor.fill_(std::numeric_limits<scalar_t>::max());
        }), kBool, AT_EXPAND(AT_INTEGRAL_TYPES_V2));
  }
  return tensor;
}
118
+
119
// The ZeroTensor allocator ignores whatever allocation is requested and always
// gives you nullptr
struct ZeroTensorAllocator final : public at::Allocator {
  ZeroTensorAllocator(at::Device device) : device_(device) {};
  ~ZeroTensorAllocator() override = default;
  // Nothing was ever allocated, so there is nothing to free; just sanity-check.
  static void deleter(void* const pointer) {
    TORCH_INTERNAL_ASSERT(!pointer);
  }
  // Hand out a null DataPtr tagged with the target device.
  DataPtr allocate(const size_t /*nbytes*/) override {
    return {nullptr, nullptr, &deleter, device_};
  }
  DeleterFnPtr raw_deleter() const override {
    return deleter;
  }
  // No backing data exists, so copies are no-ops.
  void copy_data(void* dest, const void* src, std::size_t count) const final {}
  at::Device device_;
};
136
+
137
+ using binary_fn = void (*)(TensorIterator&);
138
+
139
+ DECLARE_DISPATCH(binary_fn, complex_stub);
140
+ DECLARE_DISPATCH(binary_fn, polar_stub);
141
+
142
+ } // namespace at::native
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/TensorIteratorDynamicCasting.h ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <complex>
4
+ #include <type_traits>
5
+ #include <c10/core/ScalarType.h>
6
+ #include <ATen/detail/FunctionTraits.h>
7
+ #include <ATen/native/TensorIterator.h>
8
+
9
+
10
+ // This file includes utilities for dynamic_casting done by TensorIterator, see CUDALoops.cuh and Loops.h.
11
+
12
+ // dynamic_casting handles when the types expected by the iterator do not match the types of the arguments
13
+ // to the function that is being called.
14
+ // On CUDA, the cast is currently pushed down into the kernel (for performance reasons).
15
+ // On CPU, there is currently an internal assert that a dynamic_cast is not needed.
16
+
17
+ namespace at::native {
18
+
19
// `needs_dynamic_casting` compares the types expected by iterator
// (i.e. dtypes of the operands) with the actual type of the arguments
// (and returns) of func_t
//
// Recursive case: compare input operand `nargs - 1` against the functor's
// corresponding parameter type, then recurse toward the base case below.
template<typename func_t, int nargs=function_traits<func_t>::arity>
struct needs_dynamic_casting {
  static bool check(TensorIteratorBase& iter) {
    using traits = function_traits<func_t>;
    using cpp_type = typename traits::template arg<nargs - 1>::type;
    using cpp_map = c10::CppTypeToScalarType<cpp_type>;

    // Any mismatch between the operand's dtype and the functor's C++
    // parameter type means a dynamic cast is required.
    if (iter.input_dtype(nargs-1) != cpp_map::value) {
      return true;
    }
    return needs_dynamic_casting<func_t, nargs - 1>::check(iter);
  }
};

// Base case (all inputs checked): compare the functor's result type with the
// dtype of output 0. A void result means there is no return value to compare.
template<typename func_t>
struct needs_dynamic_casting<func_t, 0> {
  static bool check(TensorIteratorBase& iter) {
    using traits = function_traits<func_t>;
    using cpp_type = typename traits::result_type;

    // we could assert output numbers are correct here, but checks
    // (including arity) are currently pushed outside of this struct.
    if constexpr (std::is_void_v<cpp_type>) {
      return false;
    } else {
      return iter.dtype(0) != c10::CppTypeToScalarType<cpp_type>::value;
    }
  }
};
51
+
52
+ } //namespace at::native
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/TensorProperties.h ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ // See NOTE: [Tensor vs. TensorBase]
4
+ namespace at {
5
+ class TensorBase;
6
+ }
7
+
8
+ namespace at::native {
9
+
10
+ TORCH_API bool cudnn_is_acceptable(const TensorBase& self);
11
+
12
+ } // namespace at::native
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/TensorShape.h ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ #include <ATen/core/Tensor.h>
3
+ #include <c10/util/irange.h>
4
+ #include <ATen/core/IListRef.h>
5
+
6
+ namespace at::native {
7
+
8
+ TORCH_API at::Tensor clone_preserve_strides(const at::Tensor& self);
9
+
10
// True for tensors that `cat` silently ignores: 1-D tensors with zero
// elements (e.g. the default `torch.tensor([])`), kept for backward
// compatibility.
inline bool cat_should_skip_tensor(const Tensor& t) {
  return t.numel() == 0 && t.dim() == 1;
}
13
+
14
// Check to see if the shape of tensors is compatible
// for being concatenated along a given dimension.
// `index` is only used to name the offending tensor in the error message.
inline void check_cat_shape_except_dim(const Tensor & first, const Tensor & second, int64_t dimension, int64_t index) {
  int64_t first_dims = first.dim();
  int64_t second_dims = second.dim();
  TORCH_CHECK(first_dims == second_dims, "Tensors must have same number of dimensions: got ",
      first_dims, " and ", second_dims);
  for (const auto dim : c10::irange(first_dims)) {
    if (dim == dimension) {
      // Sizes may legitimately differ along the concatenation dimension.
      continue;
    }
    int64_t first_dim_size = first.sizes()[dim];
    int64_t second_dim_size = second.sizes()[dim];
    TORCH_CHECK(first_dim_size == second_dim_size, "Sizes of tensors must match except in dimension ",
        dimension, ". Expected size ", static_cast<long long>(first_dim_size), " but got size ", static_cast<long long>(second_dim_size), " for tensor number ", index, " in the list.");
  }
}
31
+
32
+ inline void check_cat_no_zero_dim(const MaterializedITensorListRef& tensors) {
33
+ int64_t i = 0;
34
+ for(const Tensor& t : tensors) {
35
+ TORCH_CHECK(t.dim() > 0,
36
+ "zero-dimensional tensor (at position ", i, ") cannot be concatenated");
37
+ i++;
38
+ }
39
+ }
40
+
41
// Computes how many chunks `split` produces along `dim` for a given
// `split_size`. Raises for 0-d tensors, negative split_size, or
// split_size == 0 when the dimension is non-empty.
inline int64_t get_num_splits(const Tensor& self, int64_t split_size, int64_t dim) {
  TORCH_CHECK(self.dim() != 0, "split expects at least a 1-dimensional tensor");
  TORCH_CHECK(split_size >= 0, "split expects split_size be non-negative, but got split_size=", split_size);
  int64_t dim_size = self.size(dim);
  TORCH_CHECK(split_size > 0 || dim_size == 0,
      "split_size can only be 0 if dimension size is 0, "
      "but got dimension size of ", dim_size);
  // if split_size is 0 and dimension size is 0, there is 1 split.
  int64_t num_splits = 1;
  if (split_size != 0) {
    // ensuring num_splits is at least 1 makes consistent the case where split_size > dim_size
    // (returns a single split). We might want to error here, but keep it for BC.
    num_splits = std::max<int64_t>((dim_size + split_size - 1) / split_size, 1);
  }
  return num_splits;
}
57
+
58
+ inline bool have_same_ndims(TensorList tensors) {
59
+ auto ndim = tensors[0].dim();
60
+ for (const auto tensor_idx : c10::irange(tensors.size())) {
61
+ if(tensors[tensor_idx].dim() != ndim) {
62
+ return false;
63
+ }
64
+ }
65
+ return true;
66
+ }
67
+
68
+ inline void leading_dimension_matches(TensorList tensors, int64_t dim) {
69
+ auto tensor_zero_size = tensors[0].sizes();
70
+ std::vector<c10::SymInt> leading_dim_sizes(tensor_zero_size.begin(), tensor_zero_size.begin() + dim);
71
+ for (const auto i : c10::irange(tensors.size())) {
72
+ at::Tensor tensor = tensors[i];
73
+ for(const auto j : c10::irange(dim)) {
74
+ TORCH_CHECK(
75
+ tensor.size(j) == leading_dim_sizes[j],
76
+ "_chunk_cat expects same sizes of 0,...,dim-1 dimensions for all tensors"
77
+ );
78
+ }
79
+ }
80
+ }
81
+
82
// Validates and normalizes the inputs of _chunk_cat: all tensors must be
// non-empty and share dtype/device, and their leading dimensions must match.
// Returns `dim` wrapped to a non-negative index.
inline int64_t preprocess_chunk_cat_inputs(TensorList tensors, int64_t dim, int64_t num_chunks) {
  TORCH_CHECK(num_chunks >= 1, "_chunk_cat expects positive num_chunks");
  TORCH_CHECK(!tensors.empty(),
      "_chunk_cat expects a non-empty input tensor list");
  auto expected_dtype = tensors[0].dtype();
  auto expected_device = tensors[0].device();
  for(const auto i : c10::irange(tensors.size())) {
    TORCH_CHECK(tensors[i].numel() > 0, "_chunk_cat expects non-empty tensor");
    TORCH_CHECK(tensors[i].dtype() == expected_dtype, "_chunk_cat expects all input tensors with the same dtype");
    TORCH_CHECK(tensors[i].device() == expected_device, "_chunk_cat expects all inputs tensors on the same device");
  }
  if (have_same_ndims(tensors)) {
    dim = maybe_wrap_dim(dim, tensors[0].dim());
  } else {
    // With mixed ndims a negative dim is ambiguous, so it must be a
    // non-negative index that is valid for every tensor.
    TORCH_CHECK(dim >= 0, "_chunk_cat expects non-negative dim when input tensors have different ndims")
    for(const auto i : c10::irange(tensors.size())) {
      TORCH_CHECK(dim < tensors[i].ndimension(), "_chunk_cat expects dim < ndim for all input tensors");
    }
  }
  leading_dimension_matches(tensors, dim);
  return dim;
}
104
+
105
+ } // namespace at::native
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/TopKImpl.h ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ #include <ATen/core/TensorAccessor.h>
3
+ #include <ATen/NumericUtils.h>
4
+
5
+ namespace at::native {
6
+
7
+ #ifdef CPU_CAPABILITY
8
+ inline namespace CPU_CAPABILITY {
9
+ #else
10
+ inline namespace DEFAULT {
11
+ #endif
12
+
13
// Core topk loop, shared between CPU and QuantizedCPU
//
// For each of the `n` outer slices, selects the top (largest=true) or bottom
// (largest=false) `k` elements of a `dim_size`-long strided slice of
// data[2] (tmp_values), writing values into data[0] (mode_values) and their
// original positions into data[1] (mode_indices). NaNs are treated as the
// largest values for numpy compatibility.
template <typename scalar_t, typename accscalar_t>
void topk_impl_loop(
    const int64_t mode_values_stride,
    const int64_t mode_indices_stride,
    const int64_t tmp_values_stride,
    const int64_t k,
    const int64_t dim_size,
    const bool largest,
    const bool sorted,
    char** data, const int64_t* strides, const int64_t n) {

  // If k is zero, then output values and indices are empty tensors
  // So iterating over other dims is pointless
  if (k == 0) {
    return;
  }
  // (value, original index) pairs for one slice along the reduced dim;
  // the buffer is reused across all `n` slices.
  using elem_t = std::pair<accscalar_t, int64_t>;
  std::vector<elem_t> queue(dim_size);
  for (const auto i : c10::irange(n)) {
    TensorAccessor<scalar_t, 1> mode_values(
        reinterpret_cast<scalar_t*>(data[0] + i * strides[0]),
        &k, &mode_values_stride);
    TensorAccessor<int64_t, 1> mode_indices(
        reinterpret_cast<int64_t*>(data[1] + i * strides[1]),
        &k, &mode_indices_stride);
    TensorAccessor<const scalar_t, 1> tmp_values(
        reinterpret_cast<scalar_t*>(data[2] + i * strides[2]),
        &dim_size, &tmp_values_stride);

    auto n_2 = dim_size;
    // Heuristic: partial_sort only pays off when k is much smaller than the
    // slice; otherwise nth_element (plus an optional prefix sort) is used.
    auto use_partial_sort = k * 64 <= n_2;

    for (const auto j : c10::irange(n_2)) {
      queue[j].first = tmp_values[j];
      queue[j].second = j;
    }

    // we want nan to be sorted as top for numpy compatibility
    if (use_partial_sort) {
      if (largest) {
        std::partial_sort(queue.begin(), queue.begin() + k, queue.end(),
            [](const elem_t& x, const elem_t& y) -> bool {
              return ((_isnan<accscalar_t>(x.first) && !_isnan<accscalar_t>(y.first)) || (x.first > y.first));
            });
      } else {
        std::partial_sort(queue.begin(), queue.begin() + k, queue.end(),
            [](const elem_t& x, const elem_t& y) -> bool {
              return ((!_isnan<accscalar_t>(x.first) && _isnan<accscalar_t>(y.first)) || (x.first < y.first));
            });
      }
    } else {
      if (largest) {
        // nth_element places the element at position k-1 in its final sorted
        // position, so when `sorted` is requested only [0, k-1) still needs
        // sorting afterwards.
        std::nth_element(queue.begin(), queue.begin() + k - 1, queue.end(),
            [](const elem_t& x, const elem_t& y) -> bool {
              return ((_isnan<accscalar_t>(x.first) && !_isnan<accscalar_t>(y.first)) || (x.first > y.first));
            });
        if (sorted) {
          std::sort(queue.begin(), queue.begin() + k - 1,
              [](const elem_t& x, const elem_t& y) -> bool {
                return ((_isnan<accscalar_t>(x.first) && !_isnan<accscalar_t>(y.first)) || (x.first > y.first));
              });
        }
      } else {
        std::nth_element(queue.begin(), queue.begin() + k -1, queue.end(),
            [](const elem_t& x, const elem_t& y) -> bool {
              return ((!_isnan<accscalar_t>(x.first) && _isnan<accscalar_t>(y.first)) || (x.first < y.first));
            });
        if (sorted) {
          std::sort(queue.begin(), queue.begin() + k -1,
              [](const elem_t& x, const elem_t& y) -> bool {
                return ((!_isnan<accscalar_t>(x.first) && _isnan<accscalar_t>(y.first)) || (x.first < y.first));
              });
        }
      }
    }

    // Emit the selected k values together with their original indices.
    for (const auto j : c10::irange(k)) {
      mode_values[j] = queue[j].first;
      mode_indices[j] = queue[j].second;
    }
  }
}
97
+ } // namespace CPU_CAPABILITY
98
+ } // namespace at::native
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/TransposeType.h ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ #include <c10/util/Exception.h>
3
+
4
+ namespace at::native {
5
+
6
// Used as an interface between the different BLAS-like libraries
enum class TransposeType {
  NoTranspose,
  Transpose,
  ConjTranspose,
};

// Transforms TransposeType into the BLAS / LAPACK format
// (the single-character 'N' / 'T' / 'C' codes).
static inline char to_blas(TransposeType trans) {
  switch (trans) {
    case TransposeType::Transpose: return 'T';
    case TransposeType::NoTranspose: return 'N';
    case TransposeType::ConjTranspose: return 'C';
  }
  // Unreachable for valid enum values; also silences "control reaches end of
  // non-void function" warnings.
  TORCH_INTERNAL_ASSERT(false, "Invalid transpose type");
}
22
+
23
+ } // namespace at::native
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/Unfold3d.h ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <c10/core/ScalarType.h>
4
+
5
+ namespace at::native {
6
+
7
+ void Unfold3dCopyCPU(
8
+ ScalarType dtype,
9
+ const void *src,
10
+ int64_t C,
11
+ int64_t X_D,
12
+ int64_t X_H,
13
+ int64_t X_W,
14
+ int64_t Y_D,
15
+ int64_t Y_H,
16
+ int64_t Y_W,
17
+ int64_t kernel_d,
18
+ int64_t kernel_h,
19
+ int64_t kernel_w,
20
+ int64_t stride_d,
21
+ int64_t stride_h,
22
+ int64_t stride_w,
23
+ int64_t pad_d,
24
+ int64_t pad_h,
25
+ int64_t pad_w,
26
+ void* dst);
27
+
28
+ void Unfold3dAccCPU(
29
+ ScalarType dtype,
30
+ const void *src,
31
+ int64_t C,
32
+ int64_t X_D,
33
+ int64_t X_H,
34
+ int64_t X_W,
35
+ int64_t Y_D,
36
+ int64_t Y_H,
37
+ int64_t Y_W,
38
+ int64_t kernel_d,
39
+ int64_t kernel_h,
40
+ int64_t kernel_w,
41
+ int64_t stride_d,
42
+ int64_t stride_h,
43
+ int64_t stride_w,
44
+ int64_t pad_d,
45
+ int64_t pad_h,
46
+ int64_t pad_w,
47
+ void *dst);
48
+
49
+ } // namespace at::native
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/UnfoldBackward.h ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/core/Tensor.h>
4
+ #include <ATen/TensorIterator.h>
5
+ #include <ATen/native/DispatchStub.h>
6
+ #include <ATen/native/NonEmptyUtils.h>
7
+
8
+ #ifndef AT_PER_OPERATOR_HEADERS
9
+ #include <ATen/Functions.h>
10
+ #else
11
+ #include <ATen/ops/arange.h>
12
+ #endif
13
+
14
+ namespace at::native {
15
+
16
+ using unfold_backward_fn = void (*)(
17
+ Tensor& grad_in,
18
+ const Tensor& grad,
19
+ int64_t dim,
20
+ int64_t size,
21
+ int64_t step
22
+ );
23
+
24
+ DECLARE_DISPATCH(unfold_backward_fn, unfold_backward_stub);
25
+
26
+ namespace {
27
+
28
// Note on naming: it is unconventional.
// grad_in does not mean that it is a gradient wrt to input,
// grad_in/grad_out is just an input/output of unfold_backward kernel.
//
// Builds the TensorIterator used by the unfold_backward kernels: grad_out is
// restrided so iteration covers only `iter_dim_size` entries along `dim`;
// grad_in collapses `dim` (stride 0 / size 1 — that dimension is indexed
// manually inside the kernel); idx_dim broadcasts the running index along
// `dim` to every iterated element.
static C10_UNUSED TensorIterator _make_unfold_backward_iter_over_grad_out(
    Tensor& grad_out,
    const Tensor& grad_in,
    int64_t dim,
    int64_t size,
    int64_t step
) {
  dim = maybe_wrap_dim(dim, grad_out.dim());
  // last dim stores the folds

  auto grad_out_dim_size = ensure_nonempty_size(grad_out, dim);
  auto grad_in_dim_size = ensure_nonempty_size(grad_in, dim);
  // dictates the number of elements to iterate over
  // in dimension `dim`
  auto iter_dim_size = std::min(
    grad_out_dim_size,
    (grad_in_dim_size - 1) * step + size
  );

  /* prepare grad_out for TensorIterator { */
  auto grad_out_strides = ensure_nonempty_vec(grad_out.strides().vec());
  auto grad_out_sizes = ensure_nonempty_vec(grad_out.sizes().vec());
  grad_out_sizes[dim] = iter_dim_size;
  auto grad_out_restrided = grad_out.as_strided(
    grad_out_sizes, grad_out_strides
  );
  /* } */

  /* prepare grad_in for TensorIterator { */
  auto grad_in_strides = ensure_nonempty_vec(grad_in.strides().vec());
  auto grad_in_sizes = ensure_nonempty_vec(grad_in.sizes().vec());

  // set strides for dim to 0
  // and size to 1 because
  // this dimension is indexed inside the kernel
  grad_in_strides[dim] = 0;
  grad_in_sizes[dim] = 1;

  grad_in_strides.pop_back();
  grad_in_sizes.pop_back();

  auto grad_in_restrided = grad_in.squeeze(-1).as_strided(
    grad_in_sizes, grad_in_strides
  );
  /* } */

  // During the TensorIterator iteration we have to know
  // i_dim in grad_out[i_1,...,i_dim,...i_n],
  // idx_dim stores this information
  /* prepare idx_dim for TensorIterator { */
  auto idx_dim = at::arange(
    0, iter_dim_size, grad_in.options().dtype(at::kLong)
  );

  auto grad_out_dim = ensure_nonempty_dim(grad_out.dim());

  auto idx_dim_strides = std::vector<int64_t>(grad_out_dim, 0);
  auto idx_dim_sizes = std::vector<int64_t>(grad_out_dim, 1);

  idx_dim_strides[dim] = 1;
  idx_dim_sizes[dim] = iter_dim_size;

  // idx_dim size will broadcast over determined by grad_out sizes in TensorIterator
  auto idx_dim_restrided = idx_dim.as_strided(idx_dim_sizes, idx_dim_strides);
  /* } */

  // Overlap and same-dtype checks are disabled on purpose: the restrided
  // views alias their source tensors, and idx_dim is int64 while the grads
  // carry the gradient dtype.
  auto iter = TensorIteratorConfig()
    .set_check_mem_overlap(false)
    .check_all_same_dtype(false)
    .resize_outputs(false)
    .add_owned_output(grad_out_restrided)
    .add_owned_input(grad_in_restrided)
    .add_owned_input(idx_dim_restrided)
    .build();

  return iter;
}
109
+
110
+ }
111
+
112
+ } // namespace at::native
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/UpSample.h ADDED
@@ -0,0 +1,506 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <math.h>
4
+
5
+ #include <ATen/OpMathType.h>
6
+ #include <ATen/TensorUtils.h>
7
+ #include <ATen/OpMathType.h>
8
+ #include <ATen/core/Tensor.h>
9
+ #include <ATen/cpu/vec/functional.h>
10
+ #include <ATen/cpu/vec/vec.h>
11
+ #include <ATen/native/DispatchStub.h>
12
+ #include <ATen/native/cpu/utils.h>
13
+
14
+ /**
15
+ * Note [compute_scales_value]
16
+ * Note [area_pixel_compute_scale]
17
+ * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
18
+ * Interpolate with scale_factor can have different behaviors
19
+ * depending on the value of recompute_scale_factor:
20
+ *
21
+ * - With recompute_scale_factor = True (current default behavior):
22
+ * the scale_factor, when provided by the user, are used to calculate
23
+ * the output size. The input size and the computed output_size
24
+ * are then used to infer new values for the scales which are
25
+ * used in the interpolation. Because floating-point math is not exact,
26
+ * this may be a different value from the user-supplied scales.
27
+ *
28
+ * - With recompute_scale_factor = False (which will be the default
29
+ * behavior starting 1.5.0):
30
+ * the behavior follows opencv logic, and the scales provided by
31
+ * the user are the ones used in the interpolation calculations.
32
+ *
33
+ * If the scales are not provided or if they are provided but
34
+ * recompute_scale_factor is set to True (default behavior), the scales
35
+ * are computed from the input and the output size;
36
+ *
37
+ *
38
+ * When the scales are inferred from the input and output sizes,
39
+ * we view each pixel as an area, idx + 0.5 as its center index.
40
+ * Here is an example formula in 1D case.
41
+ * if align_corners: center of two corner pixel areas are preserved,
42
+ * (0.5, 0.5) -> (0.5, 0.5),
43
+ * (input_size - 0.5, 0.5) -> (output_size - 0.5)
44
+ * scale = (input_size - 0.5 - 0.5) / (output_size - 0.5 - 0.5)
45
+ * src_index + 0.5 - 0.5 = scale * (dst_index + 0.5 - 0.5)
46
+ * if not align_corners: the whole range is scaled accordingly
47
+ * scale = input_size / output_size
48
+ * src_idx + 0.5 = scale * (dst_index + 0.5)
49
+ */
50
+
51
+ namespace at::native {
52
+
53
+ namespace upsample {
54
+
55
+ TORCH_API c10::SmallVector<int64_t, 3> compute_output_size(
56
+ c10::IntArrayRef input_size, // Full input tensor size.
57
+ at::OptionalIntArrayRef output_size,
58
+ c10::optional<c10::ArrayRef<double>> scale_factors);
59
+
60
+ inline c10::optional<double> get_scale_value(c10::optional<c10::ArrayRef<double>> scales, int idx) {
61
+ if (!scales) {
62
+ return c10::nullopt;
63
+ }
64
+ return scales->at(idx);
65
+ }
66
+
67
+ } // namespace upsample
68
+
69
+ using scale_t = c10::optional<double>;
70
+ using upsampling_nearest1d = void(*)(const Tensor& output, const Tensor& input, scale_t scales_w);
71
+ using _upsampling_nearest_exact1d = void(*)(const Tensor& output, const Tensor& input, scale_t scales_w);
72
+ using upsampling_nearest2d = void(*)(const Tensor& output, const Tensor& input, scale_t scales_h, scale_t scales_w);
73
+ using _upsampling_nearest_exact2d = void(*)(const Tensor& output, const Tensor& input, scale_t scales_h, scale_t scales_w);
74
+ using upsampling_nearest3d = void(*)(const Tensor& output, const Tensor& input, scale_t scales_d, scale_t scales_h, scale_t scales_w);
75
+ using _upsampling_nearest_exact3d = void(*)(const Tensor& output, const Tensor& input, scale_t scales_d, scale_t scales_h, scale_t scales_w);
76
+ using upsampling_linear1d = void(*)(const Tensor& output, const Tensor& input, bool align_corners, scale_t scales_w);
77
+ using upsampling_bilinear2d = void(*)(const Tensor& output, const Tensor& input, bool align_corners, scale_t scales_h, scale_t scales_w);
78
+ using _upsampling_bilinear2d_aa = void(*)(const Tensor& output, const Tensor& input, bool align_corners, scale_t scales_h, scale_t scales_w);
79
+ using upsampling_trilinear3d = void(*)(const Tensor& output, const Tensor& input, bool align_corners, scale_t scales_d, scale_t scales_h, scale_t scales_w);
80
+ using upsampling_bicubic2d = void(*)(const Tensor& output, const Tensor& input, bool align_corners, scale_t scales_h, scale_t scales_w);
81
+ using _upsampling_bicubic2d_aa = void(*)(const Tensor& output, const Tensor& input, bool align_corners, scale_t scales_h, scale_t scales_w);
82
+ DECLARE_DISPATCH(upsampling_nearest1d, upsample_nearest1d_kernel);
83
+ DECLARE_DISPATCH(_upsampling_nearest_exact1d, _upsample_nearest_exact1d_kernel);
84
+ DECLARE_DISPATCH(upsampling_nearest2d, upsample_nearest2d_kernel);
85
+ DECLARE_DISPATCH(_upsampling_nearest_exact2d, _upsample_nearest_exact2d_kernel);
86
+ DECLARE_DISPATCH(upsampling_nearest3d, upsample_nearest3d_kernel);
87
+ DECLARE_DISPATCH(_upsampling_nearest_exact3d, _upsample_nearest_exact3d_kernel);
88
+ DECLARE_DISPATCH(upsampling_nearest1d, upsample_nearest1d_backward_kernel);
89
+ DECLARE_DISPATCH(_upsampling_nearest_exact1d, _upsample_nearest_exact1d_backward_kernel);
90
+ DECLARE_DISPATCH(upsampling_nearest2d, upsample_nearest2d_backward_kernel);
91
+ DECLARE_DISPATCH(_upsampling_nearest_exact2d, _upsample_nearest_exact2d_backward_kernel);
92
+ DECLARE_DISPATCH(upsampling_nearest3d, upsample_nearest3d_backward_kernel);
93
+ DECLARE_DISPATCH(_upsampling_nearest_exact3d, _upsample_nearest_exact3d_backward_kernel);
94
+ DECLARE_DISPATCH(upsampling_linear1d, upsample_linear1d_kernel);
95
+ DECLARE_DISPATCH(upsampling_bilinear2d, upsample_bilinear2d_kernel);
96
+ DECLARE_DISPATCH(_upsampling_bilinear2d_aa, _upsample_bilinear2d_aa_kernel);
97
+ DECLARE_DISPATCH(upsampling_trilinear3d, upsample_trilinear3d_kernel);
98
+ DECLARE_DISPATCH(upsampling_linear1d, upsample_linear1d_backward_kernel);
99
+ DECLARE_DISPATCH(upsampling_bilinear2d, upsample_bilinear2d_backward_kernel);
100
+ DECLARE_DISPATCH(_upsampling_bilinear2d_aa, _upsample_bilinear2d_aa_backward_kernel);
101
+ DECLARE_DISPATCH(upsampling_trilinear3d, upsample_trilinear3d_backward_kernel);
102
+ DECLARE_DISPATCH(upsampling_bicubic2d, upsample_bicubic2d_kernel);
103
+ DECLARE_DISPATCH(_upsampling_bicubic2d_aa, _upsample_bicubic2d_aa_kernel);
104
+ DECLARE_DISPATCH(_upsampling_bicubic2d_aa, _upsample_bicubic2d_aa_backward_kernel);
105
+
106
// Validates 1d upsampling sizes: input must be (N, C, W) and output_size must
// hold a single positive width. Returns the output shape {N, C, W_out}.
static C10_UNUSED std::array<int64_t, 3> upsample_1d_common_check(IntArrayRef input_size, IntArrayRef output_size) {
  TORCH_CHECK(
      output_size.size() == 1,
      "It is expected output_size equals to 1, but got size ",
      output_size.size());

  TORCH_CHECK(
      input_size.size() == 3,
      "It is expected input_size equals to 3, but got size ",
      input_size.size());

  int64_t output_width = output_size[0];

  int64_t nbatch = input_size[0];
  int64_t channels = input_size[1];
  int64_t input_width = input_size[2];

  TORCH_CHECK(
      input_width > 0 && output_width > 0,
      "Input and output sizes should be greater than 0, but got input (W: ",
      input_width,
      ") and output (W: ",
      output_width,
      ")");

  return {nbatch, channels, output_width};
}
133
+
134
// Validates 2d upsampling sizes: input must be (N, C, H, W) and output_size
// must hold positive {H, W}. Returns the output shape {N, C, H_out, W_out}.
static C10_UNUSED std::array<int64_t, 4> upsample_2d_common_check(IntArrayRef input_size, IntArrayRef output_size) {
  TORCH_CHECK(
      output_size.size() == 2,
      "It is expected output_size equals to 2, but got size ",
      output_size.size());

  TORCH_CHECK(
      input_size.size() == 4,
      "It is expected input_size equals to 4, but got size ",
      input_size.size());

  int64_t output_height = output_size[0];
  int64_t output_width = output_size[1];

  int64_t nbatch = input_size[0];
  int64_t channels = input_size[1];
  int64_t input_height = input_size[2];
  int64_t input_width = input_size[3];

  TORCH_CHECK(
      input_height > 0 && input_width > 0 && output_height > 0 &&
          output_width > 0,
      "Input and output sizes should be greater than 0,"
      " but got input (H: ",
      input_height,
      ", W: ",
      input_width,
      ") output (H: ",
      output_height,
      ", W: ",
      output_width,
      ")");

  return {nbatch, channels, output_height, output_width};
}
169
+
170
// Validates 3d upsampling sizes: input must be (N, C, D, H, W) and
// output_size must hold positive {D, H, W}. Returns the output shape
// {N, C, D_out, H_out, W_out}.
static C10_UNUSED
std::array<int64_t, 5> upsample_3d_common_check(IntArrayRef input_size, IntArrayRef output_size) {
  TORCH_CHECK(
      output_size.size() == 3,
      "It is expected output_size equals to 3, but got size ",
      output_size.size());

  TORCH_CHECK(
      input_size.size() == 5,
      "It is expected input_size equals to 5, but got size ",
      input_size.size());

  int64_t output_depth = output_size[0];
  int64_t output_height = output_size[1];
  int64_t output_width = output_size[2];

  int64_t nbatch = input_size[0];
  int64_t channels = input_size[1];
  int64_t input_depth = input_size[2];
  int64_t input_height = input_size[3];
  int64_t input_width = input_size[4];

  TORCH_CHECK(
      input_depth > 0 && input_height > 0 && input_width > 0 &&
          output_depth > 0 && output_height > 0 && output_width > 0,
      "Input and output sizes should be greater than 0, but got input (D: ",
      input_depth,
      ", H: ",
      input_height,
      ", W: ",
      input_width,
      ") output (D: ",
      output_depth,
      ", H: ",
      output_height,
      ", W: ",
      output_width,
      ")");


  return {nbatch, channels, output_depth, output_height, output_width};
}
212
+
213
// Shape validation shared by the 2d upsampling kernels. Only one of
// `input` / `grad_output` is consulted: `input` when it is defined
// (forward pass — must be a non-empty 4D tensor, empty batch allowed),
// otherwise `grad_output` (backward pass — must be exactly
// (nbatch, nchannels, output_height, output_width)).
static inline void upsample_2d_shape_check(
    const Tensor& input,
    const Tensor& grad_output,
    int64_t nbatch,
    int64_t nchannels,
    int64_t input_height,
    int64_t input_width,
    int64_t output_height,
    int64_t output_width) {
  TORCH_CHECK(
      input_height > 0 && input_width > 0 && output_height > 0 &&
          output_width > 0,
      "Input and output sizes should be greater than 0,"
      " but got input (H: ",
      input_height,
      ", W: ",
      input_width,
      ") output (H: ",
      output_height,
      ", W: ",
      output_width,
      ")");

  if (input.defined()) {
    // Allow for empty batch size but not other dimensions
    TORCH_CHECK(
        (input.numel() != 0 ||
         (input.size(1) != 0 && input.size(2) != 0 && input.size(3) != 0)
        ) &&
        input.dim() == 4,
        "Non-empty 4D data tensor expected but got a tensor with sizes ",
        input.sizes());
  } else if (grad_output.defined()) {
    check_dim_size(grad_output, 4, 0, nbatch);
    check_dim_size(grad_output, 4, 1, nchannels);
    check_dim_size(grad_output, 4, 2, output_height);
    check_dim_size(grad_output, 4, 3, output_width);
  }
}
252
+
253
// Computes the reciprocal scale the interpolation kernels actually use:
// 1/scale when the user supplied a (positive) scale factor, otherwise
// input_size / output_size.
template <typename scalar_t>
static inline scalar_t compute_scales_value(
    const c10::optional<double> scale,
    int64_t input_size,
    int64_t output_size) {
  // see Note [compute_scales_value]
  // FIXME: remove magic > 0 after we ensure no models were serialized with -1 defaults.
  return (scale.has_value() && scale.value() > 0.)
      ? static_cast<scalar_t>(1.0 / scale.value())
      : (static_cast<scalar_t>(input_size) / output_size);
}
264
+
265
// Scale used for source-index computation. With align_corners the ratio
// (input_size - 1) / (output_size - 1) is used (0 when the output has at most
// one element); otherwise it defers to compute_scales_value.
template <typename scalar_t>
static inline scalar_t area_pixel_compute_scale(
    int64_t input_size,
    int64_t output_size,
    bool align_corners,
    const c10::optional<double> scale) {
  // see Note [area_pixel_compute_scale]
  if(align_corners) {
    if(output_size > 1) {
      return static_cast<scalar_t>(input_size - 1) / (output_size - 1);
    } else {
      // A single output element: corner alignment degenerates; 0 makes the
      // source index land on pixel 0.
      return static_cast<scalar_t>(0);
    }
  } else {
    return compute_scales_value<scalar_t>(scale, input_size, output_size);
  }
}
282
+
283
// Maps a destination index to its (fractional) source index.
// With align_corners the corners map exactly: src = scale * dst. Otherwise
// the half-pixel (OpenCV resize) formula src = scale * (dst + 0.5) - 0.5 is
// used. A negative result is clamped to 0 for non-cubic modes only: linear
// modes read the same pixel either way (both taps of [-1, 0] resolve to
// pixel 0), while cubic interpolation needs the unclamped value because its
// 4-tap weights depend on it.
template <typename scalar_t>
static inline scalar_t area_pixel_compute_source_index(
    scalar_t scale,
    int64_t dst_index,
    bool align_corners,
    bool cubic) {
  if (align_corners) {
    return scale * dst_index;
  }
  const scalar_t half = static_cast<scalar_t>(0.5);
  const scalar_t src_idx = scale * (dst_index + half) - half;
  if (!cubic && src_idx < static_cast<scalar_t>(0)) {
    return scalar_t(0);
  }
  return src_idx;
}
310
+
311
// Nearest-neighbor source index matching OpenCV INTER_NEAREST: truncate
// dst_index * scale and clamp to the last valid input index. Known to be
// "buggy" (truncation rather than rounding) but kept for backward
// compatibility; see nearest_neighbor_exact_compute_source_index for the
// corrected variant.
static inline int64_t nearest_neighbor_compute_source_index(
    const float scale,
    int64_t dst_index,
    int64_t input_size) {
  const auto truncated = static_cast<int64_t>(floorf(dst_index * scale));
  return std::min(truncated, input_size - 1);
}
321
+
322
// "Exact" nearest-neighbor source index:
//   index_f32 = (output_index + 0.5) * scale - 0.5
//   input_index = round(index_f32)
// implemented as floor((dst_index + 0.5) * scale), clamped to the last valid
// input index. Same as Pillow and Scikit-Image/Scipy ndi.zoom.
static inline int64_t nearest_neighbor_exact_compute_source_index(
    const float scale,
    int64_t dst_index,
    int64_t input_size) {
  const float center = (dst_index + 0.5) * scale;
  return std::min(static_cast<int64_t>(floorf(center)), input_size - 1);
}
333
+
334
// Maps an output index to a nearest-neighbor input index.
// This method specifically treats cases: output_size == input_size or
// output_size == 2 * input_size, that we would like to get rid of
// We keep this method for BC and consider as deprecated.
// See nearest_exact_idx as replacement
static inline int64_t nearest_idx(
    int64_t output_index,
    int64_t input_size,
    int64_t output_size,
    c10::optional<double> scales) {
  if (output_size == input_size) {
    // scale_factor = 1, simply copy
    return output_index;
  } else if (output_size == 2 * input_size) {
    // scale_factor = 2, shift input index
    return output_index >> 1;
  } else {
    float scale = compute_scales_value<float>(scales, input_size, output_size);
    return nearest_neighbor_compute_source_index(scale, output_index, input_size);
  }
}
354
+
355
+ static inline int64_t nearest_exact_idx(
356
+ int64_t output_index,
357
+ int64_t input_size,
358
+ int64_t output_size,
359
+ c10::optional<double> scales) {
360
+ float scale = compute_scales_value<float>(scales, input_size, output_size);
361
+ return nearest_neighbor_exact_compute_source_index(scale, output_index, input_size);
362
+ }
363
+
364
+ // Define a typedef to dispatch to nearest_idx or nearest_exact_idx
365
+ typedef int64_t (*nearest_idx_fn_t)(int64_t, int64_t, int64_t, c10::optional<double>);
366
+
367
// Read data[y][x] from a (height x width) row-major buffer, clamping both
// coordinates into range so out-of-bounds requests are safe (border
// replication).
template <typename scalar_t>
static scalar_t upsample_get_value_bounded(
    scalar_t* data,
    int64_t width,
    int64_t height,
    int64_t x,
    int64_t y) {
  const int64_t zero = static_cast<int64_t>(0);
  const int64_t cx = std::max(std::min(x, width - 1), zero);
  const int64_t cy = std::max(std::min(y, height - 1), zero);
  return data[cy * width + cx];
}
378
+
379
// Accumulate `value` into data[y][x] of a (height x width) row-major buffer,
// clamping both coordinates into range (border replication) so out-of-bounds
// writes are folded onto the nearest edge pixel.
template <typename scalar_t>
static void upsample_increment_value_bounded(
    scalar_t* data,
    int64_t width,
    int64_t height,
    int64_t x,
    int64_t y,
    scalar_t value) {
  const int64_t zero = static_cast<int64_t>(0);
  const int64_t cx = std::max(std::min(x, width - 1), zero);
  const int64_t cy = std::max(std::min(y, height - 1), zero);
  data[cy * width + cx] += value;
}
391
+
392
// Based on
// https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm
//
// Keys cubic kernel for |t| <= 1: (a+2)|t|^3 - (a+3)|t|^2 + 1, written in
// Horner form. The expression shape is kept identical to preserve FP rounding.
template <typename scalar_t>
static inline scalar_t cubic_convolution1(scalar_t t, scalar_t a) {
  return ((a + 2) * t - (a + 3)) * t * t + 1;
}
398
+
399
// Keys cubic kernel for 1 < |t| <= 2: a|t|^3 - 5a|t|^2 + 8a|t| - 4a, in
// Horner form. The expression shape is kept identical to preserve FP rounding.
template <typename scalar_t>
static inline scalar_t cubic_convolution2(scalar_t t, scalar_t a) {
  return ((a * t - 5 * a) * t + 8 * a) * t - 4 * a;
}
403
+
404
+ template <typename scalar_t>
405
+ static inline void get_cubic_upsample_coefficients(
406
+ scalar_t coeffs[4],
407
+ scalar_t t) {
408
+ scalar_t A = -0.75;
409
+
410
+ scalar_t x1 = t;
411
+ coeffs[0] = cubic_convolution2<scalar_t>(x1 + 1.0, A);
412
+ coeffs[1] = cubic_convolution1<scalar_t>(x1, A);
413
+
414
+ // opposite coefficients
415
+ scalar_t x2 = 1.0 - t;
416
+ coeffs[2] = cubic_convolution1<scalar_t>(x2, A);
417
+ coeffs[3] = cubic_convolution2<scalar_t>(x2 + 1.0, A);
418
+ }
419
+
420
+ template <typename scalar_t>
421
+ static inline scalar_t cubic_interp1d(
422
+ scalar_t x0,
423
+ scalar_t x1,
424
+ scalar_t x2,
425
+ scalar_t x3,
426
+ scalar_t t) {
427
+ scalar_t coeffs[4];
428
+ get_cubic_upsample_coefficients<scalar_t>(coeffs, t);
429
+
430
+ return x0 * coeffs[0] + x1 * coeffs[1] + x2 * coeffs[2] + x3 * coeffs[3];
431
+ }
432
+
433
// When `real_input_index` becomes larger than the range the floating point
// type can accurately represent, the cast to int64_t might exceed
// `input_size` and overflow downstream indexing. Guard by clamping the index
// to input_size - 1 and the interpolation weight (lambda) into [0, 1].
template<typename scalar_t, typename opmath_t>
static inline void guard_index_and_lambda(const opmath_t& real_input_index, const int64_t& input_size, int64_t& input_index, scalar_t& lambda) {
  input_index = std::min(static_cast<int64_t>(floorf(real_input_index)), input_size - 1);
  const opmath_t frac = real_input_index - input_index;
  lambda = std::min(std::max(frac, static_cast<opmath_t>(0)), static_cast<opmath_t>(1));
}
444
+
445
+ template<typename scalar_t, typename opmath_t>
446
+ static inline void compute_source_index_and_lambda(
447
+ int64_t& input_index0,
448
+ int64_t& input_index1,
449
+ scalar_t& lambda0,
450
+ scalar_t& lambda1,
451
+ opmath_t ratio,
452
+ int64_t output_index,
453
+ int64_t input_size,
454
+ int64_t output_size,
455
+ bool align_corners) {
456
+ if (output_size == input_size) {
457
+ // scale_factor = 1, simply copy
458
+ input_index0 = output_index;
459
+ input_index1 = output_index;
460
+ lambda0 = static_cast<scalar_t>(1);
461
+ lambda1 = static_cast<scalar_t>(0);
462
+ } else {
463
+ const auto real_input_index =
464
+ area_pixel_compute_source_index<opmath_t>(
465
+ ratio, output_index, align_corners, /*cubic=*/false);
466
+ guard_index_and_lambda(real_input_index, input_size, input_index0, lambda1);
467
+ int64_t offset = (input_index0 < input_size - 1) ? 1 : 0;
468
+ input_index1 = input_index0 + offset;
469
+ lambda0 = static_cast<scalar_t>(1.) - lambda1;
470
+ }
471
+ }
472
+
473
// It will not be used by data types other than BFloat16 and Half.
// Fallback overload selected when (scalar_in, scalar_out) is NOT
// (float, reduced-precision float). It is unreachable by construction of the
// dispatch and exists only so the overload set stays well-formed; if ever
// reached it raises a runtime error.
template <typename scalar_in, typename scalar_out,
          typename std::enable_if_t<!is_reduced_floating_point_v<scalar_out> || !std::is_same<scalar_in, float>::value, int> = 0>
void inline apply_grad_input(scalar_in* buffer_ptr, scalar_out* gin, int64_t size) {
  TORCH_CHECK((is_reduced_floating_point_v<scalar_out>),
              "Upsample backward only support BFloat16 and Half in the lower precision data types on CPU.")
  TORCH_CHECK((std::is_same<scalar_in, float>::value),
              "Upsample backward should use float as acc buffer for BFloat16 and Half grad input on CPU.")
  return;
}
483
+
484
// Accumulate a float accumulation buffer into a reduced-precision
// (BFloat16/Half) grad-input tensor, zeroing the buffer as it goes so it can
// be reused. Selected when scalar_in == float and scalar_out is a reduced
// floating point type.
template <typename scalar_in, typename scalar_out,
          typename std::enable_if_t<is_reduced_floating_point_v<scalar_out> && std::is_same<scalar_in, float>::value, int> = 0>
void inline apply_grad_input(scalar_in* buffer_ptr, scalar_out* gin, int64_t size) {
  using bVec = Vectorized<scalar_out>;
  using fVec = Vectorized<float>;
  int64_t d = 0;
  // Vectorized main loop: one bVec of grads corresponds to two fVec of floats.
  for (; d < size - (size % bVec::size()); d += bVec::size()) {
    bVec gin_bvec = bVec::loadu(gin + d);
    fVec gin_fvec0, gin_fvec1;
    // Widen the reduced-precision grads to float before accumulating.
    std::tie(gin_fvec0, gin_fvec1) = convert_to_float<scalar_out>(gin_bvec);
    gin_fvec0 += fVec::loadu(buffer_ptr + d);
    gin_fvec1 += fVec::loadu(buffer_ptr + d + fVec::size());
    // Reset the accumulation buffer for the next use.
    fVec(0).store(buffer_ptr + d);
    fVec(0).store(buffer_ptr + d + fVec::size());
    // Narrow back and write the updated grads.
    convert_from_float<scalar_out>(gin_fvec0, gin_fvec1).store(gin + d);
  }
  // Scalar tail for the remaining < bVec::size() elements.
  for (; d < size; d++) {
    gin[d] += buffer_ptr[d];
    buffer_ptr[d] = 0;
  }
}
505
+
506
+ } // namespace at::native
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/BinaryInternal.h ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // DON'T include this except from Binary*.cu files. It should not leak into
2
+ // headers.
3
+ #pragma once
4
+ #define TORCH_ASSERT_NO_OPERATORS
5
+ #include <ATen/AccumulateType.h>
6
+ #include <ATen/Dispatch.h>
7
+ #include <ATen/native/BinaryOps.h>
8
+ #include <ATen/native/DispatchStub.h>
9
+ #include <ATen/native/TensorIterator.h>
10
+ #include <c10/cuda/CUDAGuard.h>
11
+ #include <c10/cuda/CUDAMathCompat.h>
12
+ #include <c10/util/TypeSafeSignMath.h>
13
+ #include <ATen/native/cuda/JitLoops.cuh>
14
+ #include <ATen/native/cuda/Loops.cuh>
15
+
16
+ #include <type_traits>
17
+
18
+ namespace at {
19
+ namespace native {
20
+ namespace binary_internal {
21
+
22
+ template <typename scalar_t>
23
+ struct DivFunctor {
24
+ __device__ scalar_t operator()(scalar_t a, scalar_t b) const {
25
+ return a / b;
26
+ }
27
+ };
28
+
29
+ template <typename T>
30
+ struct MulFunctor {
31
+ __device__ T operator()(T a, T b) const {
32
+ return a * b;
33
+ }
34
+ };
35
+
36
+ // Workaround for the error: '*' in boolean context, suggest '&&' instead
37
+ // [-Werror=int-in-bool-context]
38
+ template <>
39
+ struct MulFunctor<bool> {
40
+ __device__ bool operator()(bool a, bool b) const {
41
+ return a && b;
42
+ }
43
+ };
44
+ void div_true_kernel_cuda(TensorIteratorBase& iter);
45
+ void div_trunc_kernel_cuda(TensorIteratorBase& iter);
46
+ } // namespace binary_internal
47
+ } // namespace native
48
+ } // namespace at
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/CompositeRandomAccessor.h ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/native/CompositeRandomAccessorCommon.h>
4
+ #include <thrust/tuple.h>
5
+
6
+ namespace at { namespace native {
7
+
8
// Tuple traits plugged into CompositeRandomAccessorCommon.h: selects
// thrust::tuple / thrust::tie as the tuple implementation.
// NOTE(review): named "CPU" although this header lives under cuda/ and uses
// thrust — the name appears to mirror the CPU counterpart; confirm against
// ATen/native/CompositeRandomAccessorCommon.h.
struct TupleInfoCPU {
  template <typename ...Types>
  using tuple = thrust::tuple<Types...>;

  template <typename ...Types>
  static constexpr auto tie(Types&... args) noexcept {
    return thrust::tie(args...);
  }
};
17
+
18
+ template <typename KeyAccessor, typename ValueAccessor>
19
+ using CompositeRandomAccessorCPU =
20
+ CompositeRandomAccessor<KeyAccessor, ValueAccessor, TupleInfoCPU>;
21
+
22
+ template <typename Values, typename References>
23
+ void swap(
24
+ references_holder<Values, References> rh1,
25
+ references_holder<Values, References> rh2
26
+ ) {
27
+ return thrust::swap(rh1.data(), rh2.data());
28
+ }
29
+
30
+ template <int N, typename Values, typename References>
31
+ auto get(references_holder<Values, References> rh) -> decltype(thrust::get<N>(rh.data())) {
32
+ return thrust::get<N>(rh.data());
33
+ }
34
+
35
+ }} // namespace at::native
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/CuFFTPlanCache.h ADDED
@@ -0,0 +1,494 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <ATen/Config.h>
2
+ #include <ATen/core/DimVector.h>
3
+ #include <ATen/cuda/CUDAContext.h>
4
+ #include <ATen/native/cuda/CuFFTUtils.h>
5
+ #include <ATen/native/utils/ParamsHash.h>
6
+ #include <c10/util/accumulate.h>
7
+ #include <c10/util/irange.h>
8
+
9
+ #include <cufft.h>
10
+ #include <cufftXt.h>
11
+
12
+ #include <limits>
13
+ #include <list>
14
+ #include <sstream>
15
+ #include <stdexcept>
16
+ #include <string>
17
+ #include <unordered_map>
18
+
19
+ namespace at { namespace native { namespace detail {
20
+
21
// Enum representing the FFT type.
// Backed by int8_t — presumably to keep CuFFTParams (which embeds it and is
// hashed byte-wise) compact; confirm against ParamsHash usage.
enum class CuFFTTransformType : int8_t {
  C2C,  // Complex-to-complex
  R2C,  // Real-to-complex
  C2R,  // Complex-to-real
};
27
+
28
// This struct is used to let us easily compute hashes of the
// parameters. It will be the **key** to the plan cache.
// It must stay trivial (enforced by the static_assert below) because
// ParamsHash/ParamsEqual treat it as a raw byte blob.
struct CuFFTParams
{
  int64_t signal_ndim_; // between 1 and max_rank, i.e., 1 <= signal_ndim <= 3
  // These include additional batch dimension as well.
  int64_t sizes_[max_rank + 1];
  int64_t input_strides_[max_rank + 1];
  int64_t output_strides_[max_rank + 1];
  CuFFTTransformType fft_type_;
  ScalarType value_type_;

  CuFFTParams() = default;

  // Build a key from input/output strides and signal sizes (batch dim first).
  CuFFTParams(IntArrayRef in_strides, IntArrayRef out_strides,
      IntArrayRef signal_sizes, CuFFTTransformType fft_type, ScalarType value_type) {
    // Padding bits must be zeroed for hashing: equal params must be
    // byte-identical, so clear the whole object before filling fields.
    memset(this, 0, sizeof(*this));
    signal_ndim_ = signal_sizes.size() - 1;
    fft_type_ = fft_type;
    value_type_ = value_type;

    TORCH_INTERNAL_ASSERT(in_strides.size() == signal_sizes.size());
    TORCH_INTERNAL_ASSERT(out_strides.size() == signal_sizes.size());
    TORCH_INTERNAL_ASSERT(1 <= signal_ndim_ && signal_ndim_ <= max_rank);

    std::copy(signal_sizes.cbegin(), signal_sizes.cend(), sizes_);
    std::copy(in_strides.cbegin(), in_strides.cend(), input_strides_);
    std::copy(out_strides.cbegin(), out_strides.cend(), output_strides_);
  }
};
60
+
61
+ static_assert(std::is_trivial<CuFFTParams>::value, "");
62
+
63
+ // Returns true if the transform type has complex input
64
+ inline bool cufft_complex_input(CuFFTTransformType type) {
65
+ switch (type) {
66
+ case CuFFTTransformType::C2C:
67
+ case CuFFTTransformType::C2R:
68
+ return true;
69
+
70
+ case CuFFTTransformType::R2C:
71
+ return false;
72
+ }
73
+ TORCH_INTERNAL_ASSERT(false);
74
+ }
75
+
76
+ // Returns true if the transform type has complex output
77
+ inline bool cufft_complex_output(CuFFTTransformType type) {
78
+ switch (type) {
79
+ case CuFFTTransformType::C2C:
80
+ case CuFFTTransformType::R2C:
81
+ return true;
82
+
83
+ case CuFFTTransformType::C2R:
84
+ return false;
85
+ }
86
+ TORCH_INTERNAL_ASSERT(false);
87
+ }
88
+
89
+ // Create transform type enum from bools representing if input and output are complex
90
+ inline CuFFTTransformType GetCuFFTTransformType(bool complex_input, bool complex_output) {
91
+ if (complex_input && complex_output) {
92
+ return CuFFTTransformType::C2C;
93
+ } else if (complex_input && !complex_output) {
94
+ return CuFFTTransformType::C2R;
95
+ } else if (!complex_input && complex_output) {
96
+ return CuFFTTransformType::R2C;
97
+ }
98
+ TORCH_INTERNAL_ASSERT(false, "Real to real FFTs are not supported");
99
+ }
100
+
101
+
102
// RAII wrapper around a raw ::cufftHandle: creates the handle on
// construction and destroys it on destruction.
class CuFFTHandle {
  ::cufftHandle handle_;
public:

  CuFFTHandle() {
    CUFFT_CHECK(cufftCreate(&handle_));
  }

  ::cufftHandle & get() { return handle_; }
  const ::cufftHandle & get() const { return handle_; }

  ~CuFFTHandle() {
// Not using fftDestroy() for rocFFT to work around double freeing of handles
#if !defined(USE_ROCM)
    cufftDestroy(handle_);
#endif
  }
};
120
+
121
// True iff x has at most one bit set.
// NOTE(review): also returns true for x == 0; callers pass FFT dimension
// sizes, which are expected to be positive — confirm before relying on this
// for non-positive inputs.
__forceinline__
static bool is_pow_of_two(int64_t x) {
  return (x & (x - 1)) == 0;
}
125
+
126
+ using cufft_size_type = long long int;
127
+
128
+ using CuFFTDimVector = c10::SmallVector<cufft_size_type, at::kDimVectorStaticSize>;
129
+
130
+ // Struct representing a tensor in CuFFT's data layout for planning transforms
131
+ // See NOTE [ cuFFT Embedded Strides ].
132
+ struct CuFFTDataLayout {
133
+ CuFFTDimVector embed;
134
+ cufft_size_type stride, dist;
135
+ bool must_clone, simple;
136
+ };
137
+
138
+ // Returns a cufft embedding for a contiguous signal of the given size.
139
+ // e.g. if the input is cloned, this will be the resulting data layout
140
+ // See NOTE [ cuFFT Embedded Strides ].
141
+ inline CuFFTDataLayout cufft_simple_embed(IntArrayRef sizes, bool onesided) {
142
+ CuFFTDataLayout layout;
143
+ layout.simple = true;
144
+ layout.must_clone = false;
145
+ layout.embed.assign(sizes.cbegin() + 1, sizes.cend());
146
+ if (onesided) {
147
+ layout.embed.back() = sizes.back() / 2 + 1;
148
+ }
149
+ layout.stride = 1;
150
+ layout.dist = 1;
151
+ for (const auto& len : layout.embed) {
152
+ layout.dist *= len;
153
+ }
154
+ return layout;
155
+ }
156
+
157
// Convert strides to a CuFFT embedded representation.
// If strides cannot be embedded, returns a simple layout and sets must_clone flag
// See NOTE [ cuFFT Embedded Strides ].
inline CuFFTDataLayout as_cufft_embed(IntArrayRef strides, IntArrayRef sizes, bool onesided) {
  const auto signal_ndim = strides.size() - 1;
  CuFFTDataLayout layout;
  auto last_stride = strides[signal_ndim];
  // Non-positive innermost stride cannot be expressed to cuFFT.
  layout.must_clone = (last_stride <= 0);

  // For one-sided (R2C/C2R) transforms the last dim stores half + 1 entries.
  const auto last_dim_size = onesided ?
      sizes[signal_ndim] / 2 + 1 : sizes[signal_ndim];
  const auto signal_numel = c10::multiply_integers(sizes.slice(1, sizes.size() - 2)) * last_dim_size;

  // Zero strides are not allowed, even if the batch size is one.
  // If that happens just set a dummy case
  if (sizes[0] == 1) {
    layout.dist = signal_numel;
  } else if (strides[0] == 0) {
    layout.must_clone = true;
  } else {
    layout.dist = strides[0];
  }

  // Calculate the embedding shape, or set must_clone if the strides cannot be embedded.
  // Walks the signal dims from innermost to outermost; each stride must be a
  // positive multiple of the previous one to be representable as an embedding.
  layout.embed.resize(signal_ndim);
  for (auto i = signal_ndim - 1; !layout.must_clone && i > 0; i--) {
    auto stride = strides[i];
    if (sizes[i] == 1) {
      layout.embed[i] = 1;
    } else if (stride > 0 && stride % last_stride == 0) {
      layout.embed[i] = stride / last_stride;
      last_stride = stride;
    } else {
      layout.must_clone = true;
    }
  }

  if (layout.must_clone) {
    // If the input needs to be cloned, assume it will be contiguous
    layout = cufft_simple_embed(sizes, onesided);
    layout.must_clone = true;
  } else {
    // embed[0] is ignored by cuFFT but set for completeness.
    layout.embed[0] = sizes[1];
    layout.stride = strides[signal_ndim];
    // Determine if layout represents a simple embedding (contiguous data)
    layout.simple = [&] {
      for (const auto i : c10::irange(1, signal_ndim - 1)) {
        if (layout.embed[i] != sizes[i + 1]) {
          return false;
        }
      }

      return (layout.stride == 1 && layout.dist == signal_numel &&
          layout.embed.back() == last_dim_size);
    }();
  }
  return layout;
}
215
+
216
// This class contains all the information needed to execute a cuFFT plan:
//   1. the plan
//   2. whether to clone input before executing the plan
//   3. the workspace size needed
//
// This class will be the **value** in the plan cache.
// It owns the plan through the RAII CuFFTHandle member, so destroying a
// config destroys the plan.
class CuFFTConfig {
public:

  // Only move semantics is enough for this class. Although the plan is
  // already owned by an RAII handle, still remove copy constructor and
  // assignment op so we don't accidentally copy and take a perf hit.
  CuFFTConfig(const CuFFTConfig&) = delete;
  CuFFTConfig& operator=(CuFFTConfig const&) = delete;

  // Convenience constructor: unpack a cache key (CuFFTParams) into the
  // stride/size/type arguments of the main constructor below.
  explicit CuFFTConfig(const CuFFTParams& params):
      CuFFTConfig(
          IntArrayRef(params.input_strides_, params.signal_ndim_ + 1),
          IntArrayRef(params.output_strides_, params.signal_ndim_ + 1),
          IntArrayRef(params.sizes_, params.signal_ndim_ + 1),
          params.fft_type_,
          params.value_type_) {}

  // For complex types, strides are in units of 2 * element_size(dtype);
  // sizes are for the full signal, including batch size, and always two-sided.
  CuFFTConfig(IntArrayRef in_strides, IntArrayRef out_strides,
      IntArrayRef sizes, CuFFTTransformType fft_type, ScalarType dtype):
      fft_type_(fft_type), value_type_(dtype) {

    // signal sizes (excluding batch dim)
    CuFFTDimVector signal_sizes(sizes.begin() + 1, sizes.end());

    // input batch size
    const int64_t batch = sizes[0];
    const int64_t signal_ndim = sizes.size() - 1;

    // Since cuFFT has limited non-unit stride support and various constraints, we
    // use a flag to keep track throughout this function to see if we need to
    // input = input.clone();

#if defined(USE_ROCM)
    // clone input to avoid issues with hipfft clobering the input and failing tests
    clone_input = true;
#else
    clone_input = false;
#endif

    // For half, base strides on the real part of real-to-complex and
    // complex-to-real transforms are not supported. Since our output is always
    // contiguous, only need to check real-to-complex case.
    if (dtype == ScalarType::Half) {
      // cuFFT on half requires compute capability of at least SM_53
      auto dev_prop = at::cuda::getCurrentDeviceProperties();
      TORCH_CHECK(dev_prop->major >= 5 && !(dev_prop->major == 5 && dev_prop->minor < 3),
                  "cuFFT doesn't support signals of half type with compute "
                  "capability less than SM_53, but the device containing input half "
                  "tensor only has SM_", dev_prop->major, dev_prop->minor);
      for (const auto i : c10::irange(signal_ndim)) {
        TORCH_CHECK(is_pow_of_two(sizes[i + 1]),
                    "cuFFT only supports dimensions whose sizes are powers of two when"
                    " computing in half precision, but got a signal size of",
                    sizes.slice(1));
      }
      clone_input |= in_strides.back() != 1;
    }

    // Derive the cuFFT "embedded stride" description of input and output.
    CuFFTDataLayout in_layout;
    if (clone_input) {
      in_layout = cufft_simple_embed(sizes, fft_type == CuFFTTransformType::C2R);
    } else {
      in_layout = as_cufft_embed(in_strides, sizes, fft_type == CuFFTTransformType::C2R);
    }
    auto out_layout = as_cufft_embed(out_strides, sizes, fft_type == CuFFTTransformType::R2C);
    TORCH_INTERNAL_ASSERT(!out_layout.must_clone, "Out strides cannot be represented as CuFFT embedding");
    clone_input |= in_layout.must_clone;

    // Check if we can take advantage of simple data layout.
    //
    // See NOTE [ cuFFT Embedded Strides ] in native/cuda/SpectralOps.cu.

    const bool simple_layout = in_layout.simple && out_layout.simple;
    cudaDataType itype, otype, exec_type;
    const auto complex_input = cufft_complex_input(fft_type);
    const auto complex_output = cufft_complex_output(fft_type);
    if (dtype == ScalarType::Float) {
      itype = complex_input ? CUDA_C_32F : CUDA_R_32F;
      otype = complex_output ? CUDA_C_32F : CUDA_R_32F;
      exec_type = CUDA_C_32F;
    } else if (dtype == ScalarType::Double) {
      itype = complex_input ? CUDA_C_64F : CUDA_R_64F;
      otype = complex_output ? CUDA_C_64F : CUDA_R_64F;
      exec_type = CUDA_C_64F;
    } else if (dtype == ScalarType::Half) {
      itype = complex_input ? CUDA_C_16F : CUDA_R_16F;
      otype = complex_output ? CUDA_C_16F : CUDA_R_16F;
      exec_type = CUDA_C_16F;
    } else {
      TORCH_CHECK(false, "cuFFT doesn't support tensor of type: ", dtype);
    }

    // disable auto allocation of workspace to use THC allocator
    CUFFT_CHECK(cufftSetAutoAllocation(plan(), /* autoAllocate */ 0));

    size_t ws_size_t;

    // make plan
    if (simple_layout) {
      // If with unit-stride, we tell cuFFT by setting inembed == onembed == NULL.
      // In such case, cuFFT ignores istride, ostride, idist, and odist
      // by assuming istride = ostride = 1.
      //
      // See NOTE [ cuFFT Embedded Strides ] in native/cuda/SpectralOps.cu.
      CUFFT_CHECK(cufftXtMakePlanMany(plan(), signal_ndim, signal_sizes.data(),
          /* inembed */ nullptr, /* base_istride */ 1, /* idist */ 1, itype,
          /* onembed */ nullptr, /* base_ostride */ 1, /* odist */ 1, otype,
          batch, &ws_size_t, exec_type));
    } else {
      CUFFT_CHECK(cufftXtMakePlanMany(plan(), signal_ndim, signal_sizes.data(),
          in_layout.embed.data(), in_layout.stride, in_layout.dist, itype,
          out_layout.embed.data(), out_layout.stride, out_layout.dist, otype,
          batch, &ws_size_t, exec_type));
    }
    ws_size = static_cast<int64_t>(ws_size_t);
  }

  const cufftHandle &plan() const { return plan_ptr.get(); }

  CuFFTTransformType transform_type() const { return fft_type_; }
  ScalarType data_type() const { return value_type_; }
  bool should_clone_input() const { return clone_input; }
  int64_t workspace_size() const { return ws_size; }

private:
  CuFFTHandle plan_ptr;
  bool clone_input;
  int64_t ws_size;
  CuFFTTransformType fft_type_;
  ScalarType value_type_;
};
356
+
357
#if defined(USE_ROCM)
// Note that the max plan number for CUDA version < 10 has to be 1023
// due to a bug that fails on the 1024th plan.
// NOTE(review): this branch compiles for ROCm; the comment refers to the
// historical CUDA<10 limit the same cap was inherited from — confirm.
constexpr int64_t CUFFT_MAX_PLAN_NUM = 1023;
constexpr int64_t CUFFT_DEFAULT_CACHE_SIZE = CUFFT_MAX_PLAN_NUM;
#else
constexpr int64_t CUFFT_MAX_PLAN_NUM = std::numeric_limits<int64_t>::max();
// The default max cache size chosen for CUDA version > 10 is arbitrary.
// This number puts a limit on how big of a plan cache should we maintain by
// default. Users can always configure it via cufft_set_plan_cache_max_size.
constexpr int64_t CUFFT_DEFAULT_CACHE_SIZE = 4096;
#endif
static_assert(0 <= CUFFT_MAX_PLAN_NUM && CUFFT_MAX_PLAN_NUM <= std::numeric_limits<int64_t>::max(),
              "CUFFT_MAX_PLAN_NUM not in size_t range");
static_assert(CUFFT_DEFAULT_CACHE_SIZE >= 0 && CUFFT_DEFAULT_CACHE_SIZE <= CUFFT_MAX_PLAN_NUM,
              "CUFFT_DEFAULT_CACHE_SIZE not in [0, CUFFT_MAX_PLAN_NUM] range");
373
+
374
// This cache assumes that the mapping from key to value never changes.
// This is **NOT** thread-safe. Please use a mutex when using it **AND** the
// value returned from lookup().
// The contract of using this cache is that lookup() should only be
// used when the max_size is positive.
class CuFFTParamsLRUCache {
public:
  using kv_t = typename std::pair<CuFFTParams, CuFFTConfig>;
  // Map keys are reference_wrappers into the list nodes, so each CuFFTParams
  // is stored exactly once (inside the list) and the map never copies it.
  using map_t = typename std::unordered_map<std::reference_wrapper<CuFFTParams>,
                                            typename std::list<kv_t>::iterator,
                                            ParamsHash<CuFFTParams>,
                                            ParamsEqual<CuFFTParams>>;
  using map_kkv_iter_t = typename map_t::iterator;


  CuFFTParamsLRUCache() : CuFFTParamsLRUCache(CUFFT_DEFAULT_CACHE_SIZE) {}

  CuFFTParamsLRUCache(int64_t max_size) {
    _set_max_size(max_size);
  }

  CuFFTParamsLRUCache(CuFFTParamsLRUCache&& other) noexcept :
      _usage_list(std::move(other._usage_list)),
      _cache_map(std::move(other._cache_map)),
      _max_size(other._max_size) {}

  CuFFTParamsLRUCache& operator=(CuFFTParamsLRUCache&& other) noexcept {
    _usage_list = std::move(other._usage_list);
    _cache_map = std::move(other._cache_map);
    _max_size = other._max_size;
    return *this;
  }

  // If key is in this cache, return the cached config. Otherwise, emplace the
  // config in this cache and return it.
  // Return const reference because CuFFTConfig shouldn't be tampered with once
  // created.
  const CuFFTConfig &lookup(CuFFTParams params) {
    AT_ASSERT(_max_size > 0);

    map_kkv_iter_t map_it = _cache_map.find(params);
    // Hit: move the node to the list front (most-recently-used). splice moves
    // the node without invalidating the iterator stored in the map.
    if (map_it != _cache_map.end()) {
      _usage_list.splice(_usage_list.begin(), _usage_list, map_it->second);
      return map_it->second->second;
    }

    // Miss: evict the least-recently-used entry (list tail) if we are full.
    if (_usage_list.size() >= _max_size) {
      auto last = _usage_list.end();
      last--;
      _cache_map.erase(last->first);
      _usage_list.pop_back();
    }

    // construct new plan at list front, then insert into _cache_map
    _usage_list.emplace_front(std::piecewise_construct,
                              std::forward_as_tuple(params),
                              std::forward_as_tuple(params));
    auto kv_it = _usage_list.begin();
    _cache_map.emplace(std::piecewise_construct,
                       std::forward_as_tuple(kv_it->first),
                       std::forward_as_tuple(kv_it));
    return kv_it->second;
  }

  void clear() {
    _cache_map.clear();
    _usage_list.clear();
  }

  // Shrink-or-grow the cache, evicting LRU entries if the new size is smaller.
  void resize(int64_t new_size) {
    _set_max_size(new_size);
    auto cur_size = _usage_list.size();
    if (cur_size > _max_size) {
      auto delete_it = _usage_list.end();
      for (size_t i = 0; i < cur_size - _max_size; i++) {
        delete_it--;
        _cache_map.erase(delete_it->first);
      }
      _usage_list.erase(delete_it, _usage_list.end());
    }
  }

  size_t size() const { return _cache_map.size(); }

  size_t max_size() const noexcept { return _max_size; }

  // Public on purpose: callers lock this around lookup() and any use of the
  // returned config.
  std::mutex mutex;

private:
  // Only sets size and does value check. Does not resize the data structures.
  void _set_max_size(int64_t new_size) {
    // We check that 0 <= new_size <= CUFFT_MAX_PLAN_NUM here. Since
    // CUFFT_MAX_PLAN_NUM is of type size_t, we need to do non-negativity check
    // first.
    TORCH_CHECK(new_size >= 0,
                "cuFFT plan cache size must be non-negative, but got ", new_size);
    TORCH_CHECK(new_size <= CUFFT_MAX_PLAN_NUM,
                "cuFFT plan cache size can not be larger than ", CUFFT_MAX_PLAN_NUM, ", but got ", new_size);
    _max_size = static_cast<size_t>(new_size);
  }

  std::list<kv_t> _usage_list;
  map_t _cache_map;
  size_t _max_size;
};
482
+
483
+ // Since ATen is separated into CPU build and CUDA build, we need a way to call
484
+ // these functions only when CUDA is loaded. We use CUDA hooks for this purpose
485
+ // (at cuda/detail/CUDAHooks.cpp), and call the hooked functions from the actual
486
+ // native function counterparts (at native/SpectralOps.cpp), i.e.,
487
+ // _cufft_get_plan_cache_max_size, _cufft_set_plan_cache_max_size
488
+ // _cufft_get_plan_cache_size, and _cufft_clear_plan_cache.
489
+ int64_t cufft_get_plan_cache_max_size_impl(DeviceIndex device_index);
490
+ void cufft_set_plan_cache_max_size_impl(DeviceIndex device_index, int64_t max_size);
491
+ int64_t cufft_get_plan_cache_size_impl(DeviceIndex device_index);
492
+ void cufft_clear_plan_cache_impl(DeviceIndex device_index);
493
+
494
+ }}} // namespace at::native::detail
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/DeviceSqrt.cuh ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
namespace at { namespace native {
#if defined(USE_ROCM)
// take these out when ROCm implements std:: math functions
#include <math.h>
// Per-type sqrt for device code; explicit float/double specializations call
// the C library functions directly.
template <typename scalar_t>
static __forceinline__ __device__ scalar_t device_sqrt(scalar_t val);

template <>
__forceinline__ __device__ float device_sqrt(float val) {
  return ::sqrtf(val);
}

template <>
__forceinline__ __device__ double device_sqrt(double val) {
  return ::sqrt(val);
}
#else
// NOTE(review): this overload returns double regardless of scalar_t (the
// result is never narrowed back to the argument type) — callers appear to
// rely on implicit conversion at the use site; confirm before changing the
// return type.
template<typename scalar_t>
__forceinline__ __device__ double device_sqrt(scalar_t val) {
  return std::sqrt(val);
}
#endif
}}
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/DistributionTemplates.h ADDED
@@ -0,0 +1,672 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/AccumulateType.h>
4
+ #include <ATen/Dispatch.h>
5
+ #include <ATen/Dispatch_v2.h>
6
+ #include <ATen/ExpandBase.h>
7
+ #include <ATen/OpMathType.h>
8
+ #include <ATen/native/TensorIterator.h>
9
+ #include <ATen/native/cuda/Loops.cuh>
10
+ #include <c10/util/Half.h>
11
+ #include <ATen/cuda/CUDAApplyUtils.cuh>
12
+ #include <ATen/cuda/CUDAContext.h>
13
+ #include <ATen/cuda/detail/OffsetCalculator.cuh>
14
+ #include <ATen/cuda/CUDAGraphsUtils.cuh>
15
+ #include <ATen/detail/FunctionTraits.h>
16
+ #include <ATen/core/DistributionsHelper.h>
17
+
18
+ #include <curand.h>
19
+ #include <curand_kernel.h>
20
+ #include <curand_philox4x32_x.h>
21
+ #include <cstdint>
22
+ #include <limits>
23
+ #include <utility>
24
+ #include <mutex>
25
+ #include <tuple>
26
+ #include <type_traits>
27
+
28
+ namespace at {
29
+ namespace native {
30
+ namespace {
31
+
32
+ // launch bounds used for kernels utilizing TensorIterator
33
+ const uint32_t block_size_bound = 256;
34
+ const uint32_t grid_size_bound = 4;
35
+ // number of randoms given by distributions like curand_uniform4, curand_uniform2_double
36
+ // used in calculating philox offset.
37
+ const uint32_t curand4_engine_calls = 4;
38
+
39
+ // utility function that calculates proper philox_offset
40
+ // for distributions utilizing TensorIterator. For distributions using
41
+ // TensorIterator, we are using a grid-stride loop with each
42
+ // thread yielding one element per thread. For the edge of the grid-stride
43
+ // loop, if the tensor size is large, the unroll loop will kick in and the float4
44
+ // from curand4 will start getting utilized (for common tensor sizes, we end up
45
+ // using rand.x from each thread). Hence, the philox_offset is
46
+ // (number of elements per thread * number of engine calls), which makes
47
+ // sure that philox offset increment is not less than the number of randoms used
48
+ // in each thread.
49
+ std::tuple<uint64_t, dim3, dim3> calc_execution_policy(int64_t total_elements) {
50
+ const uint64_t numel = static_cast<uint64_t>(total_elements);
51
+ const uint32_t block_size = block_size_bound;
52
+ const uint32_t unroll = curand4_engine_calls;
53
+ dim3 dim_block(block_size);
54
+ dim3 grid((numel + block_size - 1) / block_size);
55
+ uint32_t blocks_per_sm = at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor / block_size;
56
+ grid.x = std::min(
57
+ static_cast<uint32_t>(at::cuda::getCurrentDeviceProperties()->multiProcessorCount) * blocks_per_sm,
58
+ grid.x);
59
+ //number of times random will be generated per thread, to offset philox counter in thc random state
60
+ uint64_t counter_offset = ((numel - 1) / (block_size * grid.x * unroll) + 1)
61
+ * curand4_engine_calls;
62
+ return std::make_tuple(counter_offset, grid, dim_block);
63
+ }
64
+
65
// grid stride loop kernel for distributions
//
// Each thread initializes its own Philox state from (seed, thread id, offset)
// and then walks the flattened output in strides of
// blockDim.x * gridDim.x * unroll_factor. `dist_func` draws a vector of
// unroll_factor randoms per iteration; `transform_func` writes one output
// element per random. `rounded_size` rounds numel up to a whole number of
// strides so that every thread performs the same number of dist_func calls
// (out-of-range elements are simply not stored).
template<typename accscalar_t, int unroll_factor, typename dist_t, typename transform_t>
C10_LAUNCH_BOUNDS_2(block_size_bound, grid_size_bound)
__global__ void distribution_elementwise_grid_stride_kernel(int numel,
                                                            PhiloxCudaState philox_args,
                                                            const dist_t dist_func,
                                                            const transform_t transform_func) {
  auto seeds = at::cuda::philox::unpack(philox_args);
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  curandStatePhilox4_32_10_t state;
  curand_init(std::get<0>(seeds),
              idx,
              std::get<1>(seeds),
              &state);

  int rounded_size = ((numel - 1)/(blockDim.x * gridDim.x * unroll_factor)+1) *
      blockDim.x * gridDim.x * unroll_factor;
  for(int linear_index = idx; linear_index < rounded_size; linear_index += blockDim.x * gridDim.x * unroll_factor) {
    auto rand = dist_func(&state);
    #pragma unroll
    for (int ii = 0; ii < unroll_factor; ii++) {
      // (&rand.x)[ii] indexes the ii-th component of the vector type
      // (float4 / ulonglong2 / ...) returned by dist_func.
      int li = linear_index + blockDim.x * gridDim.x * ii;
      if (li < numel) {
        transform_func(li, static_cast<accscalar_t>((&rand.x)[ii]));
      }
    }
    // NOTE(review): barrier between strides kept as in the original;
    // presumably to keep the block in lockstep across iterations — confirm
    // before removing.
    __syncthreads();
  }
}
94
+
95
/**
 * distribution_nullary_kernel is analogous to gpu_kernel in
 * ATen/native/cuda/Loops.cuh. Like gpu_kernel, it uses
 * TensorIterator to launch a kernel. However, the differences are
 *   - it launches a grid-stride loop based kernel. The kernel is not
 *     generic like elementwise_kernel in Loops.cuh and is specialized
 *     for the distribution kernels here.
 *   - For big size tensors, we can launch multiple kernels recursively
 *     (i.e. if (!iter.can_use_32bit_indexing())) and hence, the philox
 *     offset calculation is done in this function.
 *
 * @param iter            nullary TensorIterator describing the output tensor
 * @param gen             CUDA generator; its mutex is taken while reserving
 *                        the philox counter range for this launch
 * @param dist_func       device callable: curand state -> vector of randoms
 * @param transform_func  device callable mapping one random (as accscalar_t)
 *                        to the output scalar_t
 *
 * FIXME: Can we specialize elementwise_kernel and launch_kernel in Loops.cuh
 * to have grid-stride loop kernel and then use that to launch our distribution
 * kernels? Note that we need a grid-stride loop kernel because, we found by testing
 * that it achieves peak effective bandwidth.
 */
template<typename scalar_t,
         typename accscalar_t,
         int unroll_factor,
         typename RNG,
         typename dist_t,
         typename transform_t>
void distribution_nullary_kernel(at::TensorIteratorBase& iter,
                                 RNG gen,
                                 const dist_t& dist_func,
                                 const transform_t transform_func) {
  static_assert(unroll_factor >= 1, "unroll_factor must be >= 1.");
  int64_t numel = iter.numel();
  if (numel == 0) {
    return;
  }

  auto execution_policy = calc_execution_policy(numel);
  auto counter_offset = std::get<0>(execution_policy);
  auto grid = std::get<1>(execution_policy);
  auto block = std::get<2>(execution_policy);
  PhiloxCudaState rng_engine_inputs;
  {
    // See Note [Acquire lock when using random generators]
    std::lock_guard<std::mutex> lock(gen->mutex_);
    rng_engine_inputs = gen->philox_cuda_state(counter_offset);
  }

  if (!iter.can_use_32bit_indexing()) {
    // Split into 32-bit-indexable sub-iterators and recurse; each recursive
    // call reserves its own philox range above.
    for (auto& sub_iter : iter.with_32bit_indexing()) {
      distribution_nullary_kernel<scalar_t, accscalar_t, unroll_factor>(sub_iter,
        gen, dist_func, transform_func);
    }
    return;
  }

  char* out_data = (char*)iter.data_ptr(0);

  auto stream = at::cuda::getCurrentCUDAStream();
  if (iter.is_trivial_1d()) {
    // Contiguous-ish fast path: a single stride suffices to address outputs.
    auto strides = iter.get_inner_strides();
    int stride0 = strides[0];
    distribution_elementwise_grid_stride_kernel<accscalar_t, unroll_factor><<<grid, block, 0, stream>>>(
      numel,
      rng_engine_inputs,
      dist_func,
      [=]__device__(int idx, accscalar_t rand) {
        scalar_t* out = (scalar_t*)&out_data[stride0 * idx];
        *out = transform_func(rand);
      }
    );
    C10_CUDA_KERNEL_LAUNCH_CHECK();
  } else {
    // General path: compute per-element byte offsets with an OffsetCalculator.
    auto offset_calc = make_offset_calculator<1>(iter);
    distribution_elementwise_grid_stride_kernel<accscalar_t, unroll_factor><<<grid, block, 0, stream>>>(
      numel,
      rng_engine_inputs,
      dist_func,
      [=]__device__(int idx, accscalar_t rand) {
        auto offsets = offset_calc.get(idx);
        scalar_t* out = (scalar_t*)&out_data[offsets[0]];
        *out = transform_func(rand);
      }
    );
    C10_CUDA_KERNEL_LAUNCH_CHECK();
  }
}
177
+
178
// Binary kernel: applies f(curand_state, in1, in2) elementwise.
// Each block handles block_work_size() elements: inputs are first staged into
// per-thread register arrays, then the functor is applied and results stored,
// so each element's offsets are computed once per phase.
template <typename func_t, typename inp_offset_calc_t, typename out_offset_calc_t>
__global__ void distribution_binary_elementwise_kernel(
    int numel,
    func_t f,
    PhiloxCudaState philox_args,
    typename function_traits<func_t>::result_type *output_data,
    const typename function_traits<func_t>::template arg<1>::type *input_data_1,
    const typename function_traits<func_t>::template arg<2>::type *input_data_2,
    inp_offset_calc_t inp_calc,
    out_offset_calc_t out_calc) {
  auto seeds = at::cuda::philox::unpack(philox_args);

  using input_t_1 = typename function_traits<func_t>::template arg<1>::type;
  using input_t_2 = typename function_traits<func_t>::template arg<2>::type;

  input_t_1 inputs_1[thread_work_size()];
  input_t_2 inputs_2[thread_work_size()];

  int base_index = block_work_size() * blockIdx.x;
  // Last block may be partially filled.
  int remaining = std::min<int>(numel - base_index, block_work_size());

  curandStatePhilox4_32_10_t state;
  curand_init(std::get<0>(seeds),
              blockIdx.x * blockDim.x + threadIdx.x,
              std::get<1>(seeds),
              &state);

  // load data into registers
  int thread_idx = threadIdx.x;
  #pragma unroll
  for (int i = 0; i < thread_work_size(); i++) {
    if (thread_idx >= remaining) {
      break;
    }
    int input_idx = thread_idx + base_index;
    auto offsets = inp_calc.get(input_idx);
    inputs_1[i] = input_data_1[offsets[0]];
    inputs_2[i] = input_data_2[offsets[1]];

    thread_idx += num_threads();
  }

  // compute and store
  thread_idx = threadIdx.x;
  #pragma unroll
  for (int i = 0; i < thread_work_size(); i++) {
    if (thread_idx >= remaining) {
      break;
    }
    int input_idx = thread_idx + base_index;
    auto offsets = out_calc.get(input_idx);
    output_data[offsets[0]] = f(state, inputs_1[i], inputs_2[i]);
    thread_idx += num_threads();
  }
}
234
+
235
// Host-side launcher for distribution_binary_elementwise_kernel.
// `f` must take a curandStatePhilox4_32_10_t& as its first argument; its two
// remaining argument types determine the input dtypes and its result type the
// output dtype. Falls back to recursive 32-bit-indexable sub-iterators for
// large tensors. `philox_args` is reserved by the caller.
template <typename func_t>
void distribution_binary_kernel(TensorIteratorBase &iter, PhiloxCudaState philox_args, const func_t &f) {
  static_assert(std::is_same<typename function_traits<func_t>::template arg<0>::type, curandStatePhilox4_32_10_t&>::value, "the first argument of functor must be curandStatePhilox4_32_10_t");
  using input_t_1 = typename function_traits<func_t>::template arg<1>::type;
  using input_t_2 = typename function_traits<func_t>::template arg<2>::type;
  using output_t = typename function_traits<func_t>::result_type;

  if (!iter.can_use_32bit_indexing()) {
    for (auto& sub_iter : iter.with_32bit_indexing()) {
      distribution_binary_kernel(sub_iter, philox_args, f);
    }
    return;
  }

  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(iter.can_use_32bit_indexing());

  int64_t numel = iter.numel();
  if (numel == 0) {
    return;
  }

  output_t *output_data = static_cast<output_t *>(iter.data_ptr(0));
  const input_t_1 *input_data_1 = static_cast<const input_t_1 *>(iter.data_ptr(1));
  const input_t_2 *input_data_2 = static_cast<const input_t_2 *>(iter.data_ptr(2));

  int64_t grid = (numel + block_work_size() - 1) / block_work_size();
  auto stream = at::cuda::getCurrentCUDAStream();

  if (iter.is_contiguous()) {
    // Contiguous: identity offset calculators avoid per-element arithmetic.
    distribution_binary_elementwise_kernel<<<grid,num_threads(), 0, stream>>>(
        numel, f, philox_args, output_data, input_data_1, input_data_2,
        TrivialOffsetCalculator<2>(), TrivialOffsetCalculator<1>());
    C10_CUDA_KERNEL_LAUNCH_CHECK();
  } else {
    distribution_binary_elementwise_kernel<<<grid, num_threads(), 0, stream>>>(
        numel, f, philox_args, output_data, input_data_1, input_data_2,
        make_input_offset_calculator<2>(iter), make_output_offset_calculator(iter));
    C10_CUDA_KERNEL_LAUNCH_CHECK();
  }
}
275
+
276
+ } // namespace
277
+ }} // namespace at::native
278
+
279
+
280
+ namespace at {
281
+ namespace native {
282
+ namespace templates {
283
+ namespace cuda {
284
+
285
+ // ==================================================== Random ========================================================
286
+
287
// Fills `iter` with integers uniformly drawn from [base, base + range).
// When the target type can represent values needing more than 32 random bits
// and range >= 2^32, two 32-bit curand outputs are packed into one 64-bit
// random (halving the unroll factor accordingly); otherwise a single 32-bit
// draw per element suffices.
template<typename RNG>
void random_from_to_kernel(TensorIteratorBase& iter, uint64_t range, int64_t base, RNG gen) {
  AT_DISPATCH_V2(iter.dtype(), "random_from_to_kernel_cuda", AT_WRAP([&] {
    if ((
      std::is_same<scalar_t, int64_t>::value ||
      std::is_same<scalar_t, double>::value ||
      std::is_same<scalar_t, float>::value ||
      std::is_same<scalar_t, at::BFloat16>::value) && range >= 1ULL << 32)
    {
      // define lambda to mod with range and add base
      auto random_func = [range, base] __device__ (uint64_t rand) {
        return transformation::uniform_int_from_to<scalar_t>(rand, range, base);
      };
      distribution_nullary_kernel<scalar_t, uint64_t, curand4_engine_calls/2>(iter,
        gen,
        [] __device__ (curandStatePhilox4_32_10_t* state) -> ulonglong2 {
          ulonglong2 ret;
          uint4 rand_val = curand4(state);
          // Pack pairs of 32-bit outputs into 64-bit randoms.
          ret.x = (static_cast<uint64_t>(rand_val.x) << 32) | rand_val.y;
          ret.y = (static_cast<uint64_t>(rand_val.z) << 32) | rand_val.w;
          return ret;
        },
        random_func);
    } else {
      auto random_func = [range, base] __device__ (uint32_t rand) {
        return transformation::uniform_int_from_to<scalar_t>(rand, range, base);
      };
      distribution_nullary_kernel<scalar_t, uint32_t, curand4_engine_calls>(iter,
        gen,
        [] __device__ (curandStatePhilox4_32_10_t* state) {
          return curand4(state);
        },
        random_func);
    }
  }), AT_EXPAND(AT_ALL_TYPES), kBool, kHalf, kBFloat16, AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES));
}
323
+
324
// This is the special kernel to handle single specific case:
// from(inclusive) = std::numeric_limits<int64_t>::lowest()
// to(exclusive) = None (= std::numeric_limits<int64_t>::max() + 1)
// Only dtypes that can meaningfully consume 64 random bits are supported;
// anything else is rejected with TORCH_CHECK.
template<typename RNG>
void random_full_64_bits_range_kernel(TensorIteratorBase& iter, RNG gen) {
  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::BFloat16, iter.dtype(), "random_full_64_bits_range_kernel_cuda", [&] {
    if (std::is_same<scalar_t, int64_t>::value ||
        std::is_same<scalar_t, double>::value ||
        std::is_same<scalar_t, float>::value ||
        std::is_same<scalar_t, at::BFloat16>::value) {
      auto random_func = [] __device__ (uint64_t rand) {
        return transformation::uniform_int_full_range<scalar_t>(rand);
      };
      distribution_nullary_kernel<scalar_t, uint64_t, curand4_engine_calls/2>(iter,
        gen,
        [] __device__ (curandStatePhilox4_32_10_t* state) -> ulonglong2 {
          ulonglong2 ret;
          uint4 rand_val = curand4(state);
          // Pack pairs of 32-bit outputs into 64-bit randoms.
          ret.x = (static_cast<uint64_t>(rand_val.x) << 32) | rand_val.y;
          ret.y = (static_cast<uint64_t>(rand_val.z) << 32) | rand_val.w;
          return ret;
        },
        random_func);
    } else {
      TORCH_CHECK(false, "random_full_64_bits_range_kernel_cuda handles only int64, double, float and bfloat16");
    }
  });
}

// Stub dispatched from native random_() implementations; validates and
// unwraps the optional Generator before delegating to the kernels above.
template<typename RNG>
struct RandomFromToKernel {
  void operator()(TensorIteratorBase& iter, uint64_t range, int64_t base, c10::optional<Generator> gen) {
    random_from_to_kernel(iter, range, base, check_generator<RNG>(gen));
  }
  void operator()(TensorIteratorBase& iter, c10::optional<Generator> gen) {
    random_full_64_bits_range_kernel(iter, check_generator<RNG>(gen));
  }
};
362
+
363
// Fills `iter` with uniformly random bits reinterpreted as the target dtype's
// full unsigned range (see transformation::uniform_int). 64-bit dtypes
// consume two packed 32-bit curand outputs per element; everything else uses
// one 32-bit output.
template<typename RNG>
void random_kernel(TensorIteratorBase& iter, RNG gen) {
  AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Half, at::ScalarType::BFloat16, at::ScalarType::Bool, iter.dtype(), "random_kernel_cuda", [&] {
    if (std::is_same<scalar_t, double>::value || std::is_same<scalar_t, int64_t>::value) {
      auto random_func = [] __device__ (uint64_t rand) {
        return transformation::uniform_int<scalar_t>(rand);
      };
      distribution_nullary_kernel<scalar_t, uint64_t, curand4_engine_calls/2>(iter, gen,
        [] __device__ (curandStatePhilox4_32_10_t* state) -> ulonglong2 {
          ulonglong2 ret;
          uint4 rand_val = curand4(state);
          // Pack pairs of 32-bit outputs into 64-bit randoms.
          ret.x = (static_cast<uint64_t>(rand_val.x) << 32) | rand_val.y;
          ret.y = (static_cast<uint64_t>(rand_val.z) << 32) | rand_val.w;
          return ret;
        },
        random_func);
    } else {
      auto random_func = [] __device__ (uint32_t rand) {
        return transformation::uniform_int<scalar_t>(rand);
      };
      distribution_nullary_kernel<scalar_t, uint32_t, curand4_engine_calls>(iter,
        gen,
        [] __device__ (curandStatePhilox4_32_10_t* state) {
          return curand4(state);
        },
        random_func);
    }
  });
}

// Dispatch stub wrapper for random_kernel.
template<typename RNG>
struct RandomKernel {
  void operator()(TensorIteratorBase& iter, RNG gen) {
    random_kernel(iter, gen);
  }
};
399
+
400
+ // ====================================================================================================================
401
+
402
// Draws uniform randoms and applies `transform` to each. Double uses the
// two-wide double generator (half the unroll factor); all other dtypes use
// the four-wide float generator.
template<typename scalar_t, typename accscalar_t, size_t curand4_engine_calls, typename RNG, typename transform_t>
void uniform_and_transform(TensorIteratorBase& iter, RNG gen, transform_t transform) {
  if (std::is_same<scalar_t, double>::value) {
    distribution_nullary_kernel<scalar_t, accscalar_t, curand4_engine_calls/2>(iter,
      gen,
      [] __device__ (curandStatePhilox4_32_10_t* state) { return curand_uniform2_double(state); },
      transform);
  } else {
    distribution_nullary_kernel<scalar_t, accscalar_t, curand4_engine_calls>(iter,
      gen,
      [] __device__ (curandStatePhilox4_32_10_t* state) { return curand_uniform4(state); },
      transform);
  }
}

// Same structure as uniform_and_transform, but drawing standard-normal
// randoms instead of uniforms.
template<typename scalar_t, typename accscalar_t, size_t curand4_engine_calls, typename RNG, typename transform_t>
void normal_and_transform(TensorIteratorBase& iter, RNG gen, transform_t transform) {
  if (std::is_same<scalar_t, double>::value) {
    distribution_nullary_kernel<scalar_t, accscalar_t, curand4_engine_calls/2>(iter,
      gen,
      [] __device__ (curandStatePhilox4_32_10_t* state) { return curand_normal2_double(state); },
      transform);
  } else {
    distribution_nullary_kernel<scalar_t, accscalar_t, curand4_engine_calls>(iter,
      gen,
      [] __device__ (curandStatePhilox4_32_10_t* state) { return curand_normal4(state); },
      transform);
  }
}
431
+
432
+ // ==================================================== Normal ========================================================
433
+
434
// ==================================================== Normal ========================================================

// Fills `self` with samples from N(mean_, std_^2). The affine transform is
// applied in the accumulation type (float for half/bfloat16 inputs) before
// narrowing back to scalar_t.
template<typename RNG>
void normal_kernel(const TensorBase &self, double mean_, double std_, RNG gen) {
  auto iter = TensorIterator::borrowing_nullary_op(self);
  AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "normal_kernel_cuda", [&] {
    using accscalar_t = at::acc_type<scalar_t, true>;
    auto mean = static_cast<accscalar_t>(mean_);
    auto std = static_cast<accscalar_t>(std_);
    // define lambda to multiply std and add mean
    auto normal_func = [mean, std] __device__ (accscalar_t rand) {
      return static_cast<scalar_t>(transformation::normal<accscalar_t>(rand, mean, std));
    };
    normal_and_transform<scalar_t, accscalar_t, curand4_engine_calls>(iter, gen, normal_func);
  });
}

// Dispatch stub wrapper for normal_kernel.
template<typename RNG>
struct NormalKernel {
  void operator()(const TensorBase &self, double mean, double std, c10::optional<Generator> gen) {
    normal_kernel(self, mean, std, check_generator<RNG>(gen));
  }
};
455
+
456
+ // ==================================================== Uniform ========================================================
457
+
458
// Fills `iter` with samples uniform on [from_, to_). curand generates on
// (0, 1], so the transform below maps the upper endpoint back to `from` —
// see the linked issues before modifying anything here.
template<typename RNG>
void uniform_kernel(TensorIteratorBase& iter, double from_, double to_, RNG gen) {
  AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "uniform_kernel_cuda", [&] {
    auto from = static_cast<scalar_t>(from_);
    auto to = static_cast<scalar_t>(to_);
    using opmath_t = at::opmath_type<scalar_t>;
    auto range = static_cast<opmath_t>(to-from);
    // define lambda to reverse bounds, multiply 'range' and add 'from_'
    auto uniform_func = [range, from, to] __device__ (opmath_t rand) {
      // Compute output value before reversing the bounds
      // BEFORE TOUCHING THIS CODE READ: https://github.com/pytorch/pytorch/issues/96947
      auto value = static_cast<scalar_t>(rand * range + from);
      // reverse the bounds of curand4 from (0, 1] to [0, 1)
      // Note that this method is from legacy THCTensorRandom and is likely to give
      // you more 0-s, since, the probability of gettings 1-s is higher than 0-s and
      // by reversing the bounds, we are flipping the probabilities of 1-s and 0-s.
      // BEFORE TOUCHING THIS CODE READ: https://github.com/pytorch/pytorch/issues/16706
      auto reverse_bound_value = value == to ? from : value;
      return reverse_bound_value;
    };
    uniform_and_transform<scalar_t, opmath_t, curand4_engine_calls>(iter, gen, uniform_func);
  });
}

// Dispatch stub wrapper for uniform_kernel.
template<typename RNG>
struct UniformKernel {
  void operator()(TensorIteratorBase& iter, double from, double to, c10::optional<Generator> gen) {
    uniform_kernel(iter, from, to, check_generator<RNG>(gen));
  }
};
488
+
489
+ // ================================================== LogNormal =======================================================
490
+
491
// Fills `iter` with log-normal samples: exp of a normal draw with the given
// mean/std (composition happens in the per-element transform).
template<typename RNG>
void log_normal_kernel(TensorIteratorBase& iter, double mean_, double std_, RNG gen) {
  AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "log_normal_cuda", [&] {
    using accscalar_t = at::acc_type<scalar_t, true>;
    auto mean = static_cast<accscalar_t>(mean_);
    auto std = static_cast<accscalar_t>(std_);
    // define lambda for log_normal transformation
    auto log_normal_func = [mean, std] __device__ (accscalar_t rand) {
      return static_cast<scalar_t>(transformation::log_normal<accscalar_t>(transformation::normal<accscalar_t>(rand, mean, std)));
    };
    normal_and_transform<scalar_t, accscalar_t, curand4_engine_calls>(iter, gen, log_normal_func);
  });
}

// Dispatch stub wrapper for log_normal_kernel.
template<typename RNG>
struct LogNormalKernel {
  void operator()(TensorIteratorBase& iter, double mean, double std, c10::optional<Generator> gen) {
    log_normal_kernel(iter, mean, std, check_generator<RNG>(gen));
  }
};
511
+
512
+ // =================================================== Geometric ======================================================
513
+
514
// Fills `iter` with geometric(p) samples derived from uniform draws.
template<typename RNG>
void geometric_kernel(TensorIteratorBase& iter, double p, RNG gen) {
  AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "geometric_cuda", [&] {
    using accscalar_t = at::DiscreteDistributionType<scalar_t>::type;
    // define lambda for geometric transformation
    auto geometric_func = [p] __device__ (accscalar_t rand) {
      return static_cast<scalar_t>(transformation::geometric<accscalar_t>(rand, p));
    };
    uniform_and_transform<scalar_t, accscalar_t, curand4_engine_calls>(iter, gen, geometric_func);
  });
}

// Dispatch stub wrapper for geometric_kernel.
template<typename RNG>
struct GeometricKernel {
  void operator()(TensorIteratorBase& iter, double p, c10::optional<Generator> gen) {
    geometric_kernel(iter, p, check_generator<RNG>(gen));
  }
};
532
+
533
+ // ================================================== Exponential =====================================================
534
+
535
// Fills `iter` with exponential(lambda) samples derived from uniform draws.
// Rejects non-floating dtypes up front with an explicit error.
template<typename RNG>
void exponential_kernel(TensorIteratorBase& iter, double lambda_, RNG gen) {
  TORCH_CHECK(isFloatingType(iter.dtype()), "Exponential distribution is a continuous probability distribution. dtype must be a floating point but you specified ", iter.dtype());
  AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "exponential_cuda", [&] {
    using accscalar_t = at::acc_type<scalar_t, true>;
    auto lambda = static_cast<accscalar_t>(lambda_);
    // define lambda for exponential transformation
    auto exponential_func = [lambda] __device__ (accscalar_t rand) {
      return static_cast<scalar_t>(transformation::exponential<accscalar_t>(rand, lambda));
    };
    uniform_and_transform<scalar_t, accscalar_t, curand4_engine_calls>(iter, gen, exponential_func);
  });
}

// Dispatch stub wrapper for exponential_kernel.
template<typename RNG>
struct ExponentialKernel {
  void operator()(TensorIteratorBase& iter, double lambda, c10::optional<Generator> gen) {
    exponential_kernel(iter, lambda, check_generator<RNG>(gen));
  }
};
555
+
556
+ // ==================================================== Cauchy ========================================================
557
+
558
// Fills `iter` with Cauchy(median, sigma) samples derived from uniform draws.
template<typename RNG>
void cauchy_kernel(TensorIteratorBase& iter, double median_, double sigma_, RNG gen) {
  AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "cauchy_cuda", [&] {
    using accscalar_t = at::acc_type<scalar_t, true>;
    auto median = static_cast<accscalar_t>(median_);
    auto sigma = static_cast<accscalar_t>(sigma_);
    // define lambda for cauchy transformation
    auto cauchy_func = [median, sigma] __device__ (accscalar_t rand) {
      return static_cast<scalar_t>(transformation::cauchy<accscalar_t>(rand, median, sigma));
    };
    uniform_and_transform<scalar_t, accscalar_t, curand4_engine_calls>(iter, gen, cauchy_func);
  });
}

// Dispatch stub wrapper for cauchy_kernel.
template<typename RNG>
struct CauchyKernel {
  void operator()(TensorIteratorBase& iter, double median, double sigma, c10::optional<Generator> gen) {
    cauchy_kernel(iter, median, sigma, check_generator<RNG>(gen));
  }
};
578
+
579
+ // ==================================================== Bernoulli =====================================================
580
+
581
// Elementwise Bernoulli with a per-element probability tensor `p`.
// The functor processes up to four (value, probability) pairs at a time via
// CUDA_tensor_apply2; `n` is how many of the four slots are valid. The switch
// cases fall through deliberately so that case 4 also fills slots 3..1, etc.
template<typename scalar_t, typename prob_t>
void bernoulli_tensor_cuda_kernel(
    const TensorBase &ret, const at::TensorBase &p,
    PhiloxCudaState philox_args) {
  auto functor = [philox_args] __device__(
          int n, scalar_t& v1, scalar_t& v2, scalar_t& v3, scalar_t& v4,
          const prob_t& p1, const prob_t& p2, const prob_t& p3, const prob_t& p4) {
        auto seeds = at::cuda::philox::unpack(philox_args);
        curandStatePhilox4_32_10_t state;
        curand_init(std::get<0>(seeds),
                    blockIdx.x * blockDim.x + threadIdx.x,
                    std::get<1>(seeds),
                    &state);

        // See Note [Register spilling in curand call for CUDA < 10]
        float4 rand = curand_uniform4(&state);
        switch (n) {
          case 4: {
            CUDA_KERNEL_ASSERT(0 <= p4 && p4 <= 1);
            v4 = static_cast<scalar_t>(rand.w <= p4);
            // fallthrough
          }
          case 3: {
            CUDA_KERNEL_ASSERT(0 <= p3 && p3 <= 1);
            v3 = static_cast<scalar_t>(rand.z <= p3);
            // fallthrough
          }
          case 2: {
            CUDA_KERNEL_ASSERT(0 <= p2 && p2 <= 1);
            v2 = static_cast<scalar_t>(rand.y <= p2);
            // fallthrough
          }
          case 1: {
            CUDA_KERNEL_ASSERT(0 <= p1 && p1 <= 1);
            v1 = static_cast<scalar_t>(rand.x <= p1);
          }
        }
      };
  // The template argument `4` below indicates that we want to operate on four
  // element at each time. See NOTE [ CUDA_tensor_applyN helpers ] for details.
  at::cuda::CUDA_tensor_apply2<scalar_t, prob_t, 4, decltype(functor),
                               /*max_threads_per_block=*/512,
                               /*min_blocks_per_sm==*/2>(ret, p, functor);
}
625
+
626
// Tensor-probability overload: self[i] ~ Bernoulli(p_[i]). The probability
// tensor is moved to self's device, cast (double for double self, float
// otherwise), and broadcast to self's shape before sampling.
template<typename RNG>
void bernoulli_kernel(const TensorBase &self, const TensorBase &p_, RNG gen) {
  PhiloxCudaState rng_engine_inputs;
  {
    // See Note [Acquire lock when using random generators]
    std::lock_guard<std::mutex> lock(gen->mutex_);
    // NOTE(review): the fixed offset of 10 is inherited from the original
    // code; it must cover the curand_uniform4 call per apply step — confirm
    // before changing.
    rng_engine_inputs = gen->philox_cuda_state(10);
  }
  TORCH_CHECK(at::isFloatingType(p_.scalar_type()), "expected probabilities tensor to have floating type, got ", p_.scalar_type());
  // cast probabilities tensor to double for double `self` tensor, and to `float` for everything else
  const auto p_type = self.dtype() == at::kDouble ? at::kDouble : at::kFloat;
  auto p_cuda = p_.to(TensorOptions().device(self.device()).dtype(p_type));
  auto p = expand_inplace(self, p_cuda);
  AT_DISPATCH_ALL_TYPES_AND3(
    at::ScalarType::Half, at::ScalarType::BFloat16, at::ScalarType::Bool, self.scalar_type(), "bernoulli_tensor_cuda_self_", [&] {
      if (std::is_same<scalar_t, double>::value) {
        return bernoulli_tensor_cuda_kernel<double, double>(self, *p, rng_engine_inputs);
      } else {
        return bernoulli_tensor_cuda_kernel<scalar_t, float>(self, *p, rng_engine_inputs);
      }
   });
}

// Scalar-probability overload: every element ~ Bernoulli(p), derived from a
// uniform draw per element.
template<typename RNG>
void bernoulli_kernel(TensorIteratorBase& iter, double p, RNG gen) {
  AT_DISPATCH_ALL_TYPES_AND3(
    at::ScalarType::Half, at::ScalarType::BFloat16, at::ScalarType::Bool, iter.dtype(), "bernoulli_scalar_cuda_", [&] {
      using accscalar_t = at::DiscreteDistributionType<scalar_t>::type;
      // define lambda for bernoulli transformation
      auto bernoulli_func = [p] __device__ (accscalar_t rand) {
        return static_cast<scalar_t>(transformation::bernoulli<accscalar_t>(rand, p));
      };
      uniform_and_transform<scalar_t, accscalar_t, curand4_engine_calls>(iter, gen, bernoulli_func);
   });
}

// Dispatch stub wrapper selecting between the scalar- and tensor-probability
// bernoulli_kernel overloads.
template<typename RNG>
struct BernoulliKernel {
  void operator()(TensorIteratorBase& iter, double p, c10::optional<Generator> gen) {
    bernoulli_kernel(iter, p, check_generator<RNG>(gen));
  }
  void operator()(const TensorBase &self, const TensorBase &p_, c10::optional<Generator> gen) {
    bernoulli_kernel(self, p_, check_generator<RNG>(gen));
  }
};
671
+
672
+ }}}}
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/Distributions.h ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#pragma once

namespace at {
struct CUDAGeneratorImpl;
struct TensorIteratorBase;
class TensorBase;

namespace native {

// Launch declarations for the CUDA distribution kernels implemented in
// Distributions.cu. Forward declarations above keep this header light
// (no ATen/core includes needed).

// Samples ret[i] ~ Poisson(lambda[i]).
void launch_poisson_cuda_kernel(
    const TensorBase &ret, const TensorBase &lambda, CUDAGeneratorImpl *gen);

// Samples ret[i] ~ Gamma(alpha[i]).
void launch_gamma_kernel(
    const TensorBase &ret, const TensorBase &alpha, CUDAGeneratorImpl *gen);

// Binomial sampling over a TensorIterator (output, count, prob operands).
void launch_binomial_cuda_kernel(
    TensorIteratorBase &iter, CUDAGeneratorImpl *gen);

// Dirichlet sampling / gradient helpers (no generator argument: these
// transform already-sampled gamma values).
void launch_dirichlet_kernel(TensorIteratorBase &iter);

void launch_standard_gamma_grad_kernel(TensorIteratorBase &iter);

void launch_dirichlet_grad_kernel(TensorIteratorBase &iter);

}} // namespace at::native
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/EmbeddingBackwardKernel.cuh ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#pragma once
#include <ATen/core/Tensor.h>
#include <ATen/cuda/Atomic.cuh>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/TensorUtils.h>

namespace at {
namespace native {

// CUDA backward helper for embedding-style ops: scatters/accumulates rows of
// `grad` into a weight-gradient tensor of `num_weights` rows, driven by
// `sorted_indices` (with `orig_indices` giving pre-sort positions and `count`
// per-index multiplicities).
// NOTE(review): the defaulted trailing arguments (offset2bag, bag_size,
// per_sample_weights, mode_mean) suggest this is shared with the
// embedding_bag backward — confirm against the .cu implementation.
Tensor embedding_backward_cuda_kernel(
    const Tensor &grad,
    const Tensor &orig_indices,
    const Tensor &sorted_indices,
    const Tensor &count,
    int64_t num_weights,
    int padding_idx = -1,
    bool mode_mean = false,
    const Tensor &offset2bag = Tensor(),
    const Tensor &bag_size = Tensor(),
    const Tensor &per_sample_weights = Tensor());

}}
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/ForeachFunctors.cuh ADDED
@@ -0,0 +1,681 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ #include <ATen/OpMathType.h>
3
+ #include <ATen/native/ForeachUtils.h>
4
+ #include <ATen/native/cuda/MultiTensorApply.cuh>
5
+ #include <ATen/native/cuda/Pow.cuh>
6
+
7
+ namespace at::native {
8
+
9
+ namespace {
10
+
11
+ // TODO(crcrpar): Handle version bump in codegen.
12
+ // rel:
13
+ // https://github.com/pytorch/pytorch/blob/9cf84347767c8abb8feba18a9a1baba321eeb8b9/tools/autograd/gen_inplace_or_view_type.py#L481-L482
14
+ inline void increment_version(TensorList tensors) {
15
+ for (const auto& t : tensors) {
16
+ t.unsafeGetTensorImpl()->bump_version();
17
+ }
18
+ }
19
+
20
+ // Initializes args and checks if all args are aligned
21
+ template <int depth, typename T>
22
+ __device__ bool init_args(
23
+ T** args,
24
+ TensorListMetadata<depth>& tl,
25
+ const int64_t chunk_idx,
26
+ const int64_t chunk_size,
27
+ const int64_t tensor_loc) {
28
+ bool all_aligned = true;
29
+ for (int i = 0; i < depth; i++) {
30
+ args[i] = (T*)tl.addresses[i][tensor_loc];
31
+ args[i] += chunk_idx * chunk_size;
32
+
33
+ if (!is_aligned(args[i])) {
34
+ all_aligned = false;
35
+ }
36
+ }
37
+ return all_aligned;
38
+ }
39
+
40
// Initializes args and checks if all args are aligned
// Overload for TensorListScalarListMetadata (tensor lists paired with a
// per-tensor scalar of type T2); pointer setup is identical to the
// TensorListMetadata overload above.
template <int depth, typename T, typename T2>
__device__ bool init_args(
    T** args,
    TensorListScalarListMetadata<T2, depth>& tl,
    const int64_t chunk_idx,
    const int64_t chunk_size,
    const int64_t tensor_loc) {
  bool all_aligned = true;
  for (int i = 0; i < depth; i++) {
    // Advance each operand's base pointer to this chunk.
    args[i] = (T*)tl.addresses[i][tensor_loc];
    args[i] += chunk_idx * chunk_size;

    if (!is_aligned(args[i])) {
      all_aligned = false;
    }
  }
  return all_aligned;
}
59
+
60
// Same as the overloads above, for FusedOptimizerTensorListMetadata:
// initializes per-operand chunk pointers and reports ILP-alignment.
template <int depth, typename T>
__device__ bool init_args(
    T** args,
    FusedOptimizerTensorListMetadata<depth>& tl,
    const int64_t chunk_idx,
    const int64_t chunk_size,
    const int64_t tensor_loc) {
  bool all_aligned = true;
  for (int i = 0; i < depth; i++) {
    // Advance each operand's base pointer to this chunk.
    args[i] = (T*)tl.addresses[i][tensor_loc];
    args[i] += chunk_idx * chunk_size;

    if (!is_aligned(args[i])) {
      all_aligned = false;
    }
  }
  return all_aligned;
}
78
+
79
// Loads up to kILP elements per thread from each of the `depth` operand
// pointers into the register buffer `r_args`, bounds-checked against both
// `n` (remaining elements) and `chunk_size`. Out-of-range lanes are
// zero-filled so downstream math can run unconditionally.
template <int depth, typename T>
__device__ void load_args(
    T r_args[][kILP],
    T** args,
    const int64_t i_start,
    const int64_t chunk_size,
    const int64_t n) {
#pragma unroll
  for (int ii = 0; ii < kILP; ii++) {
    const auto i = i_start + threadIdx.x + ii * blockDim.x;
    for (int r_index = 0; r_index < depth; r_index++) {
      // Default to 0 so lanes past the end hold a well-defined value.
      r_args[r_index][ii] = 0;
      if (i < n && i < chunk_size) {
        r_args[r_index][ii] = args[r_index][i];
      }
    }
  }
}
97
+
98
// Writes up to kILP register-buffer elements per thread from `src` back to
// global memory at `dst`, skipping lanes beyond `n` or `chunk_size`.
// Mirror of load_args for a single operand.
template <typename T>
__device__ void store_args(
    T* dst,
    T* src,
    const int64_t i_start,
    const int64_t chunk_size,
    const int64_t n) {
#pragma unroll
  for (int ii = 0; ii < kILP; ii++) {
    const int64_t i = i_start + threadIdx.x + ii * blockDim.x;
    if (i < n && i < chunk_size)
      dst[i] = src[ii];
  }
}
112
+
113
// Applies `op(x, scalar)` elementwise over one chunk: reads args[0], writes
// args[res_arg_index]. Values are widened to opmath_t for the computation
// and narrowed back to T on store.
template <int res_arg_index, typename Op, typename T, typename opmath_t>
__device__ __forceinline__ void binary_op_scalar(
    T r_args[][kILP],
    T** args,
    opmath_t scalar,
    const int64_t n,
    const int64_t chunk_size,
    const bool all_aligned,
    Op op) {
  // to make things simple, we put aligned case in a different code path
  if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) {
    // Fast path: whole chunk is a multiple of kILP and aligned, so the
    // vectorized load_store can be used without bounds checks.
    for (int64_t i_start = threadIdx.x;
         i_start * kILP < n && i_start * kILP < chunk_size;
         i_start += blockDim.x) {
      // load
      load_store(r_args[0], args[0], 0, i_start);
#pragma unroll
      for (int ii = 0; ii < kILP; ii++) {
        r_args[0][ii] = static_cast<T>(
            op(static_cast<opmath_t>(r_args[0][ii]),
               static_cast<opmath_t>(scalar)));
      }
      // store
      load_store(args[res_arg_index], r_args[0], i_start, 0);
    }
  } else {
    // Slow path: element-wise, bounds-checked loads/stores.
    for (int64_t i_start = 0; i_start < n && i_start < chunk_size;
         i_start += blockDim.x * kILP) {
      // Regardless if depth is 1 (for inplace) or 2 (for out of place), r_args
      // has depth 1
      load_args<1>(r_args, args, i_start, chunk_size, n);
#pragma unroll
      for (int ii = 0; ii < kILP; ii++) {
        r_args[0][ii] = static_cast<T>(
            op(static_cast<opmath_t>(r_args[0][ii]),
               static_cast<opmath_t>(scalar)));
      }
      store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n);
    }
  }
}
154
+
155
// Pointwise fused update over one chunk:
//   out = args[0] + scalar * op(args[1], args[2])
// reading three operands and writing args[res_arg_index], with the math
// performed in opmath_t precision.
template <int res_arg_index, typename Op, typename T, typename opmath_t>
__device__ __forceinline__ void pointwise_op_scalar(
    T r_args[][kILP],
    T** args,
    opmath_t scalar,
    const int64_t n,
    const int64_t chunk_size,
    const bool all_aligned,
    Op op) {
  // to make things simple, we put aligned case in a different code path
  if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) {
    // Fast path: vectorized loads of all three operands.
    for (int64_t i_start = threadIdx.x;
         i_start * kILP < n && i_start * kILP < chunk_size;
         i_start += blockDim.x) {
      // load
      load_store(r_args[0], args[0], 0, i_start);
      load_store(r_args[1], args[1], 0, i_start);
      load_store(r_args[2], args[2], 0, i_start);
#pragma unroll
      for (int ii = 0; ii < kILP; ii++) {
        r_args[0][ii] = static_cast<T>(
            static_cast<opmath_t>(r_args[0][ii]) +
            scalar *
                op(static_cast<opmath_t>(r_args[1][ii]),
                   static_cast<opmath_t>(r_args[2][ii])));
      }
      // store
      load_store(args[res_arg_index], r_args[0], i_start, 0);
    }
  } else {
    // Slow path: bounds-checked loads/stores.
    for (int64_t i_start = 0; i_start < n && i_start < chunk_size;
         i_start += blockDim.x * kILP) {
      // Regardless if depth is 3 (for inplace) or 4 (for out of place), r_args
      // has depth 3
      load_args<3>(r_args, args, i_start, chunk_size, n);
#pragma unroll
      for (int ii = 0; ii < kILP; ii++) {
        r_args[0][ii] = static_cast<T>(
            static_cast<opmath_t>(r_args[0][ii]) +
            scalar *
                op(static_cast<opmath_t>(r_args[1][ii]),
                   static_cast<opmath_t>(r_args[2][ii])));
      }
      store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n);
    }
  }
}
202
+
203
+ //
204
+ // Binary Functors
205
+ //
206
// multi_tensor_apply functor: out = op(tensor, scalar) with one scalar shared
// by every tensor in the list. Resolves this block's (tensor, chunk) from
// blockIdx.x and delegates the per-chunk loop to binary_op_scalar.
template <typename T, int depth, int r_args_depth, int res_arg_index>
struct BinaryOpScalarFunctor {
  using opmath_t = at::opmath_type<T>;
  template <typename Op>
  __device__ __forceinline__ void operator()(
      int chunk_size,
      TensorListMetadata<depth>& tl,
      Op op,
      opmath_t scalar) {
    const int tensor_loc = tl.block_to_tensor[blockIdx.x];
    const int chunk_idx = tl.block_to_chunk[blockIdx.x];
    auto n = tl.numel_for_tensor[tensor_loc];

    T* args[depth];
    const bool all_aligned =
        init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc);
    // Remaining elements from this chunk onward.
    n -= chunk_idx * chunk_size;
    T r_args[r_args_depth][kILP];

    binary_op_scalar<res_arg_index>(
        r_args, args, scalar, n, chunk_size, all_aligned, op);
  }
};
229
+
230
// multi_tensor_apply functor: out = op(tensor, scalar) where each tensor in
// the list has its own scalar, read from tl.scalar_vals.
template <typename T, int depth, int r_args_depth, int res_arg_index>
struct BinaryOpScalarListFunctor {
  using opmath_t = at::opmath_type<T>;
  template <typename Op>
  __device__ __forceinline__ void operator()(
      int chunk_size,
      TensorListScalarListMetadata<opmath_t, depth>& tl,
      Op op) {
    const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
    const auto chunk_idx = tl.block_to_chunk[blockIdx.x];
    auto n = tl.numel_for_tensor[tensor_loc];

    T* args[depth];
    const bool all_aligned =
        init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc);
    // Per-tensor scalar, unlike BinaryOpScalarFunctor's shared one.
    opmath_t scalar = tl.scalar_vals[tensor_loc];
    n -= chunk_idx * chunk_size;
    T r_args[r_args_depth][kILP];

    binary_op_scalar<res_arg_index>(
        r_args, args, scalar, n, chunk_size, all_aligned, op);
  }
};
253
+
254
// multi_tensor_apply functor for list-with-list binary ops with an alpha
// multiplier: out = op(args[0], alpha * args[1]), written to
// args[res_arg_index]. Math is done in opmath_t precision.
template <typename T, int depth, int r_args_depth, int res_arg_index>
struct BinaryOpListAlphaFunctor {
  using opmath_t = at::opmath_type<T>;
  template <typename Op>
  __device__ __forceinline__ void operator()(
      int chunk_size,
      TensorListMetadata<depth>& tl,
      Op op,
      opmath_t alpha) {
    const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
    const auto chunk_idx = tl.block_to_chunk[blockIdx.x];
    auto n = tl.numel_for_tensor[tensor_loc];

    T* args[depth];
    const bool all_aligned =
        init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc);
    n -= chunk_idx * chunk_size;
    T r_args[r_args_depth][kILP];

    // to make things simple, we put aligned case in a different code path
    if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) {
      // Fast path: vectorized loads/stores, no bounds checks needed.
      for (int64_t i_start = threadIdx.x;
           i_start * kILP < n && i_start * kILP < chunk_size;
           i_start += blockDim.x) {
        // load
        load_store(r_args[0], args[0], 0, i_start);
        load_store(r_args[1], args[1], 0, i_start);
#pragma unroll
        for (int ii = 0; ii < kILP; ii++) {
          r_args[0][ii] = static_cast<T>(
              op(static_cast<opmath_t>(r_args[0][ii]),
                 alpha * static_cast<opmath_t>(r_args[1][ii])));
        }
        // store
        load_store(args[res_arg_index], r_args[0], i_start, 0);
      }
    } else {
      // Slow path: bounds-checked element access.
      for (int64_t i_start = 0; i_start < n && i_start < chunk_size;
           i_start += blockDim.x * kILP) {
        load_args<r_args_depth>(r_args, args, i_start, chunk_size, n);
#pragma unroll
        for (int ii = 0; ii < kILP; ii++) {
          r_args[0][ii] = static_cast<T>(
              op(static_cast<opmath_t>(r_args[0][ii]),
                 alpha * static_cast<opmath_t>(r_args[1][ii])));
        }
        store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n);
      }
    }
  }
};
305
+
306
// multi_tensor_apply functor where the right-hand operand is a device-side
// 0-dim tensor: out = op(args[0], alpha * (*scalar)). `scalar` is a device
// pointer, so its value is read inside the kernel rather than captured on
// the host.
template <typename T, int depth, int r_args_depth, int res_arg_index>
struct BinaryOpScalarTensorFunctor {
  using opmath_t = at::opmath_type<T>;
  template <typename Op>
  __device__ __forceinline__ void operator()(
      int chunk_size,
      TensorListMetadata<depth>& tl,
      Op op,
      T* scalar,
      opmath_t alpha) {
    const int tensor_loc = tl.block_to_tensor[blockIdx.x];
    const int chunk_idx = tl.block_to_chunk[blockIdx.x];
    auto n = tl.numel_for_tensor[tensor_loc];

    T* args[depth];
    const bool all_aligned =
        init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc);
    n -= chunk_idx * chunk_size;
    T r_args[r_args_depth][kILP];

    // to make things simple, we put aligned case in a different code path
    if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) {
      for (int64_t i_start = threadIdx.x;
           i_start * kILP < n && i_start * kILP < chunk_size;
           i_start += blockDim.x) {
        // load
        load_store(r_args[0], args[0], 0, i_start);
#pragma unroll
        for (int ii = 0; ii < kILP; ii++) {
          r_args[0][ii] = static_cast<T>(op(
              static_cast<opmath_t>(r_args[0][ii]),
              static_cast<opmath_t>(alpha) * static_cast<opmath_t>(*scalar)));
        }
        // store
        load_store(args[res_arg_index], r_args[0], i_start, 0);
      }
    } else {
      for (int64_t i_start = 0; i_start < n && i_start < chunk_size;
           i_start += blockDim.x * kILP) {
        // Regardless if depth is 1 (for inplace) or 2 (for out of place),
        // r_args has depth 1
        load_args<1>(r_args, args, i_start, chunk_size, n);
#pragma unroll
        for (int ii = 0; ii < kILP; ii++) {
          r_args[0][ii] = static_cast<T>(op(
              static_cast<opmath_t>(r_args[0][ii]),
              static_cast<opmath_t>(alpha) * static_cast<opmath_t>(*scalar)));
        }
        store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n);
      }
    }
  }
};
359
+
360
+ //
361
+ // Unary Functors
362
+ //
363
+
364
// multi_tensor_apply functor that zero-fills every tensor in the list
// (metadata is hard-wired to depth 1: a single in/out operand per tensor).
template <typename T, int depth, int r_args_depth, int res_arg_index>
struct ZeroFunctor {
  __device__ __forceinline__ void operator()(
      int chunk_size,
      TensorListMetadata<1>& tl) {
    const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
    const auto chunk_idx = tl.block_to_chunk[blockIdx.x];
    auto n = tl.numel_for_tensor[tensor_loc];

    T* args[depth];
    const auto all_aligned =
        init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc);
    n -= chunk_idx * chunk_size;
    T r_args[r_args_depth][kILP];

    // to make things simple, we put aligned case in a different code path
    if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) {
      for (int64_t i_start = threadIdx.x;
           i_start * kILP < n && i_start * kILP < chunk_size;
           i_start += blockDim.x) {
#pragma unroll
        for (int ii = 0; ii < kILP; ii++) {
          r_args[0][ii] = 0;
        }
        // store
        // NOTE(review): this path writes args[0] while the fallback below
        // writes args[res_arg_index]; equivalent only when res_arg_index == 0
        // — confirm instantiations.
        load_store(args[0], r_args[0], i_start, 0);
      }
    } else {
      for (int64_t i_start = 0; i_start < n && i_start < chunk_size;
           i_start += blockDim.x * kILP) {
#pragma unroll
        for (int ii = 0; ii < kILP; ii++) {
          r_args[0][ii] = 0;
        }
        store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n);
      }
    }
  }
};
403
+
404
// multi_tensor_apply functor for unary ops: out = op(args[0]), written to
// args[res_arg_index], computed in opmath_t precision.
template <typename T, int depth, int r_args_depth, int res_arg_index>
struct UnaryOpFunctor {
  using opmath_t = at::opmath_type<T>;
  template <typename Op>
  __device__ __forceinline__ void operator()(
      int chunk_size,
      TensorListMetadata<depth>& tl,
      Op op) {
    const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
    const auto chunk_idx = tl.block_to_chunk[blockIdx.x];
    auto n = tl.numel_for_tensor[tensor_loc];

    T* args[depth];
    bool all_aligned =
        init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc);
    n -= chunk_idx * chunk_size;
    T r_args[r_args_depth][kILP];

    // to make things simple, we put aligned case in a different code path
    if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) {
      // Fast path: vectorized loads/stores.
      for (int64_t i_start = threadIdx.x;
           i_start * kILP < n && i_start * kILP < chunk_size;
           i_start += blockDim.x) {
        // load
        load_store(r_args[0], args[0], 0, i_start);
#pragma unroll
        for (int ii = 0; ii < kILP; ii++) {
          r_args[0][ii] =
              static_cast<T>(op(static_cast<opmath_t>(r_args[0][ii])));
        }
        // store
        load_store(args[res_arg_index], r_args[0], i_start, 0);
      }
    } else {
      // Slow path: bounds-checked element access.
      for (int64_t i_start = 0; i_start < n && i_start < chunk_size;
           i_start += blockDim.x * kILP) {
        load_args<r_args_depth>(r_args, args, i_start, chunk_size, n);
#pragma unroll
        for (int ii = 0; ii < kILP; ii++) {
          r_args[0][ii] =
              static_cast<T>(op(static_cast<opmath_t>(r_args[0][ii])));
        }
        store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n);
      }
    }
  }
};
451
+
452
+ //
453
+ // Pointwise Functors
454
+ //
455
+
456
// multi_tensor_apply functor for fused pointwise ops with one shared scalar:
// out = args[0] + scalar * op(args[1], args[2]); see pointwise_op_scalar.
template <typename T, int depth, int r_args_depth, int res_arg_index>
struct PointwiseOpScalarFunctor {
  using opmath_t = at::opmath_type<T>;
  template <typename Op>
  __device__ __forceinline__ void operator()(
      int chunk_size,
      TensorListMetadata<depth>& tl,
      Op op,
      opmath_t scalar) {
    const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
    const auto chunk_idx = tl.block_to_chunk[blockIdx.x];
    auto n = tl.numel_for_tensor[tensor_loc];

    T* args[depth];
    const bool all_aligned =
        init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc);
    n -= chunk_idx * chunk_size;
    T r_args[r_args_depth][kILP];

    pointwise_op_scalar<res_arg_index>(
        r_args, args, scalar, n, chunk_size, all_aligned, op);
  }
};
479
+
480
// Same as PointwiseOpScalarFunctor but each tensor carries its own scalar,
// read from tl.scalar_vals.
template <typename T, int depth, int r_args_depth, int res_arg_index>
struct PointwiseOpScalarListFunctor {
  using opmath_t = at::opmath_type<T>;
  template <typename Op>
  __device__ __forceinline__ void operator()(
      int chunk_size,
      TensorListScalarListMetadata<opmath_t, depth>& tl,
      Op op) {
    const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
    const auto chunk_idx = tl.block_to_chunk[blockIdx.x];
    auto n = tl.numel_for_tensor[tensor_loc];

    T* args[depth];
    const bool all_aligned =
        init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc);
    // Per-tensor scalar.
    opmath_t scalar = tl.scalar_vals[tensor_loc];
    n -= chunk_idx * chunk_size;
    T r_args[r_args_depth][kILP];

    pointwise_op_scalar<res_arg_index>(
        r_args, args, scalar, n, chunk_size, all_aligned, op);
  }
};
503
+
504
// multi_tensor_apply functor for list-with-list binary ops with a fixed
// output slot: args[2] = op(args[0], args[1]), computed in opmath_t.
template <typename T, int depth>
struct PointwiseOpListFunctor {
  using opmath_t = at::opmath_type<T>;
  template <typename Op>
  __device__ __forceinline__ void operator()(
      int chunk_size,
      TensorListMetadata<depth>& tl,
      Op op) {
    const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
    const auto chunk_idx = tl.block_to_chunk[blockIdx.x];
    auto n = tl.numel_for_tensor[tensor_loc];

    T* args[depth];
    const bool all_aligned =
        init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc);
    n -= chunk_idx * chunk_size;
    // Only the two inputs need register buffers; the result reuses r_args[0].
    T r_args[depth - 1][kILP];

    // to make things simple, we put aligned case in a different code path
    if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) {
      for (int64_t i_start = threadIdx.x;
           i_start * kILP < n && i_start * kILP < chunk_size;
           i_start += blockDim.x) {
        // load
        load_store(r_args[0], args[0], 0, i_start);
        load_store(r_args[1], args[1], 0, i_start);
#pragma unroll
        for (int ii = 0; ii < kILP; ii++) {
          r_args[0][ii] = static_cast<T>(
              op(static_cast<opmath_t>(r_args[0][ii]),
                 static_cast<opmath_t>(r_args[1][ii])));
        }
        // store
        load_store(args[2], r_args[0], i_start, 0);
      }
    } else {
      for (int64_t i_start = 0; i_start < n && i_start < chunk_size;
           i_start += blockDim.x * kILP) {
        load_args<depth - 1>(r_args, args, i_start, chunk_size, n);
#pragma unroll
        for (int ii = 0; ii < kILP; ii++) {
          r_args[0][ii] = static_cast<T>(
              op(static_cast<opmath_t>(r_args[0][ii]),
                 static_cast<opmath_t>(r_args[1][ii])));
        }
        store_args(args[2], r_args[0], i_start, chunk_size, n);
      }
    }
  }
};
554
+
555
// multi_tensor_apply functor for ternary list ops:
// out = op(args[0], args[1], args[2]) written to args[res_arg_index].
// depth 3 = inplace, depth 4 = out-of-place (enforced by static_asserts).
template <typename T, int depth, int r_args_depth, int res_arg_index>
struct TernaryOpListFunctor {
  using opmath_t = at::opmath_type<T>;
  template <typename Op>
  __device__ __forceinline__ void operator()(
      int chunk_size,
      TensorListMetadata<depth>& tl,
      Op op) {
    static_assert(depth == 3 || depth == 4, "");
    static_assert(depth >= r_args_depth, "");
    static_assert(res_arg_index == depth - 1 || res_arg_index == 0, "");
    const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
    const auto chunk_idx = tl.block_to_chunk[blockIdx.x];
    auto n = tl.numel_for_tensor[tensor_loc];

    T* args[depth];
    const bool all_aligned =
        init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc);
    n -= chunk_idx * chunk_size;
    T r_args[r_args_depth][kILP];

    // Aligned fast path vs bounds-checked fallback, as elsewhere in this file.
    if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) {
      for (int64_t i_start = threadIdx.x;
           i_start * kILP < n && i_start * kILP < chunk_size;
           i_start += blockDim.x) {
        load_store(r_args[0], args[0], 0, i_start);
        load_store(r_args[1], args[1], 0, i_start);
        load_store(r_args[2], args[2], 0, i_start);
#pragma unroll
        for (int ii = 0; ii < kILP; ii++) {
          r_args[0][ii] =
              op(static_cast<opmath_t>(r_args[0][ii]),
                 static_cast<opmath_t>(r_args[1][ii]),
                 static_cast<opmath_t>(r_args[2][ii]));
        }
        load_store(args[res_arg_index], r_args[0], i_start, 0);
      }
    } else {
      for (int64_t i_start = 0; i_start < n && i_start < chunk_size;
           i_start += blockDim.x * kILP) {
        load_args<r_args_depth>(r_args, args, i_start, chunk_size, n);
#pragma unroll
        for (int ii = 0; ii < kILP; ii++) {
          r_args[0][ii] =
              op(static_cast<opmath_t>(r_args[0][ii]),
                 static_cast<opmath_t>(r_args[1][ii]),
                 static_cast<opmath_t>(r_args[2][ii]));
        }
        store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n);
      }
    }
  }
};
608
+
609
// multi_tensor_apply functor for ternary ops whose third argument is a
// shared scalar: out = op(args[0], args[1], alpha), written to
// args[res_arg_index]. depth 2 = inplace, depth 3 = out-of-place.
template <typename T, int depth, int r_args_depth, int res_arg_index>
struct TernaryOpScalarFunctor {
  using opmath_t = at::opmath_type<T>;
  template <typename Op>
  __device__ __forceinline__ void operator()(
      int chunk_size,
      TensorListMetadata<depth>& tl,
      Op op,
      opmath_t alpha) {
    static_assert(depth == 2 || depth == 3, "");
    static_assert(depth >= r_args_depth, "");
    static_assert(res_arg_index == depth - 1 || res_arg_index == 0, "");
    const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
    const auto chunk_idx = tl.block_to_chunk[blockIdx.x];
    auto n = tl.numel_for_tensor[tensor_loc];

    T* args[depth];
    const bool all_aligned =
        init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc);
    n -= chunk_idx * chunk_size;
    T r_args[r_args_depth][kILP];

    // to make things simple, we put aligned case in a different code path
    if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) {
      for (int64_t i_start = threadIdx.x;
           i_start * kILP < n && i_start * kILP < chunk_size;
           i_start += blockDim.x) {
        // load
        load_store(r_args[0], args[0], 0, i_start);
        load_store(r_args[1], args[1], 0, i_start);
#pragma unroll
        for (int ii = 0; ii < kILP; ii++) {
          r_args[0][ii] =
              op(static_cast<opmath_t>(r_args[0][ii]),
                 static_cast<opmath_t>(r_args[1][ii]),
                 alpha);
        }
        // store
        load_store(args[res_arg_index], r_args[0], i_start, 0);
      }
    } else {
      for (int64_t i_start = 0; i_start < n && i_start < chunk_size;
           i_start += blockDim.x * kILP) {
        load_args<r_args_depth>(r_args, args, i_start, chunk_size, n);
#pragma unroll
        for (int ii = 0; ii < kILP; ii++) {
          r_args[0][ii] =
              op(static_cast<opmath_t>(r_args[0][ii]),
                 static_cast<opmath_t>(r_args[1][ii]),
                 alpha);
        }
        store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n);
      }
    }
  }
};
665
+
666
+ template <typename T>
667
+ struct power_functor {
668
+ C10_DEVICE T operator()(const T& a, const T& b) const {
669
+ return at::native::pow_(a, b);
670
+ }
671
+ };
672
+
673
+ template <typename T>
674
+ struct reverse_power_functor {
675
+ C10_DEVICE T operator()(const T& a, const T& b) const {
676
+ return at::native::pow_(b, a);
677
+ }
678
+ };
679
+
680
+ } // namespace
681
+ } // namespace at::native
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/ForeachMinMaxFunctors.cuh ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/NumericUtils.h>
4
+
5
+ namespace at::native {
6
+
7
+ // std:: does not have clamp functors
8
+ template <typename T>
9
+ struct minimum {
10
+ __device__ T operator()(const T& a, const T& b) const {
11
+ return (_isnan(a) || a < b) ? a : b;
12
+ }
13
+ };
14
+
15
+ template <typename T>
16
+ struct maximum {
17
+ __device__ T operator()(const T& a, const T& b) const {
18
+ return (_isnan(a) || a > b) ? a : b;
19
+ }
20
+ };
21
+
22
+ } // namespace at::native
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/GridSampler.cuh ADDED
@@ -0,0 +1,321 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ #include <ATen/native/cuda/KernelUtils.cuh>
3
+ #include <ATen/native/GridSamplerUtils.h>
4
+
5
+ namespace at { namespace native {
6
+
7
+ using detail::GridSamplerInterpolation;
8
+ using detail::GridSamplerPadding;
9
+
10
+ // Unnormalizes a coordinate from the -1 to +1 scale to its pixel index value,
11
+ // where we view each pixel as an area between (idx - 0.5) and (idx + 0.5).
12
+ // if align_corners: -1 and +1 get sent to the centers of the corner pixels
13
+ // -1 --> 0
14
+ // +1 --> (size - 1)
15
+ // scale_factor = (size - 1) / 2
16
+ // if not align_corners: -1 and +1 get sent to the image edges
17
+ // -1 --> -0.5
18
+ // +1 --> (size - 1) + 0.5 == size - 0.5
19
+ // scale_factor = size / 2
20
+ template <typename scalar_t>
21
+ static __forceinline__ __device__
22
+ scalar_t grid_sampler_unnormalize(scalar_t coord, int size, bool align_corners) {
23
+ if (align_corners) {
24
+ // unnormalize coord from [-1, 1] to [0, size - 1]
25
+ return ((coord + 1.f) / 2) * (size - 1);
26
+ } else {
27
+ // unnormalize coord from [-1, 1] to [-0.5, size - 0.5]
28
+ return ((coord + 1.f) * size - 1) / 2;
29
+ }
30
+ }
31
+
32
+ // grid_sampler_unnormalize_set_grad works the same as grid_sampler_unnormalize
33
+ // except that it also returns the `d output / d input` via pointer argument
34
+ // `grad_in`.
35
+ // This is useful in the backward pass of grid_sampler.
36
// Unnormalizes `coord` exactly like grid_sampler_unnormalize and additionally
// writes d(output)/d(coord) — a constant scale factor — into *grad_in for the
// backward pass.
template <typename scalar_t>
static __forceinline__ __device__
scalar_t grid_sampler_unnormalize_set_grad(scalar_t coord, int size,
                                           bool align_corners, scalar_t *grad_in) {
  if (align_corners) {
    // unnormalize coord from [-1, 1] to [0, size - 1]
    *grad_in = static_cast<scalar_t>(size - 1) / 2;
    return ((coord + 1.f) / 2) * (size - 1);
  } else {
    // unnormalize coord from [-1, 1] to [-0.5, size - 0.5]
    *grad_in = static_cast<scalar_t>(size) / 2;
    return ((coord + 1.f) * size - 1) / 2;
  }
}
50
+
51
+ // Clips coordinates to between 0 and clip_limit - 1
52
+ template <typename scalar_t>
53
+ static __forceinline__ __device__
54
+ scalar_t clip_coordinates(scalar_t in, int clip_limit) {
55
+ return ::min(static_cast<scalar_t>(clip_limit - 1), ::max(in, static_cast<scalar_t>(0)));
56
+ }
57
+
58
+ // clip_coordinates_set_grad works similarly to clip_coordinates except that
59
+ // it also returns the `d output / d input` via pointer argument `grad_in`.
60
+ // This is useful in the backward pass of grid_sampler.
61
// Clamps `in` into [0, clip_limit - 1] and writes d(output)/d(input) into
// *grad_in: 0 when clamped (including exactly at a border), 1 otherwise.
// Comparison order also routes NaN inputs through the final else (grad 1).
template <typename scalar_t>
static __forceinline__ __device__
scalar_t clip_coordinates_set_grad(scalar_t in, int clip_limit, scalar_t *grad_in) {
  // Note that it is important for the gradient calculation that borders
  // are considered out of bounds.
  if (in <= static_cast<scalar_t>(0)) {
    *grad_in = static_cast<scalar_t>(0);
    return static_cast<scalar_t>(0);
  } else {
    scalar_t max = static_cast<scalar_t>(clip_limit - 1);
    if (in >= max) {
      *grad_in = static_cast<scalar_t>(0);
      return max;
    } else {
      *grad_in = static_cast<scalar_t>(1);
      return in;
    }
  }
}
80
+
81
+ // Reflects coordinates until they fall between low and high (inclusive).
82
+ // The bounds are passed as twice their value so that half-integer values
83
+ // can be represented as ints.
84
// Folds `in` back into [twice_low/2, twice_high/2] by repeated reflection at
// the interval borders. Bounds arrive doubled so half-integer borders can be
// expressed as ints (e.g. -0.5 as -1).
template <typename scalar_t>
static __forceinline__ __device__
scalar_t reflect_coordinates(scalar_t in, int twice_low, int twice_high) {
  if (twice_low == twice_high) {
    // Degenerate interval: everything reflects onto the single point 0.
    return static_cast<scalar_t>(0);
  }
  scalar_t min = static_cast<scalar_t>(twice_low) / 2;
  scalar_t span = static_cast<scalar_t>(twice_high - twice_low) / 2;
  in = ::fabs(in - min);
  // `fmod` returns same sign as `in`, which is positive after the `fabs` above.
  scalar_t extra = ::fmod(in, span);
  int flips = static_cast<int>(::floor(in / span));
  // An even number of reflections lands in the forward direction,
  // an odd number in the mirrored direction.
  if (flips % 2 == 0) {
    return extra + min;
  } else {
    return span - extra + min;
  }
}
102
+
103
+ // reflect_coordinates_set_grad works similarly to reflect_coordinates except
104
+ // that it also returns the `d output / d input` via pointer argument
105
+ // `grad_in`.
106
+ // This is useful in the backward pass of grid_sampler.
107
// Same reflection as reflect_coordinates, but also writes d(output)/d(input)
// into *grad_in: ±1 depending on whether the net reflection (sign flip from
// being below `min`, times parity of border flips) preserves direction.
template <typename scalar_t>
static __forceinline__ __device__
scalar_t reflect_coordinates_set_grad(scalar_t in, int twice_low, int twice_high,
                                      scalar_t *grad_in) {
  if (twice_low == twice_high) {
    // Degenerate interval: constant output, zero gradient.
    *grad_in = static_cast<scalar_t>(0);
    return static_cast<scalar_t>(0);
  }
  int grad_in_mult_;
  scalar_t min = static_cast<scalar_t>(twice_low) / 2;
  scalar_t span = static_cast<scalar_t>(twice_high - twice_low) / 2;
  in = in - min;
  if (in < static_cast<scalar_t>(0)) {
    // Mirror manually (instead of fabs) so the sign of the derivative
    // can be tracked in grad_in_mult_.
    grad_in_mult_ = -1;
    in = -in;
  } else {
    grad_in_mult_ = 1;
  }
  // `fmod` returns same sign as `in`, which is positive after the `if` above.
  scalar_t extra = ::fmod(in, span);
  int flips = static_cast<int>(::floor(in / span));
  if (flips % 2 == 0) {
    *grad_in = static_cast<scalar_t>(grad_in_mult_);
    return extra + min;
  } else {
    *grad_in = static_cast<scalar_t>(-grad_in_mult_);
    return span - extra + min;
  }
}
136
+
137
+ template<typename scalar_t>
138
+ static __forceinline__ __device__
139
+ scalar_t safe_downgrade_to_int_range(scalar_t x){
140
+ // -100.0 does not have special meaning. This is just to make sure
141
+ // it's not within_bounds_2d or within_bounds_3d, and does not cause
142
+ // undefined behavior. See #35506.
143
+ if (x > INT_MAX-1 || x < INT_MIN || !::isfinite(static_cast<double>(x)))
144
+ return static_cast<scalar_t>(-100.0);
145
+ return x;
146
+ }
147
+
148
+ template<typename scalar_t>
149
+ static __forceinline__ __device__
150
+ scalar_t compute_coordinates(scalar_t coord, int size,
151
+ GridSamplerPadding padding_mode,
152
+ bool align_corners) {
153
+ if (padding_mode == GridSamplerPadding::Border) {
154
+ // clip coordinates to image borders
155
+ coord = clip_coordinates(coord, size);
156
+ } else if (padding_mode == GridSamplerPadding::Reflection) {
157
+ // reflect coordinates by image borders
158
+ if (align_corners) {
159
+ coord = reflect_coordinates(coord, 0, 2*(size - 1));
160
+ } else {
161
+ coord = reflect_coordinates(coord, -1, 2*size - 1);
162
+ }
163
+ // clip coordinates to image borders
164
+ coord = clip_coordinates(coord, size);
165
+ }
166
+
167
+ coord = safe_downgrade_to_int_range(coord);
168
+ return coord;
169
+ }
170
+
171
+ // Computes the pixel source index value for a grid coordinate
172
+ template <typename scalar_t>
173
+ static __forceinline__ __device__
174
+ scalar_t grid_sampler_compute_source_index(
175
+ scalar_t coord,
176
+ int size,
177
+ GridSamplerPadding padding_mode,
178
+ bool align_corners) {
179
+ coord = grid_sampler_unnormalize(coord, size, align_corners);
180
+ coord = compute_coordinates(coord, size, padding_mode, align_corners);
181
+ return coord;
182
+ }
183
+
184
+ // grid_sampler_compute_source_index_set_grad works similarly to
185
+ // grid_sampler_compute_source_index except that it also returns the
186
+ // `d output / d input` via pointer argument `grad_in`.
187
+ // This is useful in the backward pass of grid_sampler.
188
+ template <typename scalar_t>
189
+ static __forceinline__ __device__
190
+ scalar_t grid_sampler_compute_source_index_set_grad(
191
+ scalar_t coord,
192
+ int size,
193
+ GridSamplerPadding padding_mode,
194
+ bool align_corners,
195
+ scalar_t *grad_in) {
196
+ scalar_t grad_clip, grad_refl;
197
+ coord = grid_sampler_unnormalize_set_grad(coord, size, align_corners, grad_in);
198
+ if (padding_mode == GridSamplerPadding::Border) {
199
+ // clip coordinates to image borders
200
+ coord = clip_coordinates_set_grad(coord, size, &grad_clip);
201
+ *grad_in = (*grad_in) * grad_clip;
202
+ } else if (padding_mode == GridSamplerPadding::Reflection) {
203
+ // reflect coordinates by image borders
204
+ if (align_corners) {
205
+ coord = reflect_coordinates_set_grad(coord, 0, 2*(size - 1), &grad_refl);
206
+ } else {
207
+ coord = reflect_coordinates_set_grad(coord, -1, 2*size - 1, &grad_refl);
208
+ }
209
+ // clip coordinates to image borders
210
+ coord = clip_coordinates_set_grad(coord, size, &grad_clip);
211
+ *grad_in = (*grad_in) * grad_refl * grad_clip;
212
+ }
213
+
214
+ coord = safe_downgrade_to_int_range(coord);
215
+ return coord;
216
+ }
217
+
218
+ static __forceinline__ __device__
219
+ bool within_bounds_2d(int h, int w, int H, int W) {
220
+ return h >= 0 && h < H && w >= 0 && w < W;
221
+ }
222
+
223
+ static __forceinline__ __device__
224
+ bool within_bounds_3d(int d, int h, int w, int D, int H, int W) {
225
+ return d >= 0 && d < D && h >= 0 && h < H && w >= 0 && w < W;
226
+ }
227
+
228
+ template<typename scalar_t>
229
+ static __forceinline__ __device__
230
+ scalar_t get_value_bounded(
231
+ scalar_t *data, scalar_t x, scalar_t y, int W, int H, int sW, int sH,
232
+ GridSamplerPadding padding_mode,
233
+ bool align_corners) {
234
+
235
+ x = compute_coordinates(x, W, padding_mode, align_corners);
236
+ y = compute_coordinates(y, H, padding_mode, align_corners);
237
+
238
+ int ix = static_cast<int>(x);
239
+ int iy = static_cast<int>(y);
240
+
241
+ if (within_bounds_2d(iy, ix, H, W)) {
242
+ return data[iy * sH + ix * sW];
243
+ }
244
+ return static_cast<scalar_t>(0);
245
+ }
246
+
247
+ template<typename scalar_t, typename index_t>
248
+ static __forceinline__ __device__
249
+ void safe_add_2d(scalar_t *data, int h, int w,
250
+ int sH, int sW, int H, int W,
251
+ scalar_t delta,
252
+ const index_t NC_offset,
253
+ const index_t memory_span) {
254
+ if (within_bounds_2d(h, w, H, W)) {
255
+ fastAtomicAdd(data,
256
+ NC_offset + h * sH + w * sW,
257
+ memory_span,
258
+ delta,
259
+ true);
260
+ }
261
+ }
262
+
263
+ template<typename scalar_t, typename index_t>
264
+ static __forceinline__ __device__
265
+ void safe_add_3d(scalar_t *data, int d, int h, int w,
266
+ int sD, int sH, int sW, int D, int H, int W,
267
+ scalar_t delta,
268
+ const index_t NC_offset,
269
+ const index_t memory_span) {
270
+ if (within_bounds_3d(d, h, w, D, H, W)) {
271
+ fastAtomicAdd(data,
272
+ NC_offset + d * sD + h * sH + w * sW,
273
+ memory_span,
274
+ delta,
275
+ true);
276
+ }
277
+ }
278
+
279
+ template<typename scalar_t, typename index_t>
280
+ static __forceinline__ __device__
281
+ void add_value_bounded(
282
+ scalar_t* data, scalar_t x, scalar_t y, int W, int H, int sW, int sH,
283
+ scalar_t delta,
284
+ GridSamplerPadding padding_mode,
285
+ bool align_corners,
286
+ const index_t NC_offset,
287
+ const index_t memory_span) {
288
+
289
+ x = compute_coordinates(x, W, padding_mode, align_corners);
290
+ y = compute_coordinates(y, H, padding_mode, align_corners);
291
+
292
+ int ix = static_cast<int>(x);
293
+ int iy = static_cast<int>(y);
294
+
295
+ safe_add_2d(data, iy, ix, sH, sW, H, W, delta, NC_offset, memory_span);
296
+ }
297
+
298
+ // Calculate the differential of the cubic convolution, i.e. `d coeff / d x`
299
+ template<typename scalar_t>
300
+ static __forceinline__ __device__
301
+ void get_cubic_coefficients_grad(
302
+ scalar_t coeffs[4],
303
+ scalar_t t) {
304
+
305
+ // Must be the same as forward calculation in
306
+ // aten/src/ATen/native/cuda/UpSample.cuh:get_cubic_upsample_coefficients
307
+ scalar_t A = -0.75;
308
+
309
+ scalar_t x;
310
+ x = -1 - t; // 1 < x = |-1 - tx| < 2
311
+ coeffs[0] = (-3 * A * x - 10 * A ) * x - 8 * A;
312
+ x = -t; // x = |0 - tx| <= 1
313
+ coeffs[1] = (-3 * (A + 2) * x - 2 * (A + 3)) * x;
314
+ x = 1 - t; // x = |1 - tx| <= 1
315
+ coeffs[2] = (3 * (A + 2) * x - 2 * (A + 3)) * x;
316
+ x = 2 - t; // 1 < x = |2 - tx| < 2
317
+ coeffs[3] = (3 * A * x - 10 * A) * x + 8 * A;
318
+ }
319
+
320
+
321
+ }} // namespace at::native
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/JitLoops.cuh ADDED
@@ -0,0 +1,187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/jit_macros.h>
4
+
5
+ #if AT_USE_JITERATOR()
6
+
7
+ #include <ATen/cuda/CUDAConfig.h>
8
+
9
+ #include <ATen/OpMathType.h>
10
+ #include <ATen/TensorIterator.h>
11
+ #include <ATen/native/TensorIteratorDynamicCasting.h>
12
+
13
+ #include <ATen/native/cuda/MemoryAccess.cuh>
14
+
15
+ #include <ATen/native/cuda/CUDAJitLoops.cuh>
16
+
17
+ namespace at {
18
+ namespace native {
19
+
20
+ /* Note [Jiterator]
21
+ The "jiterator" simply just-in-time compiles the same kernels that
22
+ Loops.cuh (and CUDALoops.cuh) usually build. This reduces build time,
23
+ build size, and initial CUDA context size.
24
+
25
+ By default on non-Windows systems, it also caches compiled kernels in ~/.cache/torch/kernels.
26
+ This behavior is controlled with two environment variables:
27
+ - USE_PYTORCH_KERNEL_CACHE, if set to zero then this will disable all cache use
28
+ - PYTORCH_KERNEL_CACHE_PATH, if set specifies the folder to use for cached kernels
29
+
30
+ The jiterator currently has some limitations, however. It cannot:
31
+ - handle math on complex datatypes
32
+ - handle kernels with scalar parameters
33
+
34
+ These improvements will likely come soon.
35
+
36
+ For examples of how to use the jiterator see the i1 and gcd kernel
37
+ implementations, which pass jittable strings implementing their
38
+ operations instead of the typical CUDA functors.
39
+
40
+ To pass a runtime argument (similar to lambda captures in non-JIT kernels),
41
+ we need to pass the additional arguments to `jitted_gpu_kernel` by value.
42
+ Currently only primitive C++ types used for computation are valid.
43
+ The order of these extra arguments should be same as the order they appear
44
+ in kernel's function signature. (look at polygamma for example)
45
+
46
+ NOTE: One big restriction being that these arguments should be after the
47
+ arguments provided by TensorIterator. Eg. While capturing `n`, where
48
+ `scalar_t x` and `scalar_t y` are provided by TensorIterator,
49
+ * foo(scalar_t x, scalar_t y, int n) works!
50
+ * foo(int n, scalar_t x, scalar_t y) doesn't work
52
+ * foo(scalar_t x, int n, scalar_t y) doesn't work
52
+
53
+ */
54
+
55
+ // Entrypoint for jitted GPU kernels.
56
+ // Only handles elementwise unary and binary kernels with a
57
+ // common dtype and a single output.
58
+ // NOTE: this assumes the op's iterator has a common_dtype.
59
+ // NOTE: We use std::tuple instead of parameter pack
60
+ // for `extra_args` due to following
61
+ // bug on older versions of clang
62
+ // https://bugs.llvm.org/show_bug.cgi?id=23029
63
+ template <
64
+ char const* name,
65
+ typename return_type,
66
+ typename f_inputs_type,
67
+ int arity,
68
+ typename... Args>
69
+ void jitted_gpu_kernel(
70
+ TensorIteratorBase& iter,
71
+ const std::string& f,
72
+ at::cuda::jit::BinaryFuncVariant scalar_pos =
73
+ at::cuda::jit::BinaryFuncVariant::NoScalar,
74
+ at::opmath_type<f_inputs_type> scalar_val = 0,
75
+ std::tuple<Args...> extra_args = std::make_tuple()) {
76
+ // TODO: much of preamble is common to both jitted_gpu_kernel and gpu_kernel
77
+ // Maybe it could be refactored?
78
+ for (int arg = 0; arg < iter.ntensors(); arg++) {
79
+ TORCH_INTERNAL_ASSERT(
80
+ iter.device(arg).is_cuda(),
81
+ "argument ", arg, ": expected a CUDA device but found ", iter.device(arg));
82
+ }
83
+
84
+ if (iter.numel() == 0) {
85
+ return;
86
+ }
87
+
88
+ if (!iter.can_use_32bit_indexing()) {
89
+ for (auto& sub_iter : iter.with_32bit_indexing()) {
90
+ jitted_gpu_kernel<name, return_type, f_inputs_type, arity>(
91
+ sub_iter, f, scalar_pos, scalar_val, extra_args);
92
+ }
93
+
94
+ return;
95
+ }
96
+
97
+ // Computes if dynamic casting is needed
98
+ // Dynamic casting is needed if an input's dtype differs from the common dtype
99
+ // or if the result dtype differs from the output's dtype
100
+ // Note: this is intentionally divergent from calling needs_dynamic_casting,
101
+ // which is more general and inspects a lambda to determine if dynamic
102
+ // casting is needed.
103
+ bool needs_dynamic_casting = false;
104
+
105
+ // Checks output
106
+ const ScalarType return_scalar_type = c10::CppTypeToScalarType<return_type>::value;
107
+ const auto dtype0 = iter.dtype(0);
108
+ if (dtype0 != return_scalar_type) {
109
+ needs_dynamic_casting = true;
110
+ }
111
+
112
+ // Checks input(s)
113
+ const ScalarType inputs_scalar_type = c10::CppTypeToScalarType<f_inputs_type>::value;
114
+ for (auto i = decltype(arity){1}; i < (arity + 1); ++i) {
115
+ const auto dtypei = iter.dtype(i);
116
+ if (dtypei != inputs_scalar_type) {
117
+ needs_dynamic_casting = true;
118
+ break;
119
+ }
120
+ }
121
+ if (scalar_pos == at::cuda::jit::BinaryFuncVariant::NoScalar) {
122
+ // NOTE: With `scalar_pos=NoScalar`,`scalar_val` is not used
123
+ // for computation in the generated code and hence we pass a dummy
124
+ // value of `0`.
125
+ jitted_gpu_kernel_impl<
126
+ /*name*/ name,
127
+ /*return_type=*/return_type,
128
+ /*f_inputs_type=*/f_inputs_type,
129
+ arity,
130
+ at::cuda::jit::BinaryFuncVariant::NoScalar>(
131
+ iter, f, needs_dynamic_casting, /*scalar_val=*/scalar_val, extra_args);
132
+ } else if (scalar_pos == at::cuda::jit::BinaryFuncVariant::RhsScalar) {
133
+ jitted_gpu_kernel_impl<
134
+ /*name*/ name,
135
+ /*return_type=*/return_type,
136
+ /*f_inputs_type=*/f_inputs_type,
137
+ arity,
138
+ at::cuda::jit::BinaryFuncVariant::RhsScalar>(
139
+ iter,
140
+ f,
141
+ needs_dynamic_casting,
142
+ scalar_val,
143
+ extra_args);
144
+
145
+ } else {
146
+ jitted_gpu_kernel_impl<
147
+ /*name*/ name,
148
+ /*return_type=*/return_type,
149
+ /*f_inputs_type=*/f_inputs_type,
150
+ arity,
151
+ at::cuda::jit::BinaryFuncVariant::LhsScalar>(
152
+ iter,
153
+ f,
154
+ needs_dynamic_casting,
155
+ scalar_val,
156
+ extra_args);
157
+ }
158
+ }
159
+
160
+ // TODO: support runtime state capture similar to `jitted_gpu_kernel`.
161
+ template <char const *name, typename return_type, typename f_inputs_type>
162
+ void opmath_jitted_gpu_kernel_with_scalars(TensorIteratorBase& iter, const std::string& f) {
163
+ TORCH_INTERNAL_ASSERT(iter.ntensors() == 3);
164
+ //currently jiterator only handles binary functions where both inputs are of the same type (f_inputs_type)
165
+ using opmath_t = at::opmath_type<f_inputs_type>;
166
+ if (iter.is_cpu_scalar(1)) {
167
+ auto scalar_val = iter.scalar_value<opmath_t>(1);
168
+ iter.remove_operand(1);
169
+ // TODO: When all kernels that use gpu_kernel_with_scalars are
170
+ // ported to structured, this device guard can be deleted. This
171
+ // works around incorrect device guard generation for pre-structured
172
+ // kernels device guards, but structured kernels do it right and
173
+ // we can assume the device is already set correctly
174
+ const OptionalDeviceGuard device_guard(iter.device(1));
175
+ jitted_gpu_kernel<name, return_type, f_inputs_type, 1>(iter, f, at::cuda::jit::BinaryFuncVariant::LhsScalar, scalar_val);
176
+ } else if (iter.is_cpu_scalar(2)) {
177
+ auto scalar_val = iter.scalar_value<opmath_t>(2);
178
+ iter.remove_operand(2);
179
+ jitted_gpu_kernel<name, return_type, f_inputs_type, 1>(iter, f, at::cuda::jit::BinaryFuncVariant::RhsScalar, scalar_val);
180
+ } else {
181
+ jitted_gpu_kernel<name, return_type, f_inputs_type, 2>(iter, f);
182
+ }
183
+ }
184
+
185
+ }} // at::native
186
+
187
+ #endif // AT_USE_JITERATOR()