koichi12 commited on
Commit
b755ec5
·
verified ·
1 Parent(s): d0084d8

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ArrayRef.h +2 -0
  2. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/Functions.h +1427 -0
  3. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/MemoryOverlap.h +42 -0
  4. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/NativeMetaFunctions.h +1303 -0
  5. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/NumericUtils.h +203 -0
  6. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/OpaqueTensorImpl.h +187 -0
  7. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/Operators.h +1358 -0
  8. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/PTThreadPool.h +17 -0
  9. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ParallelNativeTBB.h +52 -0
  10. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/PythonTorchFunctionTLS.h +34 -0
  11. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/SequenceNumber.h +13 -0
  12. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/SparseTensorImpl.h +400 -0
  13. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/StorageUtils.h +49 -0
  14. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/TensorIndexing.h +735 -0
  15. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ThreadLocalPythonObjects.h +21 -0
  16. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ThreadLocalState.h +113 -0
  17. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/TypeDefault.h +30 -0
  18. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/Version.h +18 -0
  19. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/Atomic.cuh +508 -0
  20. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/CUDAApplyUtils.cuh +537 -0
  21. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/CUDAGeneratorImpl.h +138 -0
  22. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/CUDAGraph.h +92 -0
  23. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/CUDAGraphsUtils.cuh +57 -0
  24. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/CUDASparseBlas.h +318 -0
  25. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/CUDATensorMethods.cuh +15 -0
  26. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/NumericLimits.cuh +121 -0
  27. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/Sleep.h +10 -0
  28. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/detail/CUDAHooks.h +54 -0
  29. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/detail/DeviceThreadHandles.h +151 -0
  30. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/detail/IndexUtils.cuh +36 -0
  31. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/detail/IntegerDivider.cuh +124 -0
  32. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/detail/OffsetCalculator.cuh +119 -0
  33. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/detail/PhiloxCudaStateRaw.cuh +43 -0
  34. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/detail/TensorInfo.cuh +116 -0
  35. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/jiterator.h +40 -0
  36. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/tunable/GemmCommon.h +174 -0
  37. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/tunable/GemmHipblaslt.h +379 -0
  38. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/tunable/StreamTimer.h +34 -0
  39. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/TensorConversions.h +26 -0
  40. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/group_norm.h +42 -0
  41. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_cdist_forward_native.h +22 -0
  42. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_foreach_clamp_min_cpu_dispatch.h +28 -0
  43. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_log_softmax_meta_dispatch.h +25 -0
  44. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_make_per_tensor_quantized_tensor_cpu_dispatch.h +23 -0
  45. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_masked_softmax_backward_compositeexplicitautograd_dispatch.h +24 -0
  46. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_masked_softmax_backward_native.h +23 -0
  47. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_masked_softmax_cpu_dispatch.h +23 -0
  48. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_nnpack_spatial_convolution_ops.h +39 -0
  49. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_scaled_dot_product_efficient_attention_backward_ops.h +28 -0
  50. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_test_serialization_subcmul_ops.h +28 -0
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ArrayRef.h ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ #pragma once
2
+ #include <c10/util/ArrayRef.h>
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/Functions.h ADDED
@@ -0,0 +1,1427 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ // @generated by torchgen/gen.py from Functions.h
4
+
5
+ #ifdef TORCH_ASSERT_NO_OPERATORS
6
+ #error This change adds a dependency on native_functions.yaml, \
7
+ meaning the file will need to be re-compiled every time an operator \
8
+ is changed or added. Consider if your change would be better placed in \
9
+ another file, or if a more specific header might achieve the same goal. \
10
+ See NOTE: [Tensor vs. TensorBase]
11
+ #endif
12
+
13
+ #if defined(AT_PER_OPERATOR_HEADERS) && defined(TORCH_ASSERT_ONLY_METHOD_OPERATORS)
14
+ #error This change adds a dependency on all pytorch operators, meaning the \
15
+ file will need to be re-compiled every time an operator is changed or added. \
16
+ Consider including a specific operator from <ATen/ops/{my_operator}.h> and \
17
+ see NOTE [TORCH_ASSERT_ONLY_METHOD_OPERATORS].
18
+ #endif
19
+
20
+ // NOTE: [TORCH_ASSERT_ONLY_METHOD_OPERATORS]
21
+ //
22
+ // In ATen, certain generated headers files include the definitions of
23
+ // every single operator in PyTorch. Unfortunately this means every
24
+ // time an operator signature is updated or changed in
25
+ // native_functions.yaml, you (and every other PyTorch developer) need
26
+ // to recompile every source file that includes any of these headers.
27
+ //
28
+ // To break up these header dependencies, and improve incremental
29
+ // build times for all PyTorch developers. These headers are split
30
+ // into per-operator headers in the `ATen/ops` folder. This limits
31
+ // incremental builds to only changes to methods of `Tensor`, or files
32
+ // that use the specific operator being changed. With `at::sum` as an
33
+ // example, you should include
34
+ //
35
+ // <ATen/ops/sum.h> // instead of ATen/Functions.h
36
+ // <ATen/ops/sum_native.h> // instead of ATen/NativeFunctions.h
37
+ // <ATen/ops/sum_ops.h> // instead of ATen/Operators.h
38
+ // <ATen/ops/sum_cpu_dispatch.h> // instead of ATen/CPUFunctions.h
39
+ //
40
+ // However, even if you're careful to use this in your own code.
41
+ // `Functions.h` might be included indirectly through another header
42
+ // without you realising. To avoid this, you can add
43
+ //
44
+ // #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
45
+ //
46
+ // to the top of your source file. This way any time the non-specific
47
+ // headers are included, the compiler will error out.
48
+ //
49
+ // Also, be aware that `ops` are not available in all build
50
+ // configurations (namely fb-internal) so you must guard these
51
+ // includes with `#ifdef AT_PER_OPERATOR_HEADERS`. e.g.
52
+ //
53
+ // #ifndef AT_PER_OPERATOR_HEADERS
54
+ // #include <ATen/Functions.h>
55
+ // #else
56
+ // #include <ATen/ops/sum.h>
57
+ // #endif
58
+
59
+ #include <ATen/Context.h>
60
+ #include <ATen/DeviceGuard.h>
61
+ #include <ATen/TensorUtils.h>
62
+ #include <ATen/TracerMode.h>
63
+ #include <ATen/core/Generator.h>
64
+ #include <ATen/core/Reduction.h>
65
+ #include <c10/core/SymInt.h>
66
+ #include <ATen/core/Tensor.h>
67
+ #include <c10/core/Scalar.h>
68
+ #include <c10/core/Storage.h>
69
+ #include <c10/core/TensorOptions.h>
70
+ #include <c10/util/Deprecated.h>
71
+ #include <c10/util/Optional.h>
72
+ #include <c10/util/OptionalArrayRef.h>
73
+
74
+ #include <ATen/ops/from_blob.h>
75
+ #include <ATen/ops/tensor.h>
76
+
77
+ #include <ATen/ops/_adaptive_avg_pool2d.h>
78
+ #include <ATen/ops/_adaptive_avg_pool2d_backward.h>
79
+ #include <ATen/ops/_adaptive_avg_pool3d.h>
80
+ #include <ATen/ops/_adaptive_avg_pool3d_backward.h>
81
+ #include <ATen/ops/_add_batch_dim.h>
82
+ #include <ATen/ops/_add_relu.h>
83
+ #include <ATen/ops/_addmm_activation.h>
84
+ #include <ATen/ops/_aminmax.h>
85
+ #include <ATen/ops/_amp_foreach_non_finite_check_and_unscale.h>
86
+ #include <ATen/ops/_amp_update_scale.h>
87
+ #include <ATen/ops/_assert_async.h>
88
+ #include <ATen/ops/_assert_scalar.h>
89
+ #include <ATen/ops/_assert_tensor_metadata.h>
90
+ #include <ATen/ops/_autocast_to_full_precision.h>
91
+ #include <ATen/ops/_autocast_to_reduced_precision.h>
92
+ #include <ATen/ops/_backward.h>
93
+ #include <ATen/ops/_batch_norm_impl_index.h>
94
+ #include <ATen/ops/_batch_norm_impl_index_backward.h>
95
+ #include <ATen/ops/_cast_Byte.h>
96
+ #include <ATen/ops/_cast_Char.h>
97
+ #include <ATen/ops/_cast_Double.h>
98
+ #include <ATen/ops/_cast_Float.h>
99
+ #include <ATen/ops/_cast_Half.h>
100
+ #include <ATen/ops/_cast_Int.h>
101
+ #include <ATen/ops/_cast_Long.h>
102
+ #include <ATen/ops/_cast_Short.h>
103
+ #include <ATen/ops/_cdist_backward.h>
104
+ #include <ATen/ops/_cdist_forward.h>
105
+ #include <ATen/ops/_cholesky_solve_helper.h>
106
+ #include <ATen/ops/_choose_qparams_per_tensor.h>
107
+ #include <ATen/ops/_chunk_cat.h>
108
+ #include <ATen/ops/_coalesce.h>
109
+ #include <ATen/ops/_coalesced.h>
110
+ #include <ATen/ops/_compute_linear_combination.h>
111
+ #include <ATen/ops/_conj.h>
112
+ #include <ATen/ops/_conj_copy.h>
113
+ #include <ATen/ops/_conj_physical.h>
114
+ #include <ATen/ops/_conv_depthwise2d.h>
115
+ #include <ATen/ops/_convert_indices_from_coo_to_csr.h>
116
+ #include <ATen/ops/_convert_indices_from_csr_to_coo.h>
117
+ #include <ATen/ops/_convert_weight_to_int4pack.h>
118
+ #include <ATen/ops/_convolution.h>
119
+ #include <ATen/ops/_convolution_double_backward.h>
120
+ #include <ATen/ops/_convolution_mode.h>
121
+ #include <ATen/ops/_copy_from.h>
122
+ #include <ATen/ops/_copy_from_and_resize.h>
123
+ #include <ATen/ops/_cslt_compress.h>
124
+ #include <ATen/ops/_cslt_sparse_mm.h>
125
+ #include <ATen/ops/_cslt_sparse_mm_search.h>
126
+ #include <ATen/ops/_ctc_loss.h>
127
+ #include <ATen/ops/_ctc_loss_backward.h>
128
+ #include <ATen/ops/_cudnn_ctc_loss.h>
129
+ #include <ATen/ops/_cudnn_init_dropout_state.h>
130
+ #include <ATen/ops/_cudnn_rnn.h>
131
+ #include <ATen/ops/_cudnn_rnn_backward.h>
132
+ #include <ATen/ops/_cudnn_rnn_flatten_weight.h>
133
+ #include <ATen/ops/_cufft_clear_plan_cache.h>
134
+ #include <ATen/ops/_cufft_get_plan_cache_max_size.h>
135
+ #include <ATen/ops/_cufft_get_plan_cache_size.h>
136
+ #include <ATen/ops/_cufft_set_plan_cache_max_size.h>
137
+ #include <ATen/ops/_cummax_helper.h>
138
+ #include <ATen/ops/_cummin_helper.h>
139
+ #include <ATen/ops/_debug_has_internal_overlap.h>
140
+ #include <ATen/ops/_dimI.h>
141
+ #include <ATen/ops/_dimV.h>
142
+ #include <ATen/ops/_dim_arange.h>
143
+ #include <ATen/ops/_dirichlet_grad.h>
144
+ #include <ATen/ops/_efficient_attention_backward.h>
145
+ #include <ATen/ops/_efficient_attention_forward.h>
146
+ #include <ATen/ops/_efficientzerotensor.h>
147
+ #include <ATen/ops/_embedding_bag.h>
148
+ #include <ATen/ops/_embedding_bag_backward.h>
149
+ #include <ATen/ops/_embedding_bag_dense_backward.h>
150
+ #include <ATen/ops/_embedding_bag_forward_only.h>
151
+ #include <ATen/ops/_embedding_bag_per_sample_weights_backward.h>
152
+ #include <ATen/ops/_embedding_bag_sparse_backward.h>
153
+ #include <ATen/ops/_empty_affine_quantized.h>
154
+ #include <ATen/ops/_empty_per_channel_affine_quantized.h>
155
+ #include <ATen/ops/_euclidean_dist.h>
156
+ #include <ATen/ops/_fake_quantize_learnable_per_channel_affine.h>
157
+ #include <ATen/ops/_fake_quantize_learnable_per_channel_affine_backward.h>
158
+ #include <ATen/ops/_fake_quantize_learnable_per_tensor_affine.h>
159
+ #include <ATen/ops/_fake_quantize_learnable_per_tensor_affine_backward.h>
160
+ #include <ATen/ops/_fake_quantize_per_tensor_affine_cachemask_tensor_qparams.h>
161
+ #include <ATen/ops/_fft_c2c.h>
162
+ #include <ATen/ops/_fft_c2r.h>
163
+ #include <ATen/ops/_fft_r2c.h>
164
+ #include <ATen/ops/_fill_mem_eff_dropout_mask.h>
165
+ #include <ATen/ops/_flash_attention_backward.h>
166
+ #include <ATen/ops/_flash_attention_forward.h>
167
+ #include <ATen/ops/_foobar.h>
168
+ #include <ATen/ops/_foreach_abs.h>
169
+ #include <ATen/ops/_foreach_acos.h>
170
+ #include <ATen/ops/_foreach_add.h>
171
+ #include <ATen/ops/_foreach_addcdiv.h>
172
+ #include <ATen/ops/_foreach_addcmul.h>
173
+ #include <ATen/ops/_foreach_asin.h>
174
+ #include <ATen/ops/_foreach_atan.h>
175
+ #include <ATen/ops/_foreach_ceil.h>
176
+ #include <ATen/ops/_foreach_clamp_max.h>
177
+ #include <ATen/ops/_foreach_clamp_min.h>
178
+ #include <ATen/ops/_foreach_copy.h>
179
+ #include <ATen/ops/_foreach_cos.h>
180
+ #include <ATen/ops/_foreach_cosh.h>
181
+ #include <ATen/ops/_foreach_div.h>
182
+ #include <ATen/ops/_foreach_erf.h>
183
+ #include <ATen/ops/_foreach_erfc.h>
184
+ #include <ATen/ops/_foreach_exp.h>
185
+ #include <ATen/ops/_foreach_expm1.h>
186
+ #include <ATen/ops/_foreach_floor.h>
187
+ #include <ATen/ops/_foreach_frac.h>
188
+ #include <ATen/ops/_foreach_lerp.h>
189
+ #include <ATen/ops/_foreach_lgamma.h>
190
+ #include <ATen/ops/_foreach_log.h>
191
+ #include <ATen/ops/_foreach_log10.h>
192
+ #include <ATen/ops/_foreach_log1p.h>
193
+ #include <ATen/ops/_foreach_log2.h>
194
+ #include <ATen/ops/_foreach_maximum.h>
195
+ #include <ATen/ops/_foreach_minimum.h>
196
+ #include <ATen/ops/_foreach_mul.h>
197
+ #include <ATen/ops/_foreach_neg.h>
198
+ #include <ATen/ops/_foreach_norm.h>
199
+ #include <ATen/ops/_foreach_pow.h>
200
+ #include <ATen/ops/_foreach_reciprocal.h>
201
+ #include <ATen/ops/_foreach_round.h>
202
+ #include <ATen/ops/_foreach_sigmoid.h>
203
+ #include <ATen/ops/_foreach_sign.h>
204
+ #include <ATen/ops/_foreach_sin.h>
205
+ #include <ATen/ops/_foreach_sinh.h>
206
+ #include <ATen/ops/_foreach_sqrt.h>
207
+ #include <ATen/ops/_foreach_sub.h>
208
+ #include <ATen/ops/_foreach_tan.h>
209
+ #include <ATen/ops/_foreach_tanh.h>
210
+ #include <ATen/ops/_foreach_trunc.h>
211
+ #include <ATen/ops/_foreach_zero.h>
212
+ #include <ATen/ops/_functional_assert_async.h>
213
+ #include <ATen/ops/_functional_assert_scalar.h>
214
+ #include <ATen/ops/_functional_sym_constrain_range.h>
215
+ #include <ATen/ops/_functional_sym_constrain_range_for_size.h>
216
+ #include <ATen/ops/_fused_adam.h>
217
+ #include <ATen/ops/_fused_adamw.h>
218
+ #include <ATen/ops/_fused_dropout.h>
219
+ #include <ATen/ops/_fused_moving_avg_obs_fq_helper.h>
220
+ #include <ATen/ops/_fused_sdp_choice.h>
221
+ #include <ATen/ops/_fused_sgd.h>
222
+ #include <ATen/ops/_fw_primal.h>
223
+ #include <ATen/ops/_fw_primal_copy.h>
224
+ #include <ATen/ops/_gather_sparse_backward.h>
225
+ #include <ATen/ops/_grid_sampler_2d_cpu_fallback.h>
226
+ #include <ATen/ops/_grid_sampler_2d_cpu_fallback_backward.h>
227
+ #include <ATen/ops/_has_compatible_shallow_copy_type.h>
228
+ #include <ATen/ops/_has_same_storage_numel.h>
229
+ #include <ATen/ops/_histogramdd_bin_edges.h>
230
+ #include <ATen/ops/_histogramdd_from_bin_cts.h>
231
+ #include <ATen/ops/_histogramdd_from_bin_tensors.h>
232
+ #include <ATen/ops/_index_put_impl.h>
233
+ #include <ATen/ops/_indices.h>
234
+ #include <ATen/ops/_indices_copy.h>
235
+ #include <ATen/ops/_int_mm.h>
236
+ #include <ATen/ops/_is_all_true.h>
237
+ #include <ATen/ops/_is_any_true.h>
238
+ #include <ATen/ops/_is_zerotensor.h>
239
+ #include <ATen/ops/_lazy_clone.h>
240
+ #include <ATen/ops/_linalg_check_errors.h>
241
+ #include <ATen/ops/_linalg_det.h>
242
+ #include <ATen/ops/_linalg_eigh.h>
243
+ #include <ATen/ops/_linalg_eigvals.h>
244
+ #include <ATen/ops/_linalg_slogdet.h>
245
+ #include <ATen/ops/_linalg_solve_ex.h>
246
+ #include <ATen/ops/_linalg_svd.h>
247
+ #include <ATen/ops/_local_scalar_dense.h>
248
+ #include <ATen/ops/_log_softmax.h>
249
+ #include <ATen/ops/_log_softmax_backward_data.h>
250
+ #include <ATen/ops/_logcumsumexp.h>
251
+ #include <ATen/ops/_lstm_mps.h>
252
+ #include <ATen/ops/_lu_with_info.h>
253
+ #include <ATen/ops/_make_dep_token.h>
254
+ #include <ATen/ops/_make_dual.h>
255
+ #include <ATen/ops/_make_dual_copy.h>
256
+ #include <ATen/ops/_make_per_channel_quantized_tensor.h>
257
+ #include <ATen/ops/_make_per_tensor_quantized_tensor.h>
258
+ #include <ATen/ops/_masked_scale.h>
259
+ #include <ATen/ops/_masked_softmax.h>
260
+ #include <ATen/ops/_masked_softmax_backward.h>
261
+ #include <ATen/ops/_mixed_dtypes_linear.h>
262
+ #include <ATen/ops/_mkldnn_reshape.h>
263
+ #include <ATen/ops/_mkldnn_transpose.h>
264
+ #include <ATen/ops/_mps_convolution.h>
265
+ #include <ATen/ops/_mps_convolution_transpose.h>
266
+ #include <ATen/ops/_native_batch_norm_legit.h>
267
+ #include <ATen/ops/_native_batch_norm_legit_no_training.h>
268
+ #include <ATen/ops/_native_multi_head_attention.h>
269
+ #include <ATen/ops/_neg_view.h>
270
+ #include <ATen/ops/_neg_view_copy.h>
271
+ #include <ATen/ops/_nested_from_padded.h>
272
+ #include <ATen/ops/_nested_from_padded_and_nested_example.h>
273
+ #include <ATen/ops/_nested_get_jagged_dummy.h>
274
+ #include <ATen/ops/_nested_get_lengths.h>
275
+ #include <ATen/ops/_nested_get_offsets.h>
276
+ #include <ATen/ops/_nested_get_ragged_idx.h>
277
+ #include <ATen/ops/_nested_get_values.h>
278
+ #include <ATen/ops/_nested_get_values_copy.h>
279
+ #include <ATen/ops/_nested_select_backward.h>
280
+ #include <ATen/ops/_nested_sum_backward.h>
281
+ #include <ATen/ops/_nested_tensor_from_mask.h>
282
+ #include <ATen/ops/_nested_tensor_from_mask_left_aligned.h>
283
+ #include <ATen/ops/_nested_tensor_from_tensor_list.h>
284
+ #include <ATen/ops/_nested_tensor_size.h>
285
+ #include <ATen/ops/_nested_tensor_softmax_with_shape.h>
286
+ #include <ATen/ops/_nested_tensor_storage_offsets.h>
287
+ #include <ATen/ops/_nested_tensor_strides.h>
288
+ #include <ATen/ops/_nested_view_from_buffer.h>
289
+ #include <ATen/ops/_nested_view_from_buffer_copy.h>
290
+ #include <ATen/ops/_nested_view_from_jagged.h>
291
+ #include <ATen/ops/_nested_view_from_jagged_copy.h>
292
+ #include <ATen/ops/_new_zeros_with_same_feature_meta.h>
293
+ #include <ATen/ops/_nnpack_available.h>
294
+ #include <ATen/ops/_nnpack_spatial_convolution.h>
295
+ #include <ATen/ops/_nnz.h>
296
+ #include <ATen/ops/_pack_padded_sequence.h>
297
+ #include <ATen/ops/_pack_padded_sequence_backward.h>
298
+ #include <ATen/ops/_pad_circular.h>
299
+ #include <ATen/ops/_pad_enum.h>
300
+ #include <ATen/ops/_pad_packed_sequence.h>
301
+ #include <ATen/ops/_pdist_backward.h>
302
+ #include <ATen/ops/_pdist_forward.h>
303
+ #include <ATen/ops/_pin_memory.h>
304
+ #include <ATen/ops/_prelu_kernel.h>
305
+ #include <ATen/ops/_prelu_kernel_backward.h>
306
+ #include <ATen/ops/_print.h>
307
+ #include <ATen/ops/_propagate_xla_data.h>
308
+ #include <ATen/ops/_remove_batch_dim.h>
309
+ #include <ATen/ops/_reshape_alias.h>
310
+ #include <ATen/ops/_reshape_alias_copy.h>
311
+ #include <ATen/ops/_reshape_copy.h>
312
+ #include <ATen/ops/_reshape_from_tensor.h>
313
+ #include <ATen/ops/_resize_output.h>
314
+ #include <ATen/ops/_rowwise_prune.h>
315
+ #include <ATen/ops/_sample_dirichlet.h>
316
+ #include <ATen/ops/_saturate_weight_to_fp16.h>
317
+ #include <ATen/ops/_scaled_dot_product_attention_math.h>
318
+ #include <ATen/ops/_scaled_dot_product_cudnn_attention.h>
319
+ #include <ATen/ops/_scaled_dot_product_efficient_attention.h>
320
+ #include <ATen/ops/_scaled_dot_product_efficient_attention_backward.h>
321
+ #include <ATen/ops/_scaled_dot_product_flash_attention.h>
322
+ #include <ATen/ops/_scaled_dot_product_flash_attention_backward.h>
323
+ #include <ATen/ops/_scaled_dot_product_flash_attention_for_cpu.h>
324
+ #include <ATen/ops/_scaled_dot_product_flash_attention_for_cpu_backward.h>
325
+ #include <ATen/ops/_scaled_mm.h>
326
+ #include <ATen/ops/_segment_reduce_backward.h>
327
+ #include <ATen/ops/_shape_as_tensor.h>
328
+ #include <ATen/ops/_slow_conv2d_backward.h>
329
+ #include <ATen/ops/_slow_conv2d_forward.h>
330
+ #include <ATen/ops/_sobol_engine_draw.h>
331
+ #include <ATen/ops/_sobol_engine_ff.h>
332
+ #include <ATen/ops/_sobol_engine_initialize_state.h>
333
+ #include <ATen/ops/_sobol_engine_scramble.h>
334
+ #include <ATen/ops/_softmax.h>
335
+ #include <ATen/ops/_softmax_backward_data.h>
336
+ #include <ATen/ops/_sparse_addmm.h>
337
+ #include <ATen/ops/_sparse_broadcast_to.h>
338
+ #include <ATen/ops/_sparse_broadcast_to_copy.h>
339
+ #include <ATen/ops/_sparse_bsc_tensor_unsafe.h>
340
+ #include <ATen/ops/_sparse_bsr_tensor_unsafe.h>
341
+ #include <ATen/ops/_sparse_compressed_tensor_unsafe.h>
342
+ #include <ATen/ops/_sparse_coo_tensor_unsafe.h>
343
+ #include <ATen/ops/_sparse_coo_tensor_with_dims.h>
344
+ #include <ATen/ops/_sparse_coo_tensor_with_dims_and_tensors.h>
345
+ #include <ATen/ops/_sparse_csc_tensor_unsafe.h>
346
+ #include <ATen/ops/_sparse_csr_prod.h>
347
+ #include <ATen/ops/_sparse_csr_sum.h>
348
+ #include <ATen/ops/_sparse_csr_tensor_unsafe.h>
349
+ #include <ATen/ops/_sparse_log_softmax.h>
350
+ #include <ATen/ops/_sparse_log_softmax_backward_data.h>
351
+ #include <ATen/ops/_sparse_mask_projection.h>
352
+ #include <ATen/ops/_sparse_mm.h>
353
+ #include <ATen/ops/_sparse_mm_reduce_impl.h>
354
+ #include <ATen/ops/_sparse_mm_reduce_impl_backward.h>
355
+ #include <ATen/ops/_sparse_semi_structured_linear.h>
356
+ #include <ATen/ops/_sparse_softmax.h>
357
+ #include <ATen/ops/_sparse_softmax_backward_data.h>
358
+ #include <ATen/ops/_sparse_sparse_matmul.h>
359
+ #include <ATen/ops/_sparse_sum.h>
360
+ #include <ATen/ops/_sparse_sum_backward.h>
361
+ #include <ATen/ops/_spdiags.h>
362
+ #include <ATen/ops/_stack.h>
363
+ #include <ATen/ops/_standard_gamma.h>
364
+ #include <ATen/ops/_standard_gamma_grad.h>
365
+ #include <ATen/ops/_test_ambiguous_defaults.h>
366
+ #include <ATen/ops/_test_autograd_multiple_dispatch.h>
367
+ #include <ATen/ops/_test_autograd_multiple_dispatch_view.h>
368
+ #include <ATen/ops/_test_autograd_multiple_dispatch_view_copy.h>
369
+ #include <ATen/ops/_test_check_tensor.h>
370
+ #include <ATen/ops/_test_functorch_fallback.h>
371
+ #include <ATen/ops/_test_optional_filled_intlist.h>
372
+ #include <ATen/ops/_test_optional_floatlist.h>
373
+ #include <ATen/ops/_test_optional_intlist.h>
374
+ #include <ATen/ops/_test_parallel_materialize.h>
375
+ #include <ATen/ops/_test_serialization_subcmul.h>
376
+ #include <ATen/ops/_test_string_default.h>
377
+ #include <ATen/ops/_test_warn_in_autograd.h>
378
+ #include <ATen/ops/_thnn_differentiable_gru_cell_backward.h>
379
+ #include <ATen/ops/_thnn_differentiable_lstm_cell_backward.h>
380
+ #include <ATen/ops/_thnn_fused_gru_cell.h>
381
+ #include <ATen/ops/_thnn_fused_gru_cell_backward.h>
382
+ #include <ATen/ops/_thnn_fused_lstm_cell.h>
383
+ #include <ATen/ops/_thnn_fused_lstm_cell_backward.h>
384
+ #include <ATen/ops/_thnn_fused_lstm_cell_backward_impl.h>
385
+ #include <ATen/ops/_to_copy.h>
386
+ #include <ATen/ops/_to_cpu.h>
387
+ #include <ATen/ops/_to_dense.h>
388
+ #include <ATen/ops/_to_sparse.h>
389
+ #include <ATen/ops/_to_sparse_bsc.h>
390
+ #include <ATen/ops/_to_sparse_bsr.h>
391
+ #include <ATen/ops/_to_sparse_csc.h>
392
+ #include <ATen/ops/_to_sparse_csr.h>
393
+ #include <ATen/ops/_to_sparse_semi_structured.h>
394
+ #include <ATen/ops/_transform_bias_rescale_qkv.h>
395
+ #include <ATen/ops/_transformer_encoder_layer_fwd.h>
396
+ #include <ATen/ops/_trilinear.h>
397
+ #include <ATen/ops/_triton_multi_head_attention.h>
398
+ #include <ATen/ops/_triton_scaled_dot_attention.h>
399
+ #include <ATen/ops/_unique.h>
400
+ #include <ATen/ops/_unique2.h>
401
+ #include <ATen/ops/_unpack_dual.h>
402
+ #include <ATen/ops/_unsafe_index.h>
403
+ #include <ATen/ops/_unsafe_index_put.h>
404
+ #include <ATen/ops/_unsafe_view.h>
405
+ #include <ATen/ops/_upsample_bicubic2d_aa.h>
406
+ #include <ATen/ops/_upsample_bicubic2d_aa_backward.h>
407
+ #include <ATen/ops/_upsample_bilinear2d_aa.h>
408
+ #include <ATen/ops/_upsample_bilinear2d_aa_backward.h>
409
+ #include <ATen/ops/_upsample_nearest_exact1d.h>
410
+ #include <ATen/ops/_upsample_nearest_exact1d_backward.h>
411
+ #include <ATen/ops/_upsample_nearest_exact2d.h>
412
+ #include <ATen/ops/_upsample_nearest_exact2d_backward.h>
413
+ #include <ATen/ops/_upsample_nearest_exact3d.h>
414
+ #include <ATen/ops/_upsample_nearest_exact3d_backward.h>
415
+ #include <ATen/ops/_use_cudnn_ctc_loss.h>
416
+ #include <ATen/ops/_use_cudnn_rnn_flatten_weight.h>
417
+ #include <ATen/ops/_validate_compressed_sparse_indices.h>
418
+ #include <ATen/ops/_validate_sparse_bsc_tensor_args.h>
419
+ #include <ATen/ops/_validate_sparse_bsr_tensor_args.h>
420
+ #include <ATen/ops/_validate_sparse_compressed_tensor_args.h>
421
+ #include <ATen/ops/_validate_sparse_coo_tensor_args.h>
422
+ #include <ATen/ops/_validate_sparse_csc_tensor_args.h>
423
+ #include <ATen/ops/_validate_sparse_csr_tensor_args.h>
424
+ #include <ATen/ops/_values.h>
425
+ #include <ATen/ops/_values_copy.h>
426
+ #include <ATen/ops/_version.h>
427
+ #include <ATen/ops/_weight_int4pack_mm.h>
428
+ #include <ATen/ops/_weight_int8pack_mm.h>
429
+ #include <ATen/ops/_weight_norm.h>
430
+ #include <ATen/ops/_weight_norm_differentiable_backward.h>
431
+ #include <ATen/ops/_weight_norm_interface.h>
432
+ #include <ATen/ops/_weight_norm_interface_backward.h>
433
+ #include <ATen/ops/abs.h>
434
+ #include <ATen/ops/absolute.h>
435
+ #include <ATen/ops/acos.h>
436
+ #include <ATen/ops/acosh.h>
437
+ #include <ATen/ops/adaptive_avg_pool1d.h>
438
+ #include <ATen/ops/adaptive_avg_pool2d.h>
439
+ #include <ATen/ops/adaptive_avg_pool3d.h>
440
+ #include <ATen/ops/adaptive_avg_pool3d_backward.h>
441
+ #include <ATen/ops/adaptive_max_pool1d.h>
442
+ #include <ATen/ops/adaptive_max_pool2d.h>
443
+ #include <ATen/ops/adaptive_max_pool2d_backward.h>
444
+ #include <ATen/ops/adaptive_max_pool3d.h>
445
+ #include <ATen/ops/adaptive_max_pool3d_backward.h>
446
+ #include <ATen/ops/add.h>
447
+ #include <ATen/ops/addbmm.h>
448
+ #include <ATen/ops/addcdiv.h>
449
+ #include <ATen/ops/addcmul.h>
450
+ #include <ATen/ops/addmm.h>
451
+ #include <ATen/ops/addmv.h>
452
+ #include <ATen/ops/addr.h>
453
+ #include <ATen/ops/adjoint.h>
454
+ #include <ATen/ops/affine_grid_generator.h>
455
+ #include <ATen/ops/affine_grid_generator_backward.h>
456
+ #include <ATen/ops/alias.h>
457
+ #include <ATen/ops/alias_copy.h>
458
+ #include <ATen/ops/align_as.h>
459
+ #include <ATen/ops/align_tensors.h>
460
+ #include <ATen/ops/align_to.h>
461
+ #include <ATen/ops/all.h>
462
+ #include <ATen/ops/allclose.h>
463
+ #include <ATen/ops/alpha_dropout.h>
464
+ #include <ATen/ops/amax.h>
465
+ #include <ATen/ops/amin.h>
466
+ #include <ATen/ops/aminmax.h>
467
+ #include <ATen/ops/and.h>
468
+ #include <ATen/ops/angle.h>
469
+ #include <ATen/ops/any.h>
470
+ #include <ATen/ops/arange.h>
471
+ #include <ATen/ops/arccos.h>
472
+ #include <ATen/ops/arccosh.h>
473
+ #include <ATen/ops/arcsin.h>
474
+ #include <ATen/ops/arcsinh.h>
475
+ #include <ATen/ops/arctan.h>
476
+ #include <ATen/ops/arctan2.h>
477
+ #include <ATen/ops/arctanh.h>
478
+ #include <ATen/ops/argmax.h>
479
+ #include <ATen/ops/argmin.h>
480
+ #include <ATen/ops/argsort.h>
481
+ #include <ATen/ops/argwhere.h>
482
+ #include <ATen/ops/as_strided.h>
483
+ #include <ATen/ops/as_strided_copy.h>
484
+ #include <ATen/ops/as_strided_scatter.h>
485
+ #include <ATen/ops/asin.h>
486
+ #include <ATen/ops/asinh.h>
487
+ #include <ATen/ops/atan.h>
488
+ #include <ATen/ops/atan2.h>
489
+ #include <ATen/ops/atanh.h>
490
+ #include <ATen/ops/atleast_1d.h>
491
+ #include <ATen/ops/atleast_2d.h>
492
+ #include <ATen/ops/atleast_3d.h>
493
+ #include <ATen/ops/avg_pool1d.h>
494
+ #include <ATen/ops/avg_pool2d.h>
495
+ #include <ATen/ops/avg_pool2d_backward.h>
496
+ #include <ATen/ops/avg_pool3d.h>
497
+ #include <ATen/ops/avg_pool3d_backward.h>
498
+ #include <ATen/ops/baddbmm.h>
499
+ #include <ATen/ops/bartlett_window.h>
500
+ #include <ATen/ops/batch_norm.h>
501
+ #include <ATen/ops/batch_norm_backward_elemt.h>
502
+ #include <ATen/ops/batch_norm_backward_reduce.h>
503
+ #include <ATen/ops/batch_norm_elemt.h>
504
+ #include <ATen/ops/batch_norm_gather_stats.h>
505
+ #include <ATen/ops/batch_norm_gather_stats_with_counts.h>
506
+ #include <ATen/ops/batch_norm_stats.h>
507
+ #include <ATen/ops/batch_norm_update_stats.h>
508
+ #include <ATen/ops/bernoulli.h>
509
+ #include <ATen/ops/bilinear.h>
510
+ #include <ATen/ops/binary_cross_entropy.h>
511
+ #include <ATen/ops/binary_cross_entropy_backward.h>
512
+ #include <ATen/ops/binary_cross_entropy_with_logits.h>
513
+ #include <ATen/ops/bincount.h>
514
+ #include <ATen/ops/binomial.h>
515
+ #include <ATen/ops/bitwise_and.h>
516
+ #include <ATen/ops/bitwise_left_shift.h>
517
+ #include <ATen/ops/bitwise_not.h>
518
+ #include <ATen/ops/bitwise_or.h>
519
+ #include <ATen/ops/bitwise_right_shift.h>
520
+ #include <ATen/ops/bitwise_xor.h>
521
+ #include <ATen/ops/blackman_window.h>
522
+ #include <ATen/ops/block_diag.h>
523
+ #include <ATen/ops/bmm.h>
524
+ #include <ATen/ops/broadcast_tensors.h>
525
+ #include <ATen/ops/broadcast_to.h>
526
+ #include <ATen/ops/bucketize.h>
527
+ #include <ATen/ops/can_cast.h>
528
+ #include <ATen/ops/cartesian_prod.h>
529
+ #include <ATen/ops/cat.h>
530
+ #include <ATen/ops/cauchy.h>
531
+ #include <ATen/ops/ccol_indices.h>
532
+ #include <ATen/ops/ccol_indices_copy.h>
533
+ #include <ATen/ops/cdist.h>
534
+ #include <ATen/ops/ceil.h>
535
+ #include <ATen/ops/celu.h>
536
+ #include <ATen/ops/chain_matmul.h>
537
+ #include <ATen/ops/chalf.h>
538
+ #include <ATen/ops/channel_shuffle.h>
539
+ #include <ATen/ops/cholesky.h>
540
+ #include <ATen/ops/cholesky_inverse.h>
541
+ #include <ATen/ops/cholesky_solve.h>
542
+ #include <ATen/ops/choose_qparams_optimized.h>
543
+ #include <ATen/ops/chunk.h>
544
+ #include <ATen/ops/clamp.h>
545
+ #include <ATen/ops/clamp_max.h>
546
+ #include <ATen/ops/clamp_min.h>
547
+ #include <ATen/ops/clip.h>
548
+ #include <ATen/ops/clone.h>
549
+ #include <ATen/ops/coalesce.h>
550
+ #include <ATen/ops/col2im.h>
551
+ #include <ATen/ops/col_indices.h>
552
+ #include <ATen/ops/col_indices_copy.h>
553
+ #include <ATen/ops/column_stack.h>
554
+ #include <ATen/ops/combinations.h>
555
+ #include <ATen/ops/complex.h>
556
+ #include <ATen/ops/concat.h>
557
+ #include <ATen/ops/concatenate.h>
558
+ #include <ATen/ops/conj.h>
559
+ #include <ATen/ops/conj_physical.h>
560
+ #include <ATen/ops/constant_pad_nd.h>
561
+ #include <ATen/ops/contiguous.h>
562
+ #include <ATen/ops/conv1d.h>
563
+ #include <ATen/ops/conv2d.h>
564
+ #include <ATen/ops/conv3d.h>
565
+ #include <ATen/ops/conv_depthwise3d.h>
566
+ #include <ATen/ops/conv_tbc.h>
567
+ #include <ATen/ops/conv_tbc_backward.h>
568
+ #include <ATen/ops/conv_transpose1d.h>
569
+ #include <ATen/ops/conv_transpose2d.h>
570
+ #include <ATen/ops/conv_transpose3d.h>
571
+ #include <ATen/ops/convolution.h>
572
+ #include <ATen/ops/convolution_backward.h>
573
+ #include <ATen/ops/convolution_backward_overrideable.h>
574
+ #include <ATen/ops/convolution_overrideable.h>
575
+ #include <ATen/ops/copy.h>
576
+ #include <ATen/ops/copy_sparse_to_sparse.h>
577
+ #include <ATen/ops/copysign.h>
578
+ #include <ATen/ops/corrcoef.h>
579
+ #include <ATen/ops/cos.h>
580
+ #include <ATen/ops/cosh.h>
581
+ #include <ATen/ops/cosine_embedding_loss.h>
582
+ #include <ATen/ops/cosine_similarity.h>
583
+ #include <ATen/ops/count_nonzero.h>
584
+ #include <ATen/ops/cov.h>
585
+ #include <ATen/ops/cross.h>
586
+ #include <ATen/ops/cross_entropy_loss.h>
587
+ #include <ATen/ops/crow_indices.h>
588
+ #include <ATen/ops/crow_indices_copy.h>
589
+ #include <ATen/ops/ctc_loss.h>
590
+ #include <ATen/ops/cudnn_affine_grid_generator.h>
591
+ #include <ATen/ops/cudnn_affine_grid_generator_backward.h>
592
+ #include <ATen/ops/cudnn_batch_norm.h>
593
+ #include <ATen/ops/cudnn_batch_norm_backward.h>
594
+ #include <ATen/ops/cudnn_convolution.h>
595
+ #include <ATen/ops/cudnn_convolution_add_relu.h>
596
+ #include <ATen/ops/cudnn_convolution_relu.h>
597
+ #include <ATen/ops/cudnn_convolution_transpose.h>
598
+ #include <ATen/ops/cudnn_grid_sampler.h>
599
+ #include <ATen/ops/cudnn_grid_sampler_backward.h>
600
+ #include <ATen/ops/cudnn_is_acceptable.h>
601
+ #include <ATen/ops/cummax.h>
602
+ #include <ATen/ops/cummaxmin_backward.h>
603
+ #include <ATen/ops/cummin.h>
604
+ #include <ATen/ops/cumprod.h>
605
+ #include <ATen/ops/cumprod_backward.h>
606
+ #include <ATen/ops/cumsum.h>
607
+ #include <ATen/ops/cumulative_trapezoid.h>
608
+ #include <ATen/ops/data.h>
609
+ #include <ATen/ops/deg2rad.h>
610
+ #include <ATen/ops/dense_dim.h>
611
+ #include <ATen/ops/dequantize.h>
612
+ #include <ATen/ops/det.h>
613
+ #include <ATen/ops/detach.h>
614
+ #include <ATen/ops/detach_copy.h>
615
+ #include <ATen/ops/diag.h>
616
+ #include <ATen/ops/diag_embed.h>
617
+ #include <ATen/ops/diagflat.h>
618
+ #include <ATen/ops/diagonal.h>
619
+ #include <ATen/ops/diagonal_backward.h>
620
+ #include <ATen/ops/diagonal_copy.h>
621
+ #include <ATen/ops/diagonal_scatter.h>
622
+ #include <ATen/ops/diff.h>
623
+ #include <ATen/ops/digamma.h>
624
+ #include <ATen/ops/dist.h>
625
+ #include <ATen/ops/div.h>
626
+ #include <ATen/ops/divide.h>
627
+ #include <ATen/ops/dot.h>
628
+ #include <ATen/ops/dropout.h>
629
+ #include <ATen/ops/dsplit.h>
630
+ #include <ATen/ops/dstack.h>
631
+ #include <ATen/ops/einsum.h>
632
+ #include <ATen/ops/elu.h>
633
+ #include <ATen/ops/elu_backward.h>
634
+ #include <ATen/ops/embedding.h>
635
+ #include <ATen/ops/embedding_backward.h>
636
+ #include <ATen/ops/embedding_bag.h>
637
+ #include <ATen/ops/embedding_dense_backward.h>
638
+ #include <ATen/ops/embedding_renorm.h>
639
+ #include <ATen/ops/embedding_sparse_backward.h>
640
+ #include <ATen/ops/empty.h>
641
+ #include <ATen/ops/empty_like.h>
642
+ #include <ATen/ops/empty_permuted.h>
643
+ #include <ATen/ops/empty_quantized.h>
644
+ #include <ATen/ops/empty_strided.h>
645
+ #include <ATen/ops/eq.h>
646
+ #include <ATen/ops/equal.h>
647
+ #include <ATen/ops/erf.h>
648
+ #include <ATen/ops/erfc.h>
649
+ #include <ATen/ops/erfinv.h>
650
+ #include <ATen/ops/exp.h>
651
+ #include <ATen/ops/exp2.h>
652
+ #include <ATen/ops/expand.h>
653
+ #include <ATen/ops/expand_as.h>
654
+ #include <ATen/ops/expand_copy.h>
655
+ #include <ATen/ops/expm1.h>
656
+ #include <ATen/ops/exponential.h>
657
+ #include <ATen/ops/eye.h>
658
+ #include <ATen/ops/fake_quantize_per_channel_affine.h>
659
+ #include <ATen/ops/fake_quantize_per_channel_affine_cachemask.h>
660
+ #include <ATen/ops/fake_quantize_per_channel_affine_cachemask_backward.h>
661
+ #include <ATen/ops/fake_quantize_per_tensor_affine.h>
662
+ #include <ATen/ops/fake_quantize_per_tensor_affine_cachemask.h>
663
+ #include <ATen/ops/fake_quantize_per_tensor_affine_cachemask_backward.h>
664
+ #include <ATen/ops/fbgemm_linear_fp16_weight.h>
665
+ #include <ATen/ops/fbgemm_linear_fp16_weight_fp32_activation.h>
666
+ #include <ATen/ops/fbgemm_linear_int8_weight.h>
667
+ #include <ATen/ops/fbgemm_linear_int8_weight_fp32_activation.h>
668
+ #include <ATen/ops/fbgemm_linear_quantize_weight.h>
669
+ #include <ATen/ops/fbgemm_pack_gemm_matrix_fp16.h>
670
+ #include <ATen/ops/fbgemm_pack_quantized_matrix.h>
671
+ #include <ATen/ops/feature_alpha_dropout.h>
672
+ #include <ATen/ops/feature_dropout.h>
673
+ #include <ATen/ops/fft_fft.h>
674
+ #include <ATen/ops/fft_fft2.h>
675
+ #include <ATen/ops/fft_fftfreq.h>
676
+ #include <ATen/ops/fft_fftn.h>
677
+ #include <ATen/ops/fft_fftshift.h>
678
+ #include <ATen/ops/fft_hfft.h>
679
+ #include <ATen/ops/fft_hfft2.h>
680
+ #include <ATen/ops/fft_hfftn.h>
681
+ #include <ATen/ops/fft_ifft.h>
682
+ #include <ATen/ops/fft_ifft2.h>
683
+ #include <ATen/ops/fft_ifftn.h>
684
+ #include <ATen/ops/fft_ifftshift.h>
685
+ #include <ATen/ops/fft_ihfft.h>
686
+ #include <ATen/ops/fft_ihfft2.h>
687
+ #include <ATen/ops/fft_ihfftn.h>
688
+ #include <ATen/ops/fft_irfft.h>
689
+ #include <ATen/ops/fft_irfft2.h>
690
+ #include <ATen/ops/fft_irfftn.h>
691
+ #include <ATen/ops/fft_rfft.h>
692
+ #include <ATen/ops/fft_rfft2.h>
693
+ #include <ATen/ops/fft_rfftfreq.h>
694
+ #include <ATen/ops/fft_rfftn.h>
695
+ #include <ATen/ops/fill.h>
696
+ #include <ATen/ops/fill_diagonal.h>
697
+ #include <ATen/ops/fix.h>
698
+ #include <ATen/ops/flatten.h>
699
+ #include <ATen/ops/flatten_dense_tensors.h>
700
+ #include <ATen/ops/flip.h>
701
+ #include <ATen/ops/fliplr.h>
702
+ #include <ATen/ops/flipud.h>
703
+ #include <ATen/ops/float_power.h>
704
+ #include <ATen/ops/floor.h>
705
+ #include <ATen/ops/floor_divide.h>
706
+ #include <ATen/ops/fmax.h>
707
+ #include <ATen/ops/fmin.h>
708
+ #include <ATen/ops/fmod.h>
709
+ #include <ATen/ops/frac.h>
710
+ #include <ATen/ops/fractional_max_pool2d.h>
711
+ #include <ATen/ops/fractional_max_pool2d_backward.h>
712
+ #include <ATen/ops/fractional_max_pool3d.h>
713
+ #include <ATen/ops/fractional_max_pool3d_backward.h>
714
+ #include <ATen/ops/frexp.h>
715
+ #include <ATen/ops/frobenius_norm.h>
716
+ #include <ATen/ops/from_file.h>
717
+ #include <ATen/ops/full.h>
718
+ #include <ATen/ops/full_like.h>
719
+ #include <ATen/ops/fused_moving_avg_obs_fake_quant.h>
720
+ #include <ATen/ops/gather.h>
721
+ #include <ATen/ops/gather_backward.h>
722
+ #include <ATen/ops/gcd.h>
723
+ #include <ATen/ops/ge.h>
724
+ #include <ATen/ops/gelu.h>
725
+ #include <ATen/ops/gelu_backward.h>
726
+ #include <ATen/ops/geometric.h>
727
+ #include <ATen/ops/geqrf.h>
728
+ #include <ATen/ops/ger.h>
729
+ #include <ATen/ops/glu.h>
730
+ #include <ATen/ops/glu_backward.h>
731
+ #include <ATen/ops/glu_backward_jvp.h>
732
+ #include <ATen/ops/glu_jvp.h>
733
+ #include <ATen/ops/gradient.h>
734
+ #include <ATen/ops/greater.h>
735
+ #include <ATen/ops/greater_equal.h>
736
+ #include <ATen/ops/grid_sampler.h>
737
+ #include <ATen/ops/grid_sampler_2d.h>
738
+ #include <ATen/ops/grid_sampler_2d_backward.h>
739
+ #include <ATen/ops/grid_sampler_3d.h>
740
+ #include <ATen/ops/grid_sampler_3d_backward.h>
741
+ #include <ATen/ops/group_norm.h>
742
+ #include <ATen/ops/gru.h>
743
+ #include <ATen/ops/gru_cell.h>
744
+ #include <ATen/ops/gt.h>
745
+ #include <ATen/ops/hamming_window.h>
746
+ #include <ATen/ops/hann_window.h>
747
+ #include <ATen/ops/hardshrink.h>
748
+ #include <ATen/ops/hardshrink_backward.h>
749
+ #include <ATen/ops/hardsigmoid.h>
750
+ #include <ATen/ops/hardsigmoid_backward.h>
751
+ #include <ATen/ops/hardswish.h>
752
+ #include <ATen/ops/hardswish_backward.h>
753
+ #include <ATen/ops/hardtanh.h>
754
+ #include <ATen/ops/hardtanh_backward.h>
755
+ #include <ATen/ops/heaviside.h>
756
+ #include <ATen/ops/hinge_embedding_loss.h>
757
+ #include <ATen/ops/histc.h>
758
+ #include <ATen/ops/histogram.h>
759
+ #include <ATen/ops/histogramdd.h>
760
+ #include <ATen/ops/hsplit.h>
761
+ #include <ATen/ops/hspmm.h>
762
+ #include <ATen/ops/hstack.h>
763
+ #include <ATen/ops/huber_loss.h>
764
+ #include <ATen/ops/huber_loss_backward.h>
765
+ #include <ATen/ops/hypot.h>
766
+ #include <ATen/ops/i0.h>
767
+ #include <ATen/ops/igamma.h>
768
+ #include <ATen/ops/igammac.h>
769
+ #include <ATen/ops/im2col.h>
770
+ #include <ATen/ops/imag.h>
771
+ #include <ATen/ops/index.h>
772
+ #include <ATen/ops/index_add.h>
773
+ #include <ATen/ops/index_copy.h>
774
+ #include <ATen/ops/index_fill.h>
775
+ #include <ATen/ops/index_put.h>
776
+ #include <ATen/ops/index_reduce.h>
777
+ #include <ATen/ops/index_select.h>
778
+ #include <ATen/ops/index_select_backward.h>
779
+ #include <ATen/ops/indices.h>
780
+ #include <ATen/ops/indices_copy.h>
781
+ #include <ATen/ops/infinitely_differentiable_gelu_backward.h>
782
+ #include <ATen/ops/inner.h>
783
+ #include <ATen/ops/instance_norm.h>
784
+ #include <ATen/ops/int_repr.h>
785
+ #include <ATen/ops/inverse.h>
786
+ #include <ATen/ops/is_coalesced.h>
787
+ #include <ATen/ops/is_complex.h>
788
+ #include <ATen/ops/is_conj.h>
789
+ #include <ATen/ops/is_distributed.h>
790
+ #include <ATen/ops/is_floating_point.h>
791
+ #include <ATen/ops/is_inference.h>
792
+ #include <ATen/ops/is_leaf.h>
793
+ #include <ATen/ops/is_neg.h>
794
+ #include <ATen/ops/is_nonzero.h>
795
+ #include <ATen/ops/is_pinned.h>
796
+ #include <ATen/ops/is_same_size.h>
797
+ #include <ATen/ops/is_set_to.h>
798
+ #include <ATen/ops/is_signed.h>
799
+ #include <ATen/ops/is_vulkan_available.h>
800
+ #include <ATen/ops/isclose.h>
801
+ #include <ATen/ops/isfinite.h>
802
+ #include <ATen/ops/isin.h>
803
+ #include <ATen/ops/isinf.h>
804
+ #include <ATen/ops/isnan.h>
805
+ #include <ATen/ops/isneginf.h>
806
+ #include <ATen/ops/isposinf.h>
807
+ #include <ATen/ops/isreal.h>
808
+ #include <ATen/ops/istft.h>
809
+ #include <ATen/ops/item.h>
810
+ #include <ATen/ops/kaiser_window.h>
811
+ #include <ATen/ops/kl_div.h>
812
+ #include <ATen/ops/kron.h>
813
+ #include <ATen/ops/kthvalue.h>
814
+ #include <ATen/ops/l1_loss.h>
815
+ #include <ATen/ops/layer_norm.h>
816
+ #include <ATen/ops/lcm.h>
817
+ #include <ATen/ops/ldexp.h>
818
+ #include <ATen/ops/le.h>
819
+ #include <ATen/ops/leaky_relu.h>
820
+ #include <ATen/ops/leaky_relu_backward.h>
821
+ #include <ATen/ops/lerp.h>
822
+ #include <ATen/ops/less.h>
823
+ #include <ATen/ops/less_equal.h>
824
+ #include <ATen/ops/lgamma.h>
825
+ #include <ATen/ops/lift.h>
826
+ #include <ATen/ops/lift_fresh.h>
827
+ #include <ATen/ops/lift_fresh_copy.h>
828
+ #include <ATen/ops/linalg_cholesky.h>
829
+ #include <ATen/ops/linalg_cholesky_ex.h>
830
+ #include <ATen/ops/linalg_cond.h>
831
+ #include <ATen/ops/linalg_cross.h>
832
+ #include <ATen/ops/linalg_det.h>
833
+ #include <ATen/ops/linalg_diagonal.h>
834
+ #include <ATen/ops/linalg_eig.h>
835
+ #include <ATen/ops/linalg_eigh.h>
836
+ #include <ATen/ops/linalg_eigvals.h>
837
+ #include <ATen/ops/linalg_eigvalsh.h>
838
+ #include <ATen/ops/linalg_householder_product.h>
839
+ #include <ATen/ops/linalg_inv.h>
840
+ #include <ATen/ops/linalg_inv_ex.h>
841
+ #include <ATen/ops/linalg_ldl_factor.h>
842
+ #include <ATen/ops/linalg_ldl_factor_ex.h>
843
+ #include <ATen/ops/linalg_ldl_solve.h>
844
+ #include <ATen/ops/linalg_lstsq.h>
845
+ #include <ATen/ops/linalg_lu.h>
846
+ #include <ATen/ops/linalg_lu_factor.h>
847
+ #include <ATen/ops/linalg_lu_factor_ex.h>
848
+ #include <ATen/ops/linalg_lu_solve.h>
849
+ #include <ATen/ops/linalg_matmul.h>
850
+ #include <ATen/ops/linalg_matrix_exp.h>
851
+ #include <ATen/ops/linalg_matrix_norm.h>
852
+ #include <ATen/ops/linalg_matrix_power.h>
853
+ #include <ATen/ops/linalg_matrix_rank.h>
854
+ #include <ATen/ops/linalg_multi_dot.h>
855
+ #include <ATen/ops/linalg_norm.h>
856
+ #include <ATen/ops/linalg_pinv.h>
857
+ #include <ATen/ops/linalg_qr.h>
858
+ #include <ATen/ops/linalg_slogdet.h>
859
+ #include <ATen/ops/linalg_solve.h>
860
+ #include <ATen/ops/linalg_solve_ex.h>
861
+ #include <ATen/ops/linalg_solve_triangular.h>
862
+ #include <ATen/ops/linalg_svd.h>
863
+ #include <ATen/ops/linalg_svdvals.h>
864
+ #include <ATen/ops/linalg_tensorinv.h>
865
+ #include <ATen/ops/linalg_tensorsolve.h>
866
+ #include <ATen/ops/linalg_vander.h>
867
+ #include <ATen/ops/linalg_vecdot.h>
868
+ #include <ATen/ops/linalg_vector_norm.h>
869
+ #include <ATen/ops/linear.h>
870
+ #include <ATen/ops/linear_backward.h>
871
+ #include <ATen/ops/linspace.h>
872
+ #include <ATen/ops/log.h>
873
+ #include <ATen/ops/log10.h>
874
+ #include <ATen/ops/log1p.h>
875
+ #include <ATen/ops/log2.h>
876
+ #include <ATen/ops/log_normal.h>
877
+ #include <ATen/ops/log_sigmoid.h>
878
+ #include <ATen/ops/log_sigmoid_backward.h>
879
+ #include <ATen/ops/log_sigmoid_forward.h>
880
+ #include <ATen/ops/log_softmax.h>
881
+ #include <ATen/ops/logaddexp.h>
882
+ #include <ATen/ops/logaddexp2.h>
883
+ #include <ATen/ops/logcumsumexp.h>
884
+ #include <ATen/ops/logdet.h>
885
+ #include <ATen/ops/logical_and.h>
886
+ #include <ATen/ops/logical_not.h>
887
+ #include <ATen/ops/logical_or.h>
888
+ #include <ATen/ops/logical_xor.h>
889
+ #include <ATen/ops/logit.h>
890
+ #include <ATen/ops/logit_backward.h>
891
+ #include <ATen/ops/logspace.h>
892
+ #include <ATen/ops/logsumexp.h>
893
+ #include <ATen/ops/lshift.h>
894
+ #include <ATen/ops/lstm.h>
895
+ #include <ATen/ops/lstm_cell.h>
896
+ #include <ATen/ops/lstm_mps_backward.h>
897
+ #include <ATen/ops/lt.h>
898
+ #include <ATen/ops/lu_solve.h>
899
+ #include <ATen/ops/lu_unpack.h>
900
+ #include <ATen/ops/mH.h>
901
+ #include <ATen/ops/mT.h>
902
+ #include <ATen/ops/margin_ranking_loss.h>
903
+ #include <ATen/ops/masked_fill.h>
904
+ #include <ATen/ops/masked_scatter.h>
905
+ #include <ATen/ops/masked_scatter_backward.h>
906
+ #include <ATen/ops/masked_select.h>
907
+ #include <ATen/ops/masked_select_backward.h>
908
+ #include <ATen/ops/matmul.h>
909
+ #include <ATen/ops/matmul_backward.h>
910
+ #include <ATen/ops/matrix_H.h>
911
+ #include <ATen/ops/matrix_exp.h>
912
+ #include <ATen/ops/matrix_exp_backward.h>
913
+ #include <ATen/ops/matrix_power.h>
914
+ #include <ATen/ops/max.h>
915
+ #include <ATen/ops/max_pool1d.h>
916
+ #include <ATen/ops/max_pool1d_with_indices.h>
917
+ #include <ATen/ops/max_pool2d.h>
918
+ #include <ATen/ops/max_pool2d_backward.h>
919
+ #include <ATen/ops/max_pool2d_with_indices.h>
920
+ #include <ATen/ops/max_pool2d_with_indices_backward.h>
921
+ #include <ATen/ops/max_pool3d.h>
922
+ #include <ATen/ops/max_pool3d_with_indices.h>
923
+ #include <ATen/ops/max_pool3d_with_indices_backward.h>
924
+ #include <ATen/ops/max_unpool2d.h>
925
+ #include <ATen/ops/max_unpool3d.h>
926
+ #include <ATen/ops/maximum.h>
927
+ #include <ATen/ops/mean.h>
928
+ #include <ATen/ops/median.h>
929
+ #include <ATen/ops/meshgrid.h>
930
+ #include <ATen/ops/min.h>
931
+ #include <ATen/ops/minimum.h>
932
+ #include <ATen/ops/miopen_batch_norm.h>
933
+ #include <ATen/ops/miopen_batch_norm_backward.h>
934
+ #include <ATen/ops/miopen_convolution.h>
935
+ #include <ATen/ops/miopen_convolution_add_relu.h>
936
+ #include <ATen/ops/miopen_convolution_relu.h>
937
+ #include <ATen/ops/miopen_convolution_transpose.h>
938
+ #include <ATen/ops/miopen_depthwise_convolution.h>
939
+ #include <ATen/ops/miopen_rnn.h>
940
+ #include <ATen/ops/miopen_rnn_backward.h>
941
+ #include <ATen/ops/mish.h>
942
+ #include <ATen/ops/mish_backward.h>
943
+ #include <ATen/ops/mkldnn_adaptive_avg_pool2d.h>
944
+ #include <ATen/ops/mkldnn_adaptive_avg_pool2d_backward.h>
945
+ #include <ATen/ops/mkldnn_convolution.h>
946
+ #include <ATen/ops/mkldnn_linear.h>
947
+ #include <ATen/ops/mkldnn_linear_backward.h>
948
+ #include <ATen/ops/mkldnn_linear_backward_input.h>
949
+ #include <ATen/ops/mkldnn_linear_backward_weights.h>
950
+ #include <ATen/ops/mkldnn_max_pool2d.h>
951
+ #include <ATen/ops/mkldnn_max_pool2d_backward.h>
952
+ #include <ATen/ops/mkldnn_max_pool3d.h>
953
+ #include <ATen/ops/mkldnn_max_pool3d_backward.h>
954
+ #include <ATen/ops/mkldnn_reorder_conv2d_weight.h>
955
+ #include <ATen/ops/mkldnn_reorder_conv3d_weight.h>
956
+ #include <ATen/ops/mkldnn_rnn_layer.h>
957
+ #include <ATen/ops/mkldnn_rnn_layer_backward.h>
958
+ #include <ATen/ops/mm.h>
959
+ #include <ATen/ops/mode.h>
960
+ #include <ATen/ops/moveaxis.h>
961
+ #include <ATen/ops/movedim.h>
962
+ #include <ATen/ops/mps_convolution_backward.h>
963
+ #include <ATen/ops/mps_convolution_transpose_backward.h>
964
+ #include <ATen/ops/mse_loss.h>
965
+ #include <ATen/ops/mse_loss_backward.h>
966
+ #include <ATen/ops/msort.h>
967
+ #include <ATen/ops/mul.h>
968
+ #include <ATen/ops/multi_margin_loss.h>
969
+ #include <ATen/ops/multi_margin_loss_backward.h>
970
+ #include <ATen/ops/multilabel_margin_loss.h>
971
+ #include <ATen/ops/multilabel_margin_loss_backward.h>
972
+ #include <ATen/ops/multilabel_margin_loss_forward.h>
973
+ #include <ATen/ops/multinomial.h>
974
+ #include <ATen/ops/multiply.h>
975
+ #include <ATen/ops/mv.h>
976
+ #include <ATen/ops/mvlgamma.h>
977
+ #include <ATen/ops/nan_to_num.h>
978
+ #include <ATen/ops/nanmean.h>
979
+ #include <ATen/ops/nanmedian.h>
980
+ #include <ATen/ops/nanquantile.h>
981
+ #include <ATen/ops/nansum.h>
982
+ #include <ATen/ops/narrow.h>
983
+ #include <ATen/ops/narrow_copy.h>
984
+ #include <ATen/ops/native_batch_norm.h>
985
+ #include <ATen/ops/native_batch_norm_backward.h>
986
+ #include <ATen/ops/native_channel_shuffle.h>
987
+ #include <ATen/ops/native_dropout.h>
988
+ #include <ATen/ops/native_dropout_backward.h>
989
+ #include <ATen/ops/native_group_norm.h>
990
+ #include <ATen/ops/native_group_norm_backward.h>
991
+ #include <ATen/ops/native_layer_norm.h>
992
+ #include <ATen/ops/native_layer_norm_backward.h>
993
+ #include <ATen/ops/native_norm.h>
994
+ #include <ATen/ops/ne.h>
995
+ #include <ATen/ops/neg.h>
996
+ #include <ATen/ops/negative.h>
997
+ #include <ATen/ops/nested_to_padded_tensor.h>
998
+ #include <ATen/ops/new_empty.h>
999
+ #include <ATen/ops/new_empty_strided.h>
1000
+ #include <ATen/ops/new_full.h>
1001
+ #include <ATen/ops/new_ones.h>
1002
+ #include <ATen/ops/new_zeros.h>
1003
+ #include <ATen/ops/nextafter.h>
1004
+ #include <ATen/ops/nll_loss.h>
1005
+ #include <ATen/ops/nll_loss2d.h>
1006
+ #include <ATen/ops/nll_loss2d_backward.h>
1007
+ #include <ATen/ops/nll_loss2d_forward.h>
1008
+ #include <ATen/ops/nll_loss_backward.h>
1009
+ #include <ATen/ops/nll_loss_forward.h>
1010
+ #include <ATen/ops/nll_loss_nd.h>
1011
+ #include <ATen/ops/nonzero.h>
1012
+ #include <ATen/ops/nonzero_numpy.h>
1013
+ #include <ATen/ops/nonzero_static.h>
1014
+ #include <ATen/ops/norm.h>
1015
+ #include <ATen/ops/norm_except_dim.h>
1016
+ #include <ATen/ops/normal.h>
1017
+ #include <ATen/ops/not_equal.h>
1018
+ #include <ATen/ops/nuclear_norm.h>
1019
+ #include <ATen/ops/numpy_T.h>
1020
+ #include <ATen/ops/one_hot.h>
1021
+ #include <ATen/ops/ones.h>
1022
+ #include <ATen/ops/ones_like.h>
1023
+ #include <ATen/ops/or.h>
1024
+ #include <ATen/ops/orgqr.h>
1025
+ #include <ATen/ops/ormqr.h>
1026
+ #include <ATen/ops/outer.h>
1027
+ #include <ATen/ops/output_nr.h>
1028
+ #include <ATen/ops/pad.h>
1029
+ #include <ATen/ops/pad_sequence.h>
1030
+ #include <ATen/ops/pairwise_distance.h>
1031
+ #include <ATen/ops/pdist.h>
1032
+ #include <ATen/ops/permute.h>
1033
+ #include <ATen/ops/permute_copy.h>
1034
+ #include <ATen/ops/pin_memory.h>
1035
+ #include <ATen/ops/pinverse.h>
1036
+ #include <ATen/ops/pixel_shuffle.h>
1037
+ #include <ATen/ops/pixel_unshuffle.h>
1038
+ #include <ATen/ops/poisson.h>
1039
+ #include <ATen/ops/poisson_nll_loss.h>
1040
+ #include <ATen/ops/polar.h>
1041
+ #include <ATen/ops/polygamma.h>
1042
+ #include <ATen/ops/positive.h>
1043
+ #include <ATen/ops/pow.h>
1044
+ #include <ATen/ops/prelu.h>
1045
+ #include <ATen/ops/prod.h>
1046
+ #include <ATen/ops/promote_types.h>
1047
+ #include <ATen/ops/put.h>
1048
+ #include <ATen/ops/q_per_channel_axis.h>
1049
+ #include <ATen/ops/q_per_channel_scales.h>
1050
+ #include <ATen/ops/q_per_channel_zero_points.h>
1051
+ #include <ATen/ops/q_scale.h>
1052
+ #include <ATen/ops/q_zero_point.h>
1053
+ #include <ATen/ops/qr.h>
1054
+ #include <ATen/ops/qscheme.h>
1055
+ #include <ATen/ops/quantile.h>
1056
+ #include <ATen/ops/quantize_per_channel.h>
1057
+ #include <ATen/ops/quantize_per_tensor.h>
1058
+ #include <ATen/ops/quantize_per_tensor_dynamic.h>
1059
+ #include <ATen/ops/quantized_batch_norm.h>
1060
+ #include <ATen/ops/quantized_gru_cell.h>
1061
+ #include <ATen/ops/quantized_lstm_cell.h>
1062
+ #include <ATen/ops/quantized_max_pool1d.h>
1063
+ #include <ATen/ops/quantized_max_pool2d.h>
1064
+ #include <ATen/ops/quantized_max_pool3d.h>
1065
+ #include <ATen/ops/quantized_rnn_relu_cell.h>
1066
+ #include <ATen/ops/quantized_rnn_tanh_cell.h>
1067
+ #include <ATen/ops/rad2deg.h>
1068
+ #include <ATen/ops/rand.h>
1069
+ #include <ATen/ops/rand_like.h>
1070
+ #include <ATen/ops/randint.h>
1071
+ #include <ATen/ops/randint_like.h>
1072
+ #include <ATen/ops/randn.h>
1073
+ #include <ATen/ops/randn_like.h>
1074
+ #include <ATen/ops/random.h>
1075
+ #include <ATen/ops/randperm.h>
1076
+ #include <ATen/ops/range.h>
1077
+ #include <ATen/ops/ravel.h>
1078
+ #include <ATen/ops/real.h>
1079
+ #include <ATen/ops/reciprocal.h>
1080
+ #include <ATen/ops/record_stream.h>
1081
+ #include <ATen/ops/refine_names.h>
1082
+ #include <ATen/ops/reflection_pad1d.h>
1083
+ #include <ATen/ops/reflection_pad1d_backward.h>
1084
+ #include <ATen/ops/reflection_pad2d.h>
1085
+ #include <ATen/ops/reflection_pad2d_backward.h>
1086
+ #include <ATen/ops/reflection_pad3d.h>
1087
+ #include <ATen/ops/reflection_pad3d_backward.h>
1088
+ #include <ATen/ops/relu.h>
1089
+ #include <ATen/ops/relu6.h>
1090
+ #include <ATen/ops/remainder.h>
1091
+ #include <ATen/ops/rename.h>
1092
+ #include <ATen/ops/renorm.h>
1093
+ #include <ATen/ops/repeat.h>
1094
+ #include <ATen/ops/repeat_interleave.h>
1095
+ #include <ATen/ops/replication_pad1d.h>
1096
+ #include <ATen/ops/replication_pad1d_backward.h>
1097
+ #include <ATen/ops/replication_pad2d.h>
1098
+ #include <ATen/ops/replication_pad2d_backward.h>
1099
+ #include <ATen/ops/replication_pad3d.h>
1100
+ #include <ATen/ops/replication_pad3d_backward.h>
1101
+ #include <ATen/ops/requires_grad.h>
1102
+ #include <ATen/ops/reshape.h>
1103
+ #include <ATen/ops/reshape_as.h>
1104
+ #include <ATen/ops/resize.h>
1105
+ #include <ATen/ops/resize_as.h>
1106
+ #include <ATen/ops/resize_as_sparse.h>
1107
+ #include <ATen/ops/resolve_conj.h>
1108
+ #include <ATen/ops/resolve_neg.h>
1109
+ #include <ATen/ops/result_type.h>
1110
+ #include <ATen/ops/retain_grad.h>
1111
+ #include <ATen/ops/retains_grad.h>
1112
+ #include <ATen/ops/rnn_relu.h>
1113
+ #include <ATen/ops/rnn_relu_cell.h>
1114
+ #include <ATen/ops/rnn_tanh.h>
1115
+ #include <ATen/ops/rnn_tanh_cell.h>
1116
+ #include <ATen/ops/roll.h>
1117
+ #include <ATen/ops/rot90.h>
1118
+ #include <ATen/ops/round.h>
1119
+ #include <ATen/ops/row_indices.h>
1120
+ #include <ATen/ops/row_indices_copy.h>
1121
+ #include <ATen/ops/row_stack.h>
1122
+ #include <ATen/ops/rrelu.h>
1123
+ #include <ATen/ops/rrelu_with_noise.h>
1124
+ #include <ATen/ops/rrelu_with_noise_backward.h>
1125
+ #include <ATen/ops/rshift.h>
1126
+ #include <ATen/ops/rsqrt.h>
1127
+ #include <ATen/ops/rsub.h>
1128
+ #include <ATen/ops/scalar_tensor.h>
1129
+ #include <ATen/ops/scaled_dot_product_attention.h>
1130
+ #include <ATen/ops/scatter.h>
1131
+ #include <ATen/ops/scatter_add.h>
1132
+ #include <ATen/ops/scatter_reduce.h>
1133
+ #include <ATen/ops/searchsorted.h>
1134
+ #include <ATen/ops/segment_reduce.h>
1135
+ #include <ATen/ops/select.h>
1136
+ #include <ATen/ops/select_backward.h>
1137
+ #include <ATen/ops/select_copy.h>
1138
+ #include <ATen/ops/select_scatter.h>
1139
+ #include <ATen/ops/selu.h>
1140
+ #include <ATen/ops/set.h>
1141
+ #include <ATen/ops/set_data.h>
1142
+ #include <ATen/ops/sgn.h>
1143
+ #include <ATen/ops/sigmoid.h>
1144
+ #include <ATen/ops/sigmoid_backward.h>
1145
+ #include <ATen/ops/sign.h>
1146
+ #include <ATen/ops/signbit.h>
1147
+ #include <ATen/ops/silu.h>
1148
+ #include <ATen/ops/silu_backward.h>
1149
+ #include <ATen/ops/sin.h>
1150
+ #include <ATen/ops/sinc.h>
1151
+ #include <ATen/ops/sinh.h>
1152
+ #include <ATen/ops/size.h>
1153
+ #include <ATen/ops/slice.h>
1154
+ #include <ATen/ops/slice_backward.h>
1155
+ #include <ATen/ops/slice_copy.h>
1156
+ #include <ATen/ops/slice_inverse.h>
1157
+ #include <ATen/ops/slice_scatter.h>
1158
+ #include <ATen/ops/slogdet.h>
1159
+ #include <ATen/ops/slow_conv3d.h>
1160
+ #include <ATen/ops/slow_conv3d_forward.h>
1161
+ #include <ATen/ops/slow_conv_dilated2d.h>
1162
+ #include <ATen/ops/slow_conv_dilated3d.h>
1163
+ #include <ATen/ops/slow_conv_transpose2d.h>
1164
+ #include <ATen/ops/slow_conv_transpose3d.h>
1165
+ #include <ATen/ops/smm.h>
1166
+ #include <ATen/ops/smooth_l1_loss.h>
1167
+ #include <ATen/ops/smooth_l1_loss_backward.h>
1168
+ #include <ATen/ops/soft_margin_loss.h>
1169
+ #include <ATen/ops/soft_margin_loss_backward.h>
1170
+ #include <ATen/ops/softmax.h>
1171
+ #include <ATen/ops/softplus.h>
1172
+ #include <ATen/ops/softplus_backward.h>
1173
+ #include <ATen/ops/softshrink.h>
1174
+ #include <ATen/ops/softshrink_backward.h>
1175
+ #include <ATen/ops/sort.h>
1176
+ #include <ATen/ops/sparse_bsc_tensor.h>
1177
+ #include <ATen/ops/sparse_bsr_tensor.h>
1178
+ #include <ATen/ops/sparse_compressed_tensor.h>
1179
+ #include <ATen/ops/sparse_coo_tensor.h>
1180
+ #include <ATen/ops/sparse_csc_tensor.h>
1181
+ #include <ATen/ops/sparse_csr_tensor.h>
1182
+ #include <ATen/ops/sparse_dim.h>
1183
+ #include <ATen/ops/sparse_mask.h>
1184
+ #include <ATen/ops/sparse_resize.h>
1185
+ #include <ATen/ops/sparse_resize_and_clear.h>
1186
+ #include <ATen/ops/sparse_sampled_addmm.h>
1187
+ #include <ATen/ops/special_airy_ai.h>
1188
+ #include <ATen/ops/special_bessel_j0.h>
1189
+ #include <ATen/ops/special_bessel_j1.h>
1190
+ #include <ATen/ops/special_bessel_y0.h>
1191
+ #include <ATen/ops/special_bessel_y1.h>
1192
+ #include <ATen/ops/special_chebyshev_polynomial_t.h>
1193
+ #include <ATen/ops/special_chebyshev_polynomial_u.h>
1194
+ #include <ATen/ops/special_chebyshev_polynomial_v.h>
1195
+ #include <ATen/ops/special_chebyshev_polynomial_w.h>
1196
+ #include <ATen/ops/special_digamma.h>
1197
+ #include <ATen/ops/special_entr.h>
1198
+ #include <ATen/ops/special_erf.h>
1199
+ #include <ATen/ops/special_erfc.h>
1200
+ #include <ATen/ops/special_erfcx.h>
1201
+ #include <ATen/ops/special_erfinv.h>
1202
+ #include <ATen/ops/special_exp2.h>
1203
+ #include <ATen/ops/special_expit.h>
1204
+ #include <ATen/ops/special_expm1.h>
1205
+ #include <ATen/ops/special_gammainc.h>
1206
+ #include <ATen/ops/special_gammaincc.h>
1207
+ #include <ATen/ops/special_gammaln.h>
1208
+ #include <ATen/ops/special_hermite_polynomial_h.h>
1209
+ #include <ATen/ops/special_hermite_polynomial_he.h>
1210
+ #include <ATen/ops/special_i0.h>
1211
+ #include <ATen/ops/special_i0e.h>
1212
+ #include <ATen/ops/special_i1.h>
1213
+ #include <ATen/ops/special_i1e.h>
1214
+ #include <ATen/ops/special_laguerre_polynomial_l.h>
1215
+ #include <ATen/ops/special_legendre_polynomial_p.h>
1216
+ #include <ATen/ops/special_log1p.h>
1217
+ #include <ATen/ops/special_log_ndtr.h>
1218
+ #include <ATen/ops/special_log_softmax.h>
1219
+ #include <ATen/ops/special_logit.h>
1220
+ #include <ATen/ops/special_logsumexp.h>
1221
+ #include <ATen/ops/special_modified_bessel_i0.h>
1222
+ #include <ATen/ops/special_modified_bessel_i1.h>
1223
+ #include <ATen/ops/special_modified_bessel_k0.h>
1224
+ #include <ATen/ops/special_modified_bessel_k1.h>
1225
+ #include <ATen/ops/special_multigammaln.h>
1226
+ #include <ATen/ops/special_ndtr.h>
1227
+ #include <ATen/ops/special_ndtri.h>
1228
+ #include <ATen/ops/special_polygamma.h>
1229
+ #include <ATen/ops/special_psi.h>
1230
+ #include <ATen/ops/special_round.h>
1231
+ #include <ATen/ops/special_scaled_modified_bessel_k0.h>
1232
+ #include <ATen/ops/special_scaled_modified_bessel_k1.h>
1233
+ #include <ATen/ops/special_shifted_chebyshev_polynomial_t.h>
1234
+ #include <ATen/ops/special_shifted_chebyshev_polynomial_u.h>
1235
+ #include <ATen/ops/special_shifted_chebyshev_polynomial_v.h>
1236
+ #include <ATen/ops/special_shifted_chebyshev_polynomial_w.h>
1237
+ #include <ATen/ops/special_sinc.h>
1238
+ #include <ATen/ops/special_softmax.h>
1239
+ #include <ATen/ops/special_spherical_bessel_j0.h>
1240
+ #include <ATen/ops/special_xlog1py.h>
1241
+ #include <ATen/ops/special_xlogy.h>
1242
+ #include <ATen/ops/special_zeta.h>
1243
+ #include <ATen/ops/split.h>
1244
+ #include <ATen/ops/split_copy.h>
1245
+ #include <ATen/ops/split_with_sizes.h>
1246
+ #include <ATen/ops/split_with_sizes_copy.h>
1247
+ #include <ATen/ops/sqrt.h>
1248
+ #include <ATen/ops/square.h>
1249
+ #include <ATen/ops/squeeze.h>
1250
+ #include <ATen/ops/squeeze_copy.h>
1251
+ #include <ATen/ops/sspaddmm.h>
1252
+ #include <ATen/ops/stack.h>
1253
+ #include <ATen/ops/std.h>
1254
+ #include <ATen/ops/std_mean.h>
1255
+ #include <ATen/ops/stft.h>
1256
+ #include <ATen/ops/stride.h>
1257
+ #include <ATen/ops/sub.h>
1258
+ #include <ATen/ops/subtract.h>
1259
+ #include <ATen/ops/sum.h>
1260
+ #include <ATen/ops/sum_to_size.h>
1261
+ #include <ATen/ops/svd.h>
1262
+ #include <ATen/ops/swapaxes.h>
1263
+ #include <ATen/ops/swapdims.h>
1264
+ #include <ATen/ops/sym_constrain_range.h>
1265
+ #include <ATen/ops/sym_constrain_range_for_size.h>
1266
+ #include <ATen/ops/sym_numel.h>
1267
+ #include <ATen/ops/sym_size.h>
1268
+ #include <ATen/ops/sym_storage_offset.h>
1269
+ #include <ATen/ops/sym_stride.h>
1270
+ #include <ATen/ops/t.h>
1271
+ #include <ATen/ops/t_copy.h>
1272
+ #include <ATen/ops/take.h>
1273
+ #include <ATen/ops/take_along_dim.h>
1274
+ #include <ATen/ops/tan.h>
1275
+ #include <ATen/ops/tanh.h>
1276
+ #include <ATen/ops/tanh_backward.h>
1277
+ #include <ATen/ops/tensor_split.h>
1278
+ #include <ATen/ops/tensordot.h>
1279
+ #include <ATen/ops/thnn_conv2d.h>
1280
+ #include <ATen/ops/threshold.h>
1281
+ #include <ATen/ops/threshold_backward.h>
1282
+ #include <ATen/ops/tile.h>
1283
+ #include <ATen/ops/to.h>
1284
+ #include <ATen/ops/to_dense.h>
1285
+ #include <ATen/ops/to_dense_backward.h>
1286
+ #include <ATen/ops/to_mkldnn.h>
1287
+ #include <ATen/ops/to_mkldnn_backward.h>
1288
+ #include <ATen/ops/to_padded_tensor.h>
1289
+ #include <ATen/ops/to_sparse.h>
1290
+ #include <ATen/ops/to_sparse_bsc.h>
1291
+ #include <ATen/ops/to_sparse_bsr.h>
1292
+ #include <ATen/ops/to_sparse_csc.h>
1293
+ #include <ATen/ops/to_sparse_csr.h>
1294
+ #include <ATen/ops/topk.h>
1295
+ #include <ATen/ops/trace.h>
1296
+ #include <ATen/ops/trace_backward.h>
1297
+ #include <ATen/ops/transpose.h>
1298
+ #include <ATen/ops/transpose_copy.h>
1299
+ #include <ATen/ops/trapezoid.h>
1300
+ #include <ATen/ops/trapz.h>
1301
+ #include <ATen/ops/triangular_solve.h>
1302
+ #include <ATen/ops/tril.h>
1303
+ #include <ATen/ops/tril_indices.h>
1304
+ #include <ATen/ops/triplet_margin_loss.h>
1305
+ #include <ATen/ops/triu.h>
1306
+ #include <ATen/ops/triu_indices.h>
1307
+ #include <ATen/ops/true_divide.h>
1308
+ #include <ATen/ops/trunc.h>
1309
+ #include <ATen/ops/type_as.h>
1310
+ #include <ATen/ops/unbind.h>
1311
+ #include <ATen/ops/unbind_copy.h>
1312
+ #include <ATen/ops/unflatten.h>
1313
+ #include <ATen/ops/unflatten_dense_tensors.h>
1314
+ #include <ATen/ops/unfold.h>
1315
+ #include <ATen/ops/unfold_backward.h>
1316
+ #include <ATen/ops/unfold_copy.h>
1317
+ #include <ATen/ops/uniform.h>
1318
+ #include <ATen/ops/unique_consecutive.h>
1319
+ #include <ATen/ops/unique_dim.h>
1320
+ #include <ATen/ops/unique_dim_consecutive.h>
1321
+ #include <ATen/ops/unsafe_chunk.h>
1322
+ #include <ATen/ops/unsafe_split.h>
1323
+ #include <ATen/ops/unsafe_split_with_sizes.h>
1324
+ #include <ATen/ops/unsqueeze.h>
1325
+ #include <ATen/ops/unsqueeze_copy.h>
1326
+ #include <ATen/ops/upsample_bicubic2d.h>
1327
+ #include <ATen/ops/upsample_bicubic2d_backward.h>
1328
+ #include <ATen/ops/upsample_bilinear2d.h>
1329
+ #include <ATen/ops/upsample_bilinear2d_backward.h>
1330
+ #include <ATen/ops/upsample_linear1d.h>
1331
+ #include <ATen/ops/upsample_linear1d_backward.h>
1332
+ #include <ATen/ops/upsample_nearest1d.h>
1333
+ #include <ATen/ops/upsample_nearest1d_backward.h>
1334
+ #include <ATen/ops/upsample_nearest2d.h>
1335
+ #include <ATen/ops/upsample_nearest2d_backward.h>
1336
+ #include <ATen/ops/upsample_nearest3d.h>
1337
+ #include <ATen/ops/upsample_nearest3d_backward.h>
1338
+ #include <ATen/ops/upsample_trilinear3d.h>
1339
+ #include <ATen/ops/upsample_trilinear3d_backward.h>
1340
+ #include <ATen/ops/value_selecting_reduction_backward.h>
1341
+ #include <ATen/ops/values.h>
1342
+ #include <ATen/ops/values_copy.h>
1343
+ #include <ATen/ops/vander.h>
1344
+ #include <ATen/ops/var.h>
1345
+ #include <ATen/ops/var_mean.h>
1346
+ #include <ATen/ops/vdot.h>
1347
+ #include <ATen/ops/view.h>
1348
+ #include <ATen/ops/view_as.h>
1349
+ #include <ATen/ops/view_as_complex.h>
1350
+ #include <ATen/ops/view_as_complex_copy.h>
1351
+ #include <ATen/ops/view_as_real.h>
1352
+ #include <ATen/ops/view_as_real_copy.h>
1353
+ #include <ATen/ops/view_copy.h>
1354
+ #include <ATen/ops/vsplit.h>
1355
+ #include <ATen/ops/vstack.h>
1356
+ #include <ATen/ops/where.h>
1357
+ #include <ATen/ops/xlogy.h>
1358
+ #include <ATen/ops/xor.h>
1359
+ #include <ATen/ops/zero.h>
1360
+ #include <ATen/ops/zeros.h>
1361
+ #include <ATen/ops/zeros_like.h>
1362
+
1363
+ namespace at {
1364
+
1365
+
1366
+
1367
+ // Special C++ only overloads for std()-like functions (See gh-40287)
1368
+ // These are needed because int -> bool conversion takes precedence over int -> IntArrayRef
1369
+ // So, for example std(0) would select the std(unbiased=False) overload
1370
+ TORCH_API inline Tensor var(const Tensor& self, int dim) {
1371
+ return at::var(self, IntArrayRef{dim});
1372
+ }
1373
+ TORCH_API inline std::tuple<Tensor, Tensor> var_mean(const Tensor& self, int dim) {
1374
+ return at::var_mean(self, IntArrayRef{dim});
1375
+ }
1376
+ TORCH_API inline Tensor std(const Tensor& self, int dim) {
1377
+ return at::std(self, IntArrayRef{dim});
1378
+ }
1379
+ TORCH_API inline std::tuple<Tensor, Tensor> std_mean(const Tensor& self, int dim) {
1380
+ return at::std_mean(self, IntArrayRef{dim});
1381
+ }
1382
+
1383
+ inline int64_t numel(const Tensor& tensor) {
1384
+ return tensor.numel();
1385
+ }
1386
+
1387
+ inline int64_t size(const Tensor& tensor, int64_t dim) {
1388
+ return tensor.size(dim);
1389
+ }
1390
+
1391
+ inline int64_t stride(const Tensor& tensor, int64_t dim) {
1392
+ return tensor.stride(dim);
1393
+ }
1394
+
1395
+ inline bool is_complex(const Tensor& tensor) {
1396
+ return tensor.is_complex();
1397
+ }
1398
+
1399
+ inline bool is_floating_point(const Tensor& tensor) {
1400
+ return tensor.is_floating_point();
1401
+ }
1402
+
1403
+ inline bool is_signed(const Tensor& tensor) {
1404
+ return tensor.is_signed();
1405
+ }
1406
+
1407
+ inline bool is_inference(const Tensor& tensor) {
1408
+ return tensor.is_inference();
1409
+ }
1410
+
1411
+ inline bool _is_zerotensor(const Tensor& tensor) {
1412
+ return tensor._is_zerotensor();
1413
+ }
1414
+
1415
+ inline bool is_conj(const Tensor& tensor) {
1416
+ return tensor.is_conj();
1417
+ }
1418
+
1419
+ inline Tensor conj(const Tensor& tensor) {
1420
+ return tensor.conj();
1421
+ }
1422
+
1423
+ inline bool is_neg(const Tensor& tensor) {
1424
+ return tensor.is_neg();
1425
+ }
1426
+
1427
+ }
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/MemoryOverlap.h ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <c10/macros/Export.h>
4
+
5
+ namespace c10 {
6
+ struct TensorImpl;
7
+ }
8
+
9
+ namespace at {
10
+ class TensorBase;
11
+
12
+ // MemOverlap: Whether or not there is memory overlap
13
+ //
14
+ // No: Absolutely no memory overlap
15
+ // Yes: Absolutely yes memory overlap
16
+ // TooHard: There might be memory overlap, but it was too expensive to compute.
17
+ //
18
+ // NB: Please update the python test for these if you renumber them.
19
+ enum class MemOverlap { No, Yes, TooHard };
20
+
21
+ enum class MemOverlapStatus { Full, Partial, No, TooHard };
22
+
23
+ TORCH_API MemOverlap has_internal_overlap(const TensorBase& t);
24
+ TORCH_API MemOverlap has_internal_overlap(c10::TensorImpl* t);
25
+
26
+ TORCH_API void assert_no_internal_overlap(const TensorBase& t);
27
+ TORCH_API void assert_no_internal_overlap(c10::TensorImpl* t);
28
+
29
+ TORCH_API MemOverlapStatus
30
+ get_overlap_status(const TensorBase& a, const TensorBase& b);
31
+ TORCH_API MemOverlapStatus
32
+ get_overlap_status(const c10::TensorImpl* a, const c10::TensorImpl* b);
33
+
34
+ TORCH_API void assert_no_partial_overlap(
35
+ const TensorBase& a,
36
+ const TensorBase& b);
37
+ void assert_no_partial_overlap(c10::TensorImpl* a, c10::TensorImpl* b);
38
+
39
+ TORCH_API void assert_no_overlap(const TensorBase& a, const TensorBase& b);
40
+ TORCH_API void assert_no_overlap(c10::TensorImpl* a, c10::TensorImpl* b);
41
+
42
+ } // namespace at
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/NativeMetaFunctions.h ADDED
@@ -0,0 +1,1303 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ // @generated by torchgen/gen.py from NativeMetaFunctions.h
4
+
5
+ #include <ATen/core/Tensor.h>
6
+ #include <ATen/core/IListRef.h>
7
+ #include <ATen/TensorMeta.h>
8
+ #include <ATen/TensorIterator.h>
9
+
10
+ #include <ATen/ops/_adaptive_avg_pool2d_meta.h>
11
+ #include <ATen/ops/_adaptive_avg_pool2d_backward_meta.h>
12
+ #include <ATen/ops/_adaptive_avg_pool3d_meta.h>
13
+ #include <ATen/ops/_adaptive_avg_pool3d_backward_meta.h>
14
+ #include <ATen/ops/_add_batch_dim_meta.h>
15
+ #include <ATen/ops/_add_relu_meta.h>
16
+ #include <ATen/ops/_addmm_activation_meta.h>
17
+ #include <ATen/ops/_aminmax_meta.h>
18
+ #include <ATen/ops/_amp_foreach_non_finite_check_and_unscale_meta.h>
19
+ #include <ATen/ops/_amp_update_scale_meta.h>
20
+ #include <ATen/ops/_assert_async_meta.h>
21
+ #include <ATen/ops/_assert_scalar_meta.h>
22
+ #include <ATen/ops/_assert_tensor_metadata_meta.h>
23
+ #include <ATen/ops/_autocast_to_full_precision_meta.h>
24
+ #include <ATen/ops/_autocast_to_reduced_precision_meta.h>
25
+ #include <ATen/ops/_backward_meta.h>
26
+ #include <ATen/ops/_batch_norm_impl_index_meta.h>
27
+ #include <ATen/ops/_batch_norm_impl_index_backward_meta.h>
28
+ #include <ATen/ops/_cast_Byte_meta.h>
29
+ #include <ATen/ops/_cast_Char_meta.h>
30
+ #include <ATen/ops/_cast_Double_meta.h>
31
+ #include <ATen/ops/_cast_Float_meta.h>
32
+ #include <ATen/ops/_cast_Half_meta.h>
33
+ #include <ATen/ops/_cast_Int_meta.h>
34
+ #include <ATen/ops/_cast_Long_meta.h>
35
+ #include <ATen/ops/_cast_Short_meta.h>
36
+ #include <ATen/ops/_cdist_backward_meta.h>
37
+ #include <ATen/ops/_cdist_forward_meta.h>
38
+ #include <ATen/ops/_cholesky_solve_helper_meta.h>
39
+ #include <ATen/ops/_choose_qparams_per_tensor_meta.h>
40
+ #include <ATen/ops/_chunk_cat_meta.h>
41
+ #include <ATen/ops/_coalesce_meta.h>
42
+ #include <ATen/ops/_coalesced_meta.h>
43
+ #include <ATen/ops/_compute_linear_combination_meta.h>
44
+ #include <ATen/ops/_conj_meta.h>
45
+ #include <ATen/ops/_conj_copy_meta.h>
46
+ #include <ATen/ops/_conj_physical_meta.h>
47
+ #include <ATen/ops/_conv_depthwise2d_meta.h>
48
+ #include <ATen/ops/_convert_indices_from_coo_to_csr_meta.h>
49
+ #include <ATen/ops/_convert_indices_from_csr_to_coo_meta.h>
50
+ #include <ATen/ops/_convert_weight_to_int4pack_meta.h>
51
+ #include <ATen/ops/_convolution_meta.h>
52
+ #include <ATen/ops/_convolution_double_backward_meta.h>
53
+ #include <ATen/ops/_convolution_mode_meta.h>
54
+ #include <ATen/ops/_copy_from_meta.h>
55
+ #include <ATen/ops/_copy_from_and_resize_meta.h>
56
+ #include <ATen/ops/_cslt_compress_meta.h>
57
+ #include <ATen/ops/_cslt_sparse_mm_meta.h>
58
+ #include <ATen/ops/_cslt_sparse_mm_search_meta.h>
59
+ #include <ATen/ops/_ctc_loss_meta.h>
60
+ #include <ATen/ops/_ctc_loss_backward_meta.h>
61
+ #include <ATen/ops/_cudnn_ctc_loss_meta.h>
62
+ #include <ATen/ops/_cudnn_init_dropout_state_meta.h>
63
+ #include <ATen/ops/_cudnn_rnn_meta.h>
64
+ #include <ATen/ops/_cudnn_rnn_backward_meta.h>
65
+ #include <ATen/ops/_cudnn_rnn_flatten_weight_meta.h>
66
+ #include <ATen/ops/_cufft_clear_plan_cache_meta.h>
67
+ #include <ATen/ops/_cufft_get_plan_cache_max_size_meta.h>
68
+ #include <ATen/ops/_cufft_get_plan_cache_size_meta.h>
69
+ #include <ATen/ops/_cufft_set_plan_cache_max_size_meta.h>
70
+ #include <ATen/ops/_cummax_helper_meta.h>
71
+ #include <ATen/ops/_cummin_helper_meta.h>
72
+ #include <ATen/ops/_debug_has_internal_overlap_meta.h>
73
+ #include <ATen/ops/_dimI_meta.h>
74
+ #include <ATen/ops/_dimV_meta.h>
75
+ #include <ATen/ops/_dim_arange_meta.h>
76
+ #include <ATen/ops/_dirichlet_grad_meta.h>
77
+ #include <ATen/ops/_efficient_attention_backward_meta.h>
78
+ #include <ATen/ops/_efficient_attention_forward_meta.h>
79
+ #include <ATen/ops/_efficientzerotensor_meta.h>
80
+ #include <ATen/ops/_embedding_bag_meta.h>
81
+ #include <ATen/ops/_embedding_bag_backward_meta.h>
82
+ #include <ATen/ops/_embedding_bag_dense_backward_meta.h>
83
+ #include <ATen/ops/_embedding_bag_forward_only_meta.h>
84
+ #include <ATen/ops/_embedding_bag_per_sample_weights_backward_meta.h>
85
+ #include <ATen/ops/_embedding_bag_sparse_backward_meta.h>
86
+ #include <ATen/ops/_empty_affine_quantized_meta.h>
87
+ #include <ATen/ops/_empty_per_channel_affine_quantized_meta.h>
88
+ #include <ATen/ops/_euclidean_dist_meta.h>
89
+ #include <ATen/ops/_fake_quantize_learnable_per_channel_affine_meta.h>
90
+ #include <ATen/ops/_fake_quantize_learnable_per_channel_affine_backward_meta.h>
91
+ #include <ATen/ops/_fake_quantize_learnable_per_tensor_affine_meta.h>
92
+ #include <ATen/ops/_fake_quantize_learnable_per_tensor_affine_backward_meta.h>
93
+ #include <ATen/ops/_fake_quantize_per_tensor_affine_cachemask_tensor_qparams_meta.h>
94
+ #include <ATen/ops/_fft_c2c_meta.h>
95
+ #include <ATen/ops/_fft_c2r_meta.h>
96
+ #include <ATen/ops/_fft_r2c_meta.h>
97
+ #include <ATen/ops/_fill_mem_eff_dropout_mask_meta.h>
98
+ #include <ATen/ops/_flash_attention_backward_meta.h>
99
+ #include <ATen/ops/_flash_attention_forward_meta.h>
100
+ #include <ATen/ops/_foobar_meta.h>
101
+ #include <ATen/ops/_foreach_abs_meta.h>
102
+ #include <ATen/ops/_foreach_acos_meta.h>
103
+ #include <ATen/ops/_foreach_add_meta.h>
104
+ #include <ATen/ops/_foreach_addcdiv_meta.h>
105
+ #include <ATen/ops/_foreach_addcmul_meta.h>
106
+ #include <ATen/ops/_foreach_asin_meta.h>
107
+ #include <ATen/ops/_foreach_atan_meta.h>
108
+ #include <ATen/ops/_foreach_ceil_meta.h>
109
+ #include <ATen/ops/_foreach_clamp_max_meta.h>
110
+ #include <ATen/ops/_foreach_clamp_min_meta.h>
111
+ #include <ATen/ops/_foreach_copy_meta.h>
112
+ #include <ATen/ops/_foreach_cos_meta.h>
113
+ #include <ATen/ops/_foreach_cosh_meta.h>
114
+ #include <ATen/ops/_foreach_div_meta.h>
115
+ #include <ATen/ops/_foreach_erf_meta.h>
116
+ #include <ATen/ops/_foreach_erfc_meta.h>
117
+ #include <ATen/ops/_foreach_exp_meta.h>
118
+ #include <ATen/ops/_foreach_expm1_meta.h>
119
+ #include <ATen/ops/_foreach_floor_meta.h>
120
+ #include <ATen/ops/_foreach_frac_meta.h>
121
+ #include <ATen/ops/_foreach_lerp_meta.h>
122
+ #include <ATen/ops/_foreach_lgamma_meta.h>
123
+ #include <ATen/ops/_foreach_log_meta.h>
124
+ #include <ATen/ops/_foreach_log10_meta.h>
125
+ #include <ATen/ops/_foreach_log1p_meta.h>
126
+ #include <ATen/ops/_foreach_log2_meta.h>
127
+ #include <ATen/ops/_foreach_maximum_meta.h>
128
+ #include <ATen/ops/_foreach_minimum_meta.h>
129
+ #include <ATen/ops/_foreach_mul_meta.h>
130
+ #include <ATen/ops/_foreach_neg_meta.h>
131
+ #include <ATen/ops/_foreach_norm_meta.h>
132
+ #include <ATen/ops/_foreach_pow_meta.h>
133
+ #include <ATen/ops/_foreach_reciprocal_meta.h>
134
+ #include <ATen/ops/_foreach_round_meta.h>
135
+ #include <ATen/ops/_foreach_sigmoid_meta.h>
136
+ #include <ATen/ops/_foreach_sign_meta.h>
137
+ #include <ATen/ops/_foreach_sin_meta.h>
138
+ #include <ATen/ops/_foreach_sinh_meta.h>
139
+ #include <ATen/ops/_foreach_sqrt_meta.h>
140
+ #include <ATen/ops/_foreach_sub_meta.h>
141
+ #include <ATen/ops/_foreach_tan_meta.h>
142
+ #include <ATen/ops/_foreach_tanh_meta.h>
143
+ #include <ATen/ops/_foreach_trunc_meta.h>
144
+ #include <ATen/ops/_foreach_zero_meta.h>
145
+ #include <ATen/ops/_functional_assert_async_meta.h>
146
+ #include <ATen/ops/_functional_assert_scalar_meta.h>
147
+ #include <ATen/ops/_functional_sym_constrain_range_meta.h>
148
+ #include <ATen/ops/_functional_sym_constrain_range_for_size_meta.h>
149
+ #include <ATen/ops/_fused_adam_meta.h>
150
+ #include <ATen/ops/_fused_adamw_meta.h>
151
+ #include <ATen/ops/_fused_dropout_meta.h>
152
+ #include <ATen/ops/_fused_moving_avg_obs_fq_helper_meta.h>
153
+ #include <ATen/ops/_fused_sdp_choice_meta.h>
154
+ #include <ATen/ops/_fused_sgd_meta.h>
155
+ #include <ATen/ops/_fw_primal_meta.h>
156
+ #include <ATen/ops/_fw_primal_copy_meta.h>
157
+ #include <ATen/ops/_gather_sparse_backward_meta.h>
158
+ #include <ATen/ops/_grid_sampler_2d_cpu_fallback_meta.h>
159
+ #include <ATen/ops/_grid_sampler_2d_cpu_fallback_backward_meta.h>
160
+ #include <ATen/ops/_has_compatible_shallow_copy_type_meta.h>
161
+ #include <ATen/ops/_has_same_storage_numel_meta.h>
162
+ #include <ATen/ops/_histogramdd_bin_edges_meta.h>
163
+ #include <ATen/ops/_histogramdd_from_bin_cts_meta.h>
164
+ #include <ATen/ops/_histogramdd_from_bin_tensors_meta.h>
165
+ #include <ATen/ops/_index_put_impl_meta.h>
166
+ #include <ATen/ops/_indices_meta.h>
167
+ #include <ATen/ops/_indices_copy_meta.h>
168
+ #include <ATen/ops/_int_mm_meta.h>
169
+ #include <ATen/ops/_is_all_true_meta.h>
170
+ #include <ATen/ops/_is_any_true_meta.h>
171
+ #include <ATen/ops/_is_zerotensor_meta.h>
172
+ #include <ATen/ops/_lazy_clone_meta.h>
173
+ #include <ATen/ops/_linalg_check_errors_meta.h>
174
+ #include <ATen/ops/_linalg_det_meta.h>
175
+ #include <ATen/ops/_linalg_eigh_meta.h>
176
+ #include <ATen/ops/_linalg_eigvals_meta.h>
177
+ #include <ATen/ops/_linalg_slogdet_meta.h>
178
+ #include <ATen/ops/_linalg_solve_ex_meta.h>
179
+ #include <ATen/ops/_linalg_svd_meta.h>
180
+ #include <ATen/ops/_local_scalar_dense_meta.h>
181
+ #include <ATen/ops/_log_softmax_meta.h>
182
+ #include <ATen/ops/_log_softmax_backward_data_meta.h>
183
+ #include <ATen/ops/_logcumsumexp_meta.h>
184
+ #include <ATen/ops/_lstm_mps_meta.h>
185
+ #include <ATen/ops/_lu_with_info_meta.h>
186
+ #include <ATen/ops/_make_dep_token_meta.h>
187
+ #include <ATen/ops/_make_dual_meta.h>
188
+ #include <ATen/ops/_make_dual_copy_meta.h>
189
+ #include <ATen/ops/_make_per_channel_quantized_tensor_meta.h>
190
+ #include <ATen/ops/_make_per_tensor_quantized_tensor_meta.h>
191
+ #include <ATen/ops/_masked_scale_meta.h>
192
+ #include <ATen/ops/_masked_softmax_meta.h>
193
+ #include <ATen/ops/_masked_softmax_backward_meta.h>
194
+ #include <ATen/ops/_mixed_dtypes_linear_meta.h>
195
+ #include <ATen/ops/_mkldnn_reshape_meta.h>
196
+ #include <ATen/ops/_mkldnn_transpose_meta.h>
197
+ #include <ATen/ops/_mps_convolution_meta.h>
198
+ #include <ATen/ops/_mps_convolution_transpose_meta.h>
199
+ #include <ATen/ops/_native_batch_norm_legit_meta.h>
200
+ #include <ATen/ops/_native_batch_norm_legit_no_training_meta.h>
201
+ #include <ATen/ops/_native_multi_head_attention_meta.h>
202
+ #include <ATen/ops/_neg_view_meta.h>
203
+ #include <ATen/ops/_neg_view_copy_meta.h>
204
+ #include <ATen/ops/_nested_from_padded_meta.h>
205
+ #include <ATen/ops/_nested_from_padded_and_nested_example_meta.h>
206
+ #include <ATen/ops/_nested_get_jagged_dummy_meta.h>
207
+ #include <ATen/ops/_nested_get_lengths_meta.h>
208
+ #include <ATen/ops/_nested_get_offsets_meta.h>
209
+ #include <ATen/ops/_nested_get_ragged_idx_meta.h>
210
+ #include <ATen/ops/_nested_get_values_meta.h>
211
+ #include <ATen/ops/_nested_get_values_copy_meta.h>
212
+ #include <ATen/ops/_nested_select_backward_meta.h>
213
+ #include <ATen/ops/_nested_sum_backward_meta.h>
214
+ #include <ATen/ops/_nested_tensor_from_mask_meta.h>
215
+ #include <ATen/ops/_nested_tensor_from_mask_left_aligned_meta.h>
216
+ #include <ATen/ops/_nested_tensor_from_tensor_list_meta.h>
217
+ #include <ATen/ops/_nested_tensor_size_meta.h>
218
+ #include <ATen/ops/_nested_tensor_softmax_with_shape_meta.h>
219
+ #include <ATen/ops/_nested_tensor_storage_offsets_meta.h>
220
+ #include <ATen/ops/_nested_tensor_strides_meta.h>
221
+ #include <ATen/ops/_nested_view_from_buffer_meta.h>
222
+ #include <ATen/ops/_nested_view_from_buffer_copy_meta.h>
223
+ #include <ATen/ops/_nested_view_from_jagged_meta.h>
224
+ #include <ATen/ops/_nested_view_from_jagged_copy_meta.h>
225
+ #include <ATen/ops/_new_zeros_with_same_feature_meta_meta.h>
226
+ #include <ATen/ops/_nnpack_available_meta.h>
227
+ #include <ATen/ops/_nnpack_spatial_convolution_meta.h>
228
+ #include <ATen/ops/_nnz_meta.h>
229
+ #include <ATen/ops/_pack_padded_sequence_meta.h>
230
+ #include <ATen/ops/_pack_padded_sequence_backward_meta.h>
231
+ #include <ATen/ops/_pad_circular_meta.h>
232
+ #include <ATen/ops/_pad_enum_meta.h>
233
+ #include <ATen/ops/_pad_packed_sequence_meta.h>
234
+ #include <ATen/ops/_pdist_backward_meta.h>
235
+ #include <ATen/ops/_pdist_forward_meta.h>
236
+ #include <ATen/ops/_pin_memory_meta.h>
237
+ #include <ATen/ops/_prelu_kernel_meta.h>
238
+ #include <ATen/ops/_prelu_kernel_backward_meta.h>
239
+ #include <ATen/ops/_print_meta.h>
240
+ #include <ATen/ops/_propagate_xla_data_meta.h>
241
+ #include <ATen/ops/_remove_batch_dim_meta.h>
242
+ #include <ATen/ops/_reshape_alias_meta.h>
243
+ #include <ATen/ops/_reshape_alias_copy_meta.h>
244
+ #include <ATen/ops/_reshape_copy_meta.h>
245
+ #include <ATen/ops/_reshape_from_tensor_meta.h>
246
+ #include <ATen/ops/_resize_output_meta.h>
247
+ #include <ATen/ops/_rowwise_prune_meta.h>
248
+ #include <ATen/ops/_sample_dirichlet_meta.h>
249
+ #include <ATen/ops/_saturate_weight_to_fp16_meta.h>
250
+ #include <ATen/ops/_scaled_dot_product_attention_math_meta.h>
251
+ #include <ATen/ops/_scaled_dot_product_cudnn_attention_meta.h>
252
+ #include <ATen/ops/_scaled_dot_product_efficient_attention_meta.h>
253
+ #include <ATen/ops/_scaled_dot_product_efficient_attention_backward_meta.h>
254
+ #include <ATen/ops/_scaled_dot_product_flash_attention_meta.h>
255
+ #include <ATen/ops/_scaled_dot_product_flash_attention_backward_meta.h>
256
+ #include <ATen/ops/_scaled_dot_product_flash_attention_for_cpu_meta.h>
257
+ #include <ATen/ops/_scaled_dot_product_flash_attention_for_cpu_backward_meta.h>
258
+ #include <ATen/ops/_scaled_mm_meta.h>
259
+ #include <ATen/ops/_segment_reduce_backward_meta.h>
260
+ #include <ATen/ops/_shape_as_tensor_meta.h>
261
+ #include <ATen/ops/_slow_conv2d_backward_meta.h>
262
+ #include <ATen/ops/_slow_conv2d_forward_meta.h>
263
+ #include <ATen/ops/_sobol_engine_draw_meta.h>
264
+ #include <ATen/ops/_sobol_engine_ff_meta.h>
265
+ #include <ATen/ops/_sobol_engine_initialize_state_meta.h>
266
+ #include <ATen/ops/_sobol_engine_scramble_meta.h>
267
+ #include <ATen/ops/_softmax_meta.h>
268
+ #include <ATen/ops/_softmax_backward_data_meta.h>
269
+ #include <ATen/ops/_sparse_addmm_meta.h>
270
+ #include <ATen/ops/_sparse_broadcast_to_meta.h>
271
+ #include <ATen/ops/_sparse_broadcast_to_copy_meta.h>
272
+ #include <ATen/ops/_sparse_bsc_tensor_unsafe_meta.h>
273
+ #include <ATen/ops/_sparse_bsr_tensor_unsafe_meta.h>
274
+ #include <ATen/ops/_sparse_compressed_tensor_unsafe_meta.h>
275
+ #include <ATen/ops/_sparse_coo_tensor_unsafe_meta.h>
276
+ #include <ATen/ops/_sparse_coo_tensor_with_dims_meta.h>
277
+ #include <ATen/ops/_sparse_coo_tensor_with_dims_and_tensors_meta.h>
278
+ #include <ATen/ops/_sparse_csc_tensor_unsafe_meta.h>
279
+ #include <ATen/ops/_sparse_csr_prod_meta.h>
280
+ #include <ATen/ops/_sparse_csr_sum_meta.h>
281
+ #include <ATen/ops/_sparse_csr_tensor_unsafe_meta.h>
282
+ #include <ATen/ops/_sparse_log_softmax_meta.h>
283
+ #include <ATen/ops/_sparse_log_softmax_backward_data_meta.h>
284
+ #include <ATen/ops/_sparse_mask_projection_meta.h>
285
+ #include <ATen/ops/_sparse_mm_meta.h>
286
+ #include <ATen/ops/_sparse_mm_reduce_impl_meta.h>
287
+ #include <ATen/ops/_sparse_mm_reduce_impl_backward_meta.h>
288
+ #include <ATen/ops/_sparse_semi_structured_linear_meta.h>
289
+ #include <ATen/ops/_sparse_softmax_meta.h>
290
+ #include <ATen/ops/_sparse_softmax_backward_data_meta.h>
291
+ #include <ATen/ops/_sparse_sparse_matmul_meta.h>
292
+ #include <ATen/ops/_sparse_sum_meta.h>
293
+ #include <ATen/ops/_sparse_sum_backward_meta.h>
294
+ #include <ATen/ops/_spdiags_meta.h>
295
+ #include <ATen/ops/_stack_meta.h>
296
+ #include <ATen/ops/_standard_gamma_meta.h>
297
+ #include <ATen/ops/_standard_gamma_grad_meta.h>
298
+ #include <ATen/ops/_test_ambiguous_defaults_meta.h>
299
+ #include <ATen/ops/_test_autograd_multiple_dispatch_meta.h>
300
+ #include <ATen/ops/_test_autograd_multiple_dispatch_view_meta.h>
301
+ #include <ATen/ops/_test_autograd_multiple_dispatch_view_copy_meta.h>
302
+ #include <ATen/ops/_test_check_tensor_meta.h>
303
+ #include <ATen/ops/_test_functorch_fallback_meta.h>
304
+ #include <ATen/ops/_test_optional_filled_intlist_meta.h>
305
+ #include <ATen/ops/_test_optional_floatlist_meta.h>
306
+ #include <ATen/ops/_test_optional_intlist_meta.h>
307
+ #include <ATen/ops/_test_parallel_materialize_meta.h>
308
+ #include <ATen/ops/_test_serialization_subcmul_meta.h>
309
+ #include <ATen/ops/_test_string_default_meta.h>
310
+ #include <ATen/ops/_test_warn_in_autograd_meta.h>
311
+ #include <ATen/ops/_thnn_differentiable_gru_cell_backward_meta.h>
312
+ #include <ATen/ops/_thnn_differentiable_lstm_cell_backward_meta.h>
313
+ #include <ATen/ops/_thnn_fused_gru_cell_meta.h>
314
+ #include <ATen/ops/_thnn_fused_gru_cell_backward_meta.h>
315
+ #include <ATen/ops/_thnn_fused_lstm_cell_meta.h>
316
+ #include <ATen/ops/_thnn_fused_lstm_cell_backward_meta.h>
317
+ #include <ATen/ops/_thnn_fused_lstm_cell_backward_impl_meta.h>
318
+ #include <ATen/ops/_to_copy_meta.h>
319
+ #include <ATen/ops/_to_cpu_meta.h>
320
+ #include <ATen/ops/_to_dense_meta.h>
321
+ #include <ATen/ops/_to_sparse_meta.h>
322
+ #include <ATen/ops/_to_sparse_bsc_meta.h>
323
+ #include <ATen/ops/_to_sparse_bsr_meta.h>
324
+ #include <ATen/ops/_to_sparse_csc_meta.h>
325
+ #include <ATen/ops/_to_sparse_csr_meta.h>
326
+ #include <ATen/ops/_to_sparse_semi_structured_meta.h>
327
+ #include <ATen/ops/_transform_bias_rescale_qkv_meta.h>
328
+ #include <ATen/ops/_transformer_encoder_layer_fwd_meta.h>
329
+ #include <ATen/ops/_trilinear_meta.h>
330
+ #include <ATen/ops/_triton_multi_head_attention_meta.h>
331
+ #include <ATen/ops/_triton_scaled_dot_attention_meta.h>
332
+ #include <ATen/ops/_unique_meta.h>
333
+ #include <ATen/ops/_unique2_meta.h>
334
+ #include <ATen/ops/_unpack_dual_meta.h>
335
+ #include <ATen/ops/_unsafe_index_meta.h>
336
+ #include <ATen/ops/_unsafe_index_put_meta.h>
337
+ #include <ATen/ops/_unsafe_view_meta.h>
338
+ #include <ATen/ops/_upsample_bicubic2d_aa_meta.h>
339
+ #include <ATen/ops/_upsample_bicubic2d_aa_backward_meta.h>
340
+ #include <ATen/ops/_upsample_bilinear2d_aa_meta.h>
341
+ #include <ATen/ops/_upsample_bilinear2d_aa_backward_meta.h>
342
+ #include <ATen/ops/_upsample_nearest_exact1d_meta.h>
343
+ #include <ATen/ops/_upsample_nearest_exact1d_backward_meta.h>
344
+ #include <ATen/ops/_upsample_nearest_exact2d_meta.h>
345
+ #include <ATen/ops/_upsample_nearest_exact2d_backward_meta.h>
346
+ #include <ATen/ops/_upsample_nearest_exact3d_meta.h>
347
+ #include <ATen/ops/_upsample_nearest_exact3d_backward_meta.h>
348
+ #include <ATen/ops/_use_cudnn_ctc_loss_meta.h>
349
+ #include <ATen/ops/_use_cudnn_rnn_flatten_weight_meta.h>
350
+ #include <ATen/ops/_validate_compressed_sparse_indices_meta.h>
351
+ #include <ATen/ops/_validate_sparse_bsc_tensor_args_meta.h>
352
+ #include <ATen/ops/_validate_sparse_bsr_tensor_args_meta.h>
353
+ #include <ATen/ops/_validate_sparse_compressed_tensor_args_meta.h>
354
+ #include <ATen/ops/_validate_sparse_coo_tensor_args_meta.h>
355
+ #include <ATen/ops/_validate_sparse_csc_tensor_args_meta.h>
356
+ #include <ATen/ops/_validate_sparse_csr_tensor_args_meta.h>
357
+ #include <ATen/ops/_values_meta.h>
358
+ #include <ATen/ops/_values_copy_meta.h>
359
+ #include <ATen/ops/_version_meta.h>
360
+ #include <ATen/ops/_weight_int4pack_mm_meta.h>
361
+ #include <ATen/ops/_weight_int8pack_mm_meta.h>
362
+ #include <ATen/ops/_weight_norm_meta.h>
363
+ #include <ATen/ops/_weight_norm_differentiable_backward_meta.h>
364
+ #include <ATen/ops/_weight_norm_interface_meta.h>
365
+ #include <ATen/ops/_weight_norm_interface_backward_meta.h>
366
+ #include <ATen/ops/abs_meta.h>
367
+ #include <ATen/ops/absolute_meta.h>
368
+ #include <ATen/ops/acos_meta.h>
369
+ #include <ATen/ops/acosh_meta.h>
370
+ #include <ATen/ops/adaptive_avg_pool1d_meta.h>
371
+ #include <ATen/ops/adaptive_avg_pool2d_meta.h>
372
+ #include <ATen/ops/adaptive_avg_pool3d_meta.h>
373
+ #include <ATen/ops/adaptive_avg_pool3d_backward_meta.h>
374
+ #include <ATen/ops/adaptive_max_pool1d_meta.h>
375
+ #include <ATen/ops/adaptive_max_pool2d_meta.h>
376
+ #include <ATen/ops/adaptive_max_pool2d_backward_meta.h>
377
+ #include <ATen/ops/adaptive_max_pool3d_meta.h>
378
+ #include <ATen/ops/adaptive_max_pool3d_backward_meta.h>
379
+ #include <ATen/ops/add_meta.h>
380
+ #include <ATen/ops/addbmm_meta.h>
381
+ #include <ATen/ops/addcdiv_meta.h>
382
+ #include <ATen/ops/addcmul_meta.h>
383
+ #include <ATen/ops/addmm_meta.h>
384
+ #include <ATen/ops/addmv_meta.h>
385
+ #include <ATen/ops/addr_meta.h>
386
+ #include <ATen/ops/adjoint_meta.h>
387
+ #include <ATen/ops/affine_grid_generator_meta.h>
388
+ #include <ATen/ops/affine_grid_generator_backward_meta.h>
389
+ #include <ATen/ops/alias_meta.h>
390
+ #include <ATen/ops/alias_copy_meta.h>
391
+ #include <ATen/ops/align_as_meta.h>
392
+ #include <ATen/ops/align_tensors_meta.h>
393
+ #include <ATen/ops/align_to_meta.h>
394
+ #include <ATen/ops/all_meta.h>
395
+ #include <ATen/ops/allclose_meta.h>
396
+ #include <ATen/ops/alpha_dropout_meta.h>
397
+ #include <ATen/ops/amax_meta.h>
398
+ #include <ATen/ops/amin_meta.h>
399
+ #include <ATen/ops/aminmax_meta.h>
400
+ #include <ATen/ops/and_meta.h>
401
+ #include <ATen/ops/angle_meta.h>
402
+ #include <ATen/ops/any_meta.h>
403
+ #include <ATen/ops/arange_meta.h>
404
+ #include <ATen/ops/arccos_meta.h>
405
+ #include <ATen/ops/arccosh_meta.h>
406
+ #include <ATen/ops/arcsin_meta.h>
407
+ #include <ATen/ops/arcsinh_meta.h>
408
+ #include <ATen/ops/arctan_meta.h>
409
+ #include <ATen/ops/arctan2_meta.h>
410
+ #include <ATen/ops/arctanh_meta.h>
411
+ #include <ATen/ops/argmax_meta.h>
412
+ #include <ATen/ops/argmin_meta.h>
413
+ #include <ATen/ops/argsort_meta.h>
414
+ #include <ATen/ops/argwhere_meta.h>
415
+ #include <ATen/ops/as_strided_meta.h>
416
+ #include <ATen/ops/as_strided_copy_meta.h>
417
+ #include <ATen/ops/as_strided_scatter_meta.h>
418
+ #include <ATen/ops/asin_meta.h>
419
+ #include <ATen/ops/asinh_meta.h>
420
+ #include <ATen/ops/atan_meta.h>
421
+ #include <ATen/ops/atan2_meta.h>
422
+ #include <ATen/ops/atanh_meta.h>
423
+ #include <ATen/ops/atleast_1d_meta.h>
424
+ #include <ATen/ops/atleast_2d_meta.h>
425
+ #include <ATen/ops/atleast_3d_meta.h>
426
+ #include <ATen/ops/avg_pool1d_meta.h>
427
+ #include <ATen/ops/avg_pool2d_meta.h>
428
+ #include <ATen/ops/avg_pool2d_backward_meta.h>
429
+ #include <ATen/ops/avg_pool3d_meta.h>
430
+ #include <ATen/ops/avg_pool3d_backward_meta.h>
431
+ #include <ATen/ops/baddbmm_meta.h>
432
+ #include <ATen/ops/bartlett_window_meta.h>
433
+ #include <ATen/ops/batch_norm_meta.h>
434
+ #include <ATen/ops/batch_norm_backward_elemt_meta.h>
435
+ #include <ATen/ops/batch_norm_backward_reduce_meta.h>
436
+ #include <ATen/ops/batch_norm_elemt_meta.h>
437
+ #include <ATen/ops/batch_norm_gather_stats_meta.h>
438
+ #include <ATen/ops/batch_norm_gather_stats_with_counts_meta.h>
439
+ #include <ATen/ops/batch_norm_stats_meta.h>
440
+ #include <ATen/ops/batch_norm_update_stats_meta.h>
441
+ #include <ATen/ops/bernoulli_meta.h>
442
+ #include <ATen/ops/bilinear_meta.h>
443
+ #include <ATen/ops/binary_cross_entropy_meta.h>
444
+ #include <ATen/ops/binary_cross_entropy_backward_meta.h>
445
+ #include <ATen/ops/binary_cross_entropy_with_logits_meta.h>
446
+ #include <ATen/ops/bincount_meta.h>
447
+ #include <ATen/ops/binomial_meta.h>
448
+ #include <ATen/ops/bitwise_and_meta.h>
449
+ #include <ATen/ops/bitwise_left_shift_meta.h>
450
+ #include <ATen/ops/bitwise_not_meta.h>
451
+ #include <ATen/ops/bitwise_or_meta.h>
452
+ #include <ATen/ops/bitwise_right_shift_meta.h>
453
+ #include <ATen/ops/bitwise_xor_meta.h>
454
+ #include <ATen/ops/blackman_window_meta.h>
455
+ #include <ATen/ops/block_diag_meta.h>
456
+ #include <ATen/ops/bmm_meta.h>
457
+ #include <ATen/ops/broadcast_tensors_meta.h>
458
+ #include <ATen/ops/broadcast_to_meta.h>
459
+ #include <ATen/ops/bucketize_meta.h>
460
+ #include <ATen/ops/can_cast_meta.h>
461
+ #include <ATen/ops/cartesian_prod_meta.h>
462
+ #include <ATen/ops/cat_meta.h>
463
+ #include <ATen/ops/cauchy_meta.h>
464
+ #include <ATen/ops/ccol_indices_meta.h>
465
+ #include <ATen/ops/ccol_indices_copy_meta.h>
466
+ #include <ATen/ops/cdist_meta.h>
467
+ #include <ATen/ops/ceil_meta.h>
468
+ #include <ATen/ops/celu_meta.h>
469
+ #include <ATen/ops/chain_matmul_meta.h>
470
+ #include <ATen/ops/chalf_meta.h>
471
+ #include <ATen/ops/channel_shuffle_meta.h>
472
+ #include <ATen/ops/cholesky_meta.h>
473
+ #include <ATen/ops/cholesky_inverse_meta.h>
474
+ #include <ATen/ops/cholesky_solve_meta.h>
475
+ #include <ATen/ops/choose_qparams_optimized_meta.h>
476
+ #include <ATen/ops/chunk_meta.h>
477
+ #include <ATen/ops/clamp_meta.h>
478
+ #include <ATen/ops/clamp_max_meta.h>
479
+ #include <ATen/ops/clamp_min_meta.h>
480
+ #include <ATen/ops/clip_meta.h>
481
+ #include <ATen/ops/clone_meta.h>
482
+ #include <ATen/ops/coalesce_meta.h>
483
+ #include <ATen/ops/col2im_meta.h>
484
+ #include <ATen/ops/col_indices_meta.h>
485
+ #include <ATen/ops/col_indices_copy_meta.h>
486
+ #include <ATen/ops/column_stack_meta.h>
487
+ #include <ATen/ops/combinations_meta.h>
488
+ #include <ATen/ops/complex_meta.h>
489
+ #include <ATen/ops/concat_meta.h>
490
+ #include <ATen/ops/concatenate_meta.h>
491
+ #include <ATen/ops/conj_meta.h>
492
+ #include <ATen/ops/conj_physical_meta.h>
493
+ #include <ATen/ops/constant_pad_nd_meta.h>
494
+ #include <ATen/ops/contiguous_meta.h>
495
+ #include <ATen/ops/conv1d_meta.h>
496
+ #include <ATen/ops/conv2d_meta.h>
497
+ #include <ATen/ops/conv3d_meta.h>
498
+ #include <ATen/ops/conv_depthwise3d_meta.h>
499
+ #include <ATen/ops/conv_tbc_meta.h>
500
+ #include <ATen/ops/conv_tbc_backward_meta.h>
501
+ #include <ATen/ops/conv_transpose1d_meta.h>
502
+ #include <ATen/ops/conv_transpose2d_meta.h>
503
+ #include <ATen/ops/conv_transpose3d_meta.h>
504
+ #include <ATen/ops/convolution_meta.h>
505
+ #include <ATen/ops/convolution_backward_meta.h>
506
+ #include <ATen/ops/convolution_backward_overrideable_meta.h>
507
+ #include <ATen/ops/convolution_overrideable_meta.h>
508
+ #include <ATen/ops/copy_meta.h>
509
+ #include <ATen/ops/copy_sparse_to_sparse_meta.h>
510
+ #include <ATen/ops/copysign_meta.h>
511
+ #include <ATen/ops/corrcoef_meta.h>
512
+ #include <ATen/ops/cos_meta.h>
513
+ #include <ATen/ops/cosh_meta.h>
514
+ #include <ATen/ops/cosine_embedding_loss_meta.h>
515
+ #include <ATen/ops/cosine_similarity_meta.h>
516
+ #include <ATen/ops/count_nonzero_meta.h>
517
+ #include <ATen/ops/cov_meta.h>
518
+ #include <ATen/ops/cross_meta.h>
519
+ #include <ATen/ops/cross_entropy_loss_meta.h>
520
+ #include <ATen/ops/crow_indices_meta.h>
521
+ #include <ATen/ops/crow_indices_copy_meta.h>
522
+ #include <ATen/ops/ctc_loss_meta.h>
523
+ #include <ATen/ops/cudnn_affine_grid_generator_meta.h>
524
+ #include <ATen/ops/cudnn_affine_grid_generator_backward_meta.h>
525
+ #include <ATen/ops/cudnn_batch_norm_meta.h>
526
+ #include <ATen/ops/cudnn_batch_norm_backward_meta.h>
527
+ #include <ATen/ops/cudnn_convolution_meta.h>
528
+ #include <ATen/ops/cudnn_convolution_add_relu_meta.h>
529
+ #include <ATen/ops/cudnn_convolution_relu_meta.h>
530
+ #include <ATen/ops/cudnn_convolution_transpose_meta.h>
531
+ #include <ATen/ops/cudnn_grid_sampler_meta.h>
532
+ #include <ATen/ops/cudnn_grid_sampler_backward_meta.h>
533
+ #include <ATen/ops/cudnn_is_acceptable_meta.h>
534
+ #include <ATen/ops/cummax_meta.h>
535
+ #include <ATen/ops/cummaxmin_backward_meta.h>
536
+ #include <ATen/ops/cummin_meta.h>
537
+ #include <ATen/ops/cumprod_meta.h>
538
+ #include <ATen/ops/cumprod_backward_meta.h>
539
+ #include <ATen/ops/cumsum_meta.h>
540
+ #include <ATen/ops/cumulative_trapezoid_meta.h>
541
+ #include <ATen/ops/data_meta.h>
542
+ #include <ATen/ops/deg2rad_meta.h>
543
+ #include <ATen/ops/dense_dim_meta.h>
544
+ #include <ATen/ops/dequantize_meta.h>
545
+ #include <ATen/ops/det_meta.h>
546
+ #include <ATen/ops/detach_meta.h>
547
+ #include <ATen/ops/detach_copy_meta.h>
548
+ #include <ATen/ops/diag_meta.h>
549
+ #include <ATen/ops/diag_embed_meta.h>
550
+ #include <ATen/ops/diagflat_meta.h>
551
+ #include <ATen/ops/diagonal_meta.h>
552
+ #include <ATen/ops/diagonal_backward_meta.h>
553
+ #include <ATen/ops/diagonal_copy_meta.h>
554
+ #include <ATen/ops/diagonal_scatter_meta.h>
555
+ #include <ATen/ops/diff_meta.h>
556
+ #include <ATen/ops/digamma_meta.h>
557
+ #include <ATen/ops/dist_meta.h>
558
+ #include <ATen/ops/div_meta.h>
559
+ #include <ATen/ops/divide_meta.h>
560
+ #include <ATen/ops/dot_meta.h>
561
+ #include <ATen/ops/dropout_meta.h>
562
+ #include <ATen/ops/dsplit_meta.h>
563
+ #include <ATen/ops/dstack_meta.h>
564
+ #include <ATen/ops/einsum_meta.h>
565
+ #include <ATen/ops/elu_meta.h>
566
+ #include <ATen/ops/elu_backward_meta.h>
567
+ #include <ATen/ops/embedding_meta.h>
568
+ #include <ATen/ops/embedding_backward_meta.h>
569
+ #include <ATen/ops/embedding_bag_meta.h>
570
+ #include <ATen/ops/embedding_dense_backward_meta.h>
571
+ #include <ATen/ops/embedding_renorm_meta.h>
572
+ #include <ATen/ops/embedding_sparse_backward_meta.h>
573
+ #include <ATen/ops/empty_meta.h>
574
+ #include <ATen/ops/empty_like_meta.h>
575
+ #include <ATen/ops/empty_permuted_meta.h>
576
+ #include <ATen/ops/empty_quantized_meta.h>
577
+ #include <ATen/ops/empty_strided_meta.h>
578
+ #include <ATen/ops/eq_meta.h>
579
+ #include <ATen/ops/equal_meta.h>
580
+ #include <ATen/ops/erf_meta.h>
581
+ #include <ATen/ops/erfc_meta.h>
582
+ #include <ATen/ops/erfinv_meta.h>
583
+ #include <ATen/ops/exp_meta.h>
584
+ #include <ATen/ops/exp2_meta.h>
585
+ #include <ATen/ops/expand_meta.h>
586
+ #include <ATen/ops/expand_as_meta.h>
587
+ #include <ATen/ops/expand_copy_meta.h>
588
+ #include <ATen/ops/expm1_meta.h>
589
+ #include <ATen/ops/exponential_meta.h>
590
+ #include <ATen/ops/eye_meta.h>
591
+ #include <ATen/ops/fake_quantize_per_channel_affine_meta.h>
592
+ #include <ATen/ops/fake_quantize_per_channel_affine_cachemask_meta.h>
593
+ #include <ATen/ops/fake_quantize_per_channel_affine_cachemask_backward_meta.h>
594
+ #include <ATen/ops/fake_quantize_per_tensor_affine_meta.h>
595
+ #include <ATen/ops/fake_quantize_per_tensor_affine_cachemask_meta.h>
596
+ #include <ATen/ops/fake_quantize_per_tensor_affine_cachemask_backward_meta.h>
597
+ #include <ATen/ops/fbgemm_linear_fp16_weight_meta.h>
598
+ #include <ATen/ops/fbgemm_linear_fp16_weight_fp32_activation_meta.h>
599
+ #include <ATen/ops/fbgemm_linear_int8_weight_meta.h>
600
+ #include <ATen/ops/fbgemm_linear_int8_weight_fp32_activation_meta.h>
601
+ #include <ATen/ops/fbgemm_linear_quantize_weight_meta.h>
602
+ #include <ATen/ops/fbgemm_pack_gemm_matrix_fp16_meta.h>
603
+ #include <ATen/ops/fbgemm_pack_quantized_matrix_meta.h>
604
+ #include <ATen/ops/feature_alpha_dropout_meta.h>
605
+ #include <ATen/ops/feature_dropout_meta.h>
606
+ #include <ATen/ops/fft_fft_meta.h>
607
+ #include <ATen/ops/fft_fft2_meta.h>
608
+ #include <ATen/ops/fft_fftfreq_meta.h>
609
+ #include <ATen/ops/fft_fftn_meta.h>
610
+ #include <ATen/ops/fft_fftshift_meta.h>
611
+ #include <ATen/ops/fft_hfft_meta.h>
612
+ #include <ATen/ops/fft_hfft2_meta.h>
613
+ #include <ATen/ops/fft_hfftn_meta.h>
614
+ #include <ATen/ops/fft_ifft_meta.h>
615
+ #include <ATen/ops/fft_ifft2_meta.h>
616
+ #include <ATen/ops/fft_ifftn_meta.h>
617
+ #include <ATen/ops/fft_ifftshift_meta.h>
618
+ #include <ATen/ops/fft_ihfft_meta.h>
619
+ #include <ATen/ops/fft_ihfft2_meta.h>
620
+ #include <ATen/ops/fft_ihfftn_meta.h>
621
+ #include <ATen/ops/fft_irfft_meta.h>
622
+ #include <ATen/ops/fft_irfft2_meta.h>
623
+ #include <ATen/ops/fft_irfftn_meta.h>
624
+ #include <ATen/ops/fft_rfft_meta.h>
625
+ #include <ATen/ops/fft_rfft2_meta.h>
626
+ #include <ATen/ops/fft_rfftfreq_meta.h>
627
+ #include <ATen/ops/fft_rfftn_meta.h>
628
+ #include <ATen/ops/fill_meta.h>
629
+ #include <ATen/ops/fill_diagonal_meta.h>
630
+ #include <ATen/ops/fix_meta.h>
631
+ #include <ATen/ops/flatten_meta.h>
632
+ #include <ATen/ops/flatten_dense_tensors_meta.h>
633
+ #include <ATen/ops/flip_meta.h>
634
+ #include <ATen/ops/fliplr_meta.h>
635
+ #include <ATen/ops/flipud_meta.h>
636
+ #include <ATen/ops/float_power_meta.h>
637
+ #include <ATen/ops/floor_meta.h>
638
+ #include <ATen/ops/floor_divide_meta.h>
639
+ #include <ATen/ops/fmax_meta.h>
640
+ #include <ATen/ops/fmin_meta.h>
641
+ #include <ATen/ops/fmod_meta.h>
642
+ #include <ATen/ops/frac_meta.h>
643
+ #include <ATen/ops/fractional_max_pool2d_meta.h>
644
+ #include <ATen/ops/fractional_max_pool2d_backward_meta.h>
645
+ #include <ATen/ops/fractional_max_pool3d_meta.h>
646
+ #include <ATen/ops/fractional_max_pool3d_backward_meta.h>
647
+ #include <ATen/ops/frexp_meta.h>
648
+ #include <ATen/ops/frobenius_norm_meta.h>
649
+ #include <ATen/ops/from_file_meta.h>
650
+ #include <ATen/ops/full_meta.h>
651
+ #include <ATen/ops/full_like_meta.h>
652
+ #include <ATen/ops/fused_moving_avg_obs_fake_quant_meta.h>
653
+ #include <ATen/ops/gather_meta.h>
654
+ #include <ATen/ops/gather_backward_meta.h>
655
+ #include <ATen/ops/gcd_meta.h>
656
+ #include <ATen/ops/ge_meta.h>
657
+ #include <ATen/ops/gelu_meta.h>
658
+ #include <ATen/ops/gelu_backward_meta.h>
659
+ #include <ATen/ops/geometric_meta.h>
660
+ #include <ATen/ops/geqrf_meta.h>
661
+ #include <ATen/ops/ger_meta.h>
662
+ #include <ATen/ops/glu_meta.h>
663
+ #include <ATen/ops/glu_backward_meta.h>
664
+ #include <ATen/ops/glu_backward_jvp_meta.h>
665
+ #include <ATen/ops/glu_jvp_meta.h>
666
+ #include <ATen/ops/gradient_meta.h>
667
+ #include <ATen/ops/greater_meta.h>
668
+ #include <ATen/ops/greater_equal_meta.h>
669
+ #include <ATen/ops/grid_sampler_meta.h>
670
+ #include <ATen/ops/grid_sampler_2d_meta.h>
671
+ #include <ATen/ops/grid_sampler_2d_backward_meta.h>
672
+ #include <ATen/ops/grid_sampler_3d_meta.h>
673
+ #include <ATen/ops/grid_sampler_3d_backward_meta.h>
674
+ #include <ATen/ops/group_norm_meta.h>
675
+ #include <ATen/ops/gru_meta.h>
676
+ #include <ATen/ops/gru_cell_meta.h>
677
+ #include <ATen/ops/gt_meta.h>
678
+ #include <ATen/ops/hamming_window_meta.h>
679
+ #include <ATen/ops/hann_window_meta.h>
680
+ #include <ATen/ops/hardshrink_meta.h>
681
+ #include <ATen/ops/hardshrink_backward_meta.h>
682
+ #include <ATen/ops/hardsigmoid_meta.h>
683
+ #include <ATen/ops/hardsigmoid_backward_meta.h>
684
+ #include <ATen/ops/hardswish_meta.h>
685
+ #include <ATen/ops/hardswish_backward_meta.h>
686
+ #include <ATen/ops/hardtanh_meta.h>
687
+ #include <ATen/ops/hardtanh_backward_meta.h>
688
+ #include <ATen/ops/heaviside_meta.h>
689
+ #include <ATen/ops/hinge_embedding_loss_meta.h>
690
+ #include <ATen/ops/histc_meta.h>
691
+ #include <ATen/ops/histogram_meta.h>
692
+ #include <ATen/ops/histogramdd_meta.h>
693
+ #include <ATen/ops/hsplit_meta.h>
694
+ #include <ATen/ops/hspmm_meta.h>
695
+ #include <ATen/ops/hstack_meta.h>
696
+ #include <ATen/ops/huber_loss_meta.h>
697
+ #include <ATen/ops/huber_loss_backward_meta.h>
698
+ #include <ATen/ops/hypot_meta.h>
699
+ #include <ATen/ops/i0_meta.h>
700
+ #include <ATen/ops/igamma_meta.h>
701
+ #include <ATen/ops/igammac_meta.h>
702
+ #include <ATen/ops/im2col_meta.h>
703
+ #include <ATen/ops/imag_meta.h>
704
+ #include <ATen/ops/index_meta.h>
705
+ #include <ATen/ops/index_add_meta.h>
706
+ #include <ATen/ops/index_copy_meta.h>
707
+ #include <ATen/ops/index_fill_meta.h>
708
+ #include <ATen/ops/index_put_meta.h>
709
+ #include <ATen/ops/index_reduce_meta.h>
710
+ #include <ATen/ops/index_select_meta.h>
711
+ #include <ATen/ops/index_select_backward_meta.h>
712
+ #include <ATen/ops/indices_meta.h>
713
+ #include <ATen/ops/indices_copy_meta.h>
714
+ #include <ATen/ops/infinitely_differentiable_gelu_backward_meta.h>
715
+ #include <ATen/ops/inner_meta.h>
716
+ #include <ATen/ops/instance_norm_meta.h>
717
+ #include <ATen/ops/int_repr_meta.h>
718
+ #include <ATen/ops/inverse_meta.h>
719
+ #include <ATen/ops/is_coalesced_meta.h>
720
+ #include <ATen/ops/is_complex_meta.h>
721
+ #include <ATen/ops/is_conj_meta.h>
722
+ #include <ATen/ops/is_distributed_meta.h>
723
+ #include <ATen/ops/is_floating_point_meta.h>
724
+ #include <ATen/ops/is_inference_meta.h>
725
+ #include <ATen/ops/is_leaf_meta.h>
726
+ #include <ATen/ops/is_neg_meta.h>
727
+ #include <ATen/ops/is_nonzero_meta.h>
728
+ #include <ATen/ops/is_pinned_meta.h>
729
+ #include <ATen/ops/is_same_size_meta.h>
730
+ #include <ATen/ops/is_set_to_meta.h>
731
+ #include <ATen/ops/is_signed_meta.h>
732
+ #include <ATen/ops/is_vulkan_available_meta.h>
733
+ #include <ATen/ops/isclose_meta.h>
734
+ #include <ATen/ops/isfinite_meta.h>
735
+ #include <ATen/ops/isin_meta.h>
736
+ #include <ATen/ops/isinf_meta.h>
737
+ #include <ATen/ops/isnan_meta.h>
738
+ #include <ATen/ops/isneginf_meta.h>
739
+ #include <ATen/ops/isposinf_meta.h>
740
+ #include <ATen/ops/isreal_meta.h>
741
+ #include <ATen/ops/istft_meta.h>
742
+ #include <ATen/ops/item_meta.h>
743
+ #include <ATen/ops/kaiser_window_meta.h>
744
+ #include <ATen/ops/kl_div_meta.h>
745
+ #include <ATen/ops/kron_meta.h>
746
+ #include <ATen/ops/kthvalue_meta.h>
747
+ #include <ATen/ops/l1_loss_meta.h>
748
+ #include <ATen/ops/layer_norm_meta.h>
749
+ #include <ATen/ops/lcm_meta.h>
750
+ #include <ATen/ops/ldexp_meta.h>
751
+ #include <ATen/ops/le_meta.h>
752
+ #include <ATen/ops/leaky_relu_meta.h>
753
+ #include <ATen/ops/leaky_relu_backward_meta.h>
754
+ #include <ATen/ops/lerp_meta.h>
755
+ #include <ATen/ops/less_meta.h>
756
+ #include <ATen/ops/less_equal_meta.h>
757
+ #include <ATen/ops/lgamma_meta.h>
758
+ #include <ATen/ops/lift_meta.h>
759
+ #include <ATen/ops/lift_fresh_meta.h>
760
+ #include <ATen/ops/lift_fresh_copy_meta.h>
761
+ #include <ATen/ops/linalg_cholesky_meta.h>
762
+ #include <ATen/ops/linalg_cholesky_ex_meta.h>
763
+ #include <ATen/ops/linalg_cond_meta.h>
764
+ #include <ATen/ops/linalg_cross_meta.h>
765
+ #include <ATen/ops/linalg_det_meta.h>
766
+ #include <ATen/ops/linalg_diagonal_meta.h>
767
+ #include <ATen/ops/linalg_eig_meta.h>
768
+ #include <ATen/ops/linalg_eigh_meta.h>
769
+ #include <ATen/ops/linalg_eigvals_meta.h>
770
+ #include <ATen/ops/linalg_eigvalsh_meta.h>
771
+ #include <ATen/ops/linalg_householder_product_meta.h>
772
+ #include <ATen/ops/linalg_inv_meta.h>
773
+ #include <ATen/ops/linalg_inv_ex_meta.h>
774
+ #include <ATen/ops/linalg_ldl_factor_meta.h>
775
+ #include <ATen/ops/linalg_ldl_factor_ex_meta.h>
776
+ #include <ATen/ops/linalg_ldl_solve_meta.h>
777
+ #include <ATen/ops/linalg_lstsq_meta.h>
778
+ #include <ATen/ops/linalg_lu_meta.h>
779
+ #include <ATen/ops/linalg_lu_factor_meta.h>
780
+ #include <ATen/ops/linalg_lu_factor_ex_meta.h>
781
+ #include <ATen/ops/linalg_lu_solve_meta.h>
782
+ #include <ATen/ops/linalg_matmul_meta.h>
783
+ #include <ATen/ops/linalg_matrix_exp_meta.h>
784
+ #include <ATen/ops/linalg_matrix_norm_meta.h>
785
+ #include <ATen/ops/linalg_matrix_power_meta.h>
786
+ #include <ATen/ops/linalg_matrix_rank_meta.h>
787
+ #include <ATen/ops/linalg_multi_dot_meta.h>
788
+ #include <ATen/ops/linalg_norm_meta.h>
789
+ #include <ATen/ops/linalg_pinv_meta.h>
790
+ #include <ATen/ops/linalg_qr_meta.h>
791
+ #include <ATen/ops/linalg_slogdet_meta.h>
792
+ #include <ATen/ops/linalg_solve_meta.h>
793
+ #include <ATen/ops/linalg_solve_ex_meta.h>
794
+ #include <ATen/ops/linalg_solve_triangular_meta.h>
795
+ #include <ATen/ops/linalg_svd_meta.h>
796
+ #include <ATen/ops/linalg_svdvals_meta.h>
797
+ #include <ATen/ops/linalg_tensorinv_meta.h>
798
+ #include <ATen/ops/linalg_tensorsolve_meta.h>
799
+ #include <ATen/ops/linalg_vander_meta.h>
800
+ #include <ATen/ops/linalg_vecdot_meta.h>
801
+ #include <ATen/ops/linalg_vector_norm_meta.h>
802
+ #include <ATen/ops/linear_meta.h>
803
+ #include <ATen/ops/linear_backward_meta.h>
804
+ #include <ATen/ops/linspace_meta.h>
805
+ #include <ATen/ops/log_meta.h>
806
+ #include <ATen/ops/log10_meta.h>
807
+ #include <ATen/ops/log1p_meta.h>
808
+ #include <ATen/ops/log2_meta.h>
809
+ #include <ATen/ops/log_normal_meta.h>
810
+ #include <ATen/ops/log_sigmoid_meta.h>
811
+ #include <ATen/ops/log_sigmoid_backward_meta.h>
812
+ #include <ATen/ops/log_sigmoid_forward_meta.h>
813
+ #include <ATen/ops/log_softmax_meta.h>
814
+ #include <ATen/ops/logaddexp_meta.h>
815
+ #include <ATen/ops/logaddexp2_meta.h>
816
+ #include <ATen/ops/logcumsumexp_meta.h>
817
+ #include <ATen/ops/logdet_meta.h>
818
+ #include <ATen/ops/logical_and_meta.h>
819
+ #include <ATen/ops/logical_not_meta.h>
820
+ #include <ATen/ops/logical_or_meta.h>
821
+ #include <ATen/ops/logical_xor_meta.h>
822
+ #include <ATen/ops/logit_meta.h>
823
+ #include <ATen/ops/logit_backward_meta.h>
824
+ #include <ATen/ops/logspace_meta.h>
825
+ #include <ATen/ops/logsumexp_meta.h>
826
+ #include <ATen/ops/lshift_meta.h>
827
+ #include <ATen/ops/lstm_meta.h>
828
+ #include <ATen/ops/lstm_cell_meta.h>
829
+ #include <ATen/ops/lstm_mps_backward_meta.h>
830
+ #include <ATen/ops/lt_meta.h>
831
+ #include <ATen/ops/lu_solve_meta.h>
832
+ #include <ATen/ops/lu_unpack_meta.h>
833
+ #include <ATen/ops/mH_meta.h>
834
+ #include <ATen/ops/mT_meta.h>
835
+ #include <ATen/ops/margin_ranking_loss_meta.h>
836
+ #include <ATen/ops/masked_fill_meta.h>
837
+ #include <ATen/ops/masked_scatter_meta.h>
838
+ #include <ATen/ops/masked_scatter_backward_meta.h>
839
+ #include <ATen/ops/masked_select_meta.h>
840
+ #include <ATen/ops/masked_select_backward_meta.h>
841
+ #include <ATen/ops/matmul_meta.h>
842
+ #include <ATen/ops/matmul_backward_meta.h>
843
+ #include <ATen/ops/matrix_H_meta.h>
844
+ #include <ATen/ops/matrix_exp_meta.h>
845
+ #include <ATen/ops/matrix_exp_backward_meta.h>
846
+ #include <ATen/ops/matrix_power_meta.h>
847
+ #include <ATen/ops/max_meta.h>
848
+ #include <ATen/ops/max_pool1d_meta.h>
849
+ #include <ATen/ops/max_pool1d_with_indices_meta.h>
850
+ #include <ATen/ops/max_pool2d_meta.h>
851
+ #include <ATen/ops/max_pool2d_backward_meta.h>
852
+ #include <ATen/ops/max_pool2d_with_indices_meta.h>
853
+ #include <ATen/ops/max_pool2d_with_indices_backward_meta.h>
854
+ #include <ATen/ops/max_pool3d_meta.h>
855
+ #include <ATen/ops/max_pool3d_with_indices_meta.h>
856
+ #include <ATen/ops/max_pool3d_with_indices_backward_meta.h>
857
+ #include <ATen/ops/max_unpool2d_meta.h>
858
+ #include <ATen/ops/max_unpool3d_meta.h>
859
+ #include <ATen/ops/maximum_meta.h>
860
+ #include <ATen/ops/mean_meta.h>
861
+ #include <ATen/ops/median_meta.h>
862
+ #include <ATen/ops/meshgrid_meta.h>
863
+ #include <ATen/ops/min_meta.h>
864
+ #include <ATen/ops/minimum_meta.h>
865
+ #include <ATen/ops/miopen_batch_norm_meta.h>
866
+ #include <ATen/ops/miopen_batch_norm_backward_meta.h>
867
+ #include <ATen/ops/miopen_convolution_meta.h>
868
+ #include <ATen/ops/miopen_convolution_add_relu_meta.h>
869
+ #include <ATen/ops/miopen_convolution_relu_meta.h>
870
+ #include <ATen/ops/miopen_convolution_transpose_meta.h>
871
+ #include <ATen/ops/miopen_depthwise_convolution_meta.h>
872
+ #include <ATen/ops/miopen_rnn_meta.h>
873
+ #include <ATen/ops/miopen_rnn_backward_meta.h>
874
+ #include <ATen/ops/mish_meta.h>
875
+ #include <ATen/ops/mish_backward_meta.h>
876
+ #include <ATen/ops/mkldnn_adaptive_avg_pool2d_meta.h>
877
+ #include <ATen/ops/mkldnn_adaptive_avg_pool2d_backward_meta.h>
878
+ #include <ATen/ops/mkldnn_convolution_meta.h>
879
+ #include <ATen/ops/mkldnn_linear_meta.h>
880
+ #include <ATen/ops/mkldnn_linear_backward_meta.h>
881
+ #include <ATen/ops/mkldnn_linear_backward_input_meta.h>
882
+ #include <ATen/ops/mkldnn_linear_backward_weights_meta.h>
883
+ #include <ATen/ops/mkldnn_max_pool2d_meta.h>
884
+ #include <ATen/ops/mkldnn_max_pool2d_backward_meta.h>
885
+ #include <ATen/ops/mkldnn_max_pool3d_meta.h>
886
+ #include <ATen/ops/mkldnn_max_pool3d_backward_meta.h>
887
+ #include <ATen/ops/mkldnn_reorder_conv2d_weight_meta.h>
888
+ #include <ATen/ops/mkldnn_reorder_conv3d_weight_meta.h>
889
+ #include <ATen/ops/mkldnn_rnn_layer_meta.h>
890
+ #include <ATen/ops/mkldnn_rnn_layer_backward_meta.h>
891
+ #include <ATen/ops/mm_meta.h>
892
+ #include <ATen/ops/mode_meta.h>
893
+ #include <ATen/ops/moveaxis_meta.h>
894
+ #include <ATen/ops/movedim_meta.h>
895
+ #include <ATen/ops/mps_convolution_backward_meta.h>
896
+ #include <ATen/ops/mps_convolution_transpose_backward_meta.h>
897
+ #include <ATen/ops/mse_loss_meta.h>
898
+ #include <ATen/ops/mse_loss_backward_meta.h>
899
+ #include <ATen/ops/msort_meta.h>
900
+ #include <ATen/ops/mul_meta.h>
901
+ #include <ATen/ops/multi_margin_loss_meta.h>
902
+ #include <ATen/ops/multi_margin_loss_backward_meta.h>
903
+ #include <ATen/ops/multilabel_margin_loss_meta.h>
904
+ #include <ATen/ops/multilabel_margin_loss_backward_meta.h>
905
+ #include <ATen/ops/multilabel_margin_loss_forward_meta.h>
906
+ #include <ATen/ops/multinomial_meta.h>
907
+ #include <ATen/ops/multiply_meta.h>
908
+ #include <ATen/ops/mv_meta.h>
909
+ #include <ATen/ops/mvlgamma_meta.h>
910
+ #include <ATen/ops/nan_to_num_meta.h>
911
+ #include <ATen/ops/nanmean_meta.h>
912
+ #include <ATen/ops/nanmedian_meta.h>
913
+ #include <ATen/ops/nanquantile_meta.h>
914
+ #include <ATen/ops/nansum_meta.h>
915
+ #include <ATen/ops/narrow_meta.h>
916
+ #include <ATen/ops/narrow_copy_meta.h>
917
+ #include <ATen/ops/native_batch_norm_meta.h>
918
+ #include <ATen/ops/native_batch_norm_backward_meta.h>
919
+ #include <ATen/ops/native_channel_shuffle_meta.h>
920
+ #include <ATen/ops/native_dropout_meta.h>
921
+ #include <ATen/ops/native_dropout_backward_meta.h>
922
+ #include <ATen/ops/native_group_norm_meta.h>
923
+ #include <ATen/ops/native_group_norm_backward_meta.h>
924
+ #include <ATen/ops/native_layer_norm_meta.h>
925
+ #include <ATen/ops/native_layer_norm_backward_meta.h>
926
+ #include <ATen/ops/native_norm_meta.h>
927
+ #include <ATen/ops/ne_meta.h>
928
+ #include <ATen/ops/neg_meta.h>
929
+ #include <ATen/ops/negative_meta.h>
930
+ #include <ATen/ops/nested_to_padded_tensor_meta.h>
931
+ #include <ATen/ops/new_empty_meta.h>
932
+ #include <ATen/ops/new_empty_strided_meta.h>
933
+ #include <ATen/ops/new_full_meta.h>
934
+ #include <ATen/ops/new_ones_meta.h>
935
+ #include <ATen/ops/new_zeros_meta.h>
936
+ #include <ATen/ops/nextafter_meta.h>
937
+ #include <ATen/ops/nll_loss_meta.h>
938
+ #include <ATen/ops/nll_loss2d_meta.h>
939
+ #include <ATen/ops/nll_loss2d_backward_meta.h>
940
+ #include <ATen/ops/nll_loss2d_forward_meta.h>
941
+ #include <ATen/ops/nll_loss_backward_meta.h>
942
+ #include <ATen/ops/nll_loss_forward_meta.h>
943
+ #include <ATen/ops/nll_loss_nd_meta.h>
944
+ #include <ATen/ops/nonzero_meta.h>
945
+ #include <ATen/ops/nonzero_numpy_meta.h>
946
+ #include <ATen/ops/nonzero_static_meta.h>
947
+ #include <ATen/ops/norm_meta.h>
948
+ #include <ATen/ops/norm_except_dim_meta.h>
949
+ #include <ATen/ops/normal_meta.h>
950
+ #include <ATen/ops/not_equal_meta.h>
951
+ #include <ATen/ops/nuclear_norm_meta.h>
952
+ #include <ATen/ops/numpy_T_meta.h>
953
+ #include <ATen/ops/one_hot_meta.h>
954
+ #include <ATen/ops/ones_meta.h>
955
+ #include <ATen/ops/ones_like_meta.h>
956
+ #include <ATen/ops/or_meta.h>
957
+ #include <ATen/ops/orgqr_meta.h>
958
+ #include <ATen/ops/ormqr_meta.h>
959
+ #include <ATen/ops/outer_meta.h>
960
+ #include <ATen/ops/output_nr_meta.h>
961
+ #include <ATen/ops/pad_meta.h>
962
+ #include <ATen/ops/pad_sequence_meta.h>
963
+ #include <ATen/ops/pairwise_distance_meta.h>
964
+ #include <ATen/ops/pdist_meta.h>
965
+ #include <ATen/ops/permute_meta.h>
966
+ #include <ATen/ops/permute_copy_meta.h>
967
+ #include <ATen/ops/pin_memory_meta.h>
968
+ #include <ATen/ops/pinverse_meta.h>
969
+ #include <ATen/ops/pixel_shuffle_meta.h>
970
+ #include <ATen/ops/pixel_unshuffle_meta.h>
971
+ #include <ATen/ops/poisson_meta.h>
972
+ #include <ATen/ops/poisson_nll_loss_meta.h>
973
+ #include <ATen/ops/polar_meta.h>
974
+ #include <ATen/ops/polygamma_meta.h>
975
+ #include <ATen/ops/positive_meta.h>
976
+ #include <ATen/ops/pow_meta.h>
977
+ #include <ATen/ops/prelu_meta.h>
978
+ #include <ATen/ops/prod_meta.h>
979
+ #include <ATen/ops/promote_types_meta.h>
980
+ #include <ATen/ops/put_meta.h>
981
+ #include <ATen/ops/q_per_channel_axis_meta.h>
982
+ #include <ATen/ops/q_per_channel_scales_meta.h>
983
+ #include <ATen/ops/q_per_channel_zero_points_meta.h>
984
+ #include <ATen/ops/q_scale_meta.h>
985
+ #include <ATen/ops/q_zero_point_meta.h>
986
+ #include <ATen/ops/qr_meta.h>
987
+ #include <ATen/ops/qscheme_meta.h>
988
+ #include <ATen/ops/quantile_meta.h>
989
+ #include <ATen/ops/quantize_per_channel_meta.h>
990
+ #include <ATen/ops/quantize_per_tensor_meta.h>
991
+ #include <ATen/ops/quantize_per_tensor_dynamic_meta.h>
992
+ #include <ATen/ops/quantized_batch_norm_meta.h>
993
+ #include <ATen/ops/quantized_gru_cell_meta.h>
994
+ #include <ATen/ops/quantized_lstm_cell_meta.h>
995
+ #include <ATen/ops/quantized_max_pool1d_meta.h>
996
+ #include <ATen/ops/quantized_max_pool2d_meta.h>
997
+ #include <ATen/ops/quantized_max_pool3d_meta.h>
998
+ #include <ATen/ops/quantized_rnn_relu_cell_meta.h>
999
+ #include <ATen/ops/quantized_rnn_tanh_cell_meta.h>
1000
+ #include <ATen/ops/rad2deg_meta.h>
1001
+ #include <ATen/ops/rand_meta.h>
1002
+ #include <ATen/ops/rand_like_meta.h>
1003
+ #include <ATen/ops/randint_meta.h>
1004
+ #include <ATen/ops/randint_like_meta.h>
1005
+ #include <ATen/ops/randn_meta.h>
1006
+ #include <ATen/ops/randn_like_meta.h>
1007
+ #include <ATen/ops/random_meta.h>
1008
+ #include <ATen/ops/randperm_meta.h>
1009
+ #include <ATen/ops/range_meta.h>
1010
+ #include <ATen/ops/ravel_meta.h>
1011
+ #include <ATen/ops/real_meta.h>
1012
+ #include <ATen/ops/reciprocal_meta.h>
1013
+ #include <ATen/ops/record_stream_meta.h>
1014
+ #include <ATen/ops/refine_names_meta.h>
1015
+ #include <ATen/ops/reflection_pad1d_meta.h>
1016
+ #include <ATen/ops/reflection_pad1d_backward_meta.h>
1017
+ #include <ATen/ops/reflection_pad2d_meta.h>
1018
+ #include <ATen/ops/reflection_pad2d_backward_meta.h>
1019
+ #include <ATen/ops/reflection_pad3d_meta.h>
1020
+ #include <ATen/ops/reflection_pad3d_backward_meta.h>
1021
+ #include <ATen/ops/relu_meta.h>
1022
+ #include <ATen/ops/relu6_meta.h>
1023
+ #include <ATen/ops/remainder_meta.h>
1024
+ #include <ATen/ops/rename_meta.h>
1025
+ #include <ATen/ops/renorm_meta.h>
1026
+ #include <ATen/ops/repeat_meta.h>
1027
+ #include <ATen/ops/repeat_interleave_meta.h>
1028
+ #include <ATen/ops/replication_pad1d_meta.h>
1029
+ #include <ATen/ops/replication_pad1d_backward_meta.h>
1030
+ #include <ATen/ops/replication_pad2d_meta.h>
1031
+ #include <ATen/ops/replication_pad2d_backward_meta.h>
1032
+ #include <ATen/ops/replication_pad3d_meta.h>
1033
+ #include <ATen/ops/replication_pad3d_backward_meta.h>
1034
+ #include <ATen/ops/requires_grad_meta.h>
1035
+ #include <ATen/ops/reshape_meta.h>
1036
+ #include <ATen/ops/reshape_as_meta.h>
1037
+ #include <ATen/ops/resize_meta.h>
1038
+ #include <ATen/ops/resize_as_meta.h>
1039
+ #include <ATen/ops/resize_as_sparse_meta.h>
1040
+ #include <ATen/ops/resolve_conj_meta.h>
1041
+ #include <ATen/ops/resolve_neg_meta.h>
1042
+ #include <ATen/ops/result_type_meta.h>
1043
+ #include <ATen/ops/retain_grad_meta.h>
1044
+ #include <ATen/ops/retains_grad_meta.h>
1045
+ #include <ATen/ops/rnn_relu_meta.h>
1046
+ #include <ATen/ops/rnn_relu_cell_meta.h>
1047
+ #include <ATen/ops/rnn_tanh_meta.h>
1048
+ #include <ATen/ops/rnn_tanh_cell_meta.h>
1049
+ #include <ATen/ops/roll_meta.h>
1050
+ #include <ATen/ops/rot90_meta.h>
1051
+ #include <ATen/ops/round_meta.h>
1052
+ #include <ATen/ops/row_indices_meta.h>
1053
+ #include <ATen/ops/row_indices_copy_meta.h>
1054
+ #include <ATen/ops/row_stack_meta.h>
1055
+ #include <ATen/ops/rrelu_meta.h>
1056
+ #include <ATen/ops/rrelu_with_noise_meta.h>
1057
+ #include <ATen/ops/rrelu_with_noise_backward_meta.h>
1058
+ #include <ATen/ops/rshift_meta.h>
1059
+ #include <ATen/ops/rsqrt_meta.h>
1060
+ #include <ATen/ops/rsub_meta.h>
1061
+ #include <ATen/ops/scalar_tensor_meta.h>
1062
+ #include <ATen/ops/scaled_dot_product_attention_meta.h>
1063
+ #include <ATen/ops/scatter_meta.h>
1064
+ #include <ATen/ops/scatter_add_meta.h>
1065
+ #include <ATen/ops/scatter_reduce_meta.h>
1066
+ #include <ATen/ops/searchsorted_meta.h>
1067
+ #include <ATen/ops/segment_reduce_meta.h>
1068
+ #include <ATen/ops/select_meta.h>
1069
+ #include <ATen/ops/select_backward_meta.h>
1070
+ #include <ATen/ops/select_copy_meta.h>
1071
+ #include <ATen/ops/select_scatter_meta.h>
1072
+ #include <ATen/ops/selu_meta.h>
1073
+ #include <ATen/ops/set_meta.h>
1074
+ #include <ATen/ops/set_data_meta.h>
1075
+ #include <ATen/ops/sgn_meta.h>
1076
+ #include <ATen/ops/sigmoid_meta.h>
1077
+ #include <ATen/ops/sigmoid_backward_meta.h>
1078
+ #include <ATen/ops/sign_meta.h>
1079
+ #include <ATen/ops/signbit_meta.h>
1080
+ #include <ATen/ops/silu_meta.h>
1081
+ #include <ATen/ops/silu_backward_meta.h>
1082
+ #include <ATen/ops/sin_meta.h>
1083
+ #include <ATen/ops/sinc_meta.h>
1084
+ #include <ATen/ops/sinh_meta.h>
1085
+ #include <ATen/ops/size_meta.h>
1086
+ #include <ATen/ops/slice_meta.h>
1087
+ #include <ATen/ops/slice_backward_meta.h>
1088
+ #include <ATen/ops/slice_copy_meta.h>
1089
+ #include <ATen/ops/slice_inverse_meta.h>
1090
+ #include <ATen/ops/slice_scatter_meta.h>
1091
+ #include <ATen/ops/slogdet_meta.h>
1092
+ #include <ATen/ops/slow_conv3d_meta.h>
1093
+ #include <ATen/ops/slow_conv3d_forward_meta.h>
1094
+ #include <ATen/ops/slow_conv_dilated2d_meta.h>
1095
+ #include <ATen/ops/slow_conv_dilated3d_meta.h>
1096
+ #include <ATen/ops/slow_conv_transpose2d_meta.h>
1097
+ #include <ATen/ops/slow_conv_transpose3d_meta.h>
1098
+ #include <ATen/ops/smm_meta.h>
1099
+ #include <ATen/ops/smooth_l1_loss_meta.h>
1100
+ #include <ATen/ops/smooth_l1_loss_backward_meta.h>
1101
+ #include <ATen/ops/soft_margin_loss_meta.h>
1102
+ #include <ATen/ops/soft_margin_loss_backward_meta.h>
1103
+ #include <ATen/ops/softmax_meta.h>
1104
+ #include <ATen/ops/softplus_meta.h>
1105
+ #include <ATen/ops/softplus_backward_meta.h>
1106
+ #include <ATen/ops/softshrink_meta.h>
1107
+ #include <ATen/ops/softshrink_backward_meta.h>
1108
+ #include <ATen/ops/sort_meta.h>
1109
+ #include <ATen/ops/sparse_bsc_tensor_meta.h>
1110
+ #include <ATen/ops/sparse_bsr_tensor_meta.h>
1111
+ #include <ATen/ops/sparse_compressed_tensor_meta.h>
1112
+ #include <ATen/ops/sparse_coo_tensor_meta.h>
1113
+ #include <ATen/ops/sparse_csc_tensor_meta.h>
1114
+ #include <ATen/ops/sparse_csr_tensor_meta.h>
1115
+ #include <ATen/ops/sparse_dim_meta.h>
1116
+ #include <ATen/ops/sparse_mask_meta.h>
1117
+ #include <ATen/ops/sparse_resize_meta.h>
1118
+ #include <ATen/ops/sparse_resize_and_clear_meta.h>
1119
+ #include <ATen/ops/sparse_sampled_addmm_meta.h>
1120
+ #include <ATen/ops/special_airy_ai_meta.h>
1121
+ #include <ATen/ops/special_bessel_j0_meta.h>
1122
+ #include <ATen/ops/special_bessel_j1_meta.h>
1123
+ #include <ATen/ops/special_bessel_y0_meta.h>
1124
+ #include <ATen/ops/special_bessel_y1_meta.h>
1125
+ #include <ATen/ops/special_chebyshev_polynomial_t_meta.h>
1126
+ #include <ATen/ops/special_chebyshev_polynomial_u_meta.h>
1127
+ #include <ATen/ops/special_chebyshev_polynomial_v_meta.h>
1128
+ #include <ATen/ops/special_chebyshev_polynomial_w_meta.h>
1129
+ #include <ATen/ops/special_digamma_meta.h>
1130
+ #include <ATen/ops/special_entr_meta.h>
1131
+ #include <ATen/ops/special_erf_meta.h>
1132
+ #include <ATen/ops/special_erfc_meta.h>
1133
+ #include <ATen/ops/special_erfcx_meta.h>
1134
+ #include <ATen/ops/special_erfinv_meta.h>
1135
+ #include <ATen/ops/special_exp2_meta.h>
1136
+ #include <ATen/ops/special_expit_meta.h>
1137
+ #include <ATen/ops/special_expm1_meta.h>
1138
+ #include <ATen/ops/special_gammainc_meta.h>
1139
+ #include <ATen/ops/special_gammaincc_meta.h>
1140
+ #include <ATen/ops/special_gammaln_meta.h>
1141
+ #include <ATen/ops/special_hermite_polynomial_h_meta.h>
1142
+ #include <ATen/ops/special_hermite_polynomial_he_meta.h>
1143
+ #include <ATen/ops/special_i0_meta.h>
1144
+ #include <ATen/ops/special_i0e_meta.h>
1145
+ #include <ATen/ops/special_i1_meta.h>
1146
+ #include <ATen/ops/special_i1e_meta.h>
1147
+ #include <ATen/ops/special_laguerre_polynomial_l_meta.h>
1148
+ #include <ATen/ops/special_legendre_polynomial_p_meta.h>
1149
+ #include <ATen/ops/special_log1p_meta.h>
1150
+ #include <ATen/ops/special_log_ndtr_meta.h>
1151
+ #include <ATen/ops/special_log_softmax_meta.h>
1152
+ #include <ATen/ops/special_logit_meta.h>
1153
+ #include <ATen/ops/special_logsumexp_meta.h>
1154
+ #include <ATen/ops/special_modified_bessel_i0_meta.h>
1155
+ #include <ATen/ops/special_modified_bessel_i1_meta.h>
1156
+ #include <ATen/ops/special_modified_bessel_k0_meta.h>
1157
+ #include <ATen/ops/special_modified_bessel_k1_meta.h>
1158
+ #include <ATen/ops/special_multigammaln_meta.h>
1159
+ #include <ATen/ops/special_ndtr_meta.h>
1160
+ #include <ATen/ops/special_ndtri_meta.h>
1161
+ #include <ATen/ops/special_polygamma_meta.h>
1162
+ #include <ATen/ops/special_psi_meta.h>
1163
+ #include <ATen/ops/special_round_meta.h>
1164
+ #include <ATen/ops/special_scaled_modified_bessel_k0_meta.h>
1165
+ #include <ATen/ops/special_scaled_modified_bessel_k1_meta.h>
1166
+ #include <ATen/ops/special_shifted_chebyshev_polynomial_t_meta.h>
1167
+ #include <ATen/ops/special_shifted_chebyshev_polynomial_u_meta.h>
1168
+ #include <ATen/ops/special_shifted_chebyshev_polynomial_v_meta.h>
1169
+ #include <ATen/ops/special_shifted_chebyshev_polynomial_w_meta.h>
1170
+ #include <ATen/ops/special_sinc_meta.h>
1171
+ #include <ATen/ops/special_softmax_meta.h>
1172
+ #include <ATen/ops/special_spherical_bessel_j0_meta.h>
1173
+ #include <ATen/ops/special_xlog1py_meta.h>
1174
+ #include <ATen/ops/special_xlogy_meta.h>
1175
+ #include <ATen/ops/special_zeta_meta.h>
1176
+ #include <ATen/ops/split_meta.h>
1177
+ #include <ATen/ops/split_copy_meta.h>
1178
+ #include <ATen/ops/split_with_sizes_meta.h>
1179
+ #include <ATen/ops/split_with_sizes_copy_meta.h>
1180
+ #include <ATen/ops/sqrt_meta.h>
1181
+ #include <ATen/ops/square_meta.h>
1182
+ #include <ATen/ops/squeeze_meta.h>
1183
+ #include <ATen/ops/squeeze_copy_meta.h>
1184
+ #include <ATen/ops/sspaddmm_meta.h>
1185
+ #include <ATen/ops/stack_meta.h>
1186
+ #include <ATen/ops/std_meta.h>
1187
+ #include <ATen/ops/std_mean_meta.h>
1188
+ #include <ATen/ops/stft_meta.h>
1189
+ #include <ATen/ops/stride_meta.h>
1190
+ #include <ATen/ops/sub_meta.h>
1191
+ #include <ATen/ops/subtract_meta.h>
1192
+ #include <ATen/ops/sum_meta.h>
1193
+ #include <ATen/ops/sum_to_size_meta.h>
1194
+ #include <ATen/ops/svd_meta.h>
1195
+ #include <ATen/ops/swapaxes_meta.h>
1196
+ #include <ATen/ops/swapdims_meta.h>
1197
+ #include <ATen/ops/sym_constrain_range_meta.h>
1198
+ #include <ATen/ops/sym_constrain_range_for_size_meta.h>
1199
+ #include <ATen/ops/sym_numel_meta.h>
1200
+ #include <ATen/ops/sym_size_meta.h>
1201
+ #include <ATen/ops/sym_storage_offset_meta.h>
1202
+ #include <ATen/ops/sym_stride_meta.h>
1203
+ #include <ATen/ops/t_meta.h>
1204
+ #include <ATen/ops/t_copy_meta.h>
1205
+ #include <ATen/ops/take_meta.h>
1206
+ #include <ATen/ops/take_along_dim_meta.h>
1207
+ #include <ATen/ops/tan_meta.h>
1208
+ #include <ATen/ops/tanh_meta.h>
1209
+ #include <ATen/ops/tanh_backward_meta.h>
1210
+ #include <ATen/ops/tensor_split_meta.h>
1211
+ #include <ATen/ops/tensordot_meta.h>
1212
+ #include <ATen/ops/thnn_conv2d_meta.h>
1213
+ #include <ATen/ops/threshold_meta.h>
1214
+ #include <ATen/ops/threshold_backward_meta.h>
1215
+ #include <ATen/ops/tile_meta.h>
1216
+ #include <ATen/ops/to_meta.h>
1217
+ #include <ATen/ops/to_dense_meta.h>
1218
+ #include <ATen/ops/to_dense_backward_meta.h>
1219
+ #include <ATen/ops/to_mkldnn_meta.h>
1220
+ #include <ATen/ops/to_mkldnn_backward_meta.h>
1221
+ #include <ATen/ops/to_padded_tensor_meta.h>
1222
+ #include <ATen/ops/to_sparse_meta.h>
1223
+ #include <ATen/ops/to_sparse_bsc_meta.h>
1224
+ #include <ATen/ops/to_sparse_bsr_meta.h>
1225
+ #include <ATen/ops/to_sparse_csc_meta.h>
1226
+ #include <ATen/ops/to_sparse_csr_meta.h>
1227
+ #include <ATen/ops/topk_meta.h>
1228
+ #include <ATen/ops/trace_meta.h>
1229
+ #include <ATen/ops/trace_backward_meta.h>
1230
+ #include <ATen/ops/transpose_meta.h>
1231
+ #include <ATen/ops/transpose_copy_meta.h>
1232
+ #include <ATen/ops/trapezoid_meta.h>
1233
+ #include <ATen/ops/trapz_meta.h>
1234
+ #include <ATen/ops/triangular_solve_meta.h>
1235
+ #include <ATen/ops/tril_meta.h>
1236
+ #include <ATen/ops/tril_indices_meta.h>
1237
+ #include <ATen/ops/triplet_margin_loss_meta.h>
1238
+ #include <ATen/ops/triu_meta.h>
1239
+ #include <ATen/ops/triu_indices_meta.h>
1240
+ #include <ATen/ops/true_divide_meta.h>
1241
+ #include <ATen/ops/trunc_meta.h>
1242
+ #include <ATen/ops/type_as_meta.h>
1243
+ #include <ATen/ops/unbind_meta.h>
1244
+ #include <ATen/ops/unbind_copy_meta.h>
1245
+ #include <ATen/ops/unflatten_meta.h>
1246
+ #include <ATen/ops/unflatten_dense_tensors_meta.h>
1247
+ #include <ATen/ops/unfold_meta.h>
1248
+ #include <ATen/ops/unfold_backward_meta.h>
1249
+ #include <ATen/ops/unfold_copy_meta.h>
1250
+ #include <ATen/ops/uniform_meta.h>
1251
+ #include <ATen/ops/unique_consecutive_meta.h>
1252
+ #include <ATen/ops/unique_dim_meta.h>
1253
+ #include <ATen/ops/unique_dim_consecutive_meta.h>
1254
+ #include <ATen/ops/unsafe_chunk_meta.h>
1255
+ #include <ATen/ops/unsafe_split_meta.h>
1256
+ #include <ATen/ops/unsafe_split_with_sizes_meta.h>
1257
+ #include <ATen/ops/unsqueeze_meta.h>
1258
+ #include <ATen/ops/unsqueeze_copy_meta.h>
1259
+ #include <ATen/ops/upsample_bicubic2d_meta.h>
1260
+ #include <ATen/ops/upsample_bicubic2d_backward_meta.h>
1261
+ #include <ATen/ops/upsample_bilinear2d_meta.h>
1262
+ #include <ATen/ops/upsample_bilinear2d_backward_meta.h>
1263
+ #include <ATen/ops/upsample_linear1d_meta.h>
1264
+ #include <ATen/ops/upsample_linear1d_backward_meta.h>
1265
+ #include <ATen/ops/upsample_nearest1d_meta.h>
1266
+ #include <ATen/ops/upsample_nearest1d_backward_meta.h>
1267
+ #include <ATen/ops/upsample_nearest2d_meta.h>
1268
+ #include <ATen/ops/upsample_nearest2d_backward_meta.h>
1269
+ #include <ATen/ops/upsample_nearest3d_meta.h>
1270
+ #include <ATen/ops/upsample_nearest3d_backward_meta.h>
1271
+ #include <ATen/ops/upsample_trilinear3d_meta.h>
1272
+ #include <ATen/ops/upsample_trilinear3d_backward_meta.h>
1273
+ #include <ATen/ops/value_selecting_reduction_backward_meta.h>
1274
+ #include <ATen/ops/values_meta.h>
1275
+ #include <ATen/ops/values_copy_meta.h>
1276
+ #include <ATen/ops/vander_meta.h>
1277
+ #include <ATen/ops/var_meta.h>
1278
+ #include <ATen/ops/var_mean_meta.h>
1279
+ #include <ATen/ops/vdot_meta.h>
1280
+ #include <ATen/ops/view_meta.h>
1281
+ #include <ATen/ops/view_as_meta.h>
1282
+ #include <ATen/ops/view_as_complex_meta.h>
1283
+ #include <ATen/ops/view_as_complex_copy_meta.h>
1284
+ #include <ATen/ops/view_as_real_meta.h>
1285
+ #include <ATen/ops/view_as_real_copy_meta.h>
1286
+ #include <ATen/ops/view_copy_meta.h>
1287
+ #include <ATen/ops/vsplit_meta.h>
1288
+ #include <ATen/ops/vstack_meta.h>
1289
+ #include <ATen/ops/where_meta.h>
1290
+ #include <ATen/ops/xlogy_meta.h>
1291
+ #include <ATen/ops/xor_meta.h>
1292
+ #include <ATen/ops/zero_meta.h>
1293
+ #include <ATen/ops/zeros_meta.h>
1294
+ #include <ATen/ops/zeros_like_meta.h>
1295
+
1296
+ namespace at {
1297
+
1298
+ namespace meta {
1299
+
1300
+
1301
+
1302
+ } // namespace meta
1303
+ } // namespace at
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/NumericUtils.h ADDED
@@ -0,0 +1,203 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #ifdef __HIPCC__
4
+ #include <hip/hip_runtime.h>
5
+ #endif
6
+
7
+ #include <c10/macros/Macros.h>
8
+ #include <c10/util/BFloat16.h>
9
+ #include <c10/util/Float8_e4m3fn.h>
10
+ #include <c10/util/Float8_e4m3fnuz.h>
11
+ #include <c10/util/Float8_e5m2.h>
12
+ #include <c10/util/Float8_e5m2fnuz.h>
13
+ #include <c10/util/Half.h>
14
+ #include <c10/util/complex.h>
15
+
16
+ #include <cmath>
17
+ #include <type_traits>
18
+
19
+ namespace at {
20
+
21
+ // std::isnan isn't performant to use on integral types; it will
22
+ // (uselessly) convert to floating point and then do the test.
23
+ // This function is.
24
+
25
+ template <typename T, std::enable_if_t<std::is_integral_v<T>, int> = 0>
26
+ inline C10_HOST_DEVICE bool _isnan(T /*val*/) {
27
+ return false;
28
+ }
29
+
30
+ template <typename T, std::enable_if_t<std::is_floating_point_v<T>, int> = 0>
31
+ inline C10_HOST_DEVICE bool _isnan(T val) {
32
+ #if defined(__CUDACC__) || defined(__HIPCC__)
33
+ return ::isnan(val);
34
+ #else
35
+ return std::isnan(val);
36
+ #endif
37
+ }
38
+
39
+ template <typename T, std::enable_if_t<c10::is_complex<T>::value, int> = 0>
40
+ inline C10_HOST_DEVICE bool _isnan(T val) {
41
+ return std::isnan(val.real()) || std::isnan(val.imag());
42
+ }
43
+
44
+ template <typename T, std::enable_if_t<std::is_same_v<T, at::Half>, int> = 0>
45
+ inline C10_HOST_DEVICE bool _isnan(T val) {
46
+ return at::_isnan(static_cast<float>(val));
47
+ }
48
+
49
+ template <
50
+ typename T,
51
+ std::enable_if_t<std::is_same_v<T, at::BFloat16>, int> = 0>
52
+ inline C10_HOST_DEVICE bool _isnan(at::BFloat16 val) {
53
+ return at::_isnan(static_cast<float>(val));
54
+ }
55
+
56
+ inline C10_HOST_DEVICE bool _isnan(at::BFloat16 val) {
57
+ return at::_isnan(static_cast<float>(val));
58
+ }
59
+
60
+ template <
61
+ typename T,
62
+ std::enable_if_t<std::is_same_v<T, at::Float8_e5m2>, int> = 0>
63
+ inline C10_HOST_DEVICE bool _isnan(T val) {
64
+ return val.isnan();
65
+ }
66
+
67
+ template <
68
+ typename T,
69
+ std::enable_if_t<std::is_same_v<T, at::Float8_e4m3fn>, int> = 0>
70
+ inline C10_HOST_DEVICE bool _isnan(T val) {
71
+ return val.isnan();
72
+ }
73
+
74
+ template <
75
+ typename T,
76
+ std::enable_if_t<std::is_same_v<T, at::Float8_e5m2fnuz>, int> = 0>
77
+ inline C10_HOST_DEVICE bool _isnan(T val) {
78
+ return val.isnan();
79
+ }
80
+
81
+ template <
82
+ typename T,
83
+ std::enable_if_t<std::is_same_v<T, at::Float8_e4m3fnuz>, int> = 0>
84
+ inline C10_HOST_DEVICE bool _isnan(T val) {
85
+ return val.isnan();
86
+ }
87
+
88
+ // std::isinf isn't performant to use on integral types; it will
89
+ // (uselessly) convert to floating point and then do the test.
90
+ // This function is.
91
+
92
+ template <typename T, std::enable_if_t<std::is_integral_v<T>, int> = 0>
93
+ inline C10_HOST_DEVICE bool _isinf(T /*val*/) {
94
+ return false;
95
+ }
96
+
97
+ template <typename T, std::enable_if_t<std::is_floating_point_v<T>, int> = 0>
98
+ inline C10_HOST_DEVICE bool _isinf(T val) {
99
+ #if defined(__CUDACC__) || defined(__HIPCC__)
100
+ return ::isinf(val);
101
+ #else
102
+ return std::isinf(val);
103
+ #endif
104
+ }
105
+
106
+ inline C10_HOST_DEVICE bool _isinf(at::Half val) {
107
+ return at::_isinf(static_cast<float>(val));
108
+ }
109
+
110
+ inline C10_HOST_DEVICE bool _isinf(at::BFloat16 val) {
111
+ return at::_isinf(static_cast<float>(val));
112
+ }
113
+
114
+ inline C10_HOST_DEVICE bool _isinf(at::Float8_e5m2 val) {
115
+ return val.isinf();
116
+ }
117
+
118
+ inline C10_HOST_DEVICE bool _isinf(at::Float8_e4m3fn val) {
119
+ return false;
120
+ }
121
+
122
+ inline C10_HOST_DEVICE bool _isinf(at::Float8_e5m2fnuz val) {
123
+ return false;
124
+ }
125
+
126
+ inline C10_HOST_DEVICE bool _isinf(at::Float8_e4m3fnuz val) {
127
+ return false;
128
+ }
129
+
130
+ template <typename T>
131
+ C10_HOST_DEVICE inline T exp(T x) {
132
+ static_assert(
133
+ !std::is_same_v<T, double>,
134
+ "this template must be used with float or less precise type");
135
+ #if defined(__CUDA_ARCH__) || defined(__HIP_ARCH__)
136
+ // use __expf fast approximation for peak bandwidth
137
+ return __expf(x);
138
+ #else
139
+ return ::exp(x);
140
+ #endif
141
+ }
142
+
143
+ template <>
144
+ C10_HOST_DEVICE inline double exp<double>(double x) {
145
+ return ::exp(x);
146
+ }
147
+
148
+ template <typename T>
149
+ C10_HOST_DEVICE inline T log(T x) {
150
+ static_assert(
151
+ !std::is_same_v<T, double>,
152
+ "this template must be used with float or less precise type");
153
+ #if defined(__CUDA_ARCH__) || defined(__HIP_ARCH__)
154
+ // use __logf fast approximation for peak bandwidth
155
+ return __logf(x);
156
+ #else
157
+ return ::log(x);
158
+ #endif
159
+ }
160
+
161
+ template <>
162
+ C10_HOST_DEVICE inline double log<double>(double x) {
163
+ return ::log(x);
164
+ }
165
+
166
+ template <typename T>
167
+ C10_HOST_DEVICE inline T log1p(T x) {
168
+ static_assert(
169
+ !std::is_same_v<T, double>,
170
+ "this template must be used with float or less precise type");
171
+ #if defined(__CUDA_ARCH__) || defined(__HIP_ARCH__)
172
+ // use __logf fast approximation for peak bandwidth
173
+ // NOTE: There is no __log1pf so unfortunately we lose precision.
174
+ return __logf(1.0f + x);
175
+ #else
176
+ return ::log1p(x);
177
+ #endif
178
+ }
179
+
180
+ template <>
181
+ C10_HOST_DEVICE inline double log1p<double>(double x) {
182
+ return ::log1p(x);
183
+ }
184
+
185
+ template <typename T>
186
+ C10_HOST_DEVICE inline T tan(T x) {
187
+ static_assert(
188
+ !std::is_same_v<T, double>,
189
+ "this template must be used with float or less precise type");
190
+ #if defined(__CUDA_ARCH__) || defined(__HIP_ARCH__)
191
+ // use __tanf fast approximation for peak bandwidth
192
+ return __tanf(x);
193
+ #else
194
+ return ::tan(x);
195
+ #endif
196
+ }
197
+
198
+ template <>
199
+ C10_HOST_DEVICE inline double tan<double>(double x) {
200
+ return ::tan(x);
201
+ }
202
+
203
+ } // namespace at
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/OpaqueTensorImpl.h ADDED
@@ -0,0 +1,187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <c10/core/MemoryFormat.h>
4
+ #include <c10/core/SymIntArrayRef.h>
5
+ #include <c10/core/TensorImpl.h>
6
+ #include <c10/util/Exception.h>
7
+
8
+ namespace at {
9
+
10
+ // An "Opaque" TensorImpl -- there are no strides and (for now)
11
+ // even data() is not supported (thus no pointer arithmetic).
12
+
13
+ // NOTE: We could allow data() in the future, but would have to ensure pointer
14
+ // arithmetic code is properly guarded.
15
+ //
16
+ // NOTE: This does not support resize_ (and other metadata-changing ops) because
17
+ // of `shallow_copy_and_detach`. We would need to define an interface to
18
+ // "shallow copy" in order to add support.
19
+
20
+ template <typename OpaqueHandle>
21
+ struct TORCH_API OpaqueTensorImpl : public TensorImpl {
22
+ // public constructor for now...
23
+ OpaqueTensorImpl(
24
+ at::DispatchKeySet key_set,
25
+ const caffe2::TypeMeta data_type,
26
+ c10::Device device,
27
+ OpaqueHandle opaque_handle,
28
+ c10::IntArrayRef sizes,
29
+ bool is_non_overlapping_and_dense = true)
30
+ : TensorImpl(key_set, data_type, device),
31
+ opaque_handle_(std::move(opaque_handle)) {
32
+ set_storage_access_should_throw();
33
+ set_custom_sizes_strides(SizesStridesPolicy::CustomStrides);
34
+ sizes_and_strides_.set_sizes(sizes);
35
+ refresh_numel();
36
+ // NOLINTNEXTLINE(cppcoreguidelines-prefer-member-initializer)
37
+ is_non_overlapping_and_dense_ = is_non_overlapping_and_dense;
38
+ }
39
+
40
+ // Destructor doesn't call release_resources because it's
41
+ // unnecessary; don't forget to change that if needed!
42
+ void release_resources() override {
43
+ TensorImpl::release_resources();
44
+ opaque_handle_ = {};
45
+ }
46
+
47
+ void set_size(int64_t dim, int64_t new_size) override {
48
+ AT_ERROR("opaque tensors do not have set_size");
49
+ }
50
+
51
+ void set_stride(int64_t dim, int64_t new_stride) override {
52
+ AT_ERROR("opaque tensors do not have set_stride");
53
+ }
54
+
55
+ void set_storage_offset(int64_t storage_offset) override {
56
+ AT_ERROR("opaque tensors do not have set_storage_offset");
57
+ }
58
+
59
+ #ifdef DEBUG
60
+ bool has_storage() const override {
61
+ TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
62
+ !storage_, "OpaqueTensorImpl assumes that storage_ is never set");
63
+ return false;
64
+ }
65
+ #endif
66
+
67
+ /**
68
+ * Return a TensorImpl that is a shallow-copy of this TensorImpl.
69
+ *
70
+ * For usage of `version_counter` and `allow_tensor_metadata_change`,
71
+ * see NOTE [ TensorImpl Shallow-Copying ].
72
+ */
73
+ c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
74
+ const c10::VariableVersion& version_counter,
75
+ bool allow_tensor_metadata_change) const override {
76
+ auto impl = c10::make_intrusive<OpaqueTensorImpl<OpaqueHandle>>(
77
+ key_set(),
78
+ dtype(),
79
+ device(),
80
+ opaque_handle_,
81
+ sizes_and_strides_.sizes_arrayref());
82
+ copy_tensor_metadata(
83
+ /*src_opaque_impl=*/this,
84
+ /*dest_opaque_impl=*/impl.get(),
85
+ /*version_counter=*/version_counter,
86
+ /*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
87
+ impl->refresh_numel();
88
+ return impl;
89
+ }
90
+
91
+ /**
92
+ * Return a TensorImpl that is a shallow-copy of this TensorImpl.
93
+ *
94
+ * For usage of `version_counter` and `allow_tensor_metadata_change`,
95
+ * see NOTE [ TensorImpl Shallow-Copying ].
96
+ */
97
+ c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
98
+ c10::VariableVersion&& version_counter,
99
+ bool allow_tensor_metadata_change) const override {
100
+ auto impl = c10::make_intrusive<OpaqueTensorImpl<OpaqueHandle>>(
101
+ key_set(),
102
+ dtype(),
103
+ device(),
104
+ opaque_handle_,
105
+ sizes_and_strides_.sizes_arrayref());
106
+ copy_tensor_metadata(
107
+ /*src_opaque_impl=*/this,
108
+ /*dest_opaque_impl=*/impl.get(),
109
+ /*version_counter=*/std::move(version_counter),
110
+ /*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
111
+ impl->refresh_numel();
112
+ return impl;
113
+ }
114
+
115
+ /**
116
+ * Shallow-copies data from another TensorImpl into this TensorImpl.
117
+ *
118
+ * For why this function doesn't check this TensorImpl's
119
+ * `allow_tensor_metadata_change_`, see NOTE [ TensorImpl Shallow-Copying ].
120
+ */
121
+ void shallow_copy_from(const c10::intrusive_ptr<TensorImpl>& impl) override {
122
+ AT_ASSERT(has_compatible_shallow_copy_type(impl->key_set()));
123
+ auto opaque_impl =
124
+ static_cast<const OpaqueTensorImpl<OpaqueHandle>*>(impl.get());
125
+ copy_tensor_metadata(
126
+ /*src_impl=*/opaque_impl,
127
+ /*dest_impl=*/this,
128
+ /*version_counter=*/version_counter(),
129
+ /*allow_tensor_metadata_change=*/allow_tensor_metadata_change());
130
+ refresh_numel();
131
+ }
132
+
133
+ const OpaqueHandle& opaque_handle() const {
134
+ return opaque_handle_;
135
+ }
136
+
137
+ OpaqueHandle& unsafe_opaque_handle() {
138
+ return opaque_handle_;
139
+ }
140
+
141
+ protected:
142
+ /**
143
+ * Copy the tensor metadata fields (e.g. sizes / strides / storage pointer /
144
+ * storage_offset) from one TensorImpl to another TensorImpl.
145
+ *
146
+ * For usage of `version_counter` and `allow_tensor_metadata_change`, see NOTE
147
+ * [ TensorImpl Shallow-Copying ].
148
+ */
149
+ static void copy_tensor_metadata(
150
+ const OpaqueTensorImpl<OpaqueHandle>* src_opaque_impl,
151
+ OpaqueTensorImpl<OpaqueHandle>* dest_opaque_impl,
152
+ const c10::VariableVersion& version_counter,
153
+ bool allow_tensor_metadata_change) {
154
+ TensorImpl::copy_tensor_metadata(
155
+ src_opaque_impl,
156
+ dest_opaque_impl,
157
+ version_counter,
158
+ allow_tensor_metadata_change);
159
+
160
+ // OpaqueTensorImpl-specific fields.
161
+ dest_opaque_impl->opaque_handle_ = src_opaque_impl->opaque_handle_;
162
+ }
163
+
164
+ static void copy_tensor_metadata(
165
+ const OpaqueTensorImpl<OpaqueHandle>* src_opaque_impl,
166
+ OpaqueTensorImpl<OpaqueHandle>* dest_opaque_impl,
167
+ c10::VariableVersion&& version_counter,
168
+ bool allow_tensor_metadata_change) {
169
+ TensorImpl::copy_tensor_metadata(
170
+ src_opaque_impl,
171
+ dest_opaque_impl,
172
+ std::move(version_counter),
173
+ allow_tensor_metadata_change);
174
+
175
+ // OpaqueTensorImpl-specific fields.
176
+ dest_opaque_impl->opaque_handle_ = src_opaque_impl->opaque_handle_;
177
+ }
178
+
179
+ private:
180
+ const char* tensorimpl_type_name() const override {
181
+ return "OpaqueTensorImpl";
182
+ }
183
+
184
+ OpaqueHandle opaque_handle_;
185
+ };
186
+
187
+ } // namespace at
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/Operators.h ADDED
@@ -0,0 +1,1358 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ // @generated by torchgen/gen.py from Operators.h
4
+
5
+ #ifdef TORCH_ASSERT_NO_OPERATORS
6
+ #error This change adds a dependency on native_functions.yaml, \
7
+ meaning the file will need to be re-compiled every time an operator \
8
+ is changed or added. Consider if your change would be better placed in \
9
+ another file, or if a more specific header might achieve the same goal. \
10
+ See NOTE: [Tensor vs. TensorBase]
11
+ #endif
12
+
13
+ #if defined(AT_PER_OPERATOR_HEADERS) && defined(TORCH_ASSERT_ONLY_METHOD_OPERATORS)
14
+ #error This change adds a dependency on all pytorch operators, meaning the \
15
+ file will need to be re-compiled every time an operator is changed or added. \
16
+ Consider including a specific operator from <ATen/ops/{my_operator}_ops.h> \
17
+ and see NOTE [TORCH_ASSERT_ONLY_METHOD_OPERATORS].
18
+ #endif
19
+
20
+ #include <c10/core/SymInt.h>
21
+ #include <c10/core/SymIntArrayRef.h>
22
+ #include <c10/core/Scalar.h>
23
+ #include <c10/core/TensorOptions.h>
24
+ #include <c10/core/QScheme.h>
25
+ #include <c10/util/OptionalArrayRef.h>
26
+ #include <tuple>
27
+ #include <vector>
28
+
29
+ #include <ATen/ops/_adaptive_avg_pool2d_ops.h>
30
+ #include <ATen/ops/_adaptive_avg_pool2d_backward_ops.h>
31
+ #include <ATen/ops/_adaptive_avg_pool3d_ops.h>
32
+ #include <ATen/ops/_adaptive_avg_pool3d_backward_ops.h>
33
+ #include <ATen/ops/_add_batch_dim_ops.h>
34
+ #include <ATen/ops/_add_relu_ops.h>
35
+ #include <ATen/ops/_addmm_activation_ops.h>
36
+ #include <ATen/ops/_aminmax_ops.h>
37
+ #include <ATen/ops/_amp_foreach_non_finite_check_and_unscale_ops.h>
38
+ #include <ATen/ops/_amp_update_scale_ops.h>
39
+ #include <ATen/ops/_assert_async_ops.h>
40
+ #include <ATen/ops/_assert_scalar_ops.h>
41
+ #include <ATen/ops/_assert_tensor_metadata_ops.h>
42
+ #include <ATen/ops/_autocast_to_full_precision_ops.h>
43
+ #include <ATen/ops/_autocast_to_reduced_precision_ops.h>
44
+ #include <ATen/ops/_backward_ops.h>
45
+ #include <ATen/ops/_batch_norm_impl_index_ops.h>
46
+ #include <ATen/ops/_batch_norm_impl_index_backward_ops.h>
47
+ #include <ATen/ops/_cast_Byte_ops.h>
48
+ #include <ATen/ops/_cast_Char_ops.h>
49
+ #include <ATen/ops/_cast_Double_ops.h>
50
+ #include <ATen/ops/_cast_Float_ops.h>
51
+ #include <ATen/ops/_cast_Half_ops.h>
52
+ #include <ATen/ops/_cast_Int_ops.h>
53
+ #include <ATen/ops/_cast_Long_ops.h>
54
+ #include <ATen/ops/_cast_Short_ops.h>
55
+ #include <ATen/ops/_cdist_backward_ops.h>
56
+ #include <ATen/ops/_cdist_forward_ops.h>
57
+ #include <ATen/ops/_cholesky_solve_helper_ops.h>
58
+ #include <ATen/ops/_choose_qparams_per_tensor_ops.h>
59
+ #include <ATen/ops/_chunk_cat_ops.h>
60
+ #include <ATen/ops/_coalesce_ops.h>
61
+ #include <ATen/ops/_coalesced_ops.h>
62
+ #include <ATen/ops/_compute_linear_combination_ops.h>
63
+ #include <ATen/ops/_conj_ops.h>
64
+ #include <ATen/ops/_conj_copy_ops.h>
65
+ #include <ATen/ops/_conj_physical_ops.h>
66
+ #include <ATen/ops/_conv_depthwise2d_ops.h>
67
+ #include <ATen/ops/_convert_indices_from_coo_to_csr_ops.h>
68
+ #include <ATen/ops/_convert_indices_from_csr_to_coo_ops.h>
69
+ #include <ATen/ops/_convert_weight_to_int4pack_ops.h>
70
+ #include <ATen/ops/_convolution_ops.h>
71
+ #include <ATen/ops/_convolution_double_backward_ops.h>
72
+ #include <ATen/ops/_convolution_mode_ops.h>
73
+ #include <ATen/ops/_copy_from_ops.h>
74
+ #include <ATen/ops/_copy_from_and_resize_ops.h>
75
+ #include <ATen/ops/_cslt_compress_ops.h>
76
+ #include <ATen/ops/_cslt_sparse_mm_ops.h>
77
+ #include <ATen/ops/_cslt_sparse_mm_search_ops.h>
78
+ #include <ATen/ops/_ctc_loss_ops.h>
79
+ #include <ATen/ops/_ctc_loss_backward_ops.h>
80
+ #include <ATen/ops/_cudnn_ctc_loss_ops.h>
81
+ #include <ATen/ops/_cudnn_init_dropout_state_ops.h>
82
+ #include <ATen/ops/_cudnn_rnn_ops.h>
83
+ #include <ATen/ops/_cudnn_rnn_backward_ops.h>
84
+ #include <ATen/ops/_cudnn_rnn_flatten_weight_ops.h>
85
+ #include <ATen/ops/_cufft_clear_plan_cache_ops.h>
86
+ #include <ATen/ops/_cufft_get_plan_cache_max_size_ops.h>
87
+ #include <ATen/ops/_cufft_get_plan_cache_size_ops.h>
88
+ #include <ATen/ops/_cufft_set_plan_cache_max_size_ops.h>
89
+ #include <ATen/ops/_cummax_helper_ops.h>
90
+ #include <ATen/ops/_cummin_helper_ops.h>
91
+ #include <ATen/ops/_debug_has_internal_overlap_ops.h>
92
+ #include <ATen/ops/_dimI_ops.h>
93
+ #include <ATen/ops/_dimV_ops.h>
94
+ #include <ATen/ops/_dim_arange_ops.h>
95
+ #include <ATen/ops/_dirichlet_grad_ops.h>
96
+ #include <ATen/ops/_efficient_attention_backward_ops.h>
97
+ #include <ATen/ops/_efficient_attention_forward_ops.h>
98
+ #include <ATen/ops/_efficientzerotensor_ops.h>
99
+ #include <ATen/ops/_embedding_bag_ops.h>
100
+ #include <ATen/ops/_embedding_bag_backward_ops.h>
101
+ #include <ATen/ops/_embedding_bag_dense_backward_ops.h>
102
+ #include <ATen/ops/_embedding_bag_forward_only_ops.h>
103
+ #include <ATen/ops/_embedding_bag_per_sample_weights_backward_ops.h>
104
+ #include <ATen/ops/_embedding_bag_sparse_backward_ops.h>
105
+ #include <ATen/ops/_empty_affine_quantized_ops.h>
106
+ #include <ATen/ops/_empty_per_channel_affine_quantized_ops.h>
107
+ #include <ATen/ops/_euclidean_dist_ops.h>
108
+ #include <ATen/ops/_fake_quantize_learnable_per_channel_affine_ops.h>
109
+ #include <ATen/ops/_fake_quantize_learnable_per_channel_affine_backward_ops.h>
110
+ #include <ATen/ops/_fake_quantize_learnable_per_tensor_affine_ops.h>
111
+ #include <ATen/ops/_fake_quantize_learnable_per_tensor_affine_backward_ops.h>
112
+ #include <ATen/ops/_fake_quantize_per_tensor_affine_cachemask_tensor_qparams_ops.h>
113
+ #include <ATen/ops/_fft_c2c_ops.h>
114
+ #include <ATen/ops/_fft_c2r_ops.h>
115
+ #include <ATen/ops/_fft_r2c_ops.h>
116
+ #include <ATen/ops/_fill_mem_eff_dropout_mask_ops.h>
117
+ #include <ATen/ops/_flash_attention_backward_ops.h>
118
+ #include <ATen/ops/_flash_attention_forward_ops.h>
119
+ #include <ATen/ops/_foobar_ops.h>
120
+ #include <ATen/ops/_foreach_abs_ops.h>
121
+ #include <ATen/ops/_foreach_acos_ops.h>
122
+ #include <ATen/ops/_foreach_add_ops.h>
123
+ #include <ATen/ops/_foreach_addcdiv_ops.h>
124
+ #include <ATen/ops/_foreach_addcmul_ops.h>
125
+ #include <ATen/ops/_foreach_asin_ops.h>
126
+ #include <ATen/ops/_foreach_atan_ops.h>
127
+ #include <ATen/ops/_foreach_ceil_ops.h>
128
+ #include <ATen/ops/_foreach_clamp_max_ops.h>
129
+ #include <ATen/ops/_foreach_clamp_min_ops.h>
130
+ #include <ATen/ops/_foreach_copy_ops.h>
131
+ #include <ATen/ops/_foreach_cos_ops.h>
132
+ #include <ATen/ops/_foreach_cosh_ops.h>
133
+ #include <ATen/ops/_foreach_div_ops.h>
134
+ #include <ATen/ops/_foreach_erf_ops.h>
135
+ #include <ATen/ops/_foreach_erfc_ops.h>
136
+ #include <ATen/ops/_foreach_exp_ops.h>
137
+ #include <ATen/ops/_foreach_expm1_ops.h>
138
+ #include <ATen/ops/_foreach_floor_ops.h>
139
+ #include <ATen/ops/_foreach_frac_ops.h>
140
+ #include <ATen/ops/_foreach_lerp_ops.h>
141
+ #include <ATen/ops/_foreach_lgamma_ops.h>
142
+ #include <ATen/ops/_foreach_log_ops.h>
143
+ #include <ATen/ops/_foreach_log10_ops.h>
144
+ #include <ATen/ops/_foreach_log1p_ops.h>
145
+ #include <ATen/ops/_foreach_log2_ops.h>
146
+ #include <ATen/ops/_foreach_maximum_ops.h>
147
+ #include <ATen/ops/_foreach_minimum_ops.h>
148
+ #include <ATen/ops/_foreach_mul_ops.h>
149
+ #include <ATen/ops/_foreach_neg_ops.h>
150
+ #include <ATen/ops/_foreach_norm_ops.h>
151
+ #include <ATen/ops/_foreach_pow_ops.h>
152
+ #include <ATen/ops/_foreach_reciprocal_ops.h>
153
+ #include <ATen/ops/_foreach_round_ops.h>
154
+ #include <ATen/ops/_foreach_sigmoid_ops.h>
155
+ #include <ATen/ops/_foreach_sign_ops.h>
156
+ #include <ATen/ops/_foreach_sin_ops.h>
157
+ #include <ATen/ops/_foreach_sinh_ops.h>
158
+ #include <ATen/ops/_foreach_sqrt_ops.h>
159
+ #include <ATen/ops/_foreach_sub_ops.h>
160
+ #include <ATen/ops/_foreach_tan_ops.h>
161
+ #include <ATen/ops/_foreach_tanh_ops.h>
162
+ #include <ATen/ops/_foreach_trunc_ops.h>
163
+ #include <ATen/ops/_foreach_zero_ops.h>
164
+ #include <ATen/ops/_functional_assert_async_ops.h>
165
+ #include <ATen/ops/_functional_assert_scalar_ops.h>
166
+ #include <ATen/ops/_functional_sym_constrain_range_ops.h>
167
+ #include <ATen/ops/_functional_sym_constrain_range_for_size_ops.h>
168
+ #include <ATen/ops/_fused_adam_ops.h>
169
+ #include <ATen/ops/_fused_adamw_ops.h>
170
+ #include <ATen/ops/_fused_dropout_ops.h>
171
+ #include <ATen/ops/_fused_moving_avg_obs_fq_helper_ops.h>
172
+ #include <ATen/ops/_fused_sdp_choice_ops.h>
173
+ #include <ATen/ops/_fused_sgd_ops.h>
174
+ #include <ATen/ops/_fw_primal_ops.h>
175
+ #include <ATen/ops/_fw_primal_copy_ops.h>
176
+ #include <ATen/ops/_gather_sparse_backward_ops.h>
177
+ #include <ATen/ops/_grid_sampler_2d_cpu_fallback_ops.h>
178
+ #include <ATen/ops/_grid_sampler_2d_cpu_fallback_backward_ops.h>
179
+ #include <ATen/ops/_has_compatible_shallow_copy_type_ops.h>
180
+ #include <ATen/ops/_has_same_storage_numel_ops.h>
181
+ #include <ATen/ops/_histogramdd_bin_edges_ops.h>
182
+ #include <ATen/ops/_histogramdd_from_bin_cts_ops.h>
183
+ #include <ATen/ops/_histogramdd_from_bin_tensors_ops.h>
184
+ #include <ATen/ops/_index_put_impl_ops.h>
185
+ #include <ATen/ops/_indices_ops.h>
186
+ #include <ATen/ops/_indices_copy_ops.h>
187
+ #include <ATen/ops/_int_mm_ops.h>
188
+ #include <ATen/ops/_is_all_true_ops.h>
189
+ #include <ATen/ops/_is_any_true_ops.h>
190
+ #include <ATen/ops/_is_zerotensor_ops.h>
191
+ #include <ATen/ops/_lazy_clone_ops.h>
192
+ #include <ATen/ops/_linalg_check_errors_ops.h>
193
+ #include <ATen/ops/_linalg_det_ops.h>
194
+ #include <ATen/ops/_linalg_eigh_ops.h>
195
+ #include <ATen/ops/_linalg_eigvals_ops.h>
196
+ #include <ATen/ops/_linalg_slogdet_ops.h>
197
+ #include <ATen/ops/_linalg_solve_ex_ops.h>
198
+ #include <ATen/ops/_linalg_svd_ops.h>
199
+ #include <ATen/ops/_local_scalar_dense_ops.h>
200
+ #include <ATen/ops/_log_softmax_ops.h>
201
+ #include <ATen/ops/_log_softmax_backward_data_ops.h>
202
+ #include <ATen/ops/_logcumsumexp_ops.h>
203
+ #include <ATen/ops/_lstm_mps_ops.h>
204
+ #include <ATen/ops/_lu_with_info_ops.h>
205
+ #include <ATen/ops/_make_dep_token_ops.h>
206
+ #include <ATen/ops/_make_dual_ops.h>
207
+ #include <ATen/ops/_make_dual_copy_ops.h>
208
+ #include <ATen/ops/_make_per_channel_quantized_tensor_ops.h>
209
+ #include <ATen/ops/_make_per_tensor_quantized_tensor_ops.h>
210
+ #include <ATen/ops/_masked_scale_ops.h>
211
+ #include <ATen/ops/_masked_softmax_ops.h>
212
+ #include <ATen/ops/_masked_softmax_backward_ops.h>
213
+ #include <ATen/ops/_mixed_dtypes_linear_ops.h>
214
+ #include <ATen/ops/_mkldnn_reshape_ops.h>
215
+ #include <ATen/ops/_mkldnn_transpose_ops.h>
216
+ #include <ATen/ops/_mps_convolution_ops.h>
217
+ #include <ATen/ops/_mps_convolution_transpose_ops.h>
218
+ #include <ATen/ops/_native_batch_norm_legit_ops.h>
219
+ #include <ATen/ops/_native_batch_norm_legit_no_training_ops.h>
220
+ #include <ATen/ops/_native_multi_head_attention_ops.h>
221
+ #include <ATen/ops/_neg_view_ops.h>
222
+ #include <ATen/ops/_neg_view_copy_ops.h>
223
+ #include <ATen/ops/_nested_from_padded_ops.h>
224
+ #include <ATen/ops/_nested_from_padded_and_nested_example_ops.h>
225
+ #include <ATen/ops/_nested_get_jagged_dummy_ops.h>
226
+ #include <ATen/ops/_nested_get_lengths_ops.h>
227
+ #include <ATen/ops/_nested_get_offsets_ops.h>
228
+ #include <ATen/ops/_nested_get_ragged_idx_ops.h>
229
+ #include <ATen/ops/_nested_get_values_ops.h>
230
+ #include <ATen/ops/_nested_get_values_copy_ops.h>
231
+ #include <ATen/ops/_nested_select_backward_ops.h>
232
+ #include <ATen/ops/_nested_sum_backward_ops.h>
233
+ #include <ATen/ops/_nested_tensor_from_mask_ops.h>
234
+ #include <ATen/ops/_nested_tensor_from_mask_left_aligned_ops.h>
235
+ #include <ATen/ops/_nested_tensor_from_tensor_list_ops.h>
236
+ #include <ATen/ops/_nested_tensor_size_ops.h>
237
+ #include <ATen/ops/_nested_tensor_softmax_with_shape_ops.h>
238
+ #include <ATen/ops/_nested_tensor_storage_offsets_ops.h>
239
+ #include <ATen/ops/_nested_tensor_strides_ops.h>
240
+ #include <ATen/ops/_nested_view_from_buffer_ops.h>
241
+ #include <ATen/ops/_nested_view_from_buffer_copy_ops.h>
242
+ #include <ATen/ops/_nested_view_from_jagged_ops.h>
243
+ #include <ATen/ops/_nested_view_from_jagged_copy_ops.h>
244
+ #include <ATen/ops/_new_zeros_with_same_feature_meta_ops.h>
245
+ #include <ATen/ops/_nnpack_available_ops.h>
246
+ #include <ATen/ops/_nnpack_spatial_convolution_ops.h>
247
+ #include <ATen/ops/_nnz_ops.h>
248
+ #include <ATen/ops/_pack_padded_sequence_ops.h>
249
+ #include <ATen/ops/_pack_padded_sequence_backward_ops.h>
250
+ #include <ATen/ops/_pad_circular_ops.h>
251
+ #include <ATen/ops/_pad_enum_ops.h>
252
+ #include <ATen/ops/_pad_packed_sequence_ops.h>
253
+ #include <ATen/ops/_pdist_backward_ops.h>
254
+ #include <ATen/ops/_pdist_forward_ops.h>
255
+ #include <ATen/ops/_pin_memory_ops.h>
256
+ #include <ATen/ops/_prelu_kernel_ops.h>
257
+ #include <ATen/ops/_prelu_kernel_backward_ops.h>
258
+ #include <ATen/ops/_print_ops.h>
259
+ #include <ATen/ops/_propagate_xla_data_ops.h>
260
+ #include <ATen/ops/_remove_batch_dim_ops.h>
261
+ #include <ATen/ops/_reshape_alias_ops.h>
262
+ #include <ATen/ops/_reshape_alias_copy_ops.h>
263
+ #include <ATen/ops/_reshape_copy_ops.h>
264
+ #include <ATen/ops/_reshape_from_tensor_ops.h>
265
+ #include <ATen/ops/_resize_output_ops.h>
266
+ #include <ATen/ops/_rowwise_prune_ops.h>
267
+ #include <ATen/ops/_sample_dirichlet_ops.h>
268
+ #include <ATen/ops/_saturate_weight_to_fp16_ops.h>
269
+ #include <ATen/ops/_scaled_dot_product_attention_math_ops.h>
270
+ #include <ATen/ops/_scaled_dot_product_cudnn_attention_ops.h>
271
+ #include <ATen/ops/_scaled_dot_product_efficient_attention_ops.h>
272
+ #include <ATen/ops/_scaled_dot_product_efficient_attention_backward_ops.h>
273
+ #include <ATen/ops/_scaled_dot_product_flash_attention_ops.h>
274
+ #include <ATen/ops/_scaled_dot_product_flash_attention_backward_ops.h>
275
+ #include <ATen/ops/_scaled_dot_product_flash_attention_for_cpu_ops.h>
276
+ #include <ATen/ops/_scaled_dot_product_flash_attention_for_cpu_backward_ops.h>
277
+ #include <ATen/ops/_scaled_mm_ops.h>
278
+ #include <ATen/ops/_segment_reduce_backward_ops.h>
279
+ #include <ATen/ops/_shape_as_tensor_ops.h>
280
+ #include <ATen/ops/_slow_conv2d_backward_ops.h>
281
+ #include <ATen/ops/_slow_conv2d_forward_ops.h>
282
+ #include <ATen/ops/_sobol_engine_draw_ops.h>
283
+ #include <ATen/ops/_sobol_engine_ff_ops.h>
284
+ #include <ATen/ops/_sobol_engine_initialize_state_ops.h>
285
+ #include <ATen/ops/_sobol_engine_scramble_ops.h>
286
+ #include <ATen/ops/_softmax_ops.h>
287
+ #include <ATen/ops/_softmax_backward_data_ops.h>
288
+ #include <ATen/ops/_sparse_addmm_ops.h>
289
+ #include <ATen/ops/_sparse_broadcast_to_ops.h>
290
+ #include <ATen/ops/_sparse_broadcast_to_copy_ops.h>
291
+ #include <ATen/ops/_sparse_bsc_tensor_unsafe_ops.h>
292
+ #include <ATen/ops/_sparse_bsr_tensor_unsafe_ops.h>
293
+ #include <ATen/ops/_sparse_compressed_tensor_unsafe_ops.h>
294
+ #include <ATen/ops/_sparse_coo_tensor_unsafe_ops.h>
295
+ #include <ATen/ops/_sparse_coo_tensor_with_dims_ops.h>
296
+ #include <ATen/ops/_sparse_coo_tensor_with_dims_and_tensors_ops.h>
297
+ #include <ATen/ops/_sparse_csc_tensor_unsafe_ops.h>
298
+ #include <ATen/ops/_sparse_csr_prod_ops.h>
299
+ #include <ATen/ops/_sparse_csr_sum_ops.h>
300
+ #include <ATen/ops/_sparse_csr_tensor_unsafe_ops.h>
301
+ #include <ATen/ops/_sparse_log_softmax_ops.h>
302
+ #include <ATen/ops/_sparse_log_softmax_backward_data_ops.h>
303
+ #include <ATen/ops/_sparse_mask_projection_ops.h>
304
+ #include <ATen/ops/_sparse_mm_ops.h>
305
+ #include <ATen/ops/_sparse_mm_reduce_impl_ops.h>
306
+ #include <ATen/ops/_sparse_mm_reduce_impl_backward_ops.h>
307
+ #include <ATen/ops/_sparse_semi_structured_linear_ops.h>
308
+ #include <ATen/ops/_sparse_softmax_ops.h>
309
+ #include <ATen/ops/_sparse_softmax_backward_data_ops.h>
310
+ #include <ATen/ops/_sparse_sparse_matmul_ops.h>
311
+ #include <ATen/ops/_sparse_sum_ops.h>
312
+ #include <ATen/ops/_sparse_sum_backward_ops.h>
313
+ #include <ATen/ops/_spdiags_ops.h>
314
+ #include <ATen/ops/_stack_ops.h>
315
+ #include <ATen/ops/_standard_gamma_ops.h>
316
+ #include <ATen/ops/_standard_gamma_grad_ops.h>
317
+ #include <ATen/ops/_test_ambiguous_defaults_ops.h>
318
+ #include <ATen/ops/_test_autograd_multiple_dispatch_ops.h>
319
+ #include <ATen/ops/_test_autograd_multiple_dispatch_view_ops.h>
320
+ #include <ATen/ops/_test_autograd_multiple_dispatch_view_copy_ops.h>
321
+ #include <ATen/ops/_test_check_tensor_ops.h>
322
+ #include <ATen/ops/_test_functorch_fallback_ops.h>
323
+ #include <ATen/ops/_test_optional_filled_intlist_ops.h>
324
+ #include <ATen/ops/_test_optional_floatlist_ops.h>
325
+ #include <ATen/ops/_test_optional_intlist_ops.h>
326
+ #include <ATen/ops/_test_parallel_materialize_ops.h>
327
+ #include <ATen/ops/_test_serialization_subcmul_ops.h>
328
+ #include <ATen/ops/_test_string_default_ops.h>
329
+ #include <ATen/ops/_test_warn_in_autograd_ops.h>
330
+ #include <ATen/ops/_thnn_differentiable_gru_cell_backward_ops.h>
331
+ #include <ATen/ops/_thnn_differentiable_lstm_cell_backward_ops.h>
332
+ #include <ATen/ops/_thnn_fused_gru_cell_ops.h>
333
+ #include <ATen/ops/_thnn_fused_gru_cell_backward_ops.h>
334
+ #include <ATen/ops/_thnn_fused_lstm_cell_ops.h>
335
+ #include <ATen/ops/_thnn_fused_lstm_cell_backward_ops.h>
336
+ #include <ATen/ops/_thnn_fused_lstm_cell_backward_impl_ops.h>
337
+ #include <ATen/ops/_to_copy_ops.h>
338
+ #include <ATen/ops/_to_cpu_ops.h>
339
+ #include <ATen/ops/_to_dense_ops.h>
340
+ #include <ATen/ops/_to_sparse_ops.h>
341
+ #include <ATen/ops/_to_sparse_bsc_ops.h>
342
+ #include <ATen/ops/_to_sparse_bsr_ops.h>
343
+ #include <ATen/ops/_to_sparse_csc_ops.h>
344
+ #include <ATen/ops/_to_sparse_csr_ops.h>
345
+ #include <ATen/ops/_to_sparse_semi_structured_ops.h>
346
+ #include <ATen/ops/_transform_bias_rescale_qkv_ops.h>
347
+ #include <ATen/ops/_transformer_encoder_layer_fwd_ops.h>
348
+ #include <ATen/ops/_trilinear_ops.h>
349
+ #include <ATen/ops/_triton_multi_head_attention_ops.h>
350
+ #include <ATen/ops/_triton_scaled_dot_attention_ops.h>
351
+ #include <ATen/ops/_unique_ops.h>
352
+ #include <ATen/ops/_unique2_ops.h>
353
+ #include <ATen/ops/_unpack_dual_ops.h>
354
+ #include <ATen/ops/_unsafe_index_ops.h>
355
+ #include <ATen/ops/_unsafe_index_put_ops.h>
356
+ #include <ATen/ops/_unsafe_view_ops.h>
357
+ #include <ATen/ops/_upsample_bicubic2d_aa_ops.h>
358
+ #include <ATen/ops/_upsample_bicubic2d_aa_backward_ops.h>
359
+ #include <ATen/ops/_upsample_bilinear2d_aa_ops.h>
360
+ #include <ATen/ops/_upsample_bilinear2d_aa_backward_ops.h>
361
+ #include <ATen/ops/_upsample_nearest_exact1d_ops.h>
362
+ #include <ATen/ops/_upsample_nearest_exact1d_backward_ops.h>
363
+ #include <ATen/ops/_upsample_nearest_exact2d_ops.h>
364
+ #include <ATen/ops/_upsample_nearest_exact2d_backward_ops.h>
365
+ #include <ATen/ops/_upsample_nearest_exact3d_ops.h>
366
+ #include <ATen/ops/_upsample_nearest_exact3d_backward_ops.h>
367
+ #include <ATen/ops/_use_cudnn_ctc_loss_ops.h>
368
+ #include <ATen/ops/_use_cudnn_rnn_flatten_weight_ops.h>
369
+ #include <ATen/ops/_validate_compressed_sparse_indices_ops.h>
370
+ #include <ATen/ops/_validate_sparse_bsc_tensor_args_ops.h>
371
+ #include <ATen/ops/_validate_sparse_bsr_tensor_args_ops.h>
372
+ #include <ATen/ops/_validate_sparse_compressed_tensor_args_ops.h>
373
+ #include <ATen/ops/_validate_sparse_coo_tensor_args_ops.h>
374
+ #include <ATen/ops/_validate_sparse_csc_tensor_args_ops.h>
375
+ #include <ATen/ops/_validate_sparse_csr_tensor_args_ops.h>
376
+ #include <ATen/ops/_values_ops.h>
377
+ #include <ATen/ops/_values_copy_ops.h>
378
+ #include <ATen/ops/_version_ops.h>
379
+ #include <ATen/ops/_weight_int4pack_mm_ops.h>
380
+ #include <ATen/ops/_weight_int8pack_mm_ops.h>
381
+ #include <ATen/ops/_weight_norm_ops.h>
382
+ #include <ATen/ops/_weight_norm_differentiable_backward_ops.h>
383
+ #include <ATen/ops/_weight_norm_interface_ops.h>
384
+ #include <ATen/ops/_weight_norm_interface_backward_ops.h>
385
+ #include <ATen/ops/abs_ops.h>
386
+ #include <ATen/ops/absolute_ops.h>
387
+ #include <ATen/ops/acos_ops.h>
388
+ #include <ATen/ops/acosh_ops.h>
389
+ #include <ATen/ops/adaptive_avg_pool1d_ops.h>
390
+ #include <ATen/ops/adaptive_avg_pool2d_ops.h>
391
+ #include <ATen/ops/adaptive_avg_pool3d_ops.h>
392
+ #include <ATen/ops/adaptive_avg_pool3d_backward_ops.h>
393
+ #include <ATen/ops/adaptive_max_pool1d_ops.h>
394
+ #include <ATen/ops/adaptive_max_pool2d_ops.h>
395
+ #include <ATen/ops/adaptive_max_pool2d_backward_ops.h>
396
+ #include <ATen/ops/adaptive_max_pool3d_ops.h>
397
+ #include <ATen/ops/adaptive_max_pool3d_backward_ops.h>
398
+ #include <ATen/ops/add_ops.h>
399
+ #include <ATen/ops/addbmm_ops.h>
400
+ #include <ATen/ops/addcdiv_ops.h>
401
+ #include <ATen/ops/addcmul_ops.h>
402
+ #include <ATen/ops/addmm_ops.h>
403
+ #include <ATen/ops/addmv_ops.h>
404
+ #include <ATen/ops/addr_ops.h>
405
+ #include <ATen/ops/adjoint_ops.h>
406
+ #include <ATen/ops/affine_grid_generator_ops.h>
407
+ #include <ATen/ops/affine_grid_generator_backward_ops.h>
408
+ #include <ATen/ops/alias_ops.h>
409
+ #include <ATen/ops/alias_copy_ops.h>
410
+ #include <ATen/ops/align_as_ops.h>
411
+ #include <ATen/ops/align_tensors_ops.h>
412
+ #include <ATen/ops/align_to_ops.h>
413
+ #include <ATen/ops/all_ops.h>
414
+ #include <ATen/ops/allclose_ops.h>
415
+ #include <ATen/ops/alpha_dropout_ops.h>
416
+ #include <ATen/ops/amax_ops.h>
417
+ #include <ATen/ops/amin_ops.h>
418
+ #include <ATen/ops/aminmax_ops.h>
419
+ #include <ATen/ops/and_ops.h>
420
+ #include <ATen/ops/angle_ops.h>
421
+ #include <ATen/ops/any_ops.h>
422
+ #include <ATen/ops/arange_ops.h>
423
+ #include <ATen/ops/arccos_ops.h>
424
+ #include <ATen/ops/arccosh_ops.h>
425
+ #include <ATen/ops/arcsin_ops.h>
426
+ #include <ATen/ops/arcsinh_ops.h>
427
+ #include <ATen/ops/arctan_ops.h>
428
+ #include <ATen/ops/arctan2_ops.h>
429
+ #include <ATen/ops/arctanh_ops.h>
430
+ #include <ATen/ops/argmax_ops.h>
431
+ #include <ATen/ops/argmin_ops.h>
432
+ #include <ATen/ops/argsort_ops.h>
433
+ #include <ATen/ops/argwhere_ops.h>
434
+ #include <ATen/ops/as_strided_ops.h>
435
+ #include <ATen/ops/as_strided_copy_ops.h>
436
+ #include <ATen/ops/as_strided_scatter_ops.h>
437
+ #include <ATen/ops/asin_ops.h>
438
+ #include <ATen/ops/asinh_ops.h>
439
+ #include <ATen/ops/atan_ops.h>
440
+ #include <ATen/ops/atan2_ops.h>
441
+ #include <ATen/ops/atanh_ops.h>
442
+ #include <ATen/ops/atleast_1d_ops.h>
443
+ #include <ATen/ops/atleast_2d_ops.h>
444
+ #include <ATen/ops/atleast_3d_ops.h>
445
+ #include <ATen/ops/avg_pool1d_ops.h>
446
+ #include <ATen/ops/avg_pool2d_ops.h>
447
+ #include <ATen/ops/avg_pool2d_backward_ops.h>
448
+ #include <ATen/ops/avg_pool3d_ops.h>
449
+ #include <ATen/ops/avg_pool3d_backward_ops.h>
450
+ #include <ATen/ops/baddbmm_ops.h>
451
+ #include <ATen/ops/bartlett_window_ops.h>
452
+ #include <ATen/ops/batch_norm_ops.h>
453
+ #include <ATen/ops/batch_norm_backward_elemt_ops.h>
454
+ #include <ATen/ops/batch_norm_backward_reduce_ops.h>
455
+ #include <ATen/ops/batch_norm_elemt_ops.h>
456
+ #include <ATen/ops/batch_norm_gather_stats_ops.h>
457
+ #include <ATen/ops/batch_norm_gather_stats_with_counts_ops.h>
458
+ #include <ATen/ops/batch_norm_stats_ops.h>
459
+ #include <ATen/ops/batch_norm_update_stats_ops.h>
460
+ #include <ATen/ops/bernoulli_ops.h>
461
+ #include <ATen/ops/bilinear_ops.h>
462
+ #include <ATen/ops/binary_cross_entropy_ops.h>
463
+ #include <ATen/ops/binary_cross_entropy_backward_ops.h>
464
+ #include <ATen/ops/binary_cross_entropy_with_logits_ops.h>
465
+ #include <ATen/ops/bincount_ops.h>
466
+ #include <ATen/ops/binomial_ops.h>
467
+ #include <ATen/ops/bitwise_and_ops.h>
468
+ #include <ATen/ops/bitwise_left_shift_ops.h>
469
+ #include <ATen/ops/bitwise_not_ops.h>
470
+ #include <ATen/ops/bitwise_or_ops.h>
471
+ #include <ATen/ops/bitwise_right_shift_ops.h>
472
+ #include <ATen/ops/bitwise_xor_ops.h>
473
+ #include <ATen/ops/blackman_window_ops.h>
474
+ #include <ATen/ops/block_diag_ops.h>
475
+ #include <ATen/ops/bmm_ops.h>
476
+ #include <ATen/ops/broadcast_tensors_ops.h>
477
+ #include <ATen/ops/broadcast_to_ops.h>
478
+ #include <ATen/ops/bucketize_ops.h>
479
+ #include <ATen/ops/can_cast_ops.h>
480
+ #include <ATen/ops/cartesian_prod_ops.h>
481
+ #include <ATen/ops/cat_ops.h>
482
+ #include <ATen/ops/cauchy_ops.h>
483
+ #include <ATen/ops/ccol_indices_ops.h>
484
+ #include <ATen/ops/ccol_indices_copy_ops.h>
485
+ #include <ATen/ops/cdist_ops.h>
486
+ #include <ATen/ops/ceil_ops.h>
487
+ #include <ATen/ops/celu_ops.h>
488
+ #include <ATen/ops/chain_matmul_ops.h>
489
+ #include <ATen/ops/chalf_ops.h>
490
+ #include <ATen/ops/channel_shuffle_ops.h>
491
+ #include <ATen/ops/cholesky_ops.h>
492
+ #include <ATen/ops/cholesky_inverse_ops.h>
493
+ #include <ATen/ops/cholesky_solve_ops.h>
494
+ #include <ATen/ops/choose_qparams_optimized_ops.h>
495
+ #include <ATen/ops/chunk_ops.h>
496
+ #include <ATen/ops/clamp_ops.h>
497
+ #include <ATen/ops/clamp_max_ops.h>
498
+ #include <ATen/ops/clamp_min_ops.h>
499
+ #include <ATen/ops/clip_ops.h>
500
+ #include <ATen/ops/clone_ops.h>
501
+ #include <ATen/ops/coalesce_ops.h>
502
+ #include <ATen/ops/col2im_ops.h>
503
+ #include <ATen/ops/col_indices_ops.h>
504
+ #include <ATen/ops/col_indices_copy_ops.h>
505
+ #include <ATen/ops/column_stack_ops.h>
506
+ #include <ATen/ops/combinations_ops.h>
507
+ #include <ATen/ops/complex_ops.h>
508
+ #include <ATen/ops/concat_ops.h>
509
+ #include <ATen/ops/concatenate_ops.h>
510
+ #include <ATen/ops/conj_ops.h>
511
+ #include <ATen/ops/conj_physical_ops.h>
512
+ #include <ATen/ops/constant_pad_nd_ops.h>
513
+ #include <ATen/ops/contiguous_ops.h>
514
+ #include <ATen/ops/conv1d_ops.h>
515
+ #include <ATen/ops/conv2d_ops.h>
516
+ #include <ATen/ops/conv3d_ops.h>
517
+ #include <ATen/ops/conv_depthwise3d_ops.h>
518
+ #include <ATen/ops/conv_tbc_ops.h>
519
+ #include <ATen/ops/conv_tbc_backward_ops.h>
520
+ #include <ATen/ops/conv_transpose1d_ops.h>
521
+ #include <ATen/ops/conv_transpose2d_ops.h>
522
+ #include <ATen/ops/conv_transpose3d_ops.h>
523
+ #include <ATen/ops/convolution_ops.h>
524
+ #include <ATen/ops/convolution_backward_ops.h>
525
+ #include <ATen/ops/convolution_backward_overrideable_ops.h>
526
+ #include <ATen/ops/convolution_overrideable_ops.h>
527
+ #include <ATen/ops/copy_ops.h>
528
+ #include <ATen/ops/copy_sparse_to_sparse_ops.h>
529
+ #include <ATen/ops/copysign_ops.h>
530
+ #include <ATen/ops/corrcoef_ops.h>
531
+ #include <ATen/ops/cos_ops.h>
532
+ #include <ATen/ops/cosh_ops.h>
533
+ #include <ATen/ops/cosine_embedding_loss_ops.h>
534
+ #include <ATen/ops/cosine_similarity_ops.h>
535
+ #include <ATen/ops/count_nonzero_ops.h>
536
+ #include <ATen/ops/cov_ops.h>
537
+ #include <ATen/ops/cross_ops.h>
538
+ #include <ATen/ops/cross_entropy_loss_ops.h>
539
+ #include <ATen/ops/crow_indices_ops.h>
540
+ #include <ATen/ops/crow_indices_copy_ops.h>
541
+ #include <ATen/ops/ctc_loss_ops.h>
542
+ #include <ATen/ops/cudnn_affine_grid_generator_ops.h>
543
+ #include <ATen/ops/cudnn_affine_grid_generator_backward_ops.h>
544
+ #include <ATen/ops/cudnn_batch_norm_ops.h>
545
+ #include <ATen/ops/cudnn_batch_norm_backward_ops.h>
546
+ #include <ATen/ops/cudnn_convolution_ops.h>
547
+ #include <ATen/ops/cudnn_convolution_add_relu_ops.h>
548
+ #include <ATen/ops/cudnn_convolution_relu_ops.h>
549
+ #include <ATen/ops/cudnn_convolution_transpose_ops.h>
550
+ #include <ATen/ops/cudnn_grid_sampler_ops.h>
551
+ #include <ATen/ops/cudnn_grid_sampler_backward_ops.h>
552
+ #include <ATen/ops/cudnn_is_acceptable_ops.h>
553
+ #include <ATen/ops/cummax_ops.h>
554
+ #include <ATen/ops/cummaxmin_backward_ops.h>
555
+ #include <ATen/ops/cummin_ops.h>
556
+ #include <ATen/ops/cumprod_ops.h>
557
+ #include <ATen/ops/cumprod_backward_ops.h>
558
+ #include <ATen/ops/cumsum_ops.h>
559
+ #include <ATen/ops/cumulative_trapezoid_ops.h>
560
+ #include <ATen/ops/data_ops.h>
561
+ #include <ATen/ops/deg2rad_ops.h>
562
+ #include <ATen/ops/dense_dim_ops.h>
563
+ #include <ATen/ops/dequantize_ops.h>
564
+ #include <ATen/ops/det_ops.h>
565
+ #include <ATen/ops/detach_ops.h>
566
+ #include <ATen/ops/detach_copy_ops.h>
567
+ #include <ATen/ops/diag_ops.h>
568
+ #include <ATen/ops/diag_embed_ops.h>
569
+ #include <ATen/ops/diagflat_ops.h>
570
+ #include <ATen/ops/diagonal_ops.h>
571
+ #include <ATen/ops/diagonal_backward_ops.h>
572
+ #include <ATen/ops/diagonal_copy_ops.h>
573
+ #include <ATen/ops/diagonal_scatter_ops.h>
574
+ #include <ATen/ops/diff_ops.h>
575
+ #include <ATen/ops/digamma_ops.h>
576
+ #include <ATen/ops/dist_ops.h>
577
+ #include <ATen/ops/div_ops.h>
578
+ #include <ATen/ops/divide_ops.h>
579
+ #include <ATen/ops/dot_ops.h>
580
+ #include <ATen/ops/dropout_ops.h>
581
+ #include <ATen/ops/dsplit_ops.h>
582
+ #include <ATen/ops/dstack_ops.h>
583
+ #include <ATen/ops/einsum_ops.h>
584
+ #include <ATen/ops/elu_ops.h>
585
+ #include <ATen/ops/elu_backward_ops.h>
586
+ #include <ATen/ops/embedding_ops.h>
587
+ #include <ATen/ops/embedding_backward_ops.h>
588
+ #include <ATen/ops/embedding_bag_ops.h>
589
+ #include <ATen/ops/embedding_dense_backward_ops.h>
590
+ #include <ATen/ops/embedding_renorm_ops.h>
591
+ #include <ATen/ops/embedding_sparse_backward_ops.h>
592
+ #include <ATen/ops/empty_ops.h>
593
+ #include <ATen/ops/empty_like_ops.h>
594
+ #include <ATen/ops/empty_permuted_ops.h>
595
+ #include <ATen/ops/empty_quantized_ops.h>
596
+ #include <ATen/ops/empty_strided_ops.h>
597
+ #include <ATen/ops/eq_ops.h>
598
+ #include <ATen/ops/equal_ops.h>
599
+ #include <ATen/ops/erf_ops.h>
600
+ #include <ATen/ops/erfc_ops.h>
601
+ #include <ATen/ops/erfinv_ops.h>
602
+ #include <ATen/ops/exp_ops.h>
603
+ #include <ATen/ops/exp2_ops.h>
604
+ #include <ATen/ops/expand_ops.h>
605
+ #include <ATen/ops/expand_as_ops.h>
606
+ #include <ATen/ops/expand_copy_ops.h>
607
+ #include <ATen/ops/expm1_ops.h>
608
+ #include <ATen/ops/exponential_ops.h>
609
+ #include <ATen/ops/eye_ops.h>
610
+ #include <ATen/ops/fake_quantize_per_channel_affine_ops.h>
611
+ #include <ATen/ops/fake_quantize_per_channel_affine_cachemask_ops.h>
612
+ #include <ATen/ops/fake_quantize_per_channel_affine_cachemask_backward_ops.h>
613
+ #include <ATen/ops/fake_quantize_per_tensor_affine_ops.h>
614
+ #include <ATen/ops/fake_quantize_per_tensor_affine_cachemask_ops.h>
615
+ #include <ATen/ops/fake_quantize_per_tensor_affine_cachemask_backward_ops.h>
616
+ #include <ATen/ops/fbgemm_linear_fp16_weight_ops.h>
617
+ #include <ATen/ops/fbgemm_linear_fp16_weight_fp32_activation_ops.h>
618
+ #include <ATen/ops/fbgemm_linear_int8_weight_ops.h>
619
+ #include <ATen/ops/fbgemm_linear_int8_weight_fp32_activation_ops.h>
620
+ #include <ATen/ops/fbgemm_linear_quantize_weight_ops.h>
621
+ #include <ATen/ops/fbgemm_pack_gemm_matrix_fp16_ops.h>
622
+ #include <ATen/ops/fbgemm_pack_quantized_matrix_ops.h>
623
+ #include <ATen/ops/feature_alpha_dropout_ops.h>
624
+ #include <ATen/ops/feature_dropout_ops.h>
625
+ #include <ATen/ops/fft_fft_ops.h>
626
+ #include <ATen/ops/fft_fft2_ops.h>
627
+ #include <ATen/ops/fft_fftfreq_ops.h>
628
+ #include <ATen/ops/fft_fftn_ops.h>
629
+ #include <ATen/ops/fft_fftshift_ops.h>
630
+ #include <ATen/ops/fft_hfft_ops.h>
631
+ #include <ATen/ops/fft_hfft2_ops.h>
632
+ #include <ATen/ops/fft_hfftn_ops.h>
633
+ #include <ATen/ops/fft_ifft_ops.h>
634
+ #include <ATen/ops/fft_ifft2_ops.h>
635
+ #include <ATen/ops/fft_ifftn_ops.h>
636
+ #include <ATen/ops/fft_ifftshift_ops.h>
637
+ #include <ATen/ops/fft_ihfft_ops.h>
638
+ #include <ATen/ops/fft_ihfft2_ops.h>
639
+ #include <ATen/ops/fft_ihfftn_ops.h>
640
+ #include <ATen/ops/fft_irfft_ops.h>
641
+ #include <ATen/ops/fft_irfft2_ops.h>
642
+ #include <ATen/ops/fft_irfftn_ops.h>
643
+ #include <ATen/ops/fft_rfft_ops.h>
644
+ #include <ATen/ops/fft_rfft2_ops.h>
645
+ #include <ATen/ops/fft_rfftfreq_ops.h>
646
+ #include <ATen/ops/fft_rfftn_ops.h>
647
+ #include <ATen/ops/fill_ops.h>
648
+ #include <ATen/ops/fill_diagonal_ops.h>
649
+ #include <ATen/ops/fix_ops.h>
650
+ #include <ATen/ops/flatten_ops.h>
651
+ #include <ATen/ops/flatten_dense_tensors_ops.h>
652
+ #include <ATen/ops/flip_ops.h>
653
+ #include <ATen/ops/fliplr_ops.h>
654
+ #include <ATen/ops/flipud_ops.h>
655
+ #include <ATen/ops/float_power_ops.h>
656
+ #include <ATen/ops/floor_ops.h>
657
+ #include <ATen/ops/floor_divide_ops.h>
658
+ #include <ATen/ops/fmax_ops.h>
659
+ #include <ATen/ops/fmin_ops.h>
660
+ #include <ATen/ops/fmod_ops.h>
661
+ #include <ATen/ops/frac_ops.h>
662
+ #include <ATen/ops/fractional_max_pool2d_ops.h>
663
+ #include <ATen/ops/fractional_max_pool2d_backward_ops.h>
664
+ #include <ATen/ops/fractional_max_pool3d_ops.h>
665
+ #include <ATen/ops/fractional_max_pool3d_backward_ops.h>
666
+ #include <ATen/ops/frexp_ops.h>
667
+ #include <ATen/ops/frobenius_norm_ops.h>
668
+ #include <ATen/ops/from_file_ops.h>
669
+ #include <ATen/ops/full_ops.h>
670
+ #include <ATen/ops/full_like_ops.h>
671
+ #include <ATen/ops/fused_moving_avg_obs_fake_quant_ops.h>
672
+ #include <ATen/ops/gather_ops.h>
673
+ #include <ATen/ops/gather_backward_ops.h>
674
+ #include <ATen/ops/gcd_ops.h>
675
+ #include <ATen/ops/ge_ops.h>
676
+ #include <ATen/ops/gelu_ops.h>
677
+ #include <ATen/ops/gelu_backward_ops.h>
678
+ #include <ATen/ops/geometric_ops.h>
679
+ #include <ATen/ops/geqrf_ops.h>
680
+ #include <ATen/ops/ger_ops.h>
681
+ #include <ATen/ops/glu_ops.h>
682
+ #include <ATen/ops/glu_backward_ops.h>
683
+ #include <ATen/ops/glu_backward_jvp_ops.h>
684
+ #include <ATen/ops/glu_jvp_ops.h>
685
+ #include <ATen/ops/gradient_ops.h>
686
+ #include <ATen/ops/greater_ops.h>
687
+ #include <ATen/ops/greater_equal_ops.h>
688
+ #include <ATen/ops/grid_sampler_ops.h>
689
+ #include <ATen/ops/grid_sampler_2d_ops.h>
690
+ #include <ATen/ops/grid_sampler_2d_backward_ops.h>
691
+ #include <ATen/ops/grid_sampler_3d_ops.h>
692
+ #include <ATen/ops/grid_sampler_3d_backward_ops.h>
693
+ #include <ATen/ops/group_norm_ops.h>
694
+ #include <ATen/ops/gru_ops.h>
695
+ #include <ATen/ops/gru_cell_ops.h>
696
+ #include <ATen/ops/gt_ops.h>
697
+ #include <ATen/ops/hamming_window_ops.h>
698
+ #include <ATen/ops/hann_window_ops.h>
699
+ #include <ATen/ops/hardshrink_ops.h>
700
+ #include <ATen/ops/hardshrink_backward_ops.h>
701
+ #include <ATen/ops/hardsigmoid_ops.h>
702
+ #include <ATen/ops/hardsigmoid_backward_ops.h>
703
+ #include <ATen/ops/hardswish_ops.h>
704
+ #include <ATen/ops/hardswish_backward_ops.h>
705
+ #include <ATen/ops/hardtanh_ops.h>
706
+ #include <ATen/ops/hardtanh_backward_ops.h>
707
+ #include <ATen/ops/heaviside_ops.h>
708
+ #include <ATen/ops/hinge_embedding_loss_ops.h>
709
+ #include <ATen/ops/histc_ops.h>
710
+ #include <ATen/ops/histogram_ops.h>
711
+ #include <ATen/ops/histogramdd_ops.h>
712
+ #include <ATen/ops/hsplit_ops.h>
713
+ #include <ATen/ops/hspmm_ops.h>
714
+ #include <ATen/ops/hstack_ops.h>
715
+ #include <ATen/ops/huber_loss_ops.h>
716
+ #include <ATen/ops/huber_loss_backward_ops.h>
717
+ #include <ATen/ops/hypot_ops.h>
718
+ #include <ATen/ops/i0_ops.h>
719
+ #include <ATen/ops/igamma_ops.h>
720
+ #include <ATen/ops/igammac_ops.h>
721
+ #include <ATen/ops/im2col_ops.h>
722
+ #include <ATen/ops/imag_ops.h>
723
+ #include <ATen/ops/index_ops.h>
724
+ #include <ATen/ops/index_add_ops.h>
725
+ #include <ATen/ops/index_copy_ops.h>
726
+ #include <ATen/ops/index_fill_ops.h>
727
+ #include <ATen/ops/index_put_ops.h>
728
+ #include <ATen/ops/index_reduce_ops.h>
729
+ #include <ATen/ops/index_select_ops.h>
730
+ #include <ATen/ops/index_select_backward_ops.h>
731
+ #include <ATen/ops/indices_ops.h>
732
+ #include <ATen/ops/indices_copy_ops.h>
733
+ #include <ATen/ops/infinitely_differentiable_gelu_backward_ops.h>
734
+ #include <ATen/ops/inner_ops.h>
735
+ #include <ATen/ops/instance_norm_ops.h>
736
+ #include <ATen/ops/int_repr_ops.h>
737
+ #include <ATen/ops/inverse_ops.h>
738
+ #include <ATen/ops/is_coalesced_ops.h>
739
+ #include <ATen/ops/is_complex_ops.h>
740
+ #include <ATen/ops/is_conj_ops.h>
741
+ #include <ATen/ops/is_distributed_ops.h>
742
+ #include <ATen/ops/is_floating_point_ops.h>
743
+ #include <ATen/ops/is_inference_ops.h>
744
+ #include <ATen/ops/is_leaf_ops.h>
745
+ #include <ATen/ops/is_neg_ops.h>
746
+ #include <ATen/ops/is_nonzero_ops.h>
747
+ #include <ATen/ops/is_pinned_ops.h>
748
+ #include <ATen/ops/is_same_size_ops.h>
749
+ #include <ATen/ops/is_set_to_ops.h>
750
+ #include <ATen/ops/is_signed_ops.h>
751
+ #include <ATen/ops/is_vulkan_available_ops.h>
752
+ #include <ATen/ops/isclose_ops.h>
753
+ #include <ATen/ops/isfinite_ops.h>
754
+ #include <ATen/ops/isin_ops.h>
755
+ #include <ATen/ops/isinf_ops.h>
756
+ #include <ATen/ops/isnan_ops.h>
757
+ #include <ATen/ops/isneginf_ops.h>
758
+ #include <ATen/ops/isposinf_ops.h>
759
+ #include <ATen/ops/isreal_ops.h>
760
+ #include <ATen/ops/istft_ops.h>
761
+ #include <ATen/ops/item_ops.h>
762
+ #include <ATen/ops/kaiser_window_ops.h>
763
+ #include <ATen/ops/kl_div_ops.h>
764
+ #include <ATen/ops/kron_ops.h>
765
+ #include <ATen/ops/kthvalue_ops.h>
766
+ #include <ATen/ops/l1_loss_ops.h>
767
+ #include <ATen/ops/layer_norm_ops.h>
768
+ #include <ATen/ops/lcm_ops.h>
769
+ #include <ATen/ops/ldexp_ops.h>
770
+ #include <ATen/ops/le_ops.h>
771
+ #include <ATen/ops/leaky_relu_ops.h>
772
+ #include <ATen/ops/leaky_relu_backward_ops.h>
773
+ #include <ATen/ops/lerp_ops.h>
774
+ #include <ATen/ops/less_ops.h>
775
+ #include <ATen/ops/less_equal_ops.h>
776
+ #include <ATen/ops/lgamma_ops.h>
777
+ #include <ATen/ops/lift_ops.h>
778
+ #include <ATen/ops/lift_fresh_ops.h>
779
+ #include <ATen/ops/lift_fresh_copy_ops.h>
780
+ #include <ATen/ops/linalg_cholesky_ops.h>
781
+ #include <ATen/ops/linalg_cholesky_ex_ops.h>
782
+ #include <ATen/ops/linalg_cond_ops.h>
783
+ #include <ATen/ops/linalg_cross_ops.h>
784
+ #include <ATen/ops/linalg_det_ops.h>
785
+ #include <ATen/ops/linalg_diagonal_ops.h>
786
+ #include <ATen/ops/linalg_eig_ops.h>
787
+ #include <ATen/ops/linalg_eigh_ops.h>
788
+ #include <ATen/ops/linalg_eigvals_ops.h>
789
+ #include <ATen/ops/linalg_eigvalsh_ops.h>
790
+ #include <ATen/ops/linalg_householder_product_ops.h>
791
+ #include <ATen/ops/linalg_inv_ops.h>
792
+ #include <ATen/ops/linalg_inv_ex_ops.h>
793
+ #include <ATen/ops/linalg_ldl_factor_ops.h>
794
+ #include <ATen/ops/linalg_ldl_factor_ex_ops.h>
795
+ #include <ATen/ops/linalg_ldl_solve_ops.h>
796
+ #include <ATen/ops/linalg_lstsq_ops.h>
797
+ #include <ATen/ops/linalg_lu_ops.h>
798
+ #include <ATen/ops/linalg_lu_factor_ops.h>
799
+ #include <ATen/ops/linalg_lu_factor_ex_ops.h>
800
+ #include <ATen/ops/linalg_lu_solve_ops.h>
801
+ #include <ATen/ops/linalg_matmul_ops.h>
802
+ #include <ATen/ops/linalg_matrix_exp_ops.h>
803
+ #include <ATen/ops/linalg_matrix_norm_ops.h>
804
+ #include <ATen/ops/linalg_matrix_power_ops.h>
805
+ #include <ATen/ops/linalg_matrix_rank_ops.h>
806
+ #include <ATen/ops/linalg_multi_dot_ops.h>
807
+ #include <ATen/ops/linalg_norm_ops.h>
808
+ #include <ATen/ops/linalg_pinv_ops.h>
809
+ #include <ATen/ops/linalg_qr_ops.h>
810
+ #include <ATen/ops/linalg_slogdet_ops.h>
811
+ #include <ATen/ops/linalg_solve_ops.h>
812
+ #include <ATen/ops/linalg_solve_ex_ops.h>
813
+ #include <ATen/ops/linalg_solve_triangular_ops.h>
814
+ #include <ATen/ops/linalg_svd_ops.h>
815
+ #include <ATen/ops/linalg_svdvals_ops.h>
816
+ #include <ATen/ops/linalg_tensorinv_ops.h>
817
+ #include <ATen/ops/linalg_tensorsolve_ops.h>
818
+ #include <ATen/ops/linalg_vander_ops.h>
819
+ #include <ATen/ops/linalg_vecdot_ops.h>
820
+ #include <ATen/ops/linalg_vector_norm_ops.h>
821
+ #include <ATen/ops/linear_ops.h>
822
+ #include <ATen/ops/linear_backward_ops.h>
823
+ #include <ATen/ops/linspace_ops.h>
824
+ #include <ATen/ops/log_ops.h>
825
+ #include <ATen/ops/log10_ops.h>
826
+ #include <ATen/ops/log1p_ops.h>
827
+ #include <ATen/ops/log2_ops.h>
828
+ #include <ATen/ops/log_normal_ops.h>
829
+ #include <ATen/ops/log_sigmoid_ops.h>
830
+ #include <ATen/ops/log_sigmoid_backward_ops.h>
831
+ #include <ATen/ops/log_sigmoid_forward_ops.h>
832
+ #include <ATen/ops/log_softmax_ops.h>
833
+ #include <ATen/ops/logaddexp_ops.h>
834
+ #include <ATen/ops/logaddexp2_ops.h>
835
+ #include <ATen/ops/logcumsumexp_ops.h>
836
+ #include <ATen/ops/logdet_ops.h>
837
+ #include <ATen/ops/logical_and_ops.h>
838
+ #include <ATen/ops/logical_not_ops.h>
839
+ #include <ATen/ops/logical_or_ops.h>
840
+ #include <ATen/ops/logical_xor_ops.h>
841
+ #include <ATen/ops/logit_ops.h>
842
+ #include <ATen/ops/logit_backward_ops.h>
843
+ #include <ATen/ops/logspace_ops.h>
844
+ #include <ATen/ops/logsumexp_ops.h>
845
+ #include <ATen/ops/lshift_ops.h>
846
+ #include <ATen/ops/lstm_ops.h>
847
+ #include <ATen/ops/lstm_cell_ops.h>
848
+ #include <ATen/ops/lstm_mps_backward_ops.h>
849
+ #include <ATen/ops/lt_ops.h>
850
+ #include <ATen/ops/lu_solve_ops.h>
851
+ #include <ATen/ops/lu_unpack_ops.h>
852
+ #include <ATen/ops/mH_ops.h>
853
+ #include <ATen/ops/mT_ops.h>
854
+ #include <ATen/ops/margin_ranking_loss_ops.h>
855
+ #include <ATen/ops/masked_fill_ops.h>
856
+ #include <ATen/ops/masked_scatter_ops.h>
857
+ #include <ATen/ops/masked_scatter_backward_ops.h>
858
+ #include <ATen/ops/masked_select_ops.h>
859
+ #include <ATen/ops/masked_select_backward_ops.h>
860
+ #include <ATen/ops/matmul_ops.h>
861
+ #include <ATen/ops/matmul_backward_ops.h>
862
+ #include <ATen/ops/matrix_H_ops.h>
863
+ #include <ATen/ops/matrix_exp_ops.h>
864
+ #include <ATen/ops/matrix_exp_backward_ops.h>
865
+ #include <ATen/ops/matrix_power_ops.h>
866
+ #include <ATen/ops/max_ops.h>
867
+ #include <ATen/ops/max_pool1d_ops.h>
868
+ #include <ATen/ops/max_pool1d_with_indices_ops.h>
869
+ #include <ATen/ops/max_pool2d_ops.h>
870
+ #include <ATen/ops/max_pool2d_backward_ops.h>
871
+ #include <ATen/ops/max_pool2d_with_indices_ops.h>
872
+ #include <ATen/ops/max_pool2d_with_indices_backward_ops.h>
873
+ #include <ATen/ops/max_pool3d_ops.h>
874
+ #include <ATen/ops/max_pool3d_with_indices_ops.h>
875
+ #include <ATen/ops/max_pool3d_with_indices_backward_ops.h>
876
+ #include <ATen/ops/max_unpool2d_ops.h>
877
+ #include <ATen/ops/max_unpool3d_ops.h>
878
+ #include <ATen/ops/maximum_ops.h>
879
+ #include <ATen/ops/mean_ops.h>
880
+ #include <ATen/ops/median_ops.h>
881
+ #include <ATen/ops/meshgrid_ops.h>
882
+ #include <ATen/ops/min_ops.h>
883
+ #include <ATen/ops/minimum_ops.h>
884
+ #include <ATen/ops/miopen_batch_norm_ops.h>
885
+ #include <ATen/ops/miopen_batch_norm_backward_ops.h>
886
+ #include <ATen/ops/miopen_convolution_ops.h>
887
+ #include <ATen/ops/miopen_convolution_add_relu_ops.h>
888
+ #include <ATen/ops/miopen_convolution_relu_ops.h>
889
+ #include <ATen/ops/miopen_convolution_transpose_ops.h>
890
+ #include <ATen/ops/miopen_depthwise_convolution_ops.h>
891
+ #include <ATen/ops/miopen_rnn_ops.h>
892
+ #include <ATen/ops/miopen_rnn_backward_ops.h>
893
+ #include <ATen/ops/mish_ops.h>
894
+ #include <ATen/ops/mish_backward_ops.h>
895
+ #include <ATen/ops/mkldnn_adaptive_avg_pool2d_ops.h>
896
+ #include <ATen/ops/mkldnn_adaptive_avg_pool2d_backward_ops.h>
897
+ #include <ATen/ops/mkldnn_convolution_ops.h>
898
+ #include <ATen/ops/mkldnn_linear_ops.h>
899
+ #include <ATen/ops/mkldnn_linear_backward_ops.h>
900
+ #include <ATen/ops/mkldnn_linear_backward_input_ops.h>
901
+ #include <ATen/ops/mkldnn_linear_backward_weights_ops.h>
902
+ #include <ATen/ops/mkldnn_max_pool2d_ops.h>
903
+ #include <ATen/ops/mkldnn_max_pool2d_backward_ops.h>
904
+ #include <ATen/ops/mkldnn_max_pool3d_ops.h>
905
+ #include <ATen/ops/mkldnn_max_pool3d_backward_ops.h>
906
+ #include <ATen/ops/mkldnn_reorder_conv2d_weight_ops.h>
907
+ #include <ATen/ops/mkldnn_reorder_conv3d_weight_ops.h>
908
+ #include <ATen/ops/mkldnn_rnn_layer_ops.h>
909
+ #include <ATen/ops/mkldnn_rnn_layer_backward_ops.h>
910
+ #include <ATen/ops/mm_ops.h>
911
+ #include <ATen/ops/mode_ops.h>
912
+ #include <ATen/ops/moveaxis_ops.h>
913
+ #include <ATen/ops/movedim_ops.h>
914
+ #include <ATen/ops/mps_convolution_backward_ops.h>
915
+ #include <ATen/ops/mps_convolution_transpose_backward_ops.h>
916
+ #include <ATen/ops/mse_loss_ops.h>
917
+ #include <ATen/ops/mse_loss_backward_ops.h>
918
+ #include <ATen/ops/msort_ops.h>
919
+ #include <ATen/ops/mul_ops.h>
920
+ #include <ATen/ops/multi_margin_loss_ops.h>
921
+ #include <ATen/ops/multi_margin_loss_backward_ops.h>
922
+ #include <ATen/ops/multilabel_margin_loss_ops.h>
923
+ #include <ATen/ops/multilabel_margin_loss_backward_ops.h>
924
+ #include <ATen/ops/multilabel_margin_loss_forward_ops.h>
925
+ #include <ATen/ops/multinomial_ops.h>
926
+ #include <ATen/ops/multiply_ops.h>
927
+ #include <ATen/ops/mv_ops.h>
928
+ #include <ATen/ops/mvlgamma_ops.h>
929
+ #include <ATen/ops/nan_to_num_ops.h>
930
+ #include <ATen/ops/nanmean_ops.h>
931
+ #include <ATen/ops/nanmedian_ops.h>
932
+ #include <ATen/ops/nanquantile_ops.h>
933
+ #include <ATen/ops/nansum_ops.h>
934
+ #include <ATen/ops/narrow_ops.h>
935
+ #include <ATen/ops/narrow_copy_ops.h>
936
+ #include <ATen/ops/native_batch_norm_ops.h>
937
+ #include <ATen/ops/native_batch_norm_backward_ops.h>
938
+ #include <ATen/ops/native_channel_shuffle_ops.h>
939
+ #include <ATen/ops/native_dropout_ops.h>
940
+ #include <ATen/ops/native_dropout_backward_ops.h>
941
+ #include <ATen/ops/native_group_norm_ops.h>
942
+ #include <ATen/ops/native_group_norm_backward_ops.h>
943
+ #include <ATen/ops/native_layer_norm_ops.h>
944
+ #include <ATen/ops/native_layer_norm_backward_ops.h>
945
+ #include <ATen/ops/native_norm_ops.h>
946
+ #include <ATen/ops/ne_ops.h>
947
+ #include <ATen/ops/neg_ops.h>
948
+ #include <ATen/ops/negative_ops.h>
949
+ #include <ATen/ops/nested_to_padded_tensor_ops.h>
950
+ #include <ATen/ops/new_empty_ops.h>
951
+ #include <ATen/ops/new_empty_strided_ops.h>
952
+ #include <ATen/ops/new_full_ops.h>
953
+ #include <ATen/ops/new_ones_ops.h>
954
+ #include <ATen/ops/new_zeros_ops.h>
955
+ #include <ATen/ops/nextafter_ops.h>
956
+ #include <ATen/ops/nll_loss_ops.h>
957
+ #include <ATen/ops/nll_loss2d_ops.h>
958
+ #include <ATen/ops/nll_loss2d_backward_ops.h>
959
+ #include <ATen/ops/nll_loss2d_forward_ops.h>
960
+ #include <ATen/ops/nll_loss_backward_ops.h>
961
+ #include <ATen/ops/nll_loss_forward_ops.h>
962
+ #include <ATen/ops/nll_loss_nd_ops.h>
963
+ #include <ATen/ops/nonzero_ops.h>
964
+ #include <ATen/ops/nonzero_numpy_ops.h>
965
+ #include <ATen/ops/nonzero_static_ops.h>
966
+ #include <ATen/ops/norm_ops.h>
967
+ #include <ATen/ops/norm_except_dim_ops.h>
968
+ #include <ATen/ops/normal_ops.h>
969
+ #include <ATen/ops/not_equal_ops.h>
970
+ #include <ATen/ops/nuclear_norm_ops.h>
971
+ #include <ATen/ops/numpy_T_ops.h>
972
+ #include <ATen/ops/one_hot_ops.h>
973
+ #include <ATen/ops/ones_ops.h>
974
+ #include <ATen/ops/ones_like_ops.h>
975
+ #include <ATen/ops/or_ops.h>
976
+ #include <ATen/ops/orgqr_ops.h>
977
+ #include <ATen/ops/ormqr_ops.h>
978
+ #include <ATen/ops/outer_ops.h>
979
+ #include <ATen/ops/output_nr_ops.h>
980
+ #include <ATen/ops/pad_ops.h>
981
+ #include <ATen/ops/pad_sequence_ops.h>
982
+ #include <ATen/ops/pairwise_distance_ops.h>
983
+ #include <ATen/ops/pdist_ops.h>
984
+ #include <ATen/ops/permute_ops.h>
985
+ #include <ATen/ops/permute_copy_ops.h>
986
+ #include <ATen/ops/pin_memory_ops.h>
987
+ #include <ATen/ops/pinverse_ops.h>
988
+ #include <ATen/ops/pixel_shuffle_ops.h>
989
+ #include <ATen/ops/pixel_unshuffle_ops.h>
990
+ #include <ATen/ops/poisson_ops.h>
991
+ #include <ATen/ops/poisson_nll_loss_ops.h>
992
+ #include <ATen/ops/polar_ops.h>
993
+ #include <ATen/ops/polygamma_ops.h>
994
+ #include <ATen/ops/positive_ops.h>
995
+ #include <ATen/ops/pow_ops.h>
996
+ #include <ATen/ops/prelu_ops.h>
997
+ #include <ATen/ops/prod_ops.h>
998
+ #include <ATen/ops/promote_types_ops.h>
999
+ #include <ATen/ops/put_ops.h>
1000
+ #include <ATen/ops/q_per_channel_axis_ops.h>
1001
+ #include <ATen/ops/q_per_channel_scales_ops.h>
1002
+ #include <ATen/ops/q_per_channel_zero_points_ops.h>
1003
+ #include <ATen/ops/q_scale_ops.h>
1004
+ #include <ATen/ops/q_zero_point_ops.h>
1005
+ #include <ATen/ops/qr_ops.h>
1006
+ #include <ATen/ops/qscheme_ops.h>
1007
+ #include <ATen/ops/quantile_ops.h>
1008
+ #include <ATen/ops/quantize_per_channel_ops.h>
1009
+ #include <ATen/ops/quantize_per_tensor_ops.h>
1010
+ #include <ATen/ops/quantize_per_tensor_dynamic_ops.h>
1011
+ #include <ATen/ops/quantized_batch_norm_ops.h>
1012
+ #include <ATen/ops/quantized_gru_cell_ops.h>
1013
+ #include <ATen/ops/quantized_lstm_cell_ops.h>
1014
+ #include <ATen/ops/quantized_max_pool1d_ops.h>
1015
+ #include <ATen/ops/quantized_max_pool2d_ops.h>
1016
+ #include <ATen/ops/quantized_max_pool3d_ops.h>
1017
+ #include <ATen/ops/quantized_rnn_relu_cell_ops.h>
1018
+ #include <ATen/ops/quantized_rnn_tanh_cell_ops.h>
1019
+ #include <ATen/ops/rad2deg_ops.h>
1020
+ #include <ATen/ops/rand_ops.h>
1021
+ #include <ATen/ops/rand_like_ops.h>
1022
+ #include <ATen/ops/randint_ops.h>
1023
+ #include <ATen/ops/randint_like_ops.h>
1024
+ #include <ATen/ops/randn_ops.h>
1025
+ #include <ATen/ops/randn_like_ops.h>
1026
+ #include <ATen/ops/random_ops.h>
1027
+ #include <ATen/ops/randperm_ops.h>
1028
+ #include <ATen/ops/range_ops.h>
1029
+ #include <ATen/ops/ravel_ops.h>
1030
+ #include <ATen/ops/real_ops.h>
1031
+ #include <ATen/ops/reciprocal_ops.h>
1032
+ #include <ATen/ops/record_stream_ops.h>
1033
+ #include <ATen/ops/refine_names_ops.h>
1034
+ #include <ATen/ops/reflection_pad1d_ops.h>
1035
+ #include <ATen/ops/reflection_pad1d_backward_ops.h>
1036
+ #include <ATen/ops/reflection_pad2d_ops.h>
1037
+ #include <ATen/ops/reflection_pad2d_backward_ops.h>
1038
+ #include <ATen/ops/reflection_pad3d_ops.h>
1039
+ #include <ATen/ops/reflection_pad3d_backward_ops.h>
1040
+ #include <ATen/ops/relu_ops.h>
1041
+ #include <ATen/ops/relu6_ops.h>
1042
+ #include <ATen/ops/remainder_ops.h>
1043
+ #include <ATen/ops/rename_ops.h>
1044
+ #include <ATen/ops/renorm_ops.h>
1045
+ #include <ATen/ops/repeat_ops.h>
1046
+ #include <ATen/ops/repeat_interleave_ops.h>
1047
+ #include <ATen/ops/replication_pad1d_ops.h>
1048
+ #include <ATen/ops/replication_pad1d_backward_ops.h>
1049
+ #include <ATen/ops/replication_pad2d_ops.h>
1050
+ #include <ATen/ops/replication_pad2d_backward_ops.h>
1051
+ #include <ATen/ops/replication_pad3d_ops.h>
1052
+ #include <ATen/ops/replication_pad3d_backward_ops.h>
1053
+ #include <ATen/ops/requires_grad_ops.h>
1054
+ #include <ATen/ops/reshape_ops.h>
1055
+ #include <ATen/ops/reshape_as_ops.h>
1056
+ #include <ATen/ops/resize_ops.h>
1057
+ #include <ATen/ops/resize_as_ops.h>
1058
+ #include <ATen/ops/resize_as_sparse_ops.h>
1059
+ #include <ATen/ops/resolve_conj_ops.h>
1060
+ #include <ATen/ops/resolve_neg_ops.h>
1061
+ #include <ATen/ops/result_type_ops.h>
1062
+ #include <ATen/ops/retain_grad_ops.h>
1063
+ #include <ATen/ops/retains_grad_ops.h>
1064
+ #include <ATen/ops/rnn_relu_ops.h>
1065
+ #include <ATen/ops/rnn_relu_cell_ops.h>
1066
+ #include <ATen/ops/rnn_tanh_ops.h>
1067
+ #include <ATen/ops/rnn_tanh_cell_ops.h>
1068
+ #include <ATen/ops/roll_ops.h>
1069
+ #include <ATen/ops/rot90_ops.h>
1070
+ #include <ATen/ops/round_ops.h>
1071
+ #include <ATen/ops/row_indices_ops.h>
1072
+ #include <ATen/ops/row_indices_copy_ops.h>
1073
+ #include <ATen/ops/row_stack_ops.h>
1074
+ #include <ATen/ops/rrelu_ops.h>
1075
+ #include <ATen/ops/rrelu_with_noise_ops.h>
1076
+ #include <ATen/ops/rrelu_with_noise_backward_ops.h>
1077
+ #include <ATen/ops/rshift_ops.h>
1078
+ #include <ATen/ops/rsqrt_ops.h>
1079
+ #include <ATen/ops/rsub_ops.h>
1080
+ #include <ATen/ops/scalar_tensor_ops.h>
1081
+ #include <ATen/ops/scaled_dot_product_attention_ops.h>
1082
+ #include <ATen/ops/scatter_ops.h>
1083
+ #include <ATen/ops/scatter_add_ops.h>
1084
+ #include <ATen/ops/scatter_reduce_ops.h>
1085
+ #include <ATen/ops/searchsorted_ops.h>
1086
+ #include <ATen/ops/segment_reduce_ops.h>
1087
+ #include <ATen/ops/select_ops.h>
1088
+ #include <ATen/ops/select_backward_ops.h>
1089
+ #include <ATen/ops/select_copy_ops.h>
1090
+ #include <ATen/ops/select_scatter_ops.h>
1091
+ #include <ATen/ops/selu_ops.h>
1092
+ #include <ATen/ops/set_ops.h>
1093
+ #include <ATen/ops/set_data_ops.h>
1094
+ #include <ATen/ops/sgn_ops.h>
1095
+ #include <ATen/ops/sigmoid_ops.h>
1096
+ #include <ATen/ops/sigmoid_backward_ops.h>
1097
+ #include <ATen/ops/sign_ops.h>
1098
+ #include <ATen/ops/signbit_ops.h>
1099
+ #include <ATen/ops/silu_ops.h>
1100
+ #include <ATen/ops/silu_backward_ops.h>
1101
+ #include <ATen/ops/sin_ops.h>
1102
+ #include <ATen/ops/sinc_ops.h>
1103
+ #include <ATen/ops/sinh_ops.h>
1104
+ #include <ATen/ops/size_ops.h>
1105
+ #include <ATen/ops/slice_ops.h>
1106
+ #include <ATen/ops/slice_backward_ops.h>
1107
+ #include <ATen/ops/slice_copy_ops.h>
1108
+ #include <ATen/ops/slice_inverse_ops.h>
1109
+ #include <ATen/ops/slice_scatter_ops.h>
1110
+ #include <ATen/ops/slogdet_ops.h>
1111
+ #include <ATen/ops/slow_conv3d_ops.h>
1112
+ #include <ATen/ops/slow_conv3d_forward_ops.h>
1113
+ #include <ATen/ops/slow_conv_dilated2d_ops.h>
1114
+ #include <ATen/ops/slow_conv_dilated3d_ops.h>
1115
+ #include <ATen/ops/slow_conv_transpose2d_ops.h>
1116
+ #include <ATen/ops/slow_conv_transpose3d_ops.h>
1117
+ #include <ATen/ops/smm_ops.h>
1118
+ #include <ATen/ops/smooth_l1_loss_ops.h>
1119
+ #include <ATen/ops/smooth_l1_loss_backward_ops.h>
1120
+ #include <ATen/ops/soft_margin_loss_ops.h>
1121
+ #include <ATen/ops/soft_margin_loss_backward_ops.h>
1122
+ #include <ATen/ops/softmax_ops.h>
1123
+ #include <ATen/ops/softplus_ops.h>
1124
+ #include <ATen/ops/softplus_backward_ops.h>
1125
+ #include <ATen/ops/softshrink_ops.h>
1126
+ #include <ATen/ops/softshrink_backward_ops.h>
1127
+ #include <ATen/ops/sort_ops.h>
1128
+ #include <ATen/ops/sparse_bsc_tensor_ops.h>
1129
+ #include <ATen/ops/sparse_bsr_tensor_ops.h>
1130
+ #include <ATen/ops/sparse_compressed_tensor_ops.h>
1131
+ #include <ATen/ops/sparse_coo_tensor_ops.h>
1132
+ #include <ATen/ops/sparse_csc_tensor_ops.h>
1133
+ #include <ATen/ops/sparse_csr_tensor_ops.h>
1134
+ #include <ATen/ops/sparse_dim_ops.h>
1135
+ #include <ATen/ops/sparse_mask_ops.h>
1136
+ #include <ATen/ops/sparse_resize_ops.h>
1137
+ #include <ATen/ops/sparse_resize_and_clear_ops.h>
1138
+ #include <ATen/ops/sparse_sampled_addmm_ops.h>
1139
+ #include <ATen/ops/special_airy_ai_ops.h>
1140
+ #include <ATen/ops/special_bessel_j0_ops.h>
1141
+ #include <ATen/ops/special_bessel_j1_ops.h>
1142
+ #include <ATen/ops/special_bessel_y0_ops.h>
1143
+ #include <ATen/ops/special_bessel_y1_ops.h>
1144
+ #include <ATen/ops/special_chebyshev_polynomial_t_ops.h>
1145
+ #include <ATen/ops/special_chebyshev_polynomial_u_ops.h>
1146
+ #include <ATen/ops/special_chebyshev_polynomial_v_ops.h>
1147
+ #include <ATen/ops/special_chebyshev_polynomial_w_ops.h>
1148
+ #include <ATen/ops/special_digamma_ops.h>
1149
+ #include <ATen/ops/special_entr_ops.h>
1150
+ #include <ATen/ops/special_erf_ops.h>
1151
+ #include <ATen/ops/special_erfc_ops.h>
1152
+ #include <ATen/ops/special_erfcx_ops.h>
1153
+ #include <ATen/ops/special_erfinv_ops.h>
1154
+ #include <ATen/ops/special_exp2_ops.h>
1155
+ #include <ATen/ops/special_expit_ops.h>
1156
+ #include <ATen/ops/special_expm1_ops.h>
1157
+ #include <ATen/ops/special_gammainc_ops.h>
1158
+ #include <ATen/ops/special_gammaincc_ops.h>
1159
+ #include <ATen/ops/special_gammaln_ops.h>
1160
+ #include <ATen/ops/special_hermite_polynomial_h_ops.h>
1161
+ #include <ATen/ops/special_hermite_polynomial_he_ops.h>
1162
+ #include <ATen/ops/special_i0_ops.h>
1163
+ #include <ATen/ops/special_i0e_ops.h>
1164
+ #include <ATen/ops/special_i1_ops.h>
1165
+ #include <ATen/ops/special_i1e_ops.h>
1166
+ #include <ATen/ops/special_laguerre_polynomial_l_ops.h>
1167
+ #include <ATen/ops/special_legendre_polynomial_p_ops.h>
1168
+ #include <ATen/ops/special_log1p_ops.h>
1169
+ #include <ATen/ops/special_log_ndtr_ops.h>
1170
+ #include <ATen/ops/special_log_softmax_ops.h>
1171
+ #include <ATen/ops/special_logit_ops.h>
1172
+ #include <ATen/ops/special_logsumexp_ops.h>
1173
+ #include <ATen/ops/special_modified_bessel_i0_ops.h>
1174
+ #include <ATen/ops/special_modified_bessel_i1_ops.h>
1175
+ #include <ATen/ops/special_modified_bessel_k0_ops.h>
1176
+ #include <ATen/ops/special_modified_bessel_k1_ops.h>
1177
+ #include <ATen/ops/special_multigammaln_ops.h>
1178
+ #include <ATen/ops/special_ndtr_ops.h>
1179
+ #include <ATen/ops/special_ndtri_ops.h>
1180
+ #include <ATen/ops/special_polygamma_ops.h>
1181
+ #include <ATen/ops/special_psi_ops.h>
1182
+ #include <ATen/ops/special_round_ops.h>
1183
+ #include <ATen/ops/special_scaled_modified_bessel_k0_ops.h>
1184
+ #include <ATen/ops/special_scaled_modified_bessel_k1_ops.h>
1185
+ #include <ATen/ops/special_shifted_chebyshev_polynomial_t_ops.h>
1186
+ #include <ATen/ops/special_shifted_chebyshev_polynomial_u_ops.h>
1187
+ #include <ATen/ops/special_shifted_chebyshev_polynomial_v_ops.h>
1188
+ #include <ATen/ops/special_shifted_chebyshev_polynomial_w_ops.h>
1189
+ #include <ATen/ops/special_sinc_ops.h>
1190
+ #include <ATen/ops/special_softmax_ops.h>
1191
+ #include <ATen/ops/special_spherical_bessel_j0_ops.h>
1192
+ #include <ATen/ops/special_xlog1py_ops.h>
1193
+ #include <ATen/ops/special_xlogy_ops.h>
1194
+ #include <ATen/ops/special_zeta_ops.h>
1195
+ #include <ATen/ops/split_ops.h>
1196
+ #include <ATen/ops/split_copy_ops.h>
1197
+ #include <ATen/ops/split_with_sizes_ops.h>
1198
+ #include <ATen/ops/split_with_sizes_copy_ops.h>
1199
+ #include <ATen/ops/sqrt_ops.h>
1200
+ #include <ATen/ops/square_ops.h>
1201
+ #include <ATen/ops/squeeze_ops.h>
1202
+ #include <ATen/ops/squeeze_copy_ops.h>
1203
+ #include <ATen/ops/sspaddmm_ops.h>
1204
+ #include <ATen/ops/stack_ops.h>
1205
+ #include <ATen/ops/std_ops.h>
1206
+ #include <ATen/ops/std_mean_ops.h>
1207
+ #include <ATen/ops/stft_ops.h>
1208
+ #include <ATen/ops/stride_ops.h>
1209
+ #include <ATen/ops/sub_ops.h>
1210
+ #include <ATen/ops/subtract_ops.h>
1211
+ #include <ATen/ops/sum_ops.h>
1212
+ #include <ATen/ops/sum_to_size_ops.h>
1213
+ #include <ATen/ops/svd_ops.h>
1214
+ #include <ATen/ops/swapaxes_ops.h>
1215
+ #include <ATen/ops/swapdims_ops.h>
1216
+ #include <ATen/ops/sym_constrain_range_ops.h>
1217
+ #include <ATen/ops/sym_constrain_range_for_size_ops.h>
1218
+ #include <ATen/ops/sym_numel_ops.h>
1219
+ #include <ATen/ops/sym_size_ops.h>
1220
+ #include <ATen/ops/sym_storage_offset_ops.h>
1221
+ #include <ATen/ops/sym_stride_ops.h>
1222
+ #include <ATen/ops/t_ops.h>
1223
+ #include <ATen/ops/t_copy_ops.h>
1224
+ #include <ATen/ops/take_ops.h>
1225
+ #include <ATen/ops/take_along_dim_ops.h>
1226
+ #include <ATen/ops/tan_ops.h>
1227
+ #include <ATen/ops/tanh_ops.h>
1228
+ #include <ATen/ops/tanh_backward_ops.h>
1229
+ #include <ATen/ops/tensor_split_ops.h>
1230
+ #include <ATen/ops/tensordot_ops.h>
1231
+ #include <ATen/ops/thnn_conv2d_ops.h>
1232
+ #include <ATen/ops/threshold_ops.h>
1233
+ #include <ATen/ops/threshold_backward_ops.h>
1234
+ #include <ATen/ops/tile_ops.h>
1235
+ #include <ATen/ops/to_ops.h>
1236
+ #include <ATen/ops/to_dense_ops.h>
1237
+ #include <ATen/ops/to_dense_backward_ops.h>
1238
+ #include <ATen/ops/to_mkldnn_ops.h>
1239
+ #include <ATen/ops/to_mkldnn_backward_ops.h>
1240
+ #include <ATen/ops/to_padded_tensor_ops.h>
1241
+ #include <ATen/ops/to_sparse_ops.h>
1242
+ #include <ATen/ops/to_sparse_bsc_ops.h>
1243
+ #include <ATen/ops/to_sparse_bsr_ops.h>
1244
+ #include <ATen/ops/to_sparse_csc_ops.h>
1245
+ #include <ATen/ops/to_sparse_csr_ops.h>
1246
+ #include <ATen/ops/topk_ops.h>
1247
+ #include <ATen/ops/trace_ops.h>
1248
+ #include <ATen/ops/trace_backward_ops.h>
1249
+ #include <ATen/ops/transpose_ops.h>
1250
+ #include <ATen/ops/transpose_copy_ops.h>
1251
+ #include <ATen/ops/trapezoid_ops.h>
1252
+ #include <ATen/ops/trapz_ops.h>
1253
+ #include <ATen/ops/triangular_solve_ops.h>
1254
+ #include <ATen/ops/tril_ops.h>
1255
+ #include <ATen/ops/tril_indices_ops.h>
1256
+ #include <ATen/ops/triplet_margin_loss_ops.h>
1257
+ #include <ATen/ops/triu_ops.h>
1258
+ #include <ATen/ops/triu_indices_ops.h>
1259
+ #include <ATen/ops/true_divide_ops.h>
1260
+ #include <ATen/ops/trunc_ops.h>
1261
+ #include <ATen/ops/type_as_ops.h>
1262
+ #include <ATen/ops/unbind_ops.h>
1263
+ #include <ATen/ops/unbind_copy_ops.h>
1264
+ #include <ATen/ops/unflatten_ops.h>
1265
+ #include <ATen/ops/unflatten_dense_tensors_ops.h>
1266
+ #include <ATen/ops/unfold_ops.h>
1267
+ #include <ATen/ops/unfold_backward_ops.h>
1268
+ #include <ATen/ops/unfold_copy_ops.h>
1269
+ #include <ATen/ops/uniform_ops.h>
1270
+ #include <ATen/ops/unique_consecutive_ops.h>
1271
+ #include <ATen/ops/unique_dim_ops.h>
1272
+ #include <ATen/ops/unique_dim_consecutive_ops.h>
1273
+ #include <ATen/ops/unsafe_chunk_ops.h>
1274
+ #include <ATen/ops/unsafe_split_ops.h>
1275
+ #include <ATen/ops/unsafe_split_with_sizes_ops.h>
1276
+ #include <ATen/ops/unsqueeze_ops.h>
1277
+ #include <ATen/ops/unsqueeze_copy_ops.h>
1278
+ #include <ATen/ops/upsample_bicubic2d_ops.h>
1279
+ #include <ATen/ops/upsample_bicubic2d_backward_ops.h>
1280
+ #include <ATen/ops/upsample_bilinear2d_ops.h>
1281
+ #include <ATen/ops/upsample_bilinear2d_backward_ops.h>
1282
+ #include <ATen/ops/upsample_linear1d_ops.h>
1283
+ #include <ATen/ops/upsample_linear1d_backward_ops.h>
1284
+ #include <ATen/ops/upsample_nearest1d_ops.h>
1285
+ #include <ATen/ops/upsample_nearest1d_backward_ops.h>
1286
+ #include <ATen/ops/upsample_nearest2d_ops.h>
1287
+ #include <ATen/ops/upsample_nearest2d_backward_ops.h>
1288
+ #include <ATen/ops/upsample_nearest3d_ops.h>
1289
+ #include <ATen/ops/upsample_nearest3d_backward_ops.h>
1290
+ #include <ATen/ops/upsample_trilinear3d_ops.h>
1291
+ #include <ATen/ops/upsample_trilinear3d_backward_ops.h>
1292
+ #include <ATen/ops/value_selecting_reduction_backward_ops.h>
1293
+ #include <ATen/ops/values_ops.h>
1294
+ #include <ATen/ops/values_copy_ops.h>
1295
+ #include <ATen/ops/vander_ops.h>
1296
+ #include <ATen/ops/var_ops.h>
1297
+ #include <ATen/ops/var_mean_ops.h>
1298
+ #include <ATen/ops/vdot_ops.h>
1299
+ #include <ATen/ops/view_ops.h>
1300
+ #include <ATen/ops/view_as_ops.h>
1301
+ #include <ATen/ops/view_as_complex_ops.h>
1302
+ #include <ATen/ops/view_as_complex_copy_ops.h>
1303
+ #include <ATen/ops/view_as_real_ops.h>
1304
+ #include <ATen/ops/view_as_real_copy_ops.h>
1305
+ #include <ATen/ops/view_copy_ops.h>
1306
+ #include <ATen/ops/vsplit_ops.h>
1307
+ #include <ATen/ops/vstack_ops.h>
1308
+ #include <ATen/ops/where_ops.h>
1309
+ #include <ATen/ops/xlogy_ops.h>
1310
+ #include <ATen/ops/xor_ops.h>
1311
+ #include <ATen/ops/zero_ops.h>
1312
+ #include <ATen/ops/zeros_ops.h>
1313
+ #include <ATen/ops/zeros_like_ops.h>
1314
+
1315
+ // Extension writers: do you write wrapper functions? Are you frustrated with
1316
+ // resolving overloads of operators? Are you frustrated with dealing with
1317
+ // pointer-to-methods and resolving overloads of pointer-to-methods?? Look no
1318
+ // further, this is the utility for you.
1319
+ //
1320
+ // Given an operator schema: aten::op.overload(...
1321
+ //
1322
+ // Use ATEN_FN2(op, overload) to get a *function* version of the operator
1323
+ // that is guaranteed to not be overloaded. This means that you can safely
1324
+ // decltype(&ATEN_FN2(op, overload)) it. NB: the 2 means this macro takes 2 args.
1325
+ //
1326
+ // Given an operator schema without an overload name: aten::op(...
1327
+ //
1328
+ // Use ATEN_FN(op) to get an unambiguous *function* version of the operator.
1329
+ //
1330
+ // There is some interesting behavior for out= operations.
1331
+ // ATEN_FN2(sin, out) gives a function that is *faithful* to the schema;
1332
+ // that is, the order of arguments is exactly what it looks like in the schema.
1333
+
1334
+ #define ATEN_FN2(op_name, overload) at::_ops::op_name##_##overload::call
1335
+ #define ATEN_FN(op_name) at::_ops::op_name::call
1336
+
1337
+ // Separately, ATEN_OP(op) and ATEN_OP2(op, overload) define a class containing compile-time
1338
+ // metadata about a given aten operator.
1339
+ // Notable data on the class includes:
1340
+ // - ATEN_OP2(add, Tensor)::name // returns the string name: "add"
1341
+ // - ATEN_OP2(add, Tensor)::overload_name // returns the string overload name: "Tensor"
1342
+ // - ATEN_OP2(add, Tensor)::schema // returns the C++ schema type: at::Tensor (const at::Tensor &, const at::Tensor &, const at::Scalar &)
1343
+ // - ATEN_OP2(add, Tensor)::schema_str // returns the string jit type: "add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor"
1344
+
1345
+ #define ATEN_OP2(op_name, overload) at::_ops::op_name##_##overload
1346
+ #define ATEN_OP(op_name) at::_ops::op_name
1347
+
1348
+ // WARNING: Please do not call any of the ops in the _ops namespace directly.
1349
+ // Use the ATEN_FN macros. We do not guarantee stability of the naming
1350
+ // scheme for the functions in at::_ops
1351
+
1352
+ // See Note [The ATen Operators API] for details of the at::_ops namespace
1353
+
1354
+ namespace at {
1355
+ namespace _ops {
1356
+
1357
+ } // namespace _ops
1358
+ } // namespace at
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/PTThreadPool.h ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/Parallel.h>
4
+ #include <c10/core/thread_pool.h>
5
+
6
+ namespace at {
7
+
8
+ class TORCH_API PTThreadPool : public c10::ThreadPool {
9
+ public:
10
+ explicit PTThreadPool(int pool_size, int numa_node_id = -1)
11
+ : c10::ThreadPool(pool_size, numa_node_id, []() {
12
+ c10::setThreadName("PTThreadPool");
13
+ at::init_num_threads();
14
+ }) {}
15
+ };
16
+
17
+ } // namespace at
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ParallelNativeTBB.h ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <atomic>
4
+ #include <cstddef>
5
+ #include <exception>
6
+
7
+ #include <c10/util/Exception.h>
8
+
9
+ #ifdef _WIN32
10
+ #ifndef WIN32_LEAN_AND_MEAN
11
+ #define WIN32_LEAN_AND_MEAN
12
+ #endif
13
+ #endif
14
+ #include <tbb/tbb.h>
15
+
16
+ #define INTRA_OP_PARALLEL
17
+
18
+ namespace at::internal {
19
+
20
+ template <typename F>
21
+ inline void invoke_parallel(
22
+ const int64_t begin,
23
+ const int64_t end,
24
+ const int64_t grain_size,
25
+ const F& f) {
26
+ // Choose number of tasks based on grain size and number of threads.
27
+ int64_t chunk_size = divup((end - begin), get_num_threads());
28
+ // Make sure each task is at least grain_size size.
29
+ chunk_size = std::max(grain_size, chunk_size);
30
+
31
+ std::atomic_flag err_flag = ATOMIC_FLAG_INIT;
32
+ std::exception_ptr eptr;
33
+ tbb::parallel_for(
34
+ tbb::blocked_range<int64_t>(begin, end, chunk_size),
35
+ [&eptr, &err_flag, f](const tbb::blocked_range<int64_t>& r) {
36
+ try {
37
+ internal::ThreadIdGuard tid_guard(
38
+ tbb::this_task_arena::current_thread_index());
39
+ f(r.begin(), r.end());
40
+ } catch (...) {
41
+ if (!err_flag.test_and_set()) {
42
+ eptr = std::current_exception();
43
+ }
44
+ }
45
+ },
46
+ tbb::static_partitioner{});
47
+ if (eptr) {
48
+ std::rethrow_exception(eptr);
49
+ }
50
+ }
51
+
52
+ } // namespace at::internal
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/PythonTorchFunctionTLS.h ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <c10/core/SafePyObject.h>
4
+ #include <c10/macros/Macros.h>
5
+
6
+ namespace at::impl {
7
+
8
+ enum TorchFunctionDisabledState { ENABLED, SUBCLASSES_DISABLED, ALL_DISABLED };
9
+
10
+ struct TORCH_API PythonTorchFunctionTLS {
11
+ static void set_disabled_state(TorchFunctionDisabledState disabled_state_);
12
+ static TorchFunctionDisabledState get_disabled_state();
13
+
14
+ static void push_onto_stack(std::shared_ptr<SafePyObject> mode);
15
+ static const std::shared_ptr<SafePyObject> pop_stack();
16
+ static const std::shared_ptr<SafePyObject>& get_stack_at(int64_t idx);
17
+ static int64_t stack_len();
18
+
19
+ static const PythonTorchFunctionTLS& get_state();
20
+ static void set_state(const PythonTorchFunctionTLS& state);
21
+
22
+ private:
23
+ // The mode TLS is split into
24
+ // - disabled_state, which says which part of torch function are disabled
25
+ // - stack_, which is a vector of modes representing the stack of user
26
+ // defined modes
27
+ TorchFunctionDisabledState disabled_state_ =
28
+ TorchFunctionDisabledState::ENABLED;
29
+ std::vector<std::shared_ptr<c10::SafePyObject>> stack_;
30
+ };
31
+
32
+ TORCH_API bool torch_function_mode_enabled();
33
+
34
+ } // namespace at::impl
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/SequenceNumber.h ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <c10/macros/Export.h>
4
+ #include <cstdint>
5
+
6
+ // A simple thread local enumeration, used to link forward and backward pass
7
+ // ops and is used by autograd and observers framework
8
+ namespace at::sequence_number {
9
+
10
+ TORCH_API uint64_t peek();
11
+ TORCH_API uint64_t get_and_increment();
12
+
13
+ } // namespace at::sequence_number
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/SparseTensorImpl.h ADDED
@@ -0,0 +1,400 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/Tensor.h>
4
+ #include <c10/core/TensorImpl.h>
5
+ #include <c10/util/Exception.h>
6
+ #include <c10/util/irange.h>
7
+
8
+ #ifndef AT_PER_OPERATOR_HEADERS
9
+ #include <ATen/Functions.h>
10
+ #else
11
+ #include <ATen/ops/empty.h>
12
+ #include <ATen/ops/resize.h>
13
+ #endif
14
+
15
+ namespace at {
16
+ struct TORCH_API SparseTensorImpl : public TensorImpl {
17
+ // Stored in COO format, indices + values.
18
+
19
+ // INVARIANTS:
20
+ // sparse_dim: range [0, len(shape)]; sparse_dim + dense_dim = len(shape)
21
+ // dense_dim : range [0, len(shape)]; sparse_dim + dense_dim = len(shape)
22
+ // _indices.shape: dimensionality: 2, shape: (sparse_dim, nnz)
23
+ // _values.shape: dimensionality: 1 + dense_dim. shape: (nnz,
24
+ // shape[sparse_dim:])
25
+
26
+ int64_t sparse_dim_ = 0; // number of sparse dimensions
27
+ int64_t dense_dim_ = 0; // number of dense dimensions
28
+
29
+ Tensor indices_; // always a LongTensor
30
+ Tensor values_;
31
+
32
+ // A sparse tensor is 'coalesced' if every index occurs at most once in
33
+ // the indices tensor, and the indices are in sorted order. (This means
34
+ // that it is very easy to convert a coalesced tensor to CSR format: you
35
+ // need only compute CSR format indices.)
36
+ //
37
+ // Most math operations can only be performed on coalesced sparse tensors,
38
+ // because many algorithms proceed by merging two sorted lists (of indices).
39
+ bool coalesced_ = false;
40
+
41
+ // compute_numel with integer multiplication overflow check, see gh-57542
42
+ void refresh_numel() {
43
+ TensorImpl::safe_refresh_numel();
44
+ }
45
+
46
+ public:
47
+ // Public for now...
48
+ explicit SparseTensorImpl(at::DispatchKeySet, const caffe2::TypeMeta);
49
+
50
+ void release_resources() override;
51
+
52
+ int64_t nnz() const {
53
+ return values_.size(0);
54
+ }
55
+
56
+ c10::SymInt sym_nnz() const {
57
+ return values_.sym_size(0);
58
+ }
59
+ int64_t sparse_dim() const {
60
+ return sparse_dim_;
61
+ }
62
+ int64_t dense_dim() const {
63
+ return dense_dim_;
64
+ }
65
+ bool coalesced() const {
66
+ return coalesced_;
67
+ }
68
+ Tensor indices() const {
69
+ return indices_;
70
+ }
71
+ Tensor values() const {
72
+ return values_;
73
+ }
74
+
75
+ void set_size(int64_t dim, int64_t new_size) override;
76
+ void set_stride(int64_t dim, int64_t new_stride) override;
77
+ void set_storage_offset(int64_t storage_offset) override;
78
+
79
+ #ifdef DEBUG
80
+ bool has_storage() const override;
81
+ #endif
82
+
83
+ // WARNING: This function does NOT preserve invariants of sparse_dim/dense_dim
84
+ // with respect to indices and values
85
+ void raw_resize_(int64_t sparse_dim, int64_t dense_dim, IntArrayRef size) {
86
+ TORCH_CHECK(
87
+ allow_tensor_metadata_change(),
88
+ "raw_resize_ ",
89
+ err_msg_tensor_metadata_change_not_allowed);
90
+ TORCH_CHECK(
91
+ !has_symbolic_sizes_strides_,
92
+ "raw_resize_ called on tensor with symbolic shape")
93
+ set_sizes_and_strides(size, std::vector<int64_t>(size.size()));
94
+ sparse_dim_ = sparse_dim;
95
+ dense_dim_ = dense_dim;
96
+ refresh_numel();
97
+ }
98
+
99
+ // NOTE: This function preserves invariants of sparse_dim/dense_dim with
100
+ // respect to indices and values.
101
+ //
102
+ // NOTE: This function supports the following cases:
103
+ // 1. When we keep the number of dense dimensions unchanged, and NOT shrinking
104
+ // the size of any of the dense dimensions.
105
+ // 2. When we keep the number of sparse dimensions unchanged, and NOT
106
+ // shrinking the size of any of the sparse dimensions.
107
+ // 3. When the sparse tensor has zero nnz, in which case we are free to change
108
+ // the shapes of both its sparse and dense dimensions.
109
+ //
110
+ // This function DOESN'T support (and will throw an error) the following
111
+ // cases:
112
+ // 1. When we attempt to change the number of sparse dimensions on a non-empty
113
+ // sparse tensor (such an operation will invalidate the indices stored).
114
+ // 2. When we attempt to change the number of dense dimensions on a non-empty
115
+ // sparse tensor (such an operation will behave differently from an equivalent
116
+ // dense tensor's resize method, and for API consistency we don't support it).
117
+ // 3. When we attempt to shrink the size of any of the dense dimensions on a
118
+ // non-empty sparse tensor (such an operation will behave differently from an
119
+ // equivalent dense tensor's resize method, and for API consistency we don't
120
+ // support it).
121
+ // 4. When we attempt to shrink the size of any of the sparse dimensions on a
122
+ // non-empty sparse tensor (this could make some of the stored indices
123
+ // out-of-bound and thus unsafe).
124
+ template <typename T>
125
+ void _resize_(int64_t sparse_dim, int64_t dense_dim, ArrayRef<T> size) {
126
+ TORCH_CHECK(
127
+ allow_tensor_metadata_change(),
128
+ "resize_ ",
129
+ err_msg_tensor_metadata_change_not_allowed);
130
+ TORCH_CHECK(
131
+ !has_symbolic_sizes_strides_,
132
+ "resize_ called on tensor with symbolic shape")
133
+ TORCH_CHECK(
134
+ sparse_dim + dense_dim == static_cast<int64_t>(size.size()),
135
+ "number of dimensions must be sparse_dim (",
136
+ sparse_dim,
137
+ ") + dense_dim (",
138
+ dense_dim,
139
+ "), but got ",
140
+ size.size());
141
+ if (nnz() > 0) {
142
+ auto alt_options_msg =
143
+ "You could try the following options:\n\
144
+ 1. If you need an empty sparse tensor of this size, call `x = torch.sparse_coo_tensor(size)`.\n\
145
+ 2. If you need to resize this tensor, you have the following options:\n\
146
+ 1. For both sparse and dense dimensions, keep the number of them constant and the size of them non-shrinking, and then try the same call again.\n\
147
+ 2. Or, create a new sparse tensor with the correct indices and values from this sparse tensor.";
148
+
149
+ TORCH_CHECK(
150
+ sparse_dim == sparse_dim_,
151
+ "changing the number of sparse dimensions (from ",
152
+ sparse_dim_,
153
+ " to ",
154
+ sparse_dim,
155
+ ") on a non-empty sparse tensor is not supported.\n",
156
+ alt_options_msg);
157
+
158
+ TORCH_CHECK(
159
+ dense_dim == dense_dim_,
160
+ "changing the number of dense dimensions (from ",
161
+ dense_dim_,
162
+ " to ",
163
+ dense_dim,
164
+ ") on a non-empty sparse tensor is not supported.\n",
165
+ alt_options_msg);
166
+
167
+ bool shrinking_sparse_dims = false;
168
+ bool shrinking_dense_dim = false;
169
+ auto sparse_size_original = generic_sizes<T>().slice(0, sparse_dim);
170
+ auto sparse_size_new = size.slice(0, sparse_dim);
171
+ for (const auto i : c10::irange(sparse_dim)) {
172
+ if (sparse_size_new[i] < sparse_size_original[i]) {
173
+ shrinking_sparse_dims = true;
174
+ break;
175
+ }
176
+ }
177
+ auto dense_size_original = generic_sizes<T>().slice(sparse_dim);
178
+ auto dense_size_new = size.slice(sparse_dim);
179
+ for (const auto i : c10::irange(dense_dim)) {
180
+ if (dense_size_new[i] < dense_size_original[i]) {
181
+ shrinking_dense_dim = true;
182
+ break;
183
+ }
184
+ }
185
+
186
+ TORCH_CHECK(
187
+ !shrinking_sparse_dims,
188
+ "shrinking the size of sparse dimensions (from ",
189
+ sparse_size_original,
190
+ " to ",
191
+ sparse_size_new,
192
+ ") on a non-empty sparse tensor is not supported.\n",
193
+ alt_options_msg);
194
+
195
+ TORCH_CHECK(
196
+ !shrinking_dense_dim,
197
+ "shrinking the size of dense dimensions (from ",
198
+ dense_size_original,
199
+ " to ",
200
+ dense_size_new,
201
+ ") on a non-empty sparse tensor is not supported.\n",
202
+ alt_options_msg);
203
+ }
204
+
205
+ auto sizes_and_strides = generic_sizes<T>();
206
+ const bool size_equals_sizes = std::equal(
207
+ size.begin(),
208
+ size.end(),
209
+ sizes_and_strides.begin(),
210
+ sizes_and_strides.end());
211
+ if ((!size_equals_sizes) || (sparse_dim != sparse_dim_) ||
212
+ (dense_dim != dense_dim_)) {
213
+ auto nnz = at::symint::sizes<T>(values())[0];
214
+ std::vector<T> values_size = {nnz};
215
+ auto dense_size = size.slice(sparse_dim);
216
+ values_size.insert(
217
+ values_size.end(), dense_size.begin(), dense_size.end());
218
+ at::symint::resize_<T>(values_, values_size);
219
+ at::symint::resize_<T>(indices_, {T(sparse_dim), nnz});
220
+ }
221
+
222
+ if (!size_equals_sizes) {
223
+ set_sizes_and_strides(size, std::vector<T>(size.size()));
224
+ }
225
+ sparse_dim_ = sparse_dim;
226
+ dense_dim_ = dense_dim;
227
+ refresh_numel();
228
+ }
229
+
230
+ void resize_(int64_t sparse_dim, int64_t dense_dim, ArrayRef<int64_t> size) {
231
+ return _resize_(sparse_dim, dense_dim, size);
232
+ }
233
+
234
+ void resize_(
235
+ int64_t sparse_dim,
236
+ int64_t dense_dim,
237
+ ArrayRef<c10::SymInt> size) {
238
+ return _resize_(sparse_dim, dense_dim, size);
239
+ }
240
+
241
+ // NOTE: this function will resize the sparse tensor and also set `indices`
242
+ // and `values` to empty.
243
+ void resize_and_clear_(
244
+ int64_t sparse_dim,
245
+ int64_t dense_dim,
246
+ IntArrayRef size) {
247
+ TORCH_CHECK(
248
+ allow_tensor_metadata_change(),
249
+ "resize_and_clear_ ",
250
+ err_msg_tensor_metadata_change_not_allowed);
251
+ TORCH_CHECK(
252
+ !has_symbolic_sizes_strides_,
253
+ "resize_and_clear_ called on tensor with symbolic shape")
254
+ TORCH_CHECK(
255
+ sparse_dim + dense_dim == static_cast<int64_t>(size.size()),
256
+ "number of dimensions must be sparse_dim (",
257
+ sparse_dim,
258
+ ") + dense_dim (",
259
+ dense_dim,
260
+ "), but got ",
261
+ size.size());
262
+
263
+ set_sizes_and_strides(size, std::vector<int64_t>(size.size()));
264
+ sparse_dim_ = sparse_dim;
265
+ dense_dim_ = dense_dim;
266
+
267
+ auto empty_indices = at::empty({sparse_dim, 0}, indices().options());
268
+ std::vector<int64_t> values_size = {0};
269
+ auto dense_size = sizes().slice(sparse_dim);
270
+ values_size.insert(values_size.end(), dense_size.begin(), dense_size.end());
271
+ auto empty_values = at::empty(values_size, values().options());
272
+ set_indices_and_values_unsafe(empty_indices, empty_values);
273
+ refresh_numel();
274
+ }
275
+
276
+ void set_coalesced(bool coalesced) {
277
+ TORCH_CHECK(
278
+ allow_tensor_metadata_change(),
279
+ "set_coalesced ",
280
+ err_msg_tensor_metadata_change_not_allowed);
281
+ coalesced_ = coalesced;
282
+ }
283
+
284
+ // NOTE: this function is only used internally and not exposed to Python
285
+ // frontend
286
+ void set_nnz_and_narrow(int64_t new_nnz) {
287
+ TORCH_CHECK(
288
+ allow_tensor_metadata_change(),
289
+ "set_nnz_and_narrow ",
290
+ err_msg_tensor_metadata_change_not_allowed);
291
+ AT_ASSERT(new_nnz <= nnz());
292
+ indices_ = indices_.narrow(1, 0, new_nnz);
293
+ values_ = values_.narrow(0, 0, new_nnz);
294
+ if (new_nnz < 2) {
295
+ coalesced_ = true;
296
+ }
297
+ }
298
+
299
+ // Takes indices and values and directly puts them into the sparse tensor, no
300
+ // copy. NOTE: this function is unsafe because it doesn't check whether any
301
+ // indices are out of boundaries of `sizes`, so it should ONLY be used where
302
+ // we know that the indices are guaranteed to be within bounds. This used to
303
+ // be called THSTensor_(_move) NB: This used to be able to avoid a refcount
304
+ // bump, but I was too lazy to make it happen
305
+ void set_indices_and_values_unsafe(
306
+ const Tensor& indices,
307
+ const Tensor& values);
308
+
309
+ /**
310
+ * Return a TensorImpl that is a shallow-copy of this TensorImpl.
311
+ *
312
+ * For usage of `version_counter` and `allow_tensor_metadata_change`,
313
+ * see NOTE [ TensorImpl Shallow-Copying ].
314
+ */
315
+ c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
316
+ const c10::VariableVersion& version_counter,
317
+ bool allow_tensor_metadata_change) const override {
318
+ auto impl = c10::make_intrusive<SparseTensorImpl>(key_set(), dtype());
319
+ copy_tensor_metadata(
320
+ /*src_sparse_impl=*/this,
321
+ /*dest_sparse_impl=*/impl.get(),
322
+ /*version_counter=*/version_counter,
323
+ /*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
324
+ impl->refresh_numel();
325
+ return impl;
326
+ }
327
+
328
+ /**
329
+ * Return a TensorImpl that is a shallow-copy of this TensorImpl.
330
+ *
331
+ * For usage of `version_counter` and `allow_tensor_metadata_change`,
332
+ * see NOTE [ TensorImpl Shallow-Copying ].
333
+ */
334
+ c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
335
+ c10::VariableVersion&& version_counter,
336
+ bool allow_tensor_metadata_change) const override {
337
+ auto impl = c10::make_intrusive<SparseTensorImpl>(key_set(), dtype());
338
+ copy_tensor_metadata(
339
+ /*src_sparse_impl=*/this,
340
+ /*dest_sparse_impl=*/impl.get(),
341
+ /*version_counter=*/std::move(version_counter),
342
+ /*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
343
+ impl->refresh_numel();
344
+ return impl;
345
+ }
346
+
347
+ /**
348
+ * Shallow-copies data from another TensorImpl into this TensorImpl.
349
+ *
350
+ * For why this function doesn't check this TensorImpl's
351
+ * `allow_tensor_metadata_change_`, see NOTE [ TensorImpl Shallow-Copying ].
352
+ */
353
+ void shallow_copy_from(const c10::intrusive_ptr<TensorImpl>& impl) override {
354
+ AT_ASSERT(has_compatible_shallow_copy_type(impl->key_set()));
355
+ auto sparse_impl = static_cast<const SparseTensorImpl*>(impl.get());
356
+ copy_tensor_metadata(
357
+ /*src_sparse_impl=*/sparse_impl,
358
+ /*dest_sparse_impl=*/this,
359
+ /*version_counter=*/version_counter(),
360
+ /*allow_tensor_metadata_change=*/allow_tensor_metadata_change());
361
+ refresh_numel();
362
+ }
363
+
364
+ private:
365
+ explicit SparseTensorImpl(
366
+ at::DispatchKeySet,
367
+ const caffe2::TypeMeta,
368
+ at::Tensor indices,
369
+ at::Tensor values);
370
+
371
+ /**
372
+ * Copy the tensor metadata fields (e.g. sizes / strides / storage pointer /
373
+ * storage_offset) from one TensorImpl to another TensorImpl.
374
+ *
375
+ * For usage of `version_counter` and `allow_tensor_metadata_change`, see NOTE
376
+ * [ TensorImpl Shallow-Copying ].
377
+ */
378
+ static void copy_tensor_metadata(
379
+ const SparseTensorImpl* src_sparse_impl,
380
+ SparseTensorImpl* dest_sparse_impl,
381
+ c10::VariableVersion version_counter,
382
+ bool allow_tensor_metadata_change) {
383
+ TensorImpl::copy_tensor_metadata(
384
+ src_sparse_impl,
385
+ dest_sparse_impl,
386
+ std::move(version_counter),
387
+ allow_tensor_metadata_change);
388
+
389
+ // Sparse-specific fields
390
+ dest_sparse_impl->sparse_dim_ = src_sparse_impl->sparse_dim();
391
+ dest_sparse_impl->dense_dim_ = src_sparse_impl->dense_dim();
392
+ dest_sparse_impl->indices_ = src_sparse_impl->indices();
393
+ dest_sparse_impl->values_ = src_sparse_impl->values();
394
+ dest_sparse_impl->coalesced_ = src_sparse_impl->coalesced();
395
+ }
396
+
397
+ const char* tensorimpl_type_name() const override;
398
+ };
399
+
400
+ } // namespace at
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/StorageUtils.h ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <c10/core/Storage.h>
4
+ #include <c10/core/StorageImpl.h>
5
+ #include <c10/util/intrusive_ptr.h>
6
+
7
+ namespace at {
8
+
9
+ class TensorBase;
10
+
11
+ // Here we define a series of utils to create/manipulate ATen backed
12
+ // c10 storage implementations.
13
+
14
+ /**
15
+ * Create a new shared memory storage impl managed by file descriptor
16
+ *
17
+ * @param size size in bytes
18
+ */
19
+ C10_EXPORT c10::intrusive_ptr<c10::StorageImpl> new_shm_fd_storage(size_t size);
20
+
21
+ /**
22
+ * Copy src to dst
23
+ * Caller must guarantee the validness of the storage objects
24
+ * during the entire copy process, esp. when it's async.
25
+ *
26
+ * This can probably live in c10 namespace later if needed,
27
+ * but for now keep it in at to keep implementation simple.
28
+ *
29
+ * @param dst dst tensor
30
+ * @param src src tensor
31
+ * @param non_blocking (default false) whether this operation blocks caller
32
+ */
33
+ C10_EXPORT void storage_copy(
34
+ c10::Storage& dst,
35
+ const c10::Storage& src,
36
+ bool non_blocking = false);
37
+
38
+ /**
39
+ * In place change the storage to shm based.
40
+ *
41
+ * This is only applicable to CPU tensors not already shared.
42
+ * Otherwise, it's a no op to mirror the THP tensor behavior:
43
+ * https://pytorch.org/docs/stable/generated/torch.Tensor.share_memory_.html
44
+ *
45
+ * @param t a tensor
46
+ */
47
+ C10_EXPORT void share_memory_(TensorBase& t);
48
+
49
+ } // namespace at
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/TensorIndexing.h ADDED
@@ -0,0 +1,735 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/ExpandUtils.h>
4
+ #include <ATen/ScalarOps.h>
5
+ #include <ATen/core/Tensor.h>
6
+ #include <ATen/core/TensorBody.h>
7
+ #include <c10/core/SymInt.h>
8
+ #include <c10/util/Optional.h>
9
+ #include <c10/util/irange.h>
10
+
11
+ #ifndef AT_PER_OPERATOR_HEADERS
12
+ #include <ATen/Functions.h>
13
+ #include <ATen/NativeFunctions.h>
14
+ #else
15
+ #include <ATen/ops/alias.h>
16
+ #include <ATen/ops/empty.h>
17
+ #include <ATen/ops/scalar_tensor.h>
18
+ #include <ATen/ops/zeros.h>
19
+ #endif
20
+
21
+ #include <ATen/core/List.h>
22
+
23
+ #include <utility>
24
+
25
+ namespace at::indexing {
26
+
27
+ constexpr int64_t INDEX_MIN = c10::SymInt::min_representable_int();
28
+ constexpr int64_t INDEX_MAX = -(INDEX_MIN + 1);
29
+
30
+ enum class TensorIndexType { None, Ellipsis, SymInt, Boolean, Slice, Tensor };
31
+
32
+ constexpr c10::nullopt_t None = c10::nullopt;
33
+
34
+ struct TORCH_API EllipsisIndexType final {
35
+ EllipsisIndexType() = default;
36
+ };
37
+ TORCH_API extern const EllipsisIndexType Ellipsis;
38
+
39
+ struct TORCH_API Slice final {
40
+ public:
41
+ Slice(
42
+ c10::optional<c10::SymInt> start_index = c10::nullopt,
43
+ c10::optional<c10::SymInt> stop_index = c10::nullopt,
44
+ c10::optional<c10::SymInt> step_index = c10::nullopt) {
45
+ if (!step_index.has_value()) {
46
+ step_ = c10::SymInt(1);
47
+ } else {
48
+ step_ = std::move(step_index).value();
49
+ }
50
+
51
+ TORCH_CHECK_VALUE(step_ != 0, "slice step cannot be zero");
52
+
53
+ if (!start_index.has_value()) {
54
+ start_ = c10::SymInt(step_ < 0 ? INDEX_MAX : 0);
55
+ } else {
56
+ start_ = std::move(start_index).value();
57
+ }
58
+
59
+ if (!stop_index.has_value()) {
60
+ stop_ = c10::SymInt(step_ < 0 ? INDEX_MIN : INDEX_MAX);
61
+ } else {
62
+ stop_ = std::move(stop_index).value();
63
+ }
64
+ }
65
+
66
+ inline c10::SymInt start() const {
67
+ return start_;
68
+ }
69
+
70
+ inline c10::SymInt stop() const {
71
+ return stop_;
72
+ }
73
+
74
+ inline c10::SymInt step() const {
75
+ return step_;
76
+ }
77
+
78
+ private:
79
+ c10::SymInt start_;
80
+ c10::SymInt stop_;
81
+ c10::SymInt step_;
82
+ };
83
+
84
+ TORCH_API std::ostream& operator<<(std::ostream& stream, const Slice& slice);
85
+
86
+ // `at::indexing::TensorIndex` is used for converting C++ tensor indices such as
87
+ // `{None, "...", Ellipsis, 0, true, Slice(1, None, 2), torch::tensor({1, 2})}`
88
+ // into its equivalent `std::vector<TensorIndex>`, so that further tensor
89
+ // indexing operations can be performed using the supplied indices.
90
+ //
91
+ // There is one-to-one correspondence between Python and C++ tensor index types:
92
+ // Python | C++
93
+ // -----------------------------------------------------
94
+ // `None` | `at::indexing::None`
95
+ // `Ellipsis` | `at::indexing::Ellipsis`
96
+ // `...` | `"..."`
97
+ // `123` | `123`
98
+ // `True` / `False` | `true` / `false`
99
+ // `:` | `Slice()` / `Slice(None, None)`
100
+ // `::` | `Slice()` / `Slice(None, None, None)`
101
+ // `1:` | `Slice(1, None)`
102
+ // `1::` | `Slice(1, None, None)`
103
+ // `:3` | `Slice(None, 3)`
104
+ // `:3:` | `Slice(None, 3, None)`
105
+ // `::2` | `Slice(None, None, 2)`
106
+ // `1:3` | `Slice(1, 3)`
107
+ // `1::2` | `Slice(1, None, 2)`
108
+ // `:3:2` | `Slice(None, 3, 2)`
109
+ // `1:3:2` | `Slice(1, 3, 2)`
110
+ // `torch.tensor([1, 2])`) | `torch::tensor({1, 2})`
111
+ struct TORCH_API TensorIndex final {
112
+ // Case 1: `at::indexing::None`
113
+ TensorIndex(c10::nullopt_t) : type_(TensorIndexType::None) {}
114
+
115
+ // Case 2: "..." / `at::indexing::Ellipsis`
116
+ TensorIndex(at::indexing::EllipsisIndexType)
117
+ : type_(TensorIndexType::Ellipsis) {}
118
+ TensorIndex(const char* str) : TensorIndex(at::indexing::Ellipsis) {
119
+ TORCH_CHECK_VALUE(
120
+ strcmp(str, "...") == 0,
121
+ "Expected \"...\" to represent an ellipsis index, but got \"",
122
+ str,
123
+ "\"");
124
+ }
125
+
126
+ // Case 3: (Sym) Integer value
127
+ TensorIndex(SymInt integer)
128
+ : integer_(std::move(integer)), type_(TensorIndexType::SymInt) {}
129
+ TensorIndex(int64_t integer) : TensorIndex(SymInt(integer)) {}
130
+ TensorIndex(int integer) : TensorIndex(SymInt(integer)) {}
131
+
132
+ // Case 4: Boolean value
133
+ template <class T, class = std::enable_if_t<std::is_same_v<bool, T>>>
134
+ TensorIndex(T boolean) : boolean_(boolean), type_(TensorIndexType::Boolean) {}
135
+
136
+ // Case 5: Slice represented in `at::indexing::Slice` form
137
+ TensorIndex(Slice slice)
138
+ : slice_(std::move(slice)), type_(TensorIndexType::Slice) {}
139
+
140
+ // Case 6: Tensor value
141
+ TensorIndex(Tensor tensor)
142
+ : tensor_(std::move(tensor)), type_(TensorIndexType::Tensor) {}
143
+
144
+ inline bool is_none() const {
145
+ return type_ == TensorIndexType::None;
146
+ }
147
+
148
+ inline bool is_ellipsis() const {
149
+ return type_ == TensorIndexType::Ellipsis;
150
+ }
151
+
152
+ inline bool is_integer() const {
153
+ return type_ == TensorIndexType::SymInt;
154
+ }
155
+
156
+ inline SymInt integer() const {
157
+ return integer_;
158
+ }
159
+
160
+ inline bool is_boolean() const {
161
+ return type_ == TensorIndexType::Boolean;
162
+ }
163
+
164
+ inline bool boolean() const {
165
+ return boolean_;
166
+ }
167
+
168
+ inline bool is_slice() const {
169
+ return type_ == TensorIndexType::Slice;
170
+ }
171
+
172
+ inline const Slice& slice() const {
173
+ return slice_;
174
+ }
175
+
176
+ inline bool is_tensor() const {
177
+ return type_ == TensorIndexType::Tensor;
178
+ }
179
+
180
+ inline const Tensor& tensor() const {
181
+ return tensor_;
182
+ }
183
+
184
+ private:
185
+ SymInt integer_ = 0;
186
+ bool boolean_ = false;
187
+ Slice slice_;
188
+ Tensor tensor_;
189
+ TensorIndexType type_;
190
+ };
191
+
192
+ TORCH_API std::ostream& operator<<(
193
+ std::ostream& stream,
194
+ const TensorIndex& tensor_index);
195
+ TORCH_API std::ostream& operator<<(
196
+ std::ostream& stream,
197
+ const std::vector<TensorIndex>& tensor_indices);
198
+
199
+ namespace impl {
200
+ static inline Tensor applySlice(
201
+ const Tensor& self,
202
+ int64_t dim,
203
+ c10::SymInt start,
204
+ c10::SymInt stop,
205
+ c10::SymInt step,
206
+ bool disable_slice_optimization,
207
+ const at::Device& self_device,
208
+ const c10::optional<SymIntArrayRef>& self_sizes) {
209
+ // TODO: implement negative step
210
+ TORCH_CHECK_VALUE(step > 0, "step must be greater than zero");
211
+
212
+ // See NOTE [nested tensor size for indexing]
213
+ if (self_sizes.has_value()) {
214
+ // Skip this optimization if we are tracing, as the trace may be polymorphic
215
+ // over the shape of the `self` tensor, and we still want to record
216
+ // the slice.
217
+ SymInt length = (self_device == at::kCPU || self_device == at::kCUDA)
218
+ ? (*self_sizes)[dim]
219
+ : self.sym_size(dim);
220
+ if (!disable_slice_optimization &&
221
+ TORCH_GUARD_SIZE_OBLIVIOUS(start.sym_eq(0)) && length == stop &&
222
+ step == 1) {
223
+ return self;
224
+ }
225
+ }
226
+ return self.slice_symint(
227
+ dim, std::move(start), std::move(stop), std::move(step));
228
+ }
229
+
230
+ static inline Tensor applySelect(
231
+ const Tensor& self,
232
+ int64_t dim,
233
+ SymInt index,
234
+ int64_t real_dim,
235
+ const at::Device& /*self_device*/,
236
+ const c10::optional<SymIntArrayRef>& self_sizes) {
237
+ // See NOTE [nested tensor size for indexing]
238
+ if (self_sizes.has_value()) {
239
+ auto maybe_index = index.maybe_as_int();
240
+ if (maybe_index.has_value()) {
241
+ TORCH_CHECK_INDEX(
242
+ !(maybe_index.value() == 0 && dim == 0 && self_sizes->empty()),
243
+ "invalid index of a 0-dim tensor. ",
244
+ "Use `tensor.item()` in Python or `tensor.item<T>()` in C++ to convert a 0-dim tensor to a number");
245
+ }
246
+
247
+ auto size = (*self_sizes)[dim];
248
+ // Note: `size >= -index` is not equivalent to `size > -1 - index` if index
249
+ // is INT64_MIN For std::numeric_limits<int64_t>::min() result of unary
250
+ // minus is undefined by the standard but in practice is equal to self. On
251
+ // the other hand, indexing wraping is valid for all negative int64_t
252
+ // values, as x[INT64_MIN] is the same as x[INT64_MAX]
253
+ TORCH_CHECK_INDEX(
254
+ size > -1 - index && size > index,
255
+ "index ",
256
+ index,
257
+ " is out of bounds for dimension ",
258
+ real_dim,
259
+ " with size ",
260
+ size);
261
+ }
262
+
263
+ // if the index is negative, do not normalize it because that would fix the
264
+ // index on the current tensor size in the tracer. aten::select also works on
265
+ // negative indices
266
+ return self.select_symint(dim, std::move(index));
267
+ }
268
+
269
+ static inline Tensor boolToIndexingTensorCPUOrCUDA(
270
+ const Tensor& self,
271
+ bool value) {
272
+ // booleans add a dimension of size 1. true indexes this dimension as if 0:,
273
+ // false as empty.
274
+ if (value) {
275
+ return at::empty({1}, self.options().dtype(kLong)).fill_(0.);
276
+ } else {
277
+ return at::empty({0}, self.options().dtype(kLong));
278
+ }
279
+ }
280
+
281
+ static inline Tensor boolToIndexingTensorNonNativeDeviceType(
282
+ const Tensor& self,
283
+ bool value) {
284
+ // booleans add a dimension of size 1. true indexes this dimension as if 0:,
285
+ // false as empty.
286
+ if (value) {
287
+ return at::zeros({1}, self.options().dtype(kLong));
288
+ } else {
289
+ return at::empty({0}, self.options().dtype(kLong));
290
+ }
291
+ }
292
+
293
+ static inline Tensor boolToIndexingTensor(
294
+ const Tensor& self,
295
+ bool value,
296
+ const at::Device& self_device) {
297
+ if (self_device == at::kCPU || self_device == at::kCUDA) {
298
+ return boolToIndexingTensorCPUOrCUDA(self, value);
299
+ } else {
300
+ return boolToIndexingTensorNonNativeDeviceType(self, value);
301
+ }
302
+ }
303
+
304
+ static inline Tensor scalarToTensorNonNativeDeviceType(
305
+ const Scalar& v,
306
+ const TensorOptions& options) {
307
+ return at::scalar_tensor(v, options);
308
+ }
309
+
310
+ static inline void recordTensorIndex(
311
+ const Tensor& tensor,
312
+ std::vector<Tensor>& outIndices,
313
+ int64_t* dim_ptr) {
314
+ // TODO: check scalarType
315
+ outIndices.resize(*dim_ptr + 1);
316
+ outIndices[*dim_ptr] = tensor;
317
+ (*dim_ptr)++;
318
+ };
319
+
320
+ static inline c10::List<c10::optional<Tensor>> typeConvertIndices(
321
+ const Tensor& /*self*/,
322
+ std::vector<Tensor>&& indices) {
323
+ c10::List<c10::optional<Tensor>> converted_inds;
324
+ converted_inds.reserve(indices.size());
325
+ for (auto&& i : std::move(indices)) {
326
+ converted_inds.push_back(std::move(i));
327
+ }
328
+ return converted_inds;
329
+ }
330
+
331
+ // NOTE: Why do we mirror instead of replace the `count_specified_dimensions`
332
+ // function in torch/csrc/autograd/python_variable_indexing.cpp? It's because
333
+ // `count_specified_dimensions` is on the hot path of Python tensor multi-dim
334
+ // indexing (i.e. it's called by `applySlicing` which is called by
335
+ // `THPVariable_getitem` / `THPVariable_setitem` when handling indexing of more
336
+ // than one dimension). If we were to merge the Python/C++
337
+ // `count_specified_dimensions` function, on the Python side we would have to
338
+ // construct a `std::vector` container to be consumed by the C++
339
+ // `count_specified_dimensions` function, which adds 100s of nanoseconds
340
+ // overhead and is undesirable.
341
+ static inline int64_t count_specified_dimensions(
342
+ const ArrayRef<TensorIndex>& indices) {
343
+ // Count the number of indexed dimensions (everything but ellipsis and None)
344
+ int64_t count = 0;
345
+ for (auto& obj : indices) {
346
+ if (obj.is_tensor()) {
347
+ auto& tensor = obj.tensor();
348
+ if (tensor.scalar_type() == kByte || tensor.scalar_type() == kBool) {
349
+ count += tensor.dim();
350
+ } else {
351
+ count++;
352
+ }
353
+ } else if (!obj.is_none() && !obj.is_ellipsis() && !obj.is_boolean()) {
354
+ count++;
355
+ }
356
+ }
357
+ return count;
358
+ }
359
+ } // namespace impl
360
+
361
+ // NOTE: Many functions below are only for consumption from Python indexing
362
+ // implementation, they include:
363
+ //
364
+ // - `Tensor scalarToTensor(...)`
365
+ // - `IntArrayRef slicePrefix1sSize(...)`
366
+ // - `void copy_to(...)`
367
+ // - `Tensor handleDimInMultiDimIndexing(...)`
368
+ // - `Tensor dispatch_index(...)`
369
+ // - `Tensor dispatch_index_put_(...)`
370
+ // - `Tensor get_item(...)`
371
+ // - `void set_item(...)`
372
+ //
373
+ // The rest of the functions are in `at::indexing::impl` namespace, signifying
374
+ // that they shouldn't be used from Python indexing implementation.
375
+ static inline Tensor scalarToTensor(
376
+ const Scalar& v,
377
+ const TensorOptions& options,
378
+ const at::Device& self_device) {
379
+ if (self_device == at::kCPU && !v.isSymbolic()) {
380
+ return at::detail::scalar_tensor_static(
381
+ v, options.dtype_opt()->toScalarType(), self_device);
382
+ } else {
383
+ return impl::scalarToTensorNonNativeDeviceType(v, options);
384
+ }
385
+ }
386
+
387
+ // To match numpy semantics:
388
+ // As a special case for backwards compatibility,
389
+ // strip away unit dimensions from the left of 'src'
390
+ static inline SymIntArrayRef slicePrefix1sSize(const SymIntArrayRef& sizes) {
391
+ size_t first_non1_src = sizes.size();
392
+ for (const auto i : c10::irange(sizes.size())) {
393
+ // Unbacked SymInt has different behavior, but this is sound because
394
+ // failing to slice will only ever cause an error, not divergent
395
+ // behavior
396
+ if (!sizes[i].has_hint() || sizes[i] != 1) {
397
+ first_non1_src = i;
398
+ break;
399
+ }
400
+ }
401
+
402
+ return sizes.slice(first_non1_src);
403
+ }
404
+
405
+ static inline void copy_to(const Tensor& dst, const Tensor& src) {
406
+ if (dst.sym_sizes().equals(src.sym_sizes())) {
407
+ // A shortcut to avoid generating hard-coded constant sizes during tracing.
408
+ // This is not a perfect solution: when src & dst have different shapes,
409
+ // constants will still appear. Users can workaround that case by
410
+ // dst[index..] = src.reshape(..)
411
+ dst.copy_(src);
412
+ return;
413
+ } else if (src.dim() == 0 && src.device().type() == at::kCPU) {
414
+ dst.fill_(src);
415
+ return;
416
+ }
417
+ auto src_view = src.view_symint(slicePrefix1sSize(src.sym_sizes()));
418
+ c10::MaybeOwned<Tensor> b_src = expand_inplace(dst, src_view, "setitem");
419
+ dst.copy_(*b_src);
420
+ }
421
+
422
+ // See NOTE [ Setting `disable_slice_optimization` when calling C++ tensor
423
+ // indexing functions from Python ]
424
+ static inline Tensor handleDimInMultiDimIndexing(
425
+ const Tensor& prev_dim_result,
426
+ const Tensor& original_tensor,
427
+ const TensorIndex& index,
428
+ int64_t* dim_ptr,
429
+ int64_t* specified_dims_ptr,
430
+ int64_t real_dim,
431
+ std::vector<Tensor>& outIndices,
432
+ bool disable_slice_optimization,
433
+ const at::Device& original_tensor_device,
434
+ const c10::optional<SymIntArrayRef>& prev_dim_result_sizes) {
435
+ if (index.is_integer()) {
436
+ return impl::applySelect(
437
+ prev_dim_result,
438
+ *dim_ptr,
439
+ index.integer(),
440
+ real_dim,
441
+ original_tensor_device,
442
+ prev_dim_result_sizes);
443
+ } else if (index.is_slice()) {
444
+ Tensor result = impl::applySlice(
445
+ prev_dim_result,
446
+ *dim_ptr,
447
+ index.slice().start(),
448
+ index.slice().stop(),
449
+ index.slice().step(),
450
+ /*disable_slice_optimization=*/disable_slice_optimization,
451
+ original_tensor_device,
452
+ prev_dim_result_sizes);
453
+ (*dim_ptr)++;
454
+ return result;
455
+ } else if (index.is_ellipsis()) {
456
+ (*dim_ptr) += original_tensor.dim() - (*specified_dims_ptr);
457
+ return prev_dim_result;
458
+ } else if (index.is_none()) {
459
+ Tensor result = prev_dim_result.unsqueeze(*dim_ptr);
460
+ (*dim_ptr)++;
461
+ return result;
462
+ } else if (index.is_boolean()) {
463
+ Tensor result = prev_dim_result.unsqueeze(*dim_ptr);
464
+ impl::recordTensorIndex(
465
+ impl::boolToIndexingTensor(
466
+ result, index.boolean(), original_tensor_device),
467
+ outIndices,
468
+ dim_ptr);
469
+ return result;
470
+ } else if (index.is_tensor()) {
471
+ Tensor result = prev_dim_result;
472
+ const Tensor& tensor = index.tensor();
473
+ auto scalar_type = tensor.scalar_type();
474
+ if (tensor.dim() == 0 &&
475
+ at::isIntegralType(scalar_type, /*includeBool=*/true)) {
476
+ if (scalar_type != at::kByte && scalar_type != at::kBool) {
477
+ result = impl::applySelect(
478
+ result,
479
+ *dim_ptr,
480
+ tensor.item<int64_t>(),
481
+ real_dim,
482
+ original_tensor_device,
483
+ prev_dim_result_sizes);
484
+ } else {
485
+ result = result.unsqueeze(*dim_ptr);
486
+ if (scalar_type == at::kBool) {
487
+ impl::recordTensorIndex(
488
+ impl::boolToIndexingTensor(
489
+ result, tensor.item<bool>() != 0, original_tensor_device),
490
+ outIndices,
491
+ dim_ptr);
492
+ } else {
493
+ impl::recordTensorIndex(
494
+ impl::boolToIndexingTensor(
495
+ result, tensor.item<uint8_t>() != 0, original_tensor_device),
496
+ outIndices,
497
+ dim_ptr);
498
+ }
499
+ }
500
+ } else {
501
+ impl::recordTensorIndex(tensor, outIndices, dim_ptr);
502
+ }
503
+ return result;
504
+ } else {
505
+ TORCH_INTERNAL_ASSERT(false, "Invalid TensorIndex type");
506
+ }
507
+ }
508
+
509
+ namespace impl {
510
+ // This mirrors `applySlicing` in
511
+ // torch/csrc/autograd/python_variable_indexing.cpp
512
+ static inline Tensor applySlicing(
513
+ const Tensor& self,
514
+ const ArrayRef<TensorIndex>& indices,
515
+ std::vector<Tensor>& outIndices,
516
+ bool disable_slice_optimization,
517
+ const at::Device& self_device,
518
+ const c10::optional<SymIntArrayRef>& self_sizes) {
519
+ int64_t dim = 0;
520
+ int64_t specified_dims = impl::count_specified_dimensions(indices);
521
+
522
+ // See NOTE [nested tensor size for indexing]
523
+ if (self_sizes.has_value()) {
524
+ TORCH_CHECK_INDEX(
525
+ specified_dims <= (int64_t)self_sizes->size(),
526
+ "too many indices for tensor of dimension ",
527
+ (int)self_sizes->size());
528
+ }
529
+
530
+ Tensor result = self;
531
+ for (const auto i : c10::irange(indices.size())) {
532
+ auto& obj = indices[i];
533
+ // See NOTE [nested tensor size for indexing]
534
+ c10::optional<SymIntArrayRef> result_sizes = result.is_nested()
535
+ ? c10::optional<SymIntArrayRef>(c10::nullopt)
536
+ : c10::optional<SymIntArrayRef>(result.sym_sizes());
537
+ result = handleDimInMultiDimIndexing(
538
+ /*prev_dim_result=*/result,
539
+ /*original_tensor=*/self,
540
+ /*index=*/obj,
541
+ /*dim_ptr=*/&dim,
542
+ /*specified_dims_ptr=*/&specified_dims,
543
+ /*real_dim=*/static_cast<int64_t>(i),
544
+ /*outIndices=*/outIndices,
545
+ /*disable_slice_optimization=*/disable_slice_optimization,
546
+ /*original_tensor_device=*/self_device,
547
+ /*prev_dim_result_sizes=*/result_sizes);
548
+ }
549
+ return result;
550
+ }
551
+ } // namespace impl
552
+
553
+ static inline Tensor dispatch_index(
554
+ const Tensor& self,
555
+ std::vector<Tensor>&& indices) {
556
+ return self.index(impl::typeConvertIndices(self, std::move(indices)));
557
+ }
558
+
559
+ static inline Tensor dispatch_index_put_(
560
+ Tensor& self,
561
+ std::vector<Tensor>&& indices,
562
+ const Tensor& value) {
563
+ return self.index_put_(
564
+ impl::typeConvertIndices(self, std::move(indices)), value);
565
+ }
566
+
567
+ // NOTE [ Setting `disable_slice_optimization` when calling C++ tensor indexing
568
+ // functions from Python ]
569
+ //
570
+ // Question: When should we set `disable_slice_optimization` to `true` when
571
+ // calling C++ tensor indexing functions from Python indexing code?
572
+ //
573
+ // Answer: What "slice optimization" means: when we have a slicing expression
574
+ // like `x[0:5, 0]`, where the sliced tensor was of size 5 in dimension 0, we
575
+ // would skip dispatching the actual slice call as an optimization. However,
576
+ // here are the cases where we DON'T want this optimization:
577
+ //
578
+ // 1. When we are doing 1-D slicing (e.g. `tensor[:]`).
579
+ // Reason: we always return a shallow copy for expressions such as
580
+ // `tensor[:]` / `tensor[...]` / `tensor[:, :]`. (Note that for `tensor[:,
581
+ // :]`, we return an alias of `tensor` by doing the following:
582
+ // ```
583
+ // Tensor sliced = impl::applySlicing(self, indices, tensorIndices,
584
+ // disable_slice_optimization, self_device, self_sizes); if
585
+ // (tensorIndices.empty()) {
586
+ // if (sliced.is_same(self)) {
587
+ // // ensure we return a shallow copy for things like x[...]
588
+ // sliced = at::alias(sliced);
589
+ // }
590
+ // return sliced;
591
+ // }
592
+ // ```)
593
+ // 2. When we are doing JIT tracing.
594
+ // Reason: JIT tracing needs the `self.slice(...)` call to properly trace the
595
+ // slice operation.
596
+
597
+ // This mirrors `THPVariable_getitem` in
598
+ // torch/csrc/autograd/python_variable_indexing.cpp See NOTE [ Setting
599
+ // `disable_slice_optimization` when calling C++ tensor indexing functions from
600
+ // Python ]
601
+ static inline Tensor get_item(
602
+ const Tensor& self,
603
+ const ArrayRef<TensorIndex>& indices,
604
+ bool disable_slice_optimization = false) {
605
+ at::Device self_device = self.device();
606
+ // NOTE [nested tensor size for indexing]
607
+ // nested tensor does not have a size (yet) so for now we represent its size
608
+ // as null may need to be changed after we reach a better solution for nested
609
+ // tensor size
610
+ c10::optional<SymIntArrayRef> self_sizes = self.is_nested()
611
+ ? c10::optional<SymIntArrayRef>(c10::nullopt)
612
+ : c10::optional<SymIntArrayRef>(self.sym_sizes());
613
+
614
+ // handle simple types: integers, slices, none, ellipsis, bool
615
+ if (indices.size() == 1) {
616
+ const TensorIndex& index = indices[0];
617
+ if (index.is_integer()) {
618
+ return impl::applySelect(
619
+ self, 0, index.integer(), 0, self_device, self_sizes);
620
+ } else if (index.is_slice()) {
621
+ return impl::applySlice(
622
+ self,
623
+ 0,
624
+ index.slice().start(),
625
+ index.slice().stop(),
626
+ index.slice().step(),
627
+ /*disable_slice_optimization=*/true,
628
+ self_device,
629
+ self_sizes);
630
+ } else if (index.is_none()) {
631
+ return self.unsqueeze(0);
632
+ } else if (index.is_ellipsis()) {
633
+ return at::alias(self);
634
+ } else if (index.is_boolean()) {
635
+ Tensor result = self.unsqueeze(0);
636
+ return dispatch_index(
637
+ result,
638
+ std::vector<Tensor>{impl::boolToIndexingTensor(
639
+ result, index.boolean(), self_device)});
640
+ }
641
+ }
642
+
643
+ std::vector<Tensor> tensorIndices;
644
+ Tensor sliced = impl::applySlicing(
645
+ self,
646
+ indices,
647
+ tensorIndices,
648
+ disable_slice_optimization,
649
+ self_device,
650
+ self_sizes);
651
+ if (tensorIndices.empty()) {
652
+ if (sliced.is_same(self)) {
653
+ // ensure we return a shallow copy for things like x[...]
654
+ sliced = at::alias(sliced);
655
+ }
656
+ return sliced;
657
+ }
658
+
659
+ // indexing by tensors ("advanced" indexing)
660
+ return dispatch_index(sliced, std::move(tensorIndices));
661
+ }
662
+
663
+ // This mirrors `THPVariable_setitem` in
664
+ // torch/csrc/autograd/python_variable_indexing.cpp for "the assigned value is a
665
+ // Tensor" case See NOTE [ Setting `disable_slice_optimization` when calling C++
666
+ // tensor indexing functions from Python ]
667
+ static inline void set_item(
668
+ const Tensor& self,
669
+ const ArrayRef<TensorIndex>& indices,
670
+ const Tensor& value,
671
+ bool disable_slice_optimization = false) {
672
+ at::Device self_device = self.device();
673
+ SymIntArrayRef self_sizes = self.sym_sizes();
674
+
675
+ // handle simple types: integers, slices, ellipsis, bool
676
+ if (indices.size() == 1) {
677
+ const TensorIndex& index = indices[0];
678
+ if (index.is_boolean() && !index.boolean()) {
679
+ // do nothing for false (technically we should check the size, but we
680
+ // don't have real 0-sized shapes.
681
+ return;
682
+ } else if (index.is_ellipsis()) {
683
+ copy_to(self, value);
684
+ return;
685
+ } else if (index.is_none() || (index.is_boolean() && index.boolean())) {
686
+ copy_to(self.unsqueeze(0), value);
687
+ return;
688
+ } else if (index.is_integer()) {
689
+ copy_to(
690
+ impl::applySelect(
691
+ self, 0, index.integer(), 0, self_device, self_sizes),
692
+ value);
693
+ return;
694
+ } else if (index.is_slice()) {
695
+ copy_to(
696
+ impl::applySlice(
697
+ self,
698
+ 0,
699
+ index.slice().start(),
700
+ index.slice().stop(),
701
+ index.slice().step(),
702
+ /*disable_slice_optimization=*/disable_slice_optimization,
703
+ self_device,
704
+ self_sizes),
705
+ value);
706
+ return;
707
+ }
708
+ }
709
+
710
+ std::vector<Tensor> tensorIndices;
711
+ Tensor sliced = impl::applySlicing(
712
+ self,
713
+ indices,
714
+ tensorIndices,
715
+ disable_slice_optimization,
716
+ self_device,
717
+ self_sizes);
718
+ if (tensorIndices.empty()) {
719
+ copy_to(sliced, value);
720
+ return;
721
+ }
722
+
723
+ SymIntArrayRef valueSizes = value.sym_sizes();
724
+ SymIntArrayRef slicedValueSizes = slicePrefix1sSize(valueSizes);
725
+ Tensor valuesSliced;
726
+ if (!valueSizes.equals(slicedValueSizes)) {
727
+ valuesSliced = value.view_symint(slicedValueSizes);
728
+ } else {
729
+ valuesSliced = value;
730
+ }
731
+ dispatch_index_put_(sliced, std::move(tensorIndices), valuesSliced);
732
+ return;
733
+ }
734
+
735
+ } // namespace at::indexing
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ThreadLocalPythonObjects.h ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <c10/core/SafePyObject.h>
4
+ #include <c10/macros/Macros.h>
5
+ #include <unordered_map>
6
+
7
+ namespace at::impl {
8
+
9
+ struct TORCH_API ThreadLocalPythonObjects {
10
+ static void set(const std::string& key, std::shared_ptr<SafePyObject> value);
11
+ static const std::shared_ptr<SafePyObject>& get(const std::string& key);
12
+ static bool contains(const std::string& key);
13
+
14
+ static const ThreadLocalPythonObjects& get_state();
15
+ static void set_state(ThreadLocalPythonObjects state);
16
+
17
+ private:
18
+ std::unordered_map<std::string, std::shared_ptr<c10::SafePyObject>> obj_dict_;
19
+ };
20
+
21
+ } // namespace at::impl
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ThreadLocalState.h ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <c10/core/InferenceMode.h>
4
+ #include <c10/core/impl/LocalDispatchKeySet.h>
5
+ #include <c10/util/Exception.h>
6
+ #include <c10/util/ThreadLocalDebugInfo.h>
7
+
8
+ #include <ATen/FuncTorchTLS.h>
9
+ #include <ATen/PythonTorchFunctionTLS.h>
10
+ #include <ATen/SavedTensorHooks.h>
11
+ #include <ATen/ThreadLocalPythonObjects.h>
12
+ #include <ATen/record_function.h>
13
+ #include <c10/core/impl/PythonDispatcherTLS.h>
14
+ #include <c10/core/impl/TorchDispatchModeTLS.h>
15
+
16
+ namespace at {
17
+
18
+ // Thread local state contains values that are preserved across
19
+ // thread boundaries (e.g. at::launch/JIT fork, autograd).
20
+ // Note at::parallel_for doesn't preserve TLS across thread boundaries.
21
+ class TORCH_API ThreadLocalState {
22
+ public:
23
+ // Saves the thread local variables' values and
24
+ // returns them as a ThreadLocalState
25
+ ThreadLocalState();
26
+
27
+ // set_grad_mode - force the value of the grad mode TLS in
28
+ // the current state object. This is used for example in the
29
+ // autograd engine.
30
+ void set_grad_mode(bool enabled);
31
+
32
+ // set_multithreading_enabled - force the value of the multithreadinmaximum
33
+ // threads TLS in
34
+ // the current state object. This is used for example in the
35
+ // autograd engine.
36
+ void set_multithreading_enabled(bool enabled);
37
+
38
+ // Sets thread local variables in the current thread,
39
+ // according to the thread boundary specified
40
+ static void setThreadLocalState(const ThreadLocalState& state);
41
+
42
+ private:
43
+ c10::impl::LocalDispatchKeySet dispatch_key_;
44
+
45
+ // ThreadLocalDebugInfo does not change after being created
46
+ // with DebugInfoGuard
47
+ std::shared_ptr<c10::ThreadLocalDebugInfo> debug_info_;
48
+
49
+ // RecordFunction TLS
50
+ RecordFunctionTLS rf_tls_;
51
+
52
+ // TLS for out-of-tree functorch
53
+ // See NOTE [functorch TLS in pytorch/pytorch] for why this needs to be a
54
+ // pointer (spoiler alert: it's due to the indirection)
55
+ // This needs to be a shared_ptr instead of a unique_ptr because
56
+ // ThreadLocalState is copy-able and does indeed get copied. Maybe we can
57
+ // consider adding an explicit copy constructor for ThreadLocalState in the
58
+ // future but I didn't want to add one just for this.
59
+ std::shared_ptr<const functorch::FuncTorchTLSBase> functorch_tls_;
60
+
61
+ // TLS for AutogradModes
62
+ AutogradState autograd_tls_;
63
+
64
+ // TLS for enable_torch_dispatch_mode
65
+ c10::impl::TorchDispatchModeTLS torch_dispatch_mode_state_;
66
+
67
+ // TLS for enable_python_dispatcher
68
+ c10::impl::PyInterpreter* python_dispatcher_state_;
69
+
70
+ // TLS for __torch_function__ (mode and disable_torch_function)
71
+ at::impl::PythonTorchFunctionTLS python_torch_function_state_;
72
+
73
+ // TLS for saved tensors default hooks
74
+ at::impl::SavedTensorDefaultHooksTLS saved_tensors_default_hooks_state_;
75
+
76
+ bool functionalization_reapply_views_state_;
77
+
78
+ // TLS for arbitrary python objects that is registered via hooks
79
+ at::impl::ThreadLocalPythonObjects saved_objects_;
80
+
81
+ friend class ThreadLocalStateGuard;
82
+ };
83
+
84
+ // Guard to set and reset the thread local state
85
+ class TORCH_API ThreadLocalStateGuard {
86
+ public:
87
+ explicit ThreadLocalStateGuard(const ThreadLocalState& state)
88
+ : prev_state_(ThreadLocalState()) {
89
+ // set the given state across the thread boundary
90
+ ThreadLocalState::setThreadLocalState(state);
91
+ }
92
+
93
+ ~ThreadLocalStateGuard() {
94
+ // restore previously set variables
95
+ ThreadLocalState::setThreadLocalState(prev_state_);
96
+ }
97
+
98
+ private:
99
+ // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
100
+ const ThreadLocalState prev_state_;
101
+ };
102
+
103
+ template <typename T>
104
+ auto wrapPropagateTLSState(T callback) {
105
+ return [tls_state = ThreadLocalState(),
106
+ callback = std::move(callback)](auto&&... args) {
107
+ ThreadLocalStateGuard g(tls_state);
108
+ // Propagate value returned by callback().
109
+ return callback(std::forward<decltype(args)>(args)...);
110
+ };
111
+ }
112
+
113
+ } // namespace at
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/TypeDefault.h ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/Dimname.h>
4
+ #include <c10/core/MemoryFormat.h>
5
+ #include <c10/core/QScheme.h>
6
+ #include <c10/core/Scalar.h>
7
+ #include <c10/core/TensorOptions.h>
8
+ #include <c10/macros/Export.h>
9
+ #include <c10/util/ArrayRef.h>
10
+ #include <c10/util/intrusive_ptr.h>
11
+
12
+ namespace c10 {
13
+ struct Storage;
14
+ }
15
+
16
+ namespace at {
17
+
18
+ class Tensor;
19
+ using TensorList = ArrayRef<Tensor>;
20
+
21
+ class Context;
22
+ struct Generator;
23
+
24
+ struct Quantizer;
25
+ // This is temporary typedef to enable Quantizer in aten native function API
26
+ // we'll remove them when we are actually exposing Quantizer class
27
+ // to frontend
28
+ using ConstQuantizerPtr = const c10::intrusive_ptr<Quantizer>&;
29
+
30
+ } // namespace at
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/Version.h ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <ATen/Context.h>
2
+
3
+ namespace at {
4
+
5
+ /// Returns a detailed string describing the configuration PyTorch.
6
+ TORCH_API std::string show_config();
7
+
8
+ TORCH_API std::string get_mkl_version();
9
+
10
+ TORCH_API std::string get_mkldnn_version();
11
+
12
+ TORCH_API std::string get_openmp_version();
13
+
14
+ TORCH_API std::string get_cxx_flags();
15
+
16
+ TORCH_API std::string get_cpu_capability();
17
+
18
+ } // namespace at
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/Atomic.cuh ADDED
@@ -0,0 +1,508 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <cuda.h>
4
+ #include <c10/util/Half.h>
5
+ #include <c10/util/BFloat16.h>
6
+
7
+ #include <ATen/NumericUtils.h>
8
+
9
+ #if !(defined(USE_ROCM) || ((defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))))
10
+ #include <cuda_bf16.h>
11
+ #endif
12
+
13
+ template <typename T>
14
+ struct AtomicFPOp;
15
+
16
+ template <>
17
+ struct AtomicFPOp<at::Half> {
18
+ template <typename func_t>
19
+ inline __device__ at::Half operator() (at::Half *address, at::Half val, const func_t& func) {
20
+ unsigned int * address_as_ui =
21
+ (unsigned int *) ((char *)address - ((size_t)address & 2));
22
+ unsigned int old = *address_as_ui;
23
+ unsigned int assumed;
24
+
25
+ at::Half hsum;
26
+ do {
27
+ assumed = old;
28
+ hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff);
29
+ hsum = func(hsum, val);
30
+ old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16) : (old & 0xffff0000) | hsum.x;
31
+ old = atomicCAS(address_as_ui, assumed, old);
32
+ } while (assumed != old);
33
+ hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff);
34
+ return hsum;
35
+ }
36
+ };
37
+
38
+ template <>
39
+ struct AtomicFPOp<at::BFloat16> {
40
+ template <typename func_t>
41
+ inline __device__ at::BFloat16 operator() (at::BFloat16 *address, at::BFloat16 val, const func_t& func) {
42
+ unsigned int * address_as_ui =
43
+ (unsigned int *) ((char *)address - ((size_t)address & 2));
44
+ unsigned int old = *address_as_ui;
45
+ unsigned int assumed;
46
+
47
+ at::BFloat16 bsum;
48
+ do {
49
+ assumed = old;
50
+ bsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff);
51
+ bsum = func(bsum, val);
52
+ old = (size_t)address & 2 ? (old & 0xffff) | (bsum.x << 16) : (old & 0xffff0000) | bsum.x;
53
+ old = atomicCAS(address_as_ui, assumed, old);
54
+ } while (assumed != old);
55
+ bsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff);
56
+ return bsum.x;
57
+ }
58
+ };
59
+
60
+ template <>
61
+ struct AtomicFPOp<double> {
62
+ template <typename func_t>
63
+ inline __device__ double operator() (double * address, double val, const func_t& func) {
64
+ unsigned long long int* address_as_ull = (unsigned long long int*)address;
65
+ unsigned long long int old = *address_as_ull;
66
+ unsigned long long int assumed;
67
+
68
+ do {
69
+ assumed = old;
70
+ old = atomicCAS(address_as_ull, assumed, func(val, assumed));
71
+ // Note: uses integer comparison to avoid hang in case of NaN (since NaN != NaN)
72
+ } while (assumed != old);
73
+
74
+ return __longlong_as_double(old);
75
+ }
76
+ };
77
+
78
/* CAS-loop implementations of integer atomics for 1-, 2-, 4- and 8-byte
 * types. Sub-word (1- and 2-byte) types are emulated by CAS-ing the
 * enclosing 32-bit word and splicing the updated byte/half-word back in. */
#define ATOMIC_INTEGER_IMPL(NAME)                                             \
template <typename T, size_t n>                                               \
struct Atomic##NAME##IntegerImpl;                                             \
                                                                              \
template<typename T>                                                          \
struct Atomic##NAME##IntegerImpl<T, 1> {                                      \
  template <typename func_t>                                                  \
  inline __device__ void operator()(T *address, T val, const func_t& func) {  \
    size_t byte_off = (size_t)address & 3;                                    \
    uint32_t * word_addr = (uint32_t *)((char *)address - byte_off);          \
    uint32_t observed = *word_addr;                                           \
    uint32_t shift = byte_off * 8;                                            \
    uint32_t prev_byte;                                                       \
    uint32_t updated;                                                         \
    uint32_t expected;                                                        \
                                                                              \
    do {                                                                      \
      expected = observed;                                                    \
      prev_byte = (observed >> shift) & 0xff;                                 \
      updated = static_cast<uint8_t>(func(val, static_cast<T>(prev_byte)));   \
      updated = (observed & ~(0x000000ff << shift)) | (updated << shift);     \
      observed = atomicCAS(word_addr, expected, updated);                     \
    } while (expected != observed);                                           \
  }                                                                           \
};                                                                            \
                                                                              \
template<typename T>                                                          \
struct Atomic##NAME##IntegerImpl<T, 2> {                                      \
  template <typename func_t>                                                  \
  inline __device__ void operator()(T *address, T val, const func_t& func) {  \
    size_t byte_off = (size_t)address & 2;                                    \
    uint32_t * word_addr = (uint32_t *)((char *)address - byte_off);          \
    bool in_high_half = byte_off;                                             \
    uint32_t observed = *word_addr;                                           \
    uint32_t prev_half;                                                       \
    uint32_t updated;                                                         \
    uint32_t expected;                                                        \
                                                                              \
    do {                                                                      \
      expected = observed;                                                    \
      prev_half = in_high_half ? observed >> 16 : observed & 0xffff;          \
      updated = static_cast<uint16_t>(func(val, static_cast<T>(prev_half)));  \
      updated = in_high_half ? (observed & 0xffff) | (updated << 16)          \
                             : (observed & 0xffff0000) | updated;             \
      observed = atomicCAS(word_addr, expected, updated);                     \
    } while (expected != observed);                                           \
  }                                                                           \
};                                                                            \
                                                                              \
template<typename T>                                                          \
struct Atomic##NAME##IntegerImpl<T, 4> {                                      \
  template <typename func_t>                                                  \
  inline __device__ void operator()(T *address, T val, const func_t& func) {  \
    uint32_t * word_addr = (uint32_t *) (address);                            \
    uint32_t observed = *word_addr;                                           \
    uint32_t updated;                                                         \
    uint32_t expected;                                                        \
                                                                              \
    do {                                                                      \
      expected = observed;                                                    \
      updated = static_cast<uint32_t>(func(val, static_cast<T>(observed)));   \
      observed = atomicCAS(word_addr, expected, updated);                     \
    } while (expected != observed);                                           \
  }                                                                           \
};                                                                            \
                                                                              \
template<typename T>                                                          \
struct Atomic##NAME##IntegerImpl<T, 8> {                                      \
  template <typename func_t>                                                  \
  inline __device__ void operator()(T *address, T val, const func_t& func) {  \
    unsigned long long * word_addr = (unsigned long long *) (address);        \
    unsigned long long observed = *word_addr;                                 \
    unsigned long long updated;                                               \
    unsigned long long expected;                                              \
                                                                              \
    do {                                                                      \
      expected = observed;                                                    \
      updated = static_cast<uint64_t>(func(val, static_cast<T>(observed)));   \
      observed = atomicCAS(word_addr, expected, updated);                     \
    } while (expected != observed);                                           \
  }                                                                           \
};
159
+
160
+
161
+ # define GPU_ATOMIC_INTEGER(NAME, OP, DTYPE) \
162
+ static inline __device__ void gpuAtomic##NAME(DTYPE *address, DTYPE val) { \
163
+ Atomic##NAME##IntegerImpl<DTYPE, sizeof(DTYPE)>()(address, \
164
+ val, \
165
+ [](DTYPE a, DTYPE b) { \
166
+ return OP; \
167
+ }); \
168
+ } \
169
+
170
+ ATOMIC_INTEGER_IMPL(Add)
171
+ GPU_ATOMIC_INTEGER(Add, a || b, bool)
172
+
173
+ // Don't instantiate gpuAtomicAdd with the macro as it seems non-standard (see int32, int64)
174
+ static inline __device__ void gpuAtomicAdd(uint8_t *address, uint8_t val) {
175
+ AtomicAddIntegerImpl<uint8_t, sizeof(uint8_t)>()(address,
176
+ val,
177
+ [](uint8_t a, uint8_t b) {
178
+ return a + b;
179
+ });
180
+ }
181
+
182
+ static inline __device__ void gpuAtomicAdd(int8_t *address, int8_t val) {
183
+ AtomicAddIntegerImpl<int8_t, sizeof(int8_t)>()(address,
184
+ val,
185
+ [](int8_t a, int8_t b) {
186
+ return a + b;
187
+ });
188
+ }
189
+
190
+ static inline __device__ void gpuAtomicAdd(int16_t *address, int16_t val) {
191
+ AtomicAddIntegerImpl<int16_t, sizeof(int16_t)>()(address,
192
+ val,
193
+ [](int16_t a, int16_t b) {
194
+ return a + b;
195
+ });
196
+ }
197
+
198
+ static inline __device__ int32_t gpuAtomicAdd(int32_t *address, int32_t val) {
199
+ return atomicAdd(address, val);
200
+ }
201
+
202
+ static inline __device__ void gpuAtomicAdd(int64_t *address, int64_t val) {
203
+ #if defined(USE_ROCM)
204
+ __atomic_fetch_add(address, val, __ATOMIC_RELAXED);
205
+ #else
206
+ static_assert(sizeof(unsigned long long int) == sizeof(int64_t), "bitwidth change is not allowed");
207
+ atomicAdd(reinterpret_cast<unsigned long long int *>(address), static_cast<unsigned long long int>(val));
208
+ #endif
209
+ }
210
+
211
+ static inline __device__ at::Half gpuAtomicAdd(at::Half *address, at::Half val) {
212
+ #if defined(USE_ROCM) || ((defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 700)))
213
+ return AtomicFPOp<at::Half>()(address, val,
214
+ [](at::Half hsum, at::Half val) {
215
+ return hsum + val;
216
+ });
217
+ #else
218
+ return atomicAdd(reinterpret_cast<__half*>(address), val);
219
+ #endif
220
+ }
221
+
222
+ static inline __device__ at::BFloat16 gpuAtomicAdd(at::BFloat16 *address, at::BFloat16 val) {
223
+ #if defined(USE_ROCM) || ((defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)))
224
+ return AtomicFPOp<at::BFloat16>()(address, val,
225
+ [](at::BFloat16 bsum, at::BFloat16 val) {
226
+ return bsum + val;
227
+ });
228
+ #else
229
+ __nv_bfloat16 r = atomicAdd(reinterpret_cast<__nv_bfloat16*>(address), *reinterpret_cast<__nv_bfloat16*>(&val));
230
+ return *reinterpret_cast<c10::BFloat16*>(&r);
231
+ #endif
232
+ }
233
+
234
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 600)
// Pre-sm_60 hardware has no native double atomicAdd; this CAS-based
// emulation is from the CUDA C Programming Guide.
static inline __device__ double atomicAdd(double* address, double val)
#if defined(__clang__) && defined(__CUDA__)
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wgcc-compat"
__attribute__((enable_if(true, "")))
#pragma GCC diagnostic pop
#endif
{
  return AtomicFPOp<double>()(
      address, val,
      [](double val, unsigned long long int assumed) {
        return __double_as_longlong(val + __longlong_as_double(assumed));
      });
}
#elif defined(USE_ROCM) || !(defined(__CUDA_ARCH__))

/* Note [hip-clang differences to hcc]
 * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 * The upcoming hip-clang compiler for ROCm differs from hcc in a few details.
 * It exports the __HIP__ macro, we can hence differentiate between hcc and
 * hip-clang. In the below, hcc only received support for atomicAdd with double
 * typing after work week 18312. hip-clang had support from the first version.
 * In general, the code-visible differences between hip-clang and hcc will be
 * minimal.
 */

#if defined(USE_ROCM) && __hcc_workweek__ < 18312 && !__HIP__
// This needs to be defined for the host side pass
static inline __device__ double atomicAdd(double *address, double val) { }
#endif
#endif
267
+
268
// gpuAtomicAdd: the "returning" atomic add used by ATen kernels. For
// double/float it forwards to whichever atomicAdd overload is in scope
// (the hardware builtin, or the emulation defined earlier in this header).
static inline __device__ double gpuAtomicAdd(double *address, double val) {
  return atomicAdd(address, val);
}

static inline __device__ float gpuAtomicAdd(float *address, float val) {
  return atomicAdd(address, val);
}

// Complex add is performed as two independent real atomics, so the update is
// atomic per component but NOT atomic for the complex value as a whole.
template<typename T>
static inline __device__ void gpuAtomicAdd(c10::complex<T> *address, c10::complex<T> val) {
  gpuAtomicAdd(&address->real_, val.real_);
  gpuAtomicAdd(&address->imag_, val.imag_);
}

/* Note [gpuAtomicAdd vs atomicAdd]
 * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 * Some extensions such as torchvision call atomicAdd()
 * directly and require non-library provided data type support. Only for these, we
 * continue to provide atomicAdd overloads.
 */
static inline __device__ at::Half atomicAdd(at::Half *address, at::Half val) {
  return gpuAtomicAdd(address, val);
}

static inline __device__ at::BFloat16 atomicAdd(at::BFloat16 *address, at::BFloat16 val) {
  return gpuAtomicAdd(address, val);
}

// The integer/bool atomicAdd overloads below intentionally return void; they
// simply forward to the corresponding gpuAtomicAdd overloads (which must be
// declared earlier in this header for these calls to resolve).
static inline __device__ void atomicAdd(uint8_t *address, uint8_t val) {
  gpuAtomicAdd(address, val);
}

static inline __device__ void atomicAdd(int8_t *address, int8_t val) {
  gpuAtomicAdd(address, val);
}

static inline __device__ void atomicAdd(int16_t *address, int16_t val) {
  gpuAtomicAdd(address, val);
}

static inline __device__ void atomicAdd(int64_t *address, int64_t val) {
  gpuAtomicAdd(address, val);
}

static inline __device__ void atomicAdd(bool *address, bool val) {
  gpuAtomicAdd(address, val);
}

/* Note [explicitly non-returning atomics]
 * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 * AMD's MI100 (gfx908) provides an optimized fp32 atomicAdd, exposed via atomicAddNoRet().
 * Due to compiler limitations, callers must opt-in to guarantee the optimized instruction.
 * This non-returning atomicAddNoRet cannot be used to implement the returning atomicAdd,
 * therefore we need a new API 'gpuAtomicAddNoReturn'.
 */
template<typename T>
static inline __device__ void gpuAtomicAddNoReturn(c10::complex<T> *address, c10::complex<T> val) { gpuAtomicAdd(address, val); }
static inline __device__ void gpuAtomicAddNoReturn(uint8_t *address, uint8_t val) { gpuAtomicAdd(address, val); }
static inline __device__ void gpuAtomicAddNoReturn(int8_t *address, int8_t val) { gpuAtomicAdd(address, val); }
static inline __device__ void gpuAtomicAddNoReturn(int16_t *address, int16_t val) { gpuAtomicAdd(address, val); }
static inline __device__ void gpuAtomicAddNoReturn(int32_t *address, int32_t val) { gpuAtomicAdd(address, val); }
static inline __device__ void gpuAtomicAddNoReturn(int64_t *address, int64_t val) { gpuAtomicAdd(address, val); }
static inline __device__ void gpuAtomicAddNoReturn(bool *address, bool val) { gpuAtomicAdd(address, val); }
static inline __device__ void gpuAtomicAddNoReturn(at::Half *address, at::Half val) { gpuAtomicAdd(address, val); }
static inline __device__ void gpuAtomicAddNoReturn(at::BFloat16 *address, at::BFloat16 val) { gpuAtomicAdd(address, val); }
static inline __device__ void gpuAtomicAddNoReturn(double *address, double val) { gpuAtomicAdd(address, val); }

/* Special case fp32 atomic. */
// On ROCm, use atomicAddNoRet so the compiler may emit the optimized
// non-returning fp32 atomic (see note above); elsewhere fall back to the
// returning form and discard the result.
#if defined(USE_ROCM)
static inline __device__ void gpuAtomicAddNoReturn(float *address, float val) { atomicAddNoRet(address, val); }
#else
static inline __device__ void gpuAtomicAddNoReturn(float *address, float val) { gpuAtomicAdd(address, val); }
#endif
341
+
342
// Atomic multiplication implementation.

// Integer overloads are generated by macros defined earlier in this header;
// `a * b` is the combining expression they expand with.
ATOMIC_INTEGER_IMPL(Mul)
GPU_ATOMIC_INTEGER(Mul, a * b, uint8_t)
GPU_ATOMIC_INTEGER(Mul, a * b, int8_t)
GPU_ATOMIC_INTEGER(Mul, a * b, int16_t)
GPU_ATOMIC_INTEGER(Mul, a * b, int32_t)
GPU_ATOMIC_INTEGER(Mul, a * b, int64_t)

// Half/BFloat16 multiply via the generic helper AtomicFPOp (defined earlier
// in this header); here the callback combines the current value with `val`.
inline __device__ at::Half gpuAtomicMul(at::Half * address, at::Half val) {
  return AtomicFPOp<at::Half>()(address, val,
                                [](at::Half bsum, at::Half val) {
                                  return bsum * val;
                                });
}

inline __device__ at::BFloat16 gpuAtomicMul(at::BFloat16 * address, at::BFloat16 val) {
  return AtomicFPOp<at::BFloat16>()(address, val,
                                    [](at::BFloat16 bsum, at::BFloat16 val) {
                                      return bsum * val;
                                    });
}

// For double the callback works on the raw 64-bit word: it receives the
// operand and the currently-assumed bit pattern and returns the new bits.
inline __device__ double gpuAtomicMul(double * address, double val) {
  return AtomicFPOp<double>()(address, val,
                              [](double val, unsigned long long int assumed) {
                                return __double_as_longlong(val * __longlong_as_double(assumed));
                              });
}
371
+
372
+ // Dont use a templated function for this since the addition function defaults to the CUDA built-in.
373
+ inline __device__ float gpuAtomicMul (float * address, float val) {
374
+ unsigned int* address_as_ull = (unsigned int*)address;
375
+ unsigned int old = *address_as_ull;
376
+ unsigned int assumed;
377
+
378
+ do {
379
+ assumed = old;
380
+ old = atomicCAS(address_as_ull, assumed,
381
+ __float_as_int(val *
382
+ __int_as_float(assumed)));
383
+
384
+ // Note: uses integer comparison to avoid hang in case of NaN (since NaN != NaN)
385
+ } while (assumed != old);
386
+
387
+ return __int_as_float(old);
388
+ }
389
+
390
// Atomic maximum implementation.

// NaN-propagating max: if `b` is NaN the result is NaN. On HIP, `a` is also
// checked as a workaround for the compiler issue linked below.
template <typename T>
__host__ __device__ T safe_max(T a, T b) {
  #if defined(__HIPCC__)
  // TODO: remove this special case for HIP when issue is fixed:
  //       https://github.com/ROCm-Developer-Tools/HIP/issues/2209
  T max = at::_isnan(a) ? a : (at::_isnan(b) ? b : std::max<T>(a, b));
  #else
  T max = at::_isnan(b) ? b : std::max<T>(a, b);
  #endif

  return max;
}
404
+
405
// Integer gpuAtomicMax overloads, generated by macros defined earlier in this
// header; `safe_max(a, b)` is the combining expression.
ATOMIC_INTEGER_IMPL(Max)
GPU_ATOMIC_INTEGER(Max, safe_max(a, b), uint8_t)
GPU_ATOMIC_INTEGER(Max, safe_max(a, b), int8_t)
GPU_ATOMIC_INTEGER(Max, safe_max(a, b), int16_t)
GPU_ATOMIC_INTEGER(Max, safe_max(a, b), int32_t)
GPU_ATOMIC_INTEGER(Max, safe_max(a, b), int64_t)

// Half/BFloat16 max via the generic helper AtomicFPOp (defined earlier in
// this header), using NaN-propagating safe_max.
inline __device__ at::Half gpuAtomicMax(at::Half * address, at::Half val) {
  return AtomicFPOp<at::Half>()(address, val,
                                [](at::Half bsum, at::Half val) {
                                  return safe_max(bsum, val);
                                });
}

inline __device__ at::BFloat16 gpuAtomicMax(at::BFloat16 * address, at::BFloat16 val) {
  return AtomicFPOp<at::BFloat16>()(address, val,
                                    [](at::BFloat16 bsum, at::BFloat16 val) {
                                      return safe_max(bsum, val);
                                    });
}

// For double the callback works on the raw 64-bit word: it receives the
// operand and the currently-assumed bit pattern and returns the new bits.
inline __device__ double gpuAtomicMax(double * address, double val) {
  return AtomicFPOp<double>()(address, val,
                              [](double val, unsigned long long int assumed) {
                                return __double_as_longlong(safe_max(val, __longlong_as_double(assumed)));
                              });
}
432
+
433
+ // Dont use a templated function for this since the addition function defaults to the CUDA built-in.
434
+ inline __device__ float gpuAtomicMax(float * address, float val) {
435
+ unsigned int* address_as_ull = (unsigned int*)address;
436
+ unsigned int old = *address_as_ull;
437
+ unsigned int assumed;
438
+
439
+ do {
440
+ assumed = old;
441
+ old = atomicCAS(address_as_ull, assumed,
442
+ __float_as_int(safe_max(val, __int_as_float(assumed))));
443
+
444
+ // Note: uses integer comparison to avoid hang in case of NaN (since NaN != NaN)
445
+ } while (assumed != old);
446
+
447
+ return __int_as_float(old);
448
+ }
449
+
450
+ // Atomic minimum implementation.
451
+
452
+ template <typename T>
453
+ __host__ __device__ T safe_min(T a, T b) {
454
+ #if defined(__HIPCC__)
455
+ // TODO: remove this special case for HIP when issue is fixed:
456
+ // https://github.com/ROCm-Developer-Tools/HIP/issues/2209
457
+ T min = at::_isnan(a) ? a : (at::_isnan(b) ? b : std::min<T>(a, b));
458
+ #else
459
+ T min = at::_isnan(b) ? b : std::min<T>(a, b);
460
+ #endif
461
+
462
+ return min;
463
+ }
464
+
465
// Integer gpuAtomicMin overloads, generated by macros defined earlier in this
// header; `safe_min(a, b)` is the combining expression.
ATOMIC_INTEGER_IMPL(Min)
GPU_ATOMIC_INTEGER(Min, safe_min(a, b), uint8_t)
GPU_ATOMIC_INTEGER(Min, safe_min(a, b), int8_t)
GPU_ATOMIC_INTEGER(Min, safe_min(a, b), int16_t)
GPU_ATOMIC_INTEGER(Min, safe_min(a, b), int32_t)
GPU_ATOMIC_INTEGER(Min, safe_min(a, b), int64_t)

// Half/BFloat16 min via the generic helper AtomicFPOp (defined earlier in
// this header), using NaN-propagating safe_min.
inline __device__ at::Half gpuAtomicMin(at::Half * address, at::Half val) {
  return AtomicFPOp<at::Half>()(address, val,
                                [](at::Half bsum, at::Half val) {
                                  return safe_min(bsum, val);
                                });
}

inline __device__ at::BFloat16 gpuAtomicMin(at::BFloat16 * address, at::BFloat16 val) {
  return AtomicFPOp<at::BFloat16>()(address, val,
                                    [](at::BFloat16 bsum, at::BFloat16 val) {
                                      return safe_min(bsum, val);
                                    });
}

// For double the callback works on the raw 64-bit word: it receives the
// operand and the currently-assumed bit pattern and returns the new bits.
inline __device__ double gpuAtomicMin(double * address, double val) {
  return AtomicFPOp<double>()(address, val,
                              [](double val, unsigned long long int assumed) {
                                return __double_as_longlong(safe_min(val, __longlong_as_double(assumed)));
                              });
}
492
+
493
+ // Dont use a templated function for this since the addition function defaults to the CUDA built-in.
494
+ inline __device__ float gpuAtomicMin(float * address, float val) {
495
+ unsigned int* address_as_ull = (unsigned int*)address;
496
+ unsigned int old = *address_as_ull;
497
+ unsigned int assumed;
498
+
499
+ do {
500
+ assumed = old;
501
+ old = atomicCAS(address_as_ull, assumed,
502
+ __float_as_int(safe_min(val, __int_as_float(assumed))));
503
+
504
+ // Note: uses integer comparison to avoid hang in case of NaN (since NaN != NaN)
505
+ } while (assumed != old);
506
+
507
+ return __int_as_float(old);
508
+ }
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/CUDAApplyUtils.cuh ADDED
@@ -0,0 +1,537 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/cuda/ApplyGridUtils.cuh>
4
+ #include <ATen/cuda/detail/IndexUtils.cuh>
5
+ #include <ATen/core/TensorBase.h>
6
+ #include <ATen/ceil_div.h>
7
+ #include <ATen/cuda/Atomic.cuh>
8
+ #include <ATen/cuda/CUDAContext.h>
9
+ #include <c10/macros/Macros.h>
10
+ #include <ATen/native/Copy.h>
11
+
12
+ #include <math.h>
13
+
14
+ //
15
+ // This file contains pointwise operation functions and kernels that
16
+ // work on both contiguous and non-contiguous tensor arguments of
17
+ // arbitrary (up to MAX_CUTORCH_DIMS) dimensioned arguments without
18
+ // copying or temporary storage.
19
+ //
20
+
21
+ /*
22
+ NOTE [ CUDA_tensor_applyN helpers ]
23
+
24
+ The following CUDA_tensor_applyN (where N currently can be 1, 2, 3, or 4)
25
+ functions apply a pointwise operator to N tensor(s).
26
+
27
+ The calling convention is
28
+
29
+ 1. The template arguments should be, sequentially,
30
+ - First N typename args specify the scalar types of each of the N tensors.
31
+ - (Optional) `int step` arg specifies the number of elements processed
32
+ together at the same time.
33
+ Default is 1.
34
+ - A usually omitted (i.e., inferred) typename arg specifies the type of the
35
+ function/functor applied on `N * step` values in each iteration of each
36
+ CUDA thread.
37
+ 2. The arguments should be, sequentially,
38
+ - N tensors
39
+ - op: a function/functor that processes `N * step` values at the same time.
40
+ - If `step == 1`, it must have signature
41
+ `void(*)(scalar1_t&, scalar2_t&, ..., scalarN_t&)`, where
42
+ `scalar*_t`s are the first N typename template args, and the inputs
43
+ are the `N` values from the `N` tensors retrieved at a common index.
44
+ - Otherwise, it must have signature
45
+ void(*)(int n, scalar1_t&, scalar1_t&, ..., scalar1_t&, // repeat `step` times
46
+ scalar2_t&, scalar2_t&, ..., scalar2_t&, // repeat `step` times
47
+ ...,
48
+ scalarN_t&, scalarN_t&, ..., scalarN_t&) // repeat `step` times
49
+ Different from `step == 1` case, it processes `N * step` values taken
50
+ from `step` common indices. Moreover, the first input `n` represents the
51
+ number of valid indices (it will always have `0 < n <= step`). It will
52
+ almost always be `step`, but at the boundary we may not have full `step`
53
+ elements and `n` can be a lesser value.
54
+
55
+ E.g., if `step == 4` and `N == 2`, `op` could be
56
+
57
+ [](int n, scalar1_t &u1, scalar1_t &u2, scalar1_t &u3, scalar1_t &u4,
58
+ scalar2_t &v1, scalar2_t &v2, scalar2_t &v3, scalar2_t &v4) {
59
+ // Only process u1, ..., un and v1, ..., vn.
60
+ // So if `n == 3`, `u4` and `v4` need not to be considered.
61
+ }
62
+
63
+ In both cases, the references can actually be const, but at least one of
64
+ them should be non-const in order to write the output.
65
+ - (Optional, but recommended) N TensorArgType args that specify for each
66
+ tensor whether `op` reads AND writes ] (i.e., TensorArgType::ReadWrite),
67
+ or only reads (i.e., TensorArgType::ReadOnly).
68
+ Default is TensorArgType::ReadWrite for first Tensor, and
69
+ TensorArgType::ReadOnly for the rest.
70
+
71
+ E.g.,
72
+
73
+ to compute a = b^2 for a and b of same dtype, we can call
74
+
75
+ CUDA_tensor_apply2<scalar, scalar>(
76
+ a, b,
77
+ [] __device__ (scalar &a_val, const scalar &b_val) { a_val = b_val * b_val; }
78
+ );
79
+
80
+ to work on 2 values at the same time, we can call
81
+
82
+ CUDA_tensor_apply2<scalar1, scalar2, 2>(
83
+ a, b,
84
+ [] __device__ (int n, scalar1 &a_val1, scalar1 &a_val2,
85
+ const scalar2 &b_val1, const scalar2 &b_val2) {
86
+ // call special vectorized op here, or just do elementwise and enjoy unrolling...
87
+ // if n == 1, only process a_val1 and b_val1
88
+ }
89
+ );
90
+ */
91
+
92
+ namespace at::cuda {
93
+
94
+ // TODO: combine with TensorArg? So far that's been for debugging, and this is functional...
95
+ enum class TensorArgType { ReadWrite, ReadOnly };
96
+
97
+ namespace {
98
+
99
+ // Rearrange dimensions for pointwise operations so that strides are in
100
+ // decreasing order as much as possible, so that kernels have better memory
101
+ // access patterns.
102
+ //
103
+ // For example, consider a binary operation on two "transposed" 2-dim tensors:
104
+ // sizes: 256 512
105
+ // aInfo->strides: 1 256
106
+ // bInfo->strides: 1 256
107
+ //
108
+ // Given this, each concurrent memory access inside kernelPointwiseApply2() is
109
+ // exactly 256 elements apart, resulting in poor performance.
110
+ //
111
+ // This function exchanges dimensions so that memory access is contiguous:
112
+ // sizes: 512 256
113
+ // aInfo->strides: 256 1
114
+ // bInfo->strides: 256 1
115
+ //
116
+ // (Actually, it becomes even better because now collapseDims() can turn each
117
+ // input into one contiguous array.)
118
+ //
119
+ // In general, given M (<=4) TensorInfo's with N dimensions, we can view each
120
+ // strides[i] (0 <= i < N) as an M-tuple. Given each pair i < j, we exchange
121
+ // strides[i] and [j] if
122
+ // (1) strides[i][k] < strides[j][k] for some k (0 <= k < M)
123
+ // (exchanging them will benefit input #k), and
124
+ // (2) strides[i][k] <= strides[j][k] for all k
125
+ // (exchanging them will not make any input worse).
126
// Permutes the (shared) dimension order of up to four TensorInfo structs, in
// place, so that strides tend to decrease left-to-right (see the explanatory
// comment above this function). Leaves the inputs untouched when the infos
// disagree on rank or sizes.
template <typename T1, typename IndexType,
          typename T2 = void, typename T3 = void, typename T4 = void>
inline void rearrangeDims(detail::TensorInfo<T1, IndexType>* aInfo,
                          detail::TensorInfo<T2, IndexType>* bInfo = nullptr,
                          detail::TensorInfo<T3, IndexType>* cInfo = nullptr,
                          detail::TensorInfo<T4, IndexType>* dInfo = nullptr) {
  int numInfos = 1;
  int dims = aInfo->dims;
  // Collect the sizes/strides arrays of all provided infos so they can be
  // permuted in lock-step below.
  IndexType *sizes[4] = { aInfo->sizes, };
  IndexType *strides[4] = { aInfo->strides, };

  if (bInfo != nullptr) {
    ++numInfos;
    if (bInfo->dims != dims) return;
    sizes[1] = bInfo->sizes;
    strides[1] = bInfo->strides;
  }

  if (cInfo != nullptr) {
    ++numInfos;
    if (cInfo->dims != dims) return;
    sizes[2] = cInfo->sizes;
    strides[2] = cInfo->strides;
  }

  if (dInfo != nullptr) {
    ++numInfos;
    if (dInfo->dims != dims) return;
    sizes[3] = dInfo->sizes;
    strides[3] = dInfo->strides;
  }

  // Bail out if sizes do not match: we are using "deprecated pointwise
  // behavior" among tensors of different shapes but same number of elements.
  for (int i = 1; i < numInfos; ++i) {
    for (int j = 0; j < dims; ++j) {
      if (sizes[i][j] != sizes[0][j]) return;
    }
  }

  for (int i = 0; i < dims - 1; ++i) {
    // No need to consider dimensions of size 1.
    if (sizes[0][i] == 1) continue;

    for (int j = i + 1; j < dims; ++j) {
      if (sizes[0][j] == 1) continue;

      // Compare the relative sizes of strides between dim #i and dim #j.
      bool hasIncreasingStrides = false;
      bool hasDecreasingStrides = false;

      for (int k = 0; k < numInfos; k++) {
        IndexType stride_i = strides[k][i];
        IndexType stride_j = strides[k][j];
        if (stride_i < stride_j) {
          hasIncreasingStrides = true;
        } else if (stride_i > stride_j) {
          hasDecreasingStrides = true;
        }
      }

      // Swap dims #i and #j only when at least one tensor benefits (some
      // stride increases) and none is made worse (no stride decreases) --
      // conditions (1) and (2) from the comment above this function.
      if (hasIncreasingStrides && !hasDecreasingStrides) {
        for (int k = 0; k < numInfos; k++) {
          IndexType size = sizes[k][i];
          sizes[k][i] = sizes[k][j];
          sizes[k][j] = size;

          IndexType stride = strides[k][i];
          strides[k][i] = strides[k][j];
          strides[k][j] = stride;
        }
      }
    }
  }
}
201
+
202
+ // The `remaining_steps` argument is used to support Op that operates on
203
+ // multiple elements at the same time. Generally, the strategy of ApplyOpN is to
204
+ // 1. Initialize `remaining_steps = step`, where `step` is the template arg of
205
+ // CUDA_tensor_applyN helpers. The input arg `n` to `apply()` represents the
206
+ // number of elements in bound for this call. It will almost always equal to
207
+ // `step` except at boundaries.
208
+ // 2. If `remaining_steps > 0` convert the current linearIndex to offset (if in
209
+ // bound), and recursively call `ApplyOpN` with `remaining_steps - 1`.
210
+ // 3. At `remaining_steps = 0`,
211
+ // if `step = 1`, call `op(tensor1_val, tensor2_val, ...)`;
212
+ // if `step > 1`, call `op(n, tensor1_val1, tensor1_val2, ..., tensor1_valstep,
213
+ // tensor2_val1, tensor2_val2, ..., tensor2_valstep,
214
+ // ...
215
+ // tensorN_val1, tensorN_val2, ..., tensorN_valstep);`
216
+ //
217
+ // See NOTE [ CUDA_tensor_applyN helpers ] above for how Op may look like.
218
+
219
+ template <typename Op,
220
+ typename scalar,
221
+ typename IndexType,
222
+ int ADims,
223
+ int remaining_steps,
224
+ typename... Offsets>
225
+ struct ApplyOp1 {
226
+ __device__ __forceinline__
227
+ static void apply(detail::TensorInfo<scalar, IndexType> &a, const Op &op, int n,
228
+ IndexType linearIndex, Offsets... aOffsets) {
229
+ // Convert `linearIndex` into an offset of `a`
230
+ const IndexType aOffset = sizeof...(Offsets) < n ?
231
+ detail::IndexToOffset<scalar, IndexType, ADims>::get(linearIndex, a) : 0;
232
+
233
+ ApplyOp1<Op, scalar, IndexType, ADims, remaining_steps - 1, const IndexType, Offsets...>::apply(
234
+ a, op, n, linearIndex + 1, aOffsets..., aOffset
235
+ );
236
+ }
237
+ };
238
+
239
+ // Specialize `step=1` case (i.e., `remaining_steps=0` and `len(Offsets)=1`).
240
+ // We don't need to pass in how many elements need to processed in this case.
241
+ template <typename Op,
242
+ typename scalar,
243
+ typename IndexType,
244
+ int ADims,
245
+ typename Offset>
246
+ struct ApplyOp1<Op, scalar, IndexType, ADims, 0, Offset> {
247
+ __device__ __forceinline__
248
+ static void apply(detail::TensorInfo<scalar, IndexType> &a, const Op &op,
249
+ int n, IndexType linearIndex, Offset offset) {
250
+ op(a.data[offset]);
251
+ }
252
+ };
253
+
254
// General `step > 1` base case: all offsets have been accumulated; invoke the
// user op with the in-bounds count `n` followed by `step` element references.
template <typename Op,
          typename scalar,
          typename IndexType,
          int ADims,
          typename... Offsets>
struct ApplyOp1<Op, scalar, IndexType, ADims, 0, Offsets...> {
  __device__ __forceinline__
  static void apply(detail::TensorInfo<scalar, IndexType> &a, const Op &op, int n,
                    IndexType linearIndex, Offsets... offsets) {
    op(n, a.data[offsets]...);
  }
};

// One-tensor pointwise kernel: a grid-stride loop where each thread handles
// `step` consecutive linear indices per iteration; the ::min clamps the
// in-bounds count at the tail of the tensor.
template <typename Op,
          typename scalar,
          typename IndexType,
          int ADims,
          int step>
#if __CUDA_ARCH__ >= 350 || defined(USE_ROCM)
C10_LAUNCH_BOUNDS_2(AT_APPLY_THREADS_PER_BLOCK, AT_APPLY_BLOCKS_PER_SM)
#endif
__global__ void kernelPointwiseApply1(detail::TensorInfo<scalar, IndexType> a,
                                      IndexType totalElements, const Op op) {
  for (IndexType linearIndex = (blockIdx.x * blockDim.x + threadIdx.x) * step;
       linearIndex < totalElements;
       linearIndex += gridDim.x * blockDim.x * step) {
    ApplyOp1<Op, scalar, IndexType, ADims, step>::apply(
      a, op, ::min(step, static_cast<int>(totalElements - linearIndex)), linearIndex);
  }
}
284
+
285
+
286
// Two-tensor analogue of ApplyOp1: accumulates one offset per tensor per
// recursion level, then the base-case specializations below invoke the op.
template <typename Op,
          typename scalar1,
          typename scalar2,
          typename IndexType,
          int ADims,
          int BDims,
          int remaining_steps,
          typename... Offsets>
struct ApplyOp2 {
  __device__ __forceinline__
  static void apply(detail::TensorInfo<scalar1, IndexType> &a,
                    detail::TensorInfo<scalar2, IndexType> &b,
                    const Op &op, int64_t n, IndexType linearIndex,
                    Offsets... aOffsets, Offsets... bOffsets) {
    // Convert `linearIndex` into an offset of `a` (0 for out-of-bounds slots;
    // op receives `n` and must ignore the trailing values).
    const IndexType aOffset = static_cast<int64_t>(sizeof...(Offsets)) < n ?
      detail::IndexToOffset<scalar1, IndexType, ADims>::get(linearIndex, a) : 0;

    // Convert `linearIndex` into an offset of `b`
    const IndexType bOffset = static_cast<int64_t>(sizeof...(Offsets)) < n ?
      detail::IndexToOffset<scalar2, IndexType, BDims>::get(linearIndex, b) : 0;

    ApplyOp2<Op, scalar1, scalar2, IndexType, ADims, BDims, remaining_steps - 1, const IndexType, Offsets...>::apply(
      a, b, op, n, linearIndex + 1, aOffsets..., aOffset, bOffsets..., bOffset
    );
  }
};

// Specialize `step=1` case (i.e., `remaining_steps=0` and `len(Offsets)=1`).
// We don't need to pass in how many elements need to processed in this case.
template <typename Op,
          typename scalar1,
          typename scalar2,
          typename IndexType,
          int ADims,
          int BDims,
          typename Offset>
struct ApplyOp2<Op, scalar1, scalar2, IndexType, ADims, BDims, 0, Offset> {
  __device__ __forceinline__
  static void apply(detail::TensorInfo<scalar1, IndexType> &a,
                    detail::TensorInfo<scalar2, IndexType> &b,
                    const Op &op, int /*n*/, IndexType /*linearIndex*/,
                    Offset aOffset, Offset bOffset) {
    op(a.data[aOffset], b.data[bOffset]);
  }
};

// General `step > 1` base case: pass the in-bounds count `n` plus `step`
// element references from each tensor.
template <typename Op,
          typename scalar1,
          typename scalar2,
          typename IndexType,
          int ADims,
          int BDims,
          typename... Offsets>
struct ApplyOp2<Op, scalar1, scalar2, IndexType, ADims, BDims, 0, Offsets...> {
  __device__ __forceinline__
  static void apply(detail::TensorInfo<scalar1, IndexType> &a,
                    detail::TensorInfo<scalar2, IndexType> &b,
                    const Op &op, int n, IndexType linearIndex,
                    Offsets... aOffsets, Offsets... bOffsets) {
    op(n, a.data[aOffsets]..., b.data[bOffsets]...);
  }
};

// Two-tensor pointwise kernel: a grid-stride loop where each thread handles
// `step` consecutive linear indices per iteration; the ::min clamps the
// in-bounds count at the tail.
template <typename Op,
          typename scalar1,
          typename scalar2,
          typename IndexType,
          int ADims, int BDims,
          int step,
          int max_threads_per_block=AT_APPLY_THREADS_PER_BLOCK,
          int min_blocks_per_sm=AT_APPLY_BLOCKS_PER_SM>
#if __CUDA_ARCH__ >= 350 || defined(USE_ROCM)
C10_LAUNCH_BOUNDS_2(max_threads_per_block, min_blocks_per_sm)
#endif
__global__ void
kernelPointwiseApply2(detail::TensorInfo<scalar1, IndexType> a,
                      detail::TensorInfo<scalar2, IndexType> b,
                      IndexType totalElements,
                      const Op op) {
  for (IndexType linearIndex = (blockIdx.x * blockDim.x + threadIdx.x) * step;
       linearIndex < totalElements;
       linearIndex += gridDim.x * blockDim.x * step) {
    ApplyOp2<Op, scalar1, scalar2, IndexType, ADims, BDims, step>::apply(
      a, b, op, ::min(step, static_cast<int>(totalElements - linearIndex)),
      linearIndex);
  }
}
374
+
375
+ } // anonymous namespace
376
+
377
// Applies `op` pointwise over tensors `a` and `b`; see
// NOTE [ CUDA_tensor_applyN helpers ] at the top of this file for the op
// contract. Returns false when the apply cannot be performed (numel mismatch,
// too many dims, or no current device); true on success (including numel==0).
template <typename scalar1, typename scalar2, int step, typename Op,
          int max_threads_per_block=AT_APPLY_THREADS_PER_BLOCK,
          int min_blocks_per_sm=AT_APPLY_BLOCKS_PER_SM>
inline bool CUDA_tensor_apply2(at::TensorBase a,
                               at::TensorBase b,
                               const Op op,
                               TensorArgType aType = TensorArgType::ReadWrite,
                               TensorArgType bType = TensorArgType::ReadOnly) {
  TORCH_CHECK(a.device().is_cuda() && b.device().is_cuda(),
              "CUDA_tensor_apply2: Expected tensors to have CUDA DeviceType, but got "
              "tensors with type ", a.device().type(), " and ", b.device().type());
  int64_t totalElements = a.numel();

  if (totalElements != b.numel()) {
    return false;
  }

  if (a.dim() > MAX_TENSORINFO_DIMS ||
      b.dim() > MAX_TENSORINFO_DIMS) {
    return false;
  }

  if (a.numel() == 0) {
    // Empty tensor; do nothing
    return true;
  }
  const dim3 block = getApplyBlock(max_threads_per_block);

  dim3 grid;
  auto curDevice = current_device();
  if (curDevice == -1) return false;
  if (!getApplyGrid<step>(totalElements, grid, curDevice, max_threads_per_block)) {
    return false;
  }

  /*
  Expands readable/writable tensors whose indices may be "overlapped."
  This ensures that each element of the tensor is operated on once and only
  once.
  */
  // If a tensor is swapped for a contiguous copy here, the original is kept
  // in oldA/oldB and the results are copied back after the kernel runs.
  TensorBase oldA;
  TensorBase oldB;

  if (aType == TensorArgType::ReadWrite && detail::maybeOverlappingIndices(a)) {
    // Must perform in contiguous space
    oldA = std::exchange(a, a.contiguous());
  }
  if (bType == TensorArgType::ReadWrite && detail::maybeOverlappingIndices(b)) {
    // Must perform in contiguous space
    oldB = std::exchange(b, b.contiguous());
  }

  // It is possible that the tensor dimensions are able to be collapsed,
  // and thus we can reduce the actual code complexity of the copy by
  // exploiting this knowledge statically, since the div/mod is the
  // most expensive part of the operation, more so than memory accesses.
  // For instance, when copying a non-contiguous to a contiguous tensor
  // (or vice versa), the contiguous tensor can be collapsed to one
  // dimension, and the loop to translate the linear index to the array
  // index can be similarly collapsed. That is what this unrolling is for.

  // HANDLE_CASE instantiates and launches one kernel for a fixed index type
  // and statically-known dim counts (A/B = 1, 2, or -1 for "dynamic").
#define HANDLE_CASE(TYPE, A, B)                                        \
  kernelPointwiseApply2<Op,                                            \
                        scalar1,                                       \
                        scalar2,                                       \
                        TYPE, A, B, step,                              \
                        max_threads_per_block,                         \
                        min_blocks_per_sm>                             \
   <<<grid, block, 0, at::cuda::getCurrentCUDAStream(curDevice)>>>(    \
       aInfo, bInfo, static_cast<TYPE>(totalElements), op);            \
  C10_CUDA_KERNEL_LAUNCH_CHECK();

#define HANDLE_B_CASE(TYPE, A, B) {         \
  switch (B) {                              \
    case 1:                                 \
      HANDLE_CASE(TYPE, A, 1);              \
      break;                                \
    case 2:                                 \
      HANDLE_CASE(TYPE, A, 2);              \
      break;                                \
    default:                                \
      HANDLE_CASE(TYPE, A, -1);             \
      break;                                \
  }                                         \
}

#define HANDLE_A_CASE(TYPE, A, B) {         \
  switch (A) {                              \
    case 1:                                 \
      HANDLE_B_CASE(TYPE, 1, B);            \
      break;                                \
    case 2:                                 \
      HANDLE_B_CASE(TYPE, 2, B);            \
      break;                                \
    default:                                \
      HANDLE_B_CASE(TYPE, -1, B);           \
      break;                                \
  }                                         \
}

  // Prefer 32-bit indexing when both tensors allow it (cheaper div/mod).
  if (detail::canUse32BitIndexMath(a) &&
      detail::canUse32BitIndexMath(b)) {
    detail::TensorInfo<scalar1, unsigned int> aInfo =
      detail::getTensorInfo<scalar1, unsigned int>(a);

    detail::TensorInfo<scalar2, unsigned int> bInfo =
      detail::getTensorInfo<scalar2, unsigned int>(b);
    rearrangeDims(&aInfo, &bInfo);
    aInfo.collapseDims();
    bInfo.collapseDims();

    HANDLE_A_CASE(unsigned int, aInfo.dims, bInfo.dims);
  } else {
    detail::TensorInfo<scalar1, uint64_t> aInfo =
      detail::getTensorInfo<scalar1, uint64_t>(a);

    detail::TensorInfo<scalar2, uint64_t> bInfo =
      detail::getTensorInfo<scalar2, uint64_t>(b);
    rearrangeDims(&aInfo, &bInfo);
    aInfo.collapseDims();
    bInfo.collapseDims();

    /*
    Only instantiates the all 1D special case and the fallback all nD case for
    large (64-bit indexed) tensors to reduce compilation time.
    */
    if (aInfo.dims == 1 && bInfo.dims == 1) {
      HANDLE_CASE(uint64_t, 1, 1);
    } else {
      HANDLE_CASE(uint64_t, -1, -1);
    }
  }
#undef HANDLE_CASE
#undef HANDLE_B_CASE
#undef HANDLE_A_CASE

  // Copy results back into the original (overlapping) tensors, if any were
  // swapped out above.
  if (oldA.defined()) {
    at::native::copy_ignoring_overlaps(oldA, a);
  }

  if (oldB.defined()) {
    at::native::copy_ignoring_overlaps(oldB, b);
  }

  return true;
}
523
+
524
/* Provides default step = 1 to CUDA_tensor_apply2. */
// Convenience overload: forwards to the step-templated version with step == 1,
// where `op` takes one element reference per tensor (see NOTE at the top).
template <typename scalar1, typename scalar2, typename Op,
          int max_threads_per_block=AT_APPLY_THREADS_PER_BLOCK,
          int min_blocks_per_sm=AT_APPLY_BLOCKS_PER_SM>
inline bool CUDA_tensor_apply2(const at::TensorBase &a,
                               const at::TensorBase &b,
                               const Op op,
                               TensorArgType aType = TensorArgType::ReadWrite,
                               TensorArgType bType = TensorArgType::ReadOnly) {
  return CUDA_tensor_apply2<scalar1, scalar2, 1, Op,
    max_threads_per_block, min_blocks_per_sm>(a, b, op, aType, bType);
}
536
+
537
+ } // namespace at::cuda
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/CUDAGeneratorImpl.h ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/core/Generator.h>
4
+ #include <ATen/cuda/PhiloxCudaState.h>
5
+ #include <ATen/Context.h>
6
+ #include <limits>
7
+ #include <atomic>
8
+
9
+ namespace at {
10
+ /**
11
+ * Note [CUDA Graph-safe RNG states]
12
+ * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
13
+ *
14
+ * Strategy:
15
+ * ~~~~~~~~~
16
+ * (It helps to look at
17
+ * cuda/detail/PhiloxCudaStateRaw.cuh and
18
+ * cuda/detail/UnpackRaw.cuh
19
+ * while you read this.)
20
+ *
21
+ * A CUDA graph containing multiple RNG ops behaves like a
22
+ * single giant kernel from the perspective of ops external
23
+ * to the graph. During graph capture, logic in CUDAGeneratorImpl
24
+ * records the total of all offset increments that occur in the
25
+ * graphed region, and records the final total as the offset for
26
+ * the entire graph.
27
+ *
28
+ * When the graph reruns, the logic that reruns it
29
+ * increments this device's CUDA generator's offset
30
+ * by that total.
31
+ *
32
+ * Meanwhile, within the graph, at capture time, instead of
33
+ * populating PhiloxCudaStates with the uint64_t offset pulled
34
+ * directly from the global state, PhiloxCudaState uses a pointer
35
+ * to a one-element stream-local int64_t device tensor
36
+ * holding an initial offset value, and a uint64_t holding an
37
+ * intra-graph offset. (The intra-graph offset starts from zero
38
+ * when capture begins.) In each consumer kernel,
39
+ * at::cuda::philox::unpack computes the offset to use for this kernel
40
+ * as intra-graph offset + *initial offset.
41
+ *
42
+ * When the graph reruns, the logic that reruns it first
43
+ * fill_s the initial offset tensor with this device's
44
+ * CUDA generator's current offset.
45
+ *
46
+ * The control flow above ensures graphed execution is bitwise
47
+ * identical to eager execution as long as RNG ops are enqueued
48
+ * from a single thread, even if RNG ops and graphs containing
49
+ * RNG ops are enqueued and run simultaneously on multiple streams.
50
+ *
51
+ * Usage:
52
+ * ~~~~~~
53
+ * PhiloxCudaState in this file, and unpack() in
54
+ * cuda/CUDAGraphsUtils.cuh allow non-divergent use of
55
+ * CUDAGeneratorImpl whether graph capture is underway or not.
56
+ *
57
+ * Each PhiloxCudaState instance should be used for one and only one
58
+ * consumer kernel.
59
+ *
60
+ * Example (see e.g. native/cuda/Dropout.cu):
61
+ *
62
+ * #include <ATen/cuda/CUDAGeneratorImpl.h>
63
+ * #include <ATen/cuda/CUDAGraphsUtils.cuh>
64
+ *
65
+ * __global__ void kernel(..., PhiloxCudaState philox_args) {
66
+ * auto seeds = at::cuda::philox::unpack(philox_args);
67
+ * IndexType idx = blockIdx.x * blockDim.x + threadIdx.x;
68
+ * curandStatePhilox4_32_10_t state;
69
+ * curand_init(std::get<0>(seeds), // seed
70
+ * idx, // per-thread subsequence
71
+ * std::get<1>(seeds), // offset in subsequence
72
+ * &state);
73
+ * ...
74
+ * }
75
+ *
76
+ * host_caller(...) {
77
+ * PhiloxCudaState rng_engine_inputs;
78
+ * {
79
+ * // See Note [Acquire lock when using random generators]
80
+ * std::lock_guard<std::mutex> lock(gen->mutex_);
81
+ *
82
+ * // gen could be HostState or DevState here! No divergent code needed!
83
+ * rng_engine_inputs = gen->philox_cuda_state(offset_increment);
84
+ * }
85
+ * kernel<<<...>>>(..., rng_engine_inputs);
86
+ * }
87
+ *
88
+ */
89
+
90
+ struct TORCH_CUDA_CPP_API CUDAGeneratorImpl : public c10::GeneratorImpl {
91
+ // Constructors
92
+ CUDAGeneratorImpl(DeviceIndex device_index = -1);
93
+ ~CUDAGeneratorImpl() override = default;
94
+
95
+ // CUDAGeneratorImpl methods
96
+ std::shared_ptr<CUDAGeneratorImpl> clone() const;
97
+ void set_current_seed(uint64_t seed) override;
98
+ void set_offset(uint64_t offset) override;
99
+ uint64_t get_offset() const override;
100
+ uint64_t current_seed() const override;
101
+ uint64_t seed() override;
102
+ void set_state(const c10::TensorImpl& new_state) override;
103
+ c10::intrusive_ptr<c10::TensorImpl> get_state() const override;
104
+ void set_philox_offset_per_thread(uint64_t offset);
105
+ uint64_t philox_offset_per_thread() const;
106
+ void capture_prologue(int64_t* seed_extragraph, int64_t* offset_extragraph);
107
+ uint64_t capture_epilogue();
108
+ PhiloxCudaState philox_cuda_state(uint64_t increment);
109
+
110
+ bool reset_rnn_state() {
111
+ return !no_reset_rnn_state_.test_and_set();
112
+ }
113
+
114
+ // Temporarily accommodates call sites that use philox_engine_inputs.
115
+ // Allows incremental refactor of call sites to use philox_cuda_state.
116
+ std::pair<uint64_t, uint64_t> philox_engine_inputs(uint64_t increment);
117
+
118
+ static c10::DeviceType device_type();
119
+
120
+ private:
121
+ CUDAGeneratorImpl* clone_impl() const override;
122
+ uint64_t seed_ = default_rng_seed_val;
123
+ uint64_t philox_offset_per_thread_ = 0;
124
+ int64_t* seed_extragraph_{};
125
+ int64_t* offset_extragraph_{};
126
+ uint32_t offset_intragraph_ = 0;
127
+ bool graph_expects_this_gen_ = false;
128
+ std::atomic_flag no_reset_rnn_state_;
129
+ };
130
+
131
+ namespace cuda::detail {
132
+
133
+ TORCH_CUDA_CPP_API const Generator& getDefaultCUDAGenerator(
134
+ DeviceIndex device_index = -1);
135
+ TORCH_CUDA_CPP_API Generator createCUDAGenerator(DeviceIndex device_index = -1);
136
+
137
+ } // namespace cuda::detail
138
+ } // namespace at
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/CUDAGraph.h ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/Tensor.h>
4
+ #include <c10/core/Device.h>
5
+ #include <c10/cuda/CUDAGraphsC10Utils.h>
6
+ #include <c10/cuda/CUDAStream.h>
7
+
8
+ #include <mutex>
9
+
10
+ namespace at {
11
+
12
+ struct CUDAGeneratorImpl;
13
+
14
+ namespace cuda {
15
+
16
+ // Standalone way to get a unique mempool id usable as a pool=... argument
17
+ // to CUDAGraph::capture_begin
18
+ TORCH_CUDA_CPP_API MempoolId_t graph_pool_handle();
19
+
20
+ struct TORCH_CUDA_CPP_API CUDAGraph {
21
+ CUDAGraph();
22
+ ~CUDAGraph();
23
+
24
+ static void inc_pending_event_queries();
25
+ static void dec_pending_event_queries();
26
+ static int num_pending_event_queries();
27
+ void capture_begin(MempoolId_t pool={0, 0}, cudaStreamCaptureMode capture_mode = cudaStreamCaptureModeGlobal);
28
+ void capture_end();
29
+ void replay();
30
+ void reset();
31
+ MempoolId_t pool();
32
+ void enable_debug_mode();
33
+ void debug_dump(const std::string& debug_path);
34
+
35
+ protected:
36
+ #if !defined(USE_ROCM) || ROCM_VERSION >= 50300
37
+ cudaGraph_t graph_ = NULL;
38
+ cudaGraphExec_t graph_exec_ = NULL;
39
+ #endif
40
+
41
+ static std::atomic<int> pending_event_queries;
42
+
43
+ // internal states so reset() can do its best cleaning up
44
+ // Set to true in capture_end if cudaStreamEndCapture succeeded
45
+ // Set back to false soon after, when graph_ is consumed by cudaGraphInstantiate
46
+ // to create graph_exec_, then graph_ is deleted
47
+ bool has_graph_ = false;
48
+ // Set to true in capture_end if cudaGraphInstantiate succeeded
49
+ bool has_graph_exec_ = false;
50
+
51
+ // uuid of this instance's current capture, used to
52
+ // specify the pool.
53
+ CaptureId_t id_;
54
+
55
+ // the ID assigned by cuda during graph capture,
56
+ // used to identify when a stream is participating in capture
57
+ CaptureId_t capture_id_ = -1;
58
+
59
+ // uuid used to request a particular private mempool from CUDACachingAllocator.
60
+ // By default, this will be set to {id_, 0}.
61
+ //
62
+ // If capture_begin is called with "pool=other_graph.pool()", this graph's mempool_id_
63
+ // will be set to the other graph's mempool_id_, and therefore share a mempool with the
64
+ // other graph.
65
+ //
66
+ // If capture_begin is called with "pool=handle" where "handle" came from graph_pool_handle(),
67
+ // it will share a mempool with any other captures that used "pool=handle".
68
+ //
69
+ // Sharing a mempool across graphs saves memory, and it's safe if you
70
+ // know you'll replay those graphs in the same order you captured them.
71
+ MempoolId_t mempool_id_;
72
+
73
+ // Stream on which capture began
74
+ at::cuda::CUDAStream capture_stream_;
75
+
76
+ // Default generator on device where capture began
77
+ at::CUDAGeneratorImpl* capture_gen_;
78
+
79
+ // Device where capture occurred. Right now, for simplicity, we require all ops
80
+ // in a capture to run on the same device, but this is a limitation of CUDAGraph,
81
+ // not CUDA itself. We can straightforwardly modify CUDAGraph to support multi-device
82
+ // captures if needed.
83
+ int capture_dev_;
84
+
85
+ // RNG state trackers
86
+ at::Tensor seed_extragraph_;
87
+ at::Tensor offset_extragraph_;
88
+ uint64_t wholegraph_increment_;
89
+ };
90
+
91
+ } // namespace cuda
92
+ } // namespace at
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/CUDAGraphsUtils.cuh ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/cuda/CUDAGeneratorImpl.h>
4
+ #include <ATen/cuda/CUDAEvent.h>
5
+ #include <ATen/cuda/PhiloxUtils.cuh>
6
+ #include <ATen/cuda/detail/CUDAHooks.h>
7
+ #include <ATen/detail/CUDAHooksInterface.h>
8
+ #include <c10/core/StreamGuard.h>
9
+ #include <c10/cuda/CUDAGraphsC10Utils.h>
10
+ #include <c10/cuda/CUDAGuard.h>
11
+
12
+ // c10/cuda/CUDAGraphsC10Utils.h has utils used by both c10 and aten.
13
+ // This file adds utils used by aten only.
14
+
15
+ namespace at::cuda {
16
+
17
+ using CaptureId_t = c10::cuda::CaptureId_t;
18
+ using CaptureStatus = c10::cuda::CaptureStatus;
19
+
20
+ // Use this version where you don't want to create a CUDA context if none exists.
21
+ inline CaptureStatus currentStreamCaptureStatus() {
22
+ #if !defined(USE_ROCM) || ROCM_VERSION >= 50300
23
+ // don't create a context if we don't have to
24
+ if (c10::cuda::hasPrimaryContext(c10::cuda::current_device())) {
25
+ return c10::cuda::currentStreamCaptureStatusMayInitCtx();
26
+ } else {
27
+ return CaptureStatus::None;
28
+ }
29
+ #else
30
+ return CaptureStatus::None;
31
+ #endif
32
+ }
33
+
34
+ inline void assertNotCapturing(std::string attempt) {
35
+ auto status = currentStreamCaptureStatus();
36
+ TORCH_CHECK(status == CaptureStatus::None,
37
+ attempt,
38
+ " during CUDA graph capture. If you need this call to be captured, "
39
+ "please file an issue. "
40
+ "Current cudaStreamCaptureStatus: ",
41
+ status);
42
+ }
43
+
44
+ inline void errorIfCapturingCudnnBenchmark(std::string version_specific) {
45
+ auto status = currentStreamCaptureStatus();
46
+ TORCH_CHECK(status == CaptureStatus::None,
47
+ "Current cudaStreamCaptureStatus: ",
48
+ status,
49
+ "\nCapturing ",
50
+ version_specific,
51
+ "is prohibited. Possible causes of this error:\n"
52
+ "1. No warmup iterations occurred before capture.\n"
53
+ "2. The convolutions you're trying to capture use dynamic shapes, "
54
+ "in which case capturing them is generally prohibited.");
55
+ }
56
+
57
+ } // namespace at::cuda
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/CUDASparseBlas.h ADDED
@@ -0,0 +1,318 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ /*
4
+ Provides a subset of cuSPARSE functions as templates:
5
+
6
+ csrgeam2<scalar_t>(...)
7
+
8
+ where scalar_t is double, float, c10::complex<double> or c10::complex<float>.
9
+ The functions are available in at::cuda::sparse namespace.
10
+ */
11
+
12
+ #include <ATen/cuda/CUDAContext.h>
13
+ #include <ATen/cuda/CUDASparse.h>
14
+
15
+ namespace at::cuda::sparse {
16
+
17
+ #define CUSPARSE_CSRGEAM2_BUFFERSIZE_ARGTYPES(scalar_t) \
18
+ cusparseHandle_t handle, int m, int n, const scalar_t *alpha, \
19
+ const cusparseMatDescr_t descrA, int nnzA, \
20
+ const scalar_t *csrSortedValA, const int *csrSortedRowPtrA, \
21
+ const int *csrSortedColIndA, const scalar_t *beta, \
22
+ const cusparseMatDescr_t descrB, int nnzB, \
23
+ const scalar_t *csrSortedValB, const int *csrSortedRowPtrB, \
24
+ const int *csrSortedColIndB, const cusparseMatDescr_t descrC, \
25
+ const scalar_t *csrSortedValC, const int *csrSortedRowPtrC, \
26
+ const int *csrSortedColIndC, size_t *pBufferSizeInBytes
27
+
28
+ template <typename scalar_t>
29
+ inline void csrgeam2_bufferSizeExt(
30
+ CUSPARSE_CSRGEAM2_BUFFERSIZE_ARGTYPES(scalar_t)) {
31
+ TORCH_INTERNAL_ASSERT(
32
+ false,
33
+ "at::cuda::sparse::csrgeam2_bufferSizeExt: not implemented for ",
34
+ typeid(scalar_t).name());
35
+ }
36
+
37
+ template <>
38
+ void csrgeam2_bufferSizeExt<float>(
39
+ CUSPARSE_CSRGEAM2_BUFFERSIZE_ARGTYPES(float));
40
+ template <>
41
+ void csrgeam2_bufferSizeExt<double>(
42
+ CUSPARSE_CSRGEAM2_BUFFERSIZE_ARGTYPES(double));
43
+ template <>
44
+ void csrgeam2_bufferSizeExt<c10::complex<float>>(
45
+ CUSPARSE_CSRGEAM2_BUFFERSIZE_ARGTYPES(c10::complex<float>));
46
+ template <>
47
+ void csrgeam2_bufferSizeExt<c10::complex<double>>(
48
+ CUSPARSE_CSRGEAM2_BUFFERSIZE_ARGTYPES(c10::complex<double>));
49
+
50
+ #define CUSPARSE_CSRGEAM2_NNZ_ARGTYPES() \
51
+ cusparseHandle_t handle, int m, int n, const cusparseMatDescr_t descrA, \
52
+ int nnzA, const int *csrSortedRowPtrA, const int *csrSortedColIndA, \
53
+ const cusparseMatDescr_t descrB, int nnzB, const int *csrSortedRowPtrB, \
54
+ const int *csrSortedColIndB, const cusparseMatDescr_t descrC, \
55
+ int *csrSortedRowPtrC, int *nnzTotalDevHostPtr, void *workspace
56
+
57
+ template <typename scalar_t>
58
+ inline void csrgeam2Nnz(CUSPARSE_CSRGEAM2_NNZ_ARGTYPES()) {
59
+ TORCH_CUDASPARSE_CHECK(cusparseXcsrgeam2Nnz(
60
+ handle,
61
+ m,
62
+ n,
63
+ descrA,
64
+ nnzA,
65
+ csrSortedRowPtrA,
66
+ csrSortedColIndA,
67
+ descrB,
68
+ nnzB,
69
+ csrSortedRowPtrB,
70
+ csrSortedColIndB,
71
+ descrC,
72
+ csrSortedRowPtrC,
73
+ nnzTotalDevHostPtr,
74
+ workspace));
75
+ }
76
+
77
+ #define CUSPARSE_CSRGEAM2_ARGTYPES(scalar_t) \
78
+ cusparseHandle_t handle, int m, int n, const scalar_t *alpha, \
79
+ const cusparseMatDescr_t descrA, int nnzA, \
80
+ const scalar_t *csrSortedValA, const int *csrSortedRowPtrA, \
81
+ const int *csrSortedColIndA, const scalar_t *beta, \
82
+ const cusparseMatDescr_t descrB, int nnzB, \
83
+ const scalar_t *csrSortedValB, const int *csrSortedRowPtrB, \
84
+ const int *csrSortedColIndB, const cusparseMatDescr_t descrC, \
85
+ scalar_t *csrSortedValC, int *csrSortedRowPtrC, int *csrSortedColIndC, \
86
+ void *pBuffer
87
+
88
+ template <typename scalar_t>
89
+ inline void csrgeam2(CUSPARSE_CSRGEAM2_ARGTYPES(scalar_t)) {
90
+ TORCH_INTERNAL_ASSERT(
91
+ false,
92
+ "at::cuda::sparse::csrgeam2: not implemented for ",
93
+ typeid(scalar_t).name());
94
+ }
95
+
96
+ template <>
97
+ void csrgeam2<float>(CUSPARSE_CSRGEAM2_ARGTYPES(float));
98
+ template <>
99
+ void csrgeam2<double>(CUSPARSE_CSRGEAM2_ARGTYPES(double));
100
+ template <>
101
+ void csrgeam2<c10::complex<float>>(
102
+ CUSPARSE_CSRGEAM2_ARGTYPES(c10::complex<float>));
103
+ template <>
104
+ void csrgeam2<c10::complex<double>>(
105
+ CUSPARSE_CSRGEAM2_ARGTYPES(c10::complex<double>));
106
+
107
+ #define CUSPARSE_BSRMM_ARGTYPES(scalar_t) \
108
+ cusparseHandle_t handle, cusparseDirection_t dirA, \
109
+ cusparseOperation_t transA, cusparseOperation_t transB, int mb, int n, \
110
+ int kb, int nnzb, const scalar_t *alpha, \
111
+ const cusparseMatDescr_t descrA, const scalar_t *bsrValA, \
112
+ const int *bsrRowPtrA, const int *bsrColIndA, int blockDim, \
113
+ const scalar_t *B, int ldb, const scalar_t *beta, scalar_t *C, int ldc
114
+
115
+ template <typename scalar_t>
116
+ inline void bsrmm(CUSPARSE_BSRMM_ARGTYPES(scalar_t)) {
117
+ TORCH_INTERNAL_ASSERT(
118
+ false,
119
+ "at::cuda::sparse::bsrmm: not implemented for ",
120
+ typeid(scalar_t).name());
121
+ }
122
+
123
+ template <>
124
+ void bsrmm<float>(CUSPARSE_BSRMM_ARGTYPES(float));
125
+ template <>
126
+ void bsrmm<double>(CUSPARSE_BSRMM_ARGTYPES(double));
127
+ template <>
128
+ void bsrmm<c10::complex<float>>(CUSPARSE_BSRMM_ARGTYPES(c10::complex<float>));
129
+ template <>
130
+ void bsrmm<c10::complex<double>>(CUSPARSE_BSRMM_ARGTYPES(c10::complex<double>));
131
+
132
+ #define CUSPARSE_BSRMV_ARGTYPES(scalar_t) \
133
+ cusparseHandle_t handle, cusparseDirection_t dirA, \
134
+ cusparseOperation_t transA, int mb, int nb, int nnzb, \
135
+ const scalar_t *alpha, const cusparseMatDescr_t descrA, \
136
+ const scalar_t *bsrValA, const int *bsrRowPtrA, const int *bsrColIndA, \
137
+ int blockDim, const scalar_t *x, const scalar_t *beta, scalar_t *y
138
+
139
+ template <typename scalar_t>
140
+ inline void bsrmv(CUSPARSE_BSRMV_ARGTYPES(scalar_t)) {
141
+ TORCH_INTERNAL_ASSERT(
142
+ false,
143
+ "at::cuda::sparse::bsrmv: not implemented for ",
144
+ typeid(scalar_t).name());
145
+ }
146
+
147
+ template <>
148
+ void bsrmv<float>(CUSPARSE_BSRMV_ARGTYPES(float));
149
+ template <>
150
+ void bsrmv<double>(CUSPARSE_BSRMV_ARGTYPES(double));
151
+ template <>
152
+ void bsrmv<c10::complex<float>>(CUSPARSE_BSRMV_ARGTYPES(c10::complex<float>));
153
+ template <>
154
+ void bsrmv<c10::complex<double>>(CUSPARSE_BSRMV_ARGTYPES(c10::complex<double>));
155
+
156
+ #if AT_USE_HIPSPARSE_TRIANGULAR_SOLVE()
157
+
158
+ #define CUSPARSE_BSRSV2_BUFFER_ARGTYPES(scalar_t) \
159
+ cusparseHandle_t handle, cusparseDirection_t dirA, \
160
+ cusparseOperation_t transA, int mb, int nnzb, \
161
+ const cusparseMatDescr_t descrA, scalar_t *bsrValA, \
162
+ const int *bsrRowPtrA, const int *bsrColIndA, int blockDim, \
163
+ bsrsv2Info_t info, int *pBufferSizeInBytes
164
+
165
+ template <typename scalar_t>
166
+ inline void bsrsv2_bufferSize(CUSPARSE_BSRSV2_BUFFER_ARGTYPES(scalar_t)) {
167
+ TORCH_INTERNAL_ASSERT(
168
+ false,
169
+ "at::cuda::sparse::bsrsv2_bufferSize: not implemented for ",
170
+ typeid(scalar_t).name());
171
+ }
172
+
173
+ template <>
174
+ void bsrsv2_bufferSize<float>(CUSPARSE_BSRSV2_BUFFER_ARGTYPES(float));
175
+ template <>
176
+ void bsrsv2_bufferSize<double>(CUSPARSE_BSRSV2_BUFFER_ARGTYPES(double));
177
+ template <>
178
+ void bsrsv2_bufferSize<c10::complex<float>>(
179
+ CUSPARSE_BSRSV2_BUFFER_ARGTYPES(c10::complex<float>));
180
+ template <>
181
+ void bsrsv2_bufferSize<c10::complex<double>>(
182
+ CUSPARSE_BSRSV2_BUFFER_ARGTYPES(c10::complex<double>));
183
+
184
+ #define CUSPARSE_BSRSV2_ANALYSIS_ARGTYPES(scalar_t) \
185
+ cusparseHandle_t handle, cusparseDirection_t dirA, \
186
+ cusparseOperation_t transA, int mb, int nnzb, \
187
+ const cusparseMatDescr_t descrA, const scalar_t *bsrValA, \
188
+ const int *bsrRowPtrA, const int *bsrColIndA, int blockDim, \
189
+ bsrsv2Info_t info, cusparseSolvePolicy_t policy, void *pBuffer
190
+
191
+ template <typename scalar_t>
192
+ inline void bsrsv2_analysis(CUSPARSE_BSRSV2_ANALYSIS_ARGTYPES(scalar_t)) {
193
+ TORCH_INTERNAL_ASSERT(
194
+ false,
195
+ "at::cuda::sparse::bsrsv2_analysis: not implemented for ",
196
+ typeid(scalar_t).name());
197
+ }
198
+
199
+ template <>
200
+ void bsrsv2_analysis<float>(CUSPARSE_BSRSV2_ANALYSIS_ARGTYPES(float));
201
+ template <>
202
+ void bsrsv2_analysis<double>(CUSPARSE_BSRSV2_ANALYSIS_ARGTYPES(double));
203
+ template <>
204
+ void bsrsv2_analysis<c10::complex<float>>(
205
+ CUSPARSE_BSRSV2_ANALYSIS_ARGTYPES(c10::complex<float>));
206
+ template <>
207
+ void bsrsv2_analysis<c10::complex<double>>(
208
+ CUSPARSE_BSRSV2_ANALYSIS_ARGTYPES(c10::complex<double>));
209
+
210
+ #define CUSPARSE_BSRSV2_SOLVE_ARGTYPES(scalar_t) \
211
+ cusparseHandle_t handle, cusparseDirection_t dirA, \
212
+ cusparseOperation_t transA, int mb, int nnzb, const scalar_t *alpha, \
213
+ const cusparseMatDescr_t descrA, const scalar_t *bsrValA, \
214
+ const int *bsrRowPtrA, const int *bsrColIndA, int blockDim, \
215
+ bsrsv2Info_t info, const scalar_t *x, scalar_t *y, \
216
+ cusparseSolvePolicy_t policy, void *pBuffer
217
+
218
+ template <typename scalar_t>
219
+ inline void bsrsv2_solve(CUSPARSE_BSRSV2_SOLVE_ARGTYPES(scalar_t)) {
220
+ TORCH_INTERNAL_ASSERT(
221
+ false,
222
+ "at::cuda::sparse::bsrsv2_solve: not implemented for ",
223
+ typeid(scalar_t).name());
224
+ }
225
+
226
+ template <>
227
+ void bsrsv2_solve<float>(CUSPARSE_BSRSV2_SOLVE_ARGTYPES(float));
228
+ template <>
229
+ void bsrsv2_solve<double>(CUSPARSE_BSRSV2_SOLVE_ARGTYPES(double));
230
+ template <>
231
+ void bsrsv2_solve<c10::complex<float>>(
232
+ CUSPARSE_BSRSV2_SOLVE_ARGTYPES(c10::complex<float>));
233
+ template <>
234
+ void bsrsv2_solve<c10::complex<double>>(
235
+ CUSPARSE_BSRSV2_SOLVE_ARGTYPES(c10::complex<double>));
236
+
237
+ #define CUSPARSE_BSRSM2_BUFFER_ARGTYPES(scalar_t) \
238
+ cusparseHandle_t handle, cusparseDirection_t dirA, \
239
+ cusparseOperation_t transA, cusparseOperation_t transX, int mb, int n, \
240
+ int nnzb, const cusparseMatDescr_t descrA, scalar_t *bsrValA, \
241
+ const int *bsrRowPtrA, const int *bsrColIndA, int blockDim, \
242
+ bsrsm2Info_t info, int *pBufferSizeInBytes
243
+
244
+ template <typename scalar_t>
245
+ inline void bsrsm2_bufferSize(CUSPARSE_BSRSM2_BUFFER_ARGTYPES(scalar_t)) {
246
+ TORCH_INTERNAL_ASSERT(
247
+ false,
248
+ "at::cuda::sparse::bsrsm2_bufferSize: not implemented for ",
249
+ typeid(scalar_t).name());
250
+ }
251
+
252
+ template <>
253
+ void bsrsm2_bufferSize<float>(CUSPARSE_BSRSM2_BUFFER_ARGTYPES(float));
254
+ template <>
255
+ void bsrsm2_bufferSize<double>(CUSPARSE_BSRSM2_BUFFER_ARGTYPES(double));
256
+ template <>
257
+ void bsrsm2_bufferSize<c10::complex<float>>(
258
+ CUSPARSE_BSRSM2_BUFFER_ARGTYPES(c10::complex<float>));
259
+ template <>
260
+ void bsrsm2_bufferSize<c10::complex<double>>(
261
+ CUSPARSE_BSRSM2_BUFFER_ARGTYPES(c10::complex<double>));
262
+
263
+ #define CUSPARSE_BSRSM2_ANALYSIS_ARGTYPES(scalar_t) \
264
+ cusparseHandle_t handle, cusparseDirection_t dirA, \
265
+ cusparseOperation_t transA, cusparseOperation_t transX, int mb, int n, \
266
+ int nnzb, const cusparseMatDescr_t descrA, const scalar_t *bsrValA, \
267
+ const int *bsrRowPtrA, const int *bsrColIndA, int blockDim, \
268
+ bsrsm2Info_t info, cusparseSolvePolicy_t policy, void *pBuffer
269
+
270
+ template <typename scalar_t>
271
+ inline void bsrsm2_analysis(CUSPARSE_BSRSM2_ANALYSIS_ARGTYPES(scalar_t)) {
272
+ TORCH_INTERNAL_ASSERT(
273
+ false,
274
+ "at::cuda::sparse::bsrsm2_analysis: not implemented for ",
275
+ typeid(scalar_t).name());
276
+ }
277
+
278
+ template <>
279
+ void bsrsm2_analysis<float>(CUSPARSE_BSRSM2_ANALYSIS_ARGTYPES(float));
280
+ template <>
281
+ void bsrsm2_analysis<double>(CUSPARSE_BSRSM2_ANALYSIS_ARGTYPES(double));
282
+ template <>
283
+ void bsrsm2_analysis<c10::complex<float>>(
284
+ CUSPARSE_BSRSM2_ANALYSIS_ARGTYPES(c10::complex<float>));
285
+ template <>
286
+ void bsrsm2_analysis<c10::complex<double>>(
287
+ CUSPARSE_BSRSM2_ANALYSIS_ARGTYPES(c10::complex<double>));
288
+
289
+ #define CUSPARSE_BSRSM2_SOLVE_ARGTYPES(scalar_t) \
290
+ cusparseHandle_t handle, cusparseDirection_t dirA, \
291
+ cusparseOperation_t transA, cusparseOperation_t transX, int mb, int n, \
292
+ int nnzb, const scalar_t *alpha, const cusparseMatDescr_t descrA, \
293
+ const scalar_t *bsrValA, const int *bsrRowPtrA, const int *bsrColIndA, \
294
+ int blockDim, bsrsm2Info_t info, const scalar_t *B, int ldb, \
295
+ scalar_t *X, int ldx, cusparseSolvePolicy_t policy, void *pBuffer
296
+
297
+ template <typename scalar_t>
298
+ inline void bsrsm2_solve(CUSPARSE_BSRSM2_SOLVE_ARGTYPES(scalar_t)) {
299
+ TORCH_INTERNAL_ASSERT(
300
+ false,
301
+ "at::cuda::sparse::bsrsm2_solve: not implemented for ",
302
+ typeid(scalar_t).name());
303
+ }
304
+
305
+ template <>
306
+ void bsrsm2_solve<float>(CUSPARSE_BSRSM2_SOLVE_ARGTYPES(float));
307
+ template <>
308
+ void bsrsm2_solve<double>(CUSPARSE_BSRSM2_SOLVE_ARGTYPES(double));
309
+ template <>
310
+ void bsrsm2_solve<c10::complex<float>>(
311
+ CUSPARSE_BSRSM2_SOLVE_ARGTYPES(c10::complex<float>));
312
+ template <>
313
+ void bsrsm2_solve<c10::complex<double>>(
314
+ CUSPARSE_BSRSM2_SOLVE_ARGTYPES(c10::complex<double>));
315
+
316
+ #endif // AT_USE_HIPSPARSE_TRIANGULAR_SOLVE
317
+
318
+ } // namespace at::cuda::sparse
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/CUDATensorMethods.cuh ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/Tensor.h>
4
+ #include <c10/util/Half.h>
5
+
6
+ #include <cuda.h>
7
+ #include <cuda_runtime.h>
8
+ #include <cuda_fp16.h>
9
+
10
+ namespace at {
11
+ template <>
12
+ inline __half* Tensor::data() const {
13
+ return reinterpret_cast<__half*>(data<Half>());
14
+ }
15
+ } // namespace at
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/NumericLimits.cuh ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <cuda.h>
4
+ #include <limits.h>
5
+ #include <math.h>
6
+ #include <float.h>
7
+
8
+ // NumericLimits.cuh is a holder for numeric limits definitions of commonly used
9
+ // types. This header is very specific to ROCm HIP and may be removed in the future.
10
+ // This header is derived from the legacy THCNumerics.cuh.
11
+
12
+ // The lower_bound and upper_bound constants are same as lowest and max for
13
+ // integral types, but are -inf and +inf for floating point types. They are
14
+ // useful in implementing min, max, etc.
15
+
16
+ namespace at {
17
+
18
+ template <typename T>
19
+ struct numeric_limits {
20
+ };
21
+
22
+ // WARNING: the following at::numeric_limits definitions are there only to support
23
+ // HIP compilation for the moment. Use std::numeric_limits if you are not
24
+ // compiling for ROCm.
25
+ // from @colesbury: "The functions on numeric_limits aren't marked with
26
+ // __device__ which is why they don't work with ROCm. CUDA allows them
27
+ // because they're constexpr."
28
+
29
+ namespace {
30
+ // ROCm doesn't like INFINITY too.
31
+ constexpr double inf = INFINITY;
32
+ }
33
+
34
+ template <>
35
+ struct numeric_limits<bool> {
36
+ static inline __host__ __device__ bool lowest() { return false; }
37
+ static inline __host__ __device__ bool max() { return true; }
38
+ static inline __host__ __device__ bool lower_bound() { return false; }
39
+ static inline __host__ __device__ bool upper_bound() { return true; }
40
+ };
41
+
42
+ template <>
43
+ struct numeric_limits<uint8_t> {
44
+ static inline __host__ __device__ uint8_t lowest() { return 0; }
45
+ static inline __host__ __device__ uint8_t max() { return UINT8_MAX; }
46
+ static inline __host__ __device__ uint8_t lower_bound() { return 0; }
47
+ static inline __host__ __device__ uint8_t upper_bound() { return UINT8_MAX; }
48
+ };
49
+
50
+ template <>
51
+ struct numeric_limits<int8_t> {
52
+ static inline __host__ __device__ int8_t lowest() { return INT8_MIN; }
53
+ static inline __host__ __device__ int8_t max() { return INT8_MAX; }
54
+ static inline __host__ __device__ int8_t lower_bound() { return INT8_MIN; }
55
+ static inline __host__ __device__ int8_t upper_bound() { return INT8_MAX; }
56
+ };
57
+
58
+ template <>
59
+ struct numeric_limits<int16_t> {
60
+ static inline __host__ __device__ int16_t lowest() { return INT16_MIN; }
61
+ static inline __host__ __device__ int16_t max() { return INT16_MAX; }
62
+ static inline __host__ __device__ int16_t lower_bound() { return INT16_MIN; }
63
+ static inline __host__ __device__ int16_t upper_bound() { return INT16_MAX; }
64
+ };
65
+
66
+ template <>
67
+ struct numeric_limits<int32_t> {
68
+ static inline __host__ __device__ int32_t lowest() { return INT32_MIN; }
69
+ static inline __host__ __device__ int32_t max() { return INT32_MAX; }
70
+ static inline __host__ __device__ int32_t lower_bound() { return INT32_MIN; }
71
+ static inline __host__ __device__ int32_t upper_bound() { return INT32_MAX; }
72
+ };
73
+
74
+ template <>
75
+ struct numeric_limits<int64_t> {
76
+ #ifdef _MSC_VER
77
+ static inline __host__ __device__ int64_t lowest() { return _I64_MIN; }
78
+ static inline __host__ __device__ int64_t max() { return _I64_MAX; }
79
+ static inline __host__ __device__ int64_t lower_bound() { return _I64_MIN; }
80
+ static inline __host__ __device__ int64_t upper_bound() { return _I64_MAX; }
81
+ #else
82
+ static inline __host__ __device__ int64_t lowest() { return INT64_MIN; }
83
+ static inline __host__ __device__ int64_t max() { return INT64_MAX; }
84
+ static inline __host__ __device__ int64_t lower_bound() { return INT64_MIN; }
85
+ static inline __host__ __device__ int64_t upper_bound() { return INT64_MAX; }
86
+ #endif
87
+ };
88
+
89
+ template <>
90
+ struct numeric_limits<at::Half> {
91
+ static inline __host__ __device__ at::Half lowest() { return at::Half(0xFBFF, at::Half::from_bits()); }
92
+ static inline __host__ __device__ at::Half max() { return at::Half(0x7BFF, at::Half::from_bits()); }
93
+ static inline __host__ __device__ at::Half lower_bound() { return at::Half(0xFC00, at::Half::from_bits()); }
94
+ static inline __host__ __device__ at::Half upper_bound() { return at::Half(0x7C00, at::Half::from_bits()); }
95
+ };
96
+
97
+ template <>
98
+ struct numeric_limits<at::BFloat16> {
99
+ static inline __host__ __device__ at::BFloat16 lowest() { return at::BFloat16(0xFF7F, at::BFloat16::from_bits()); }
100
+ static inline __host__ __device__ at::BFloat16 max() { return at::BFloat16(0x7F7F, at::BFloat16::from_bits()); }
101
+ static inline __host__ __device__ at::BFloat16 lower_bound() { return at::BFloat16(0xFF80, at::BFloat16::from_bits()); }
102
+ static inline __host__ __device__ at::BFloat16 upper_bound() { return at::BFloat16(0x7F80, at::BFloat16::from_bits()); }
103
+ };
104
+
105
+ template <>
106
+ struct numeric_limits<float> {
107
+ static inline __host__ __device__ float lowest() { return -FLT_MAX; }
108
+ static inline __host__ __device__ float max() { return FLT_MAX; }
109
+ static inline __host__ __device__ float lower_bound() { return -static_cast<float>(inf); }
110
+ static inline __host__ __device__ float upper_bound() { return static_cast<float>(inf); }
111
+ };
112
+
113
+ template <>
114
+ struct numeric_limits<double> {
115
+ static inline __host__ __device__ double lowest() { return -DBL_MAX; }
116
+ static inline __host__ __device__ double max() { return DBL_MAX; }
117
+ static inline __host__ __device__ double lower_bound() { return -inf; }
118
+ static inline __host__ __device__ double upper_bound() { return inf; }
119
+ };
120
+
121
+ } // namespace at
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/Sleep.h ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ #include <c10/macros/Export.h>
3
+ #include <cstdint>
4
+
5
+ namespace at::cuda {
6
+
7
+ // enqueues a kernel that spins for the specified number of cycles
8
+ TORCH_CUDA_CU_API void sleep(int64_t cycles);
9
+
10
+ } // namespace at::cuda
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/detail/CUDAHooks.h ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/detail/CUDAHooksInterface.h>
4
+
5
+ #include <ATen/Generator.h>
6
+ #include <c10/util/Optional.h>
7
+
8
+ // TODO: No need to have this whole header, we can just put it all in
9
+ // the cpp file
10
+
11
+ namespace at::cuda::detail {
12
+
13
+ // Set the callback to initialize Magma, which is set by
14
+ // torch_cuda_cu. This indirection is required so magma_init is called
15
+ // in the same library where Magma will be used.
16
+ TORCH_CUDA_CPP_API void set_magma_init_fn(void (*magma_init_fn)());
17
+
18
+
19
+ // The real implementation of CUDAHooksInterface
20
+ struct CUDAHooks : public at::CUDAHooksInterface {
21
+ CUDAHooks(at::CUDAHooksArgs) {}
22
+ void initCUDA() const override;
23
+ Device getDeviceFromPtr(void* data) const override;
24
+ bool isPinnedPtr(const void* data) const override;
25
+ const Generator& getDefaultCUDAGenerator(DeviceIndex device_index = -1) const override;
26
+ bool hasCUDA() const override;
27
+ bool hasMAGMA() const override;
28
+ bool hasCuDNN() const override;
29
+ bool hasCuSOLVER() const override;
30
+ bool hasROCM() const override;
31
+ const at::cuda::NVRTC& nvrtc() const override;
32
+ DeviceIndex current_device() const override;
33
+ bool hasPrimaryContext(DeviceIndex device_index) const override;
34
+ Allocator* getCUDADeviceAllocator() const override;
35
+ Allocator* getPinnedMemoryAllocator() const override;
36
+ bool compiledWithCuDNN() const override;
37
+ bool compiledWithMIOpen() const override;
38
+ bool supportsDilatedConvolutionWithCuDNN() const override;
39
+ bool supportsDepthwiseConvolutionWithCuDNN() const override;
40
+ bool supportsBFloat16ConvolutionWithCuDNNv8() const override;
41
+ bool hasCUDART() const override;
42
+ long versionCUDART() const override;
43
+ long versionCuDNN() const override;
44
+ std::string showConfig() const override;
45
+ double batchnormMinEpsilonCuDNN() const override;
46
+ int64_t cuFFTGetPlanCacheMaxSize(DeviceIndex device_index) const override;
47
+ void cuFFTSetPlanCacheMaxSize(DeviceIndex device_index, int64_t max_size) const override;
48
+ int64_t cuFFTGetPlanCacheSize(DeviceIndex device_index) const override;
49
+ void cuFFTClearPlanCache(DeviceIndex device_index) const override;
50
+ int getNumGPUs() const override;
51
+ void deviceSynchronize(DeviceIndex device_index) const override;
52
+ };
53
+
54
+ } // at::cuda::detail
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/detail/DeviceThreadHandles.h ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Some stateful GPU libraries, such as cuDNN, cuBLAS, use handles to store states.
2
+ // These handles are tied to device, and these libraries requires/recommends not to
3
+ // share handles across host threads.
4
+ //
5
+ // These libraries recommend using one handle per host thread. We may not want to do
6
+ // this because threads are relatively light-weight, but creating and destroying
7
+ // handles is expensive (destroying the handle causes synchronizations). DataParallel,
8
+ // for example, creates new threads for each forward pass.
9
+ //
10
+ // This file implements a handle pool mechanism. The handle pool returns handles on
11
+ // demand as threads request them. If all existing handles in the pool are in use,
12
+ // it creates a new one. As threads terminate, they release handles back into the pool.
13
+ // In this way, the handle pool never creates more handles than the high-water mark of
14
+ // active threads, so it's efficient with DataParallel.
15
+
16
+ #pragma once
17
+
18
+ #include <unordered_map>
19
+ #include <vector>
20
+ #include <utility>
21
+ #include <mutex>
22
+ #include <memory>
23
+
24
+ #include <c10/util/Exception.h>
25
+
26
+ namespace at::cuda { namespace {
27
+
28
// Thread-safe pool of per-device library handles (e.g. cuDNN/cuBLAS).
// Handle_t is the raw handle type; Create/Destroy are the library's
// factory and teardown functions. The pool is held via shared_ptr
// (enable_shared_from_this) so per-thread PoolWindows can detect via
// weak_ptr whether the pool still exists at thread exit.
template <typename Handle_t, void Create(Handle_t *), void Destroy(Handle_t)>
struct DeviceThreadHandlePool : public std::enable_shared_from_this<DeviceThreadHandlePool<Handle_t, Create, Destroy>> {

    // RAII wrapper owning a single raw handle; destroys it on destruction.
    struct Handle {
        Handle_t handle;
        // Lazily creates the handle only when `create` is true; otherwise
        // holds nullptr so vector temporaries stay cheap and safe.
        Handle(bool create = false) : handle(nullptr)
        {
            if(create) Create(&handle);
        }
        // std::vector.emplace() and push_back() may route through temporaries and call
        // copy/move constructors along the way. If this is the case, we don't want
        // the destructors of temporaries to call cudnnDestroy on the handle.
        // We can achieve safety (for the narrow case of stashing within std::vectors)
        // by making Handle moveable but not copyable, and transferring handle ownership
        // to the latest constructed object. This is not a substitute for full-blown
        // reference counting, but reference counting may be overkill here.
        // Another alternative is to wrap the saved Handles in unique_ptrs, i.e.,
        // unordered_map<int, vector<unique_ptr<Handle>>> created_handles;
        Handle(const Handle& rhs) = delete;
        // Following https://stackoverflow.com/questions/3279543/what-is-the-copy-and-swap-idiom
        Handle(Handle&& rhs) : Handle() { std::swap(handle, rhs.handle); }
        // operator= takes argument by value
        Handle& operator=(Handle rhs) { std::swap(handle, rhs.handle); return *this; }
        ~Handle() {
            if(handle) Destroy(handle);
        }
    };

    // Guards created_handles and available_handles below.
    std::mutex mutex;

    // Handles are lazily created as different threads request them,
    // but are never destroyed until the end of the process.
    // The maximum number of handles this process will create for each device is equal
    // to the high-water mark of the number of concurrently active threads that request
    // handles for that device.
    // When threads terminate, they release their handles back into the pool for reuse.
    // Otherwise, new handles would be created every time new threads were spawned,
    // resulting in poor performance for Python modules that repeatedly or frequently
    // spawned new sets of threads (like DataParallel, which creates a new set of threads
    // for each forward pass).
    //
    // To prevent potential deadlocks, we explicitly choose not to cap the number
    // of handles that are created per device.
    // Example of danger: If we cap the max handles at 4, and 5 threads are sharing a device,
    // only 4 can make forward progress at any time. The other 4 will not release their
    // handles until they exit, so the fifth cannot make progress until then. This is
    // not a problem...UNLESS all 5 threads attempt some sort of synchronization at an
    // intermediate point (ie, before any of them have exited). We have no way to anticipate
    // or enforce that user threads will not attempt such intermediate synchronization.
    // The only way to ensure safety is to avoid imposing a cap on the number of handles.
    std::unordered_map<int, std::vector<Handle>> created_handles;      // owns all handles ever created, keyed by device
    std::unordered_map<int, std::vector<Handle_t>> available_handles;  // raw handles currently free for reuse, keyed by device

    // PoolWindow lazily creates and caches the handles that a particular thread is using,
    // so in the common case handle access doesn't incur either handle creation or a mutex lock.
    class PoolWindow
    {
    public:
        PoolWindow(std::shared_ptr<DeviceThreadHandlePool> parent): weak_parent(std::move(parent)) {}
        ~PoolWindow(){ release(); }

        // Returns this thread's handle for `device`, creating or borrowing
        // one from the parent pool on first use.
        Handle_t reserve(int device)
        {
            // If this thread already has a handle for this device, return it
            if(my_handles.find(device) != my_handles.end())
                return my_handles[device];

            // otherwise, either grab a handle from the pool if one is available,
            // or if not, create a new one.
            auto parent = weak_parent.lock();
            TORCH_CHECK(parent, "Cannot create handle during program termination");
            std::lock_guard<std::mutex> guard(parent->mutex);

            if(parent->available_handles[device].size() > 0)
            {
                my_handles[device] = parent->available_handles[device].back();
                parent->available_handles[device].pop_back();
            }
            else
            {
                // In local testing, I do observe that emplace_back sometimes routes through temporaries
                // that incur move-constructor and destructor calls. See comments in Handle above.
                parent->created_handles[device].emplace_back(true /*create*/);
                my_handles[device] = parent->created_handles[device].back().handle;
            }

            return my_handles[device];
        }

    private:
        // Stores the per-device handles currently owned by this thread
        std::unordered_map<int, Handle_t> my_handles;

        // Non-owning reference to the pool; expires once the pool is destroyed.
        std::weak_ptr<DeviceThreadHandlePool> weak_parent;

        // Called by the destructor. Releases this thread's handles back into the pool.
        void release() {
            if(my_handles.size() > 0) {
                auto parent = weak_parent.lock();
                if (!parent) {
                    // If this thread exits after atexit handlers have completed, the
                    // cuda context itself may be invalid, so we must leak the handles.
                    return;
                }

                std::lock_guard<std::mutex> guard(parent->mutex);
                for(auto d_h : my_handles)
                    parent->available_handles[d_h.first].push_back(d_h.second);
            }
        }
    };

    // Warning:
    // If you want to change this function, be aware that this function will be called
    // by multiple threads and there is no mutex guarding the call of this function, so
    // make sure your implementation is thread-safe.
    PoolWindow *newPoolWindow() {
        // The returned pointer will be owned by a thread local variable
        // so that different threads does not share the same PoolWindow.
        return new PoolWindow(this->shared_from_this());
    }
};
150
+
151
+ }} // namespace at::cuda::detail::<anonymous>
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/detail/IndexUtils.cuh ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/core/TensorBase.h>
4
+ #include <ATen/cuda/detail/TensorInfo.cuh>
5
+ #include <ATen/native/CanUse32BitIndexMath.h>
6
+
7
+ namespace at::cuda::detail {
8
+
9
+ TORCH_CUDA_CU_API bool maybeOverlappingIndices(const at::TensorBase &t);
10
+ using at::native::canUse32BitIndexMath;
11
+
12
+ template <typename scalar, typename IndexType>
13
+ TensorInfo<scalar, IndexType>
14
+ getTensorInfo(const at::TensorBase &t) {
15
+ IndexType sz[MAX_TENSORINFO_DIMS];
16
+ IndexType st[MAX_TENSORINFO_DIMS];
17
+
18
+ int dims = t.dim();
19
+ for (int i = 0; i < dims; ++i) {
20
+ sz[i] = t.size(i);
21
+ st[i] = t.stride(i);
22
+ }
23
+
24
+ scalar* data_ptr = nullptr;
25
+
26
+ if constexpr (std::is_const<scalar>::value) {
27
+ data_ptr = t.const_data_ptr<scalar>();
28
+ } else {
29
+ data_ptr = t.mutable_data_ptr<scalar>();
30
+ }
31
+
32
+ return TensorInfo<scalar, IndexType>(
33
+ data_ptr, dims, sz, st);
34
+ }
35
+
36
+ } // namespace at::cuda::detail
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/detail/IntegerDivider.cuh ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <assert.h>
4
+ #if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__)
5
+ #include <cuda_runtime.h>
6
+ #endif
7
+
8
+ namespace at::cuda::detail {
9
+
10
+ // A utility class to implement integer division by multiplication, given a fixed
11
+ // divisor.
12
+ //
13
+ // WARNING: The fast divider algorithm is only implemented for unsigned int;
14
+ // otherwise we default to plain integer division. For unsigned int,
15
+ // we further assume that the dividend is at most INT32_MAX. Thus,
16
+ // IntDivider must NOT be used for general integer division.
17
+ //
18
+ // This reduced range is enough for our purpose, and it allows us to
19
+ // slightly simplify the computation.
20
+ //
21
+ // (NOTE: Below, "2^k" denotes exponentiation, i.e., 1<<k.)
22
+ //
23
+ // For any N-bit unsigned integer d (> 0), we can find a "magic number" m (2^N
24
+ // <= m < 2^(N+1)) and shift s such that:
25
+ //
26
+ // \floor(n / d) = \floor((m * n) / 2^(N+s)).
27
+ //
28
+ // Given such m and s, the integer division can be then implemented as:
29
+ //
30
+ // let m' = m - 2^N // 0 <= m' < 2^N
31
+ //
32
+ // fast_integer_division(n):
33
+ // // Multiply two N-bit unsigned integers: the result is a 2N-bit unsigned
34
+ // // integer. Then take the higher N bits.
35
+ // t = (m' * n) >> N
36
+ //
37
+ // // Here we use the fact that n is less than 2^(N-1): otherwise the value
38
+ // // of (t + n) may not fit in an N-bit integer.
39
+ // return (t + n) >> s
40
+ //
41
+ // Finding such a magic number is surprisingly easy:
42
+ //
43
+ // s = \ceil(\log_2 d)
44
+ // m' = \floor(2^N * (2^s - d) / d) + 1 // Need 2N-bit integer arithmetic.
45
+ //
46
+ // See also:
47
+ // - Division by Invariant Integers Using Multiplication,
48
+ // Torbjörn Granlund and Peter L. Montgomery, 1994.
49
+ //
50
+ // - http://www.hackersdelight.org/magic.htm
51
+ //
52
+ // - http://ridiculousfish.com/blog/posts/labor-of-division-episode-i.html
53
+
54
+ // Result of div/mod operation stored together.
55
+ template <typename Value>
56
+ struct DivMod {
57
+ Value div, mod;
58
+
59
+ C10_HOST_DEVICE DivMod(Value div, Value mod) : div(div), mod(mod) { }
60
+ };
61
+
62
+ // Base case: we only have an implementation for uint32_t for now. For
63
+ // everything else, we use plain division.
64
+ template <typename Value>
65
+ struct IntDivider {
66
+ IntDivider() = default;
67
+ IntDivider(Value d) : divisor(d) { }
68
+
69
+ C10_HOST_DEVICE inline Value div(Value n) const { return n / divisor; }
70
+ C10_HOST_DEVICE inline Value mod(Value n) const { return n % divisor; }
71
+ C10_HOST_DEVICE inline DivMod<Value> divmod(Value n) const {
72
+ return DivMod<Value>(n / divisor, n % divisor);
73
+ }
74
+
75
+ Value divisor;
76
+ };
77
+
78
+ // Implement fast integer division.
79
// Implement fast integer division.
// Precomputes a magic multiplier m1 and shift for a fixed divisor so
// that each division becomes one multiply-high, one add and one shift
// (see the derivation in the file-header comment above).
template <>
struct IntDivider<unsigned int> {
  static_assert(sizeof(unsigned int) == 4, "Assumes 32-bit unsigned int.");

  IntDivider() = default;

  IntDivider(unsigned int d) : divisor(d) {
    assert(divisor >= 1 && divisor <= INT32_MAX);

    // TODO: gcc/clang has __builtin_clz() but it's not portable.
    // shift = ceil(log2(divisor)): smallest s with (1 << s) >= divisor.
    for (shift = 0; shift < 32; shift++) if ((1U << shift) >= divisor) break;

    // m' = floor(2^32 * (2^s - d) / d) + 1, computed in 64-bit arithmetic.
    uint64_t one = 1;
    uint64_t magic = ((one << 32) * ((one << shift) - divisor)) / divisor + 1;
    m1 = magic;
    assert(m1 > 0 && m1 == magic);  // m1 must fit in 32 bits.
  }

  C10_HOST_DEVICE inline unsigned int div(unsigned int n) const {
#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__)
    // 't' is the higher 32-bits of unsigned 32-bit multiplication of 'n' and
    // 'm1'.
    unsigned int t = __umulhi(n, m1);
    return (t + n) >> shift;
#else
    // Using uint64_t so that the addition does not overflow.
    uint64_t t = ((uint64_t) n * m1) >> 32;
    return (t + n) >> shift;
#endif
  }

  // Remainder via n - div(n) * divisor; exact given div() is correct.
  C10_HOST_DEVICE inline unsigned int mod(unsigned int n) const {
    return n - div(n) * divisor;
  }

  // Quotient and remainder with a single div() evaluation.
  C10_HOST_DEVICE inline DivMod<unsigned int> divmod(unsigned int n) const {
    unsigned int q = div(n);
    return DivMod<unsigned int>(q, n - q * divisor);
  }

  unsigned int divisor;  // d above.
  unsigned int m1;  // Magic number: m' above.
  unsigned int shift;  // Shift amounts.
};
123
+
124
+ } // namespace at::cuda::detail
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/detail/OffsetCalculator.cuh ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <array>
4
+ #include <cstdint>
5
+ #include <type_traits>
6
+ #include <c10/macros/Macros.h>
7
+ #include <ATen/core/Array.h>
8
+ #include <ATen/native/TensorIterator.h>
9
+ #include <ATen/cuda/detail/IntegerDivider.cuh>
10
+
11
+ // If element_sizes is nullptr, then the strides will be in bytes, otherwise
12
+ // the strides will be in # of elements.
13
+ // Operands that share the same shape, but may have different strides.
14
+ // OffsetCalculator iterates the tensor in a column-major order
15
+
16
+ #if defined(USE_ROCM)
17
+ constexpr int MAX_DIMS = 16;
18
+ #else
19
+ constexpr int MAX_DIMS = 25;
20
+ #endif
21
+
22
// Computes, for each of NARGS operands sharing one logical shape, the
// offset of a given linear index under that operand's strides.
// Iteration is column-major: dimension 0 varies fastest.
template <int NARGS, typename index_t = uint32_t, bool signed_strides = false>
struct OffsetCalculator {
  // We allow having negative strides to implement some operations like torch.flip
  using stride_t = std::conditional_t<signed_strides,
                                      std::make_signed_t<index_t>,
                                      index_t>;
  // The offset for each argument. Wrapper around fixed-size array.
  // On CUDA, zero sized array is not allowed, so when we are handling nullary
  // operators, we need to create a size 1 offset to avoid compiler failure.
  // This size 1 offset is just a placeholder, and we will not use it.
  using offset_type = at::detail::Array<stride_t, std::max<int>(NARGS, 1)>;

  // if element_sizes is nullptr, then the strides will be in bytes, otherwise
  // the strides will be in # of elements.
  OffsetCalculator(int dims, const int64_t* sizes, const int64_t* const* strides, const int64_t* element_sizes=nullptr) : dims(dims) {
    TORCH_CHECK(dims <= MAX_DIMS, "tensor has too many (>", MAX_DIMS, ") dims");
    for (int i=0; i < dims; i++){
      // Fast-divider per dimension: used at lookup time to peel off the
      // index component for this dimension.
      sizes_[i] = at::cuda::detail::IntDivider<index_t>(sizes[i]);
      for (int arg = 0; arg < NARGS; arg++) {
        // Dividing by the element size converts byte strides to element
        // strides; element_size == 1 leaves byte strides unchanged.
        int64_t element_size = (element_sizes == nullptr ? 1LL : element_sizes[arg]);
        strides_[i][arg] = strides[arg][i] / element_size;
      }
    }
  }

  // Decomposes linear_idx dimension by dimension (via divmod) and
  // accumulates each operand's stride-weighted component.
  C10_HOST_DEVICE offset_type get(index_t linear_idx) const {
    offset_type offsets;
    #pragma unroll
    for (int arg = 0; arg < NARGS; arg++) {
      offsets[arg] = 0;
    }

    // Loop bound is the compile-time MAX_DIMS (so #pragma unroll applies);
    // the break exits once the runtime dimension count is reached.
    #pragma unroll
    for (int dim = 0; dim < MAX_DIMS; ++dim) {
      if (dim == dims) {
        break;
      }
      auto divmod = sizes_[dim].divmod(linear_idx);
      linear_idx = divmod.div;

      #pragma unroll
      for (int arg = 0; arg < NARGS; arg++) {
        offsets[arg] += divmod.mod * strides_[dim][arg];
      }

    }
    return offsets;
  }

  int dims;
  at::cuda::detail::IntDivider<index_t> sizes_[MAX_DIMS];
  // Per-dimension, per-argument strides (bytes or elements; see ctor).
  stride_t strides_[MAX_DIMS][std::max<int>(NARGS, 1)];
};
75
+
76
+ template <int NARGS, typename index_t = uint32_t>
77
+ struct TrivialOffsetCalculator {
78
+ // The offset for each argument. Wrapper around fixed-size array.
79
+ // The offsets are in # of elements, not in bytes.
80
+ // On CUDA, zero sized array is not allowed, so when we are handling nullary
81
+ // operators, we need to create a size 1 offset to avoid compiler failure.
82
+ // This size 1 offset is just a placeholder, and we will not use it.
83
+ using offset_type = at::detail::Array<index_t, std::max<int>(NARGS, 1)>;
84
+
85
+ C10_HOST_DEVICE offset_type get(index_t linear_idx) const {
86
+ offset_type offsets;
87
+ #pragma unroll
88
+ for (int arg = 0; arg < NARGS; arg++) {
89
+ offsets[arg] = linear_idx;
90
+ }
91
+ return offsets;
92
+ }
93
+ };
94
+
95
+ // Make an OffsetCalculator with byte offsets
96
+ template<int N, bool signed_strides = false>
97
+ static OffsetCalculator<N, uint32_t, signed_strides> make_offset_calculator(const at::TensorIteratorBase& iter) {
98
+ TORCH_INTERNAL_ASSERT(N <= iter.ntensors());
99
+ std::array<const int64_t*, N> strides;
100
+ for (int i = 0; i < N; i++) {
101
+ strides[i] = iter.strides(i).data();
102
+ }
103
+ return OffsetCalculator<N, uint32_t, signed_strides>(iter.ndim(), iter.shape().data(), strides.data());
104
+ }
105
+
106
+ // Make an OffsetCalculator with element offsets
107
+ template<int N, bool signed_strides = false>
108
+ static OffsetCalculator<N, uint32_t, signed_strides> make_element_offset_calculator(
109
+ const at::TensorIteratorBase& iter) {
110
+ TORCH_INTERNAL_ASSERT(N <= iter.ntensors());
111
+ std::array<const int64_t*, N> strides;
112
+ std::array<int64_t, N> element_sizes;
113
+ for (int i = 0; i < N; i++) {
114
+ strides[i] = iter.strides(i).data();
115
+ element_sizes[i] = iter.element_size(i);
116
+ }
117
+ return OffsetCalculator<N, uint32_t, signed_strides>(
118
+ iter.ndim(), iter.shape().data(), strides.data(), element_sizes.data());
119
+ }
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/detail/PhiloxCudaStateRaw.cuh ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // No "#pragma once" because this is a raw definition that can be copied by jit codegen.
2
+ // Eager mode clients should not include this file directly, instead,
3
+ // they should #include <ATen/cuda/PhiloxCudaState.h>, which has a #pragma once.
4
+
5
+ // Stores RNG state values. Passed as a kernel argument.
6
+ // See Note [CUDA Graph-safe RNG states].
7
+ //
8
+ // The raw definition lives in its own file so jit codegen can easily copy it.
9
+ namespace at {
10
+
11
// Philox RNG state passed to CUDA kernels by value. Two modes:
// eager (seed/offset stored as plain values) and graph-capture
// (seed/offset read through pointers filled in at replay time).
// Layout is part of the kernel ABI; jit codegen copies this definition
// verbatim, so keep members and their order stable.
struct PhiloxCudaState {
  PhiloxCudaState() = default;
  // Called if graph capture is not underway
  PhiloxCudaState(uint64_t seed,
                  uint64_t offset) {
    seed_.val = seed;
    offset_.val = offset;
  }
  // Called if graph capture is underway
  PhiloxCudaState(int64_t* seed,
                  int64_t* offset_extragraph,
                  uint32_t offset_intragraph) {
    seed_.ptr = seed;
    offset_.ptr = offset_extragraph;
    offset_intragraph_ = offset_intragraph;
    captured_ = true;
  }

  // Public members, directly accessible by at::cuda::philox::unpack.
  // If we made them private with getters/setters, the getters/setters
  // would have to be __device__, and we can't declare __device__ in ATen.
  // Union: `val` is used in eager mode, `ptr` in captured mode
  // (discriminated by `captured_`).
  union Payload {
    uint64_t val;
    int64_t* ptr;
  };

  Payload seed_;
  Payload offset_;
  uint32_t offset_intragraph_ = 0;  // extra offset applied within a captured graph
  bool captured_ = false;           // true iff the pointer-based constructor was used
};
42
+
43
+ } // namespace at
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/detail/TensorInfo.cuh ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/CollapseDims.h>
4
+
5
+ namespace at::cuda::detail {
6
+
7
+ #define MAX_TENSORINFO_DIMS 25
8
+
9
+ // CUDA kernel argument that defines tensor layout
10
// CUDA kernel argument that defines tensor layout: data pointer plus
// fixed-capacity size/stride arrays. Passed by value to kernels, so it
// must stay trivially copyable.
template <typename T, typename IndexType>
struct TensorInfo {
  TensorInfo();
  // Copies the first `dim` entries of sz/st; see definition below for the
  // dimension-count check.
  TensorInfo(T* p,
             int dim,
             IndexType sz[MAX_TENSORINFO_DIMS],
             IndexType st[MAX_TENSORINFO_DIMS]);

  // Set the size of the given dimension to 1, as if it were a
  // reduction dim (allows you to calculate offsets of the reduction
  // slice)
  void reduceDim(int dim);

  // See note on [collapse dims].
  int collapseDims(const int excludeDim = -1);

  // Contiguous tensors of more than one dimension are collapsed down
  // to one tensor
  __host__ __device__ inline bool isContiguous() const {
    return (dims == 1 && strides[0] == 1);
  }

  T* data;                                // element pointer (not owned)
  IndexType sizes[MAX_TENSORINFO_DIMS];   // only the first `dims` entries are valid
  IndexType strides[MAX_TENSORINFO_DIMS]; // only the first `dims` entries are valid
  int dims;
};
37
+
38
+ template <typename T, typename IndexType>
39
+ TensorInfo<T, IndexType>::TensorInfo() {
40
+ data = nullptr;
41
+ dims = 0;
42
+ }
43
+
44
// Constructs a descriptor from a pointer plus `dim` size/stride entries.
// NOTE(review): the check uses the strict bound `dims < MAX_TENSORINFO_DIMS`,
// so exactly MAX_TENSORINFO_DIMS (25) dimensions are rejected even though the
// arrays could hold them; the message says "more than 25" — confirm intent
// upstream before relying on the boundary.
template <typename T, typename IndexType>
TensorInfo<T, IndexType>::TensorInfo(T* p,
                                     int dim,
                                     IndexType sz[MAX_TENSORINFO_DIMS],
                                     IndexType st[MAX_TENSORINFO_DIMS]) {
  data = p;
  dims = dim;
  TORCH_CHECK(dims < MAX_TENSORINFO_DIMS, "CUDA Tensors cannot have more than 25 dimensions");

  // Copy only the valid prefix of the caller's arrays.
  for (int i = 0; i < dim; ++i) {
    sizes[i] = sz[i];
    strides[i] = st[i];
  }
}
58
+
59
+ template <typename T, typename IndexType>
60
+ void
61
+ TensorInfo<T, IndexType>::reduceDim(int dim) {
62
+ TORCH_CHECK(dim < dims && dim >= 0, "expected dim between 0 and dims - 1");
63
+ sizes[dim] = 1;
64
+ }
65
+
66
+ template <typename T, typename IndexType>
67
+ int
68
+ TensorInfo<T, IndexType>::collapseDims(const int excludeDim) {
69
+ auto result = at::collapse_dims(sizes, strides, dims, excludeDim);
70
+ dims = std::get<1>(result);
71
+ return std::get<0>(result);
72
+ }
73
+
74
+ // Translate a linear index for the apply to a T* offset;
75
+ // specialized on `Dims` to reduce nvcc compilation time
76
// Translate a linear index for the apply to a T* offset;
// specialized on `Dims` to reduce nvcc compilation time
template <typename T, typename IndexType, int Dims>
struct IndexToOffset {
  static __host__ __device__ IndexType get(
    IndexType linearId,
    const TensorInfo<T, IndexType>& info) {

    IndexType offset = 0;

    // Uses static dims
    // Peel off one dimension per iteration, innermost (last) first;
    // dimension 0 is handled by the final multiply outside the loop.
    for (int i = Dims - 1; i > 0; --i) {
      IndexType curDimIndex = linearId % info.sizes[i];
      IndexType curDimOffset = curDimIndex * info.strides[i];
      offset += curDimOffset;
      linearId /= info.sizes[i];
    }

    return offset + linearId * info.strides[0];
  }
};
95
+
96
+ // Uses dynamic (runtime) instead of static (compiletime) dims
97
// Uses dynamic (runtime) instead of static (compiletime) dims
template <typename T, typename IndexType>
struct IndexToOffset<T, IndexType, -1> {
  static inline __host__ __device__ IndexType get(
    IndexType linearId,
    const TensorInfo<T, IndexType>& info) {

    IndexType offset = 0;

    // Same decomposition as the static version, but bounded by the
    // runtime dimension count info.dims.
    for (int i = info.dims - 1; i > 0; --i) {
      IndexType curDimIndex = linearId % info.sizes[i];
      IndexType curDimOffset = curDimIndex * info.strides[i];
      offset += curDimOffset;
      linearId /= info.sizes[i];
    }

    return offset + linearId * info.strides[0];
  }
};
115
+
116
+ } // namespace at::cuda::detail
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/jiterator.h ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ #include <ATen/jit_macros.h>
3
+
4
#if AT_USE_JITERATOR()

#include <c10/macros/Export.h>
#include <c10/util/SmallVector.h>
#include <ATen/core/Tensor.h>

#include <string>
#include <vector>

namespace at::cuda {

// Jiterator entry point: compiles the kernel source in `code_string` and
// launches `kernel_name` over `tensors`, returning `num_outputs` tensors.
// `extra_args` are scalar kernel arguments; `return_by_ref` selects the
// return-by-reference convention. (Implementation lives in the .cu file.)
TORCH_CUDA_CPP_API c10::SmallVector<at::Tensor> CompileAndLaunchKernel(
    const std::string& code_string,
    const std::string& kernel_name,
    const int num_outputs,
    const c10::SmallVector<at::Tensor>& tensors,
    const c10::SmallVector<at::Scalar>& extra_args,
    bool return_by_ref);

} // namespace at::cuda

#else

namespace at::cuda {

// Stub used when the jiterator is compiled out: always errors.
TORCH_CUDA_CPP_API c10::SmallVector<at::Tensor> CompileAndLaunchKernel(
    const std::string& code_string,
    const std::string& kernel_name,
    const int num_outputs,
    const c10::SmallVector<at::Tensor>& tensors,
    const c10::SmallVector<at::Scalar>& extra_args,
    bool return_by_ref) {
  TORCH_CHECK(false, "Jiterator is not supported");
}
} // namespace at::cuda

#endif // AT_USE_JITERATOR()
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/tunable/GemmCommon.h ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Original TunableOp is from onnxruntime.
2
+ // https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/framework/tunable.h
3
+ // https://github.com/microsoft/onnxruntime/tree/main/onnxruntime/core/providers/rocm/tunable
4
+ // Copyright (c) Microsoft Corporation.
5
+ // Licensed under the MIT license.
6
+ //
7
+ // Adapting TunableOp into PyTorch
8
+ // Copyright (c) Advanced Micro Devices, Inc.
9
+ //
10
+ #pragma once
11
+
12
+ #include <string>
13
+
14
+ #include <ATen/cuda/tunable/TunableOp.h>
15
+ #include <ATen/cuda/Exceptions.h>
16
+ #include <c10/util/StringUtil.h>
17
+
18
+ namespace at::cuda::tunable {
19
+
20
// BLAS transpose option for one operand: N = no transpose, T = transpose.
enum class BlasOp {
  N = 0,
  T = 1
};
24
+
25
+ inline std::string BlasOpToString(BlasOp op) {
26
+ switch (op) {
27
+ case BlasOp::N:
28
+ return "N";
29
+ case BlasOp::T:
30
+ return "T";
31
+ }
32
+ TORCH_CHECK(false, "unrecognized BlasOp");
33
+ return "N";
34
+ }
35
+
36
// Parameter pack for one GEMM problem used by the TunableOp framework.
// Holds operand pointers, leading dimensions and scalars; also provides
// helpers to duplicate the output buffer and numerically compare results
// of candidate implementations.
template <typename T>
struct GemmParams : OpParams {
  // Cache key identifying the problem shape (transposes and m/n/k).
  std::string Signature() const override {
    return c10::str(transa, transb, "_", m, "_", n, "_", k);
  }

  // Shallow-copies every field, then gives the copy its own freshly
  // allocated C buffer initialized from this->c, so a candidate kernel
  // can write output without clobbering the reference result.
  // NOTE(review): Delete() frees only the C buffer, not the heap-allocated
  // struct itself — confirm the caller owns/frees the object.
  GemmParams* DeepCopy() const {
    GemmParams* copy = new GemmParams;
    *copy = *this;
    c10::DeviceIndex device = 0;
    AT_CUDA_CHECK(c10::cuda::GetDevice(&device));
    size_t c_size = m * n * sizeof(T);
    copy->c = static_cast<T*>(c10::cuda::CUDACachingAllocator::raw_alloc(c_size));
    AT_CUDA_CHECK(c10::cuda::CUDACachingAllocator::memcpyAsync(
        copy->c, device, c, device, c_size, getCurrentCUDAStream(device), true));
    return copy;
  }

  // only call on object returned by DeepCopy
  void Delete() {
    c10::cuda::CUDACachingAllocator::raw_delete(c);
  }

  // Compares this->c against other->c (viewed as 1D float tensors) over a
  // grid of tolerance pairs from loose (1e-1) to tight (1e-5), remembering
  // the last pair in iteration order for which at::allclose succeeded.
  // Returns FAIL if even the loosest pair fails, OK otherwise.
  TuningStatus NumericalCheck(GemmParams<T> *other) {
    auto options = at::TensorOptions().dtype(c10::CppTypeToScalarType<T>::value).device(at::kCUDA);
    // comparison done as 1D tensor
    at::Tensor ref = at::from_blob(c, {m*n}, options);
    at::Tensor oth = at::from_blob(other->c, {m*n}, options);
    at::Tensor ref_float = ref.to(at::kFloat);
    at::Tensor oth_float = oth.to(at::kFloat);
    std::vector<double> atols{1e-1, 1e-2, 1e-3, 1e-4, 1e-5};
    std::vector<double> rtols{1e-1, 1e-2, 1e-3, 1e-4, 1e-5};
    double last_succeed_atol = 1;
    double last_succeed_rtol = 1;
    for (auto& atol : atols) {
      for (auto& rtol : rtols) {
        if (at::allclose(ref_float, oth_float, rtol, atol)) {
          last_succeed_atol = atol;
          last_succeed_rtol = rtol;
        }
      }
    }
    if (last_succeed_atol == 1) {
      return FAIL;
    }
    else {
      TUNABLE_LOG("├──verify numerics: atol=", last_succeed_atol, ", rtol=", last_succeed_rtol);
    }

    return OK;
  }

  char transa;            // op for A: 'n'/'N', 't'/'T' (or 'c'/'C')
  char transb;            // op for B
  int64_t m;              // rows of op(A) / C
  int64_t n;              // cols of op(B) / C
  int64_t k;              // inner dimension
  at::opmath_type<T> alpha;
  const T* a;
  int64_t lda;            // leading dimension of A
  const T* b;
  int64_t ldb;            // leading dimension of B
  at::opmath_type<T> beta;
  T* c;                   // output buffer (owned externally, or by DeepCopy)
  int64_t ldc;            // leading dimension of C
};
102
+
103
// Parameter pack for a strided-batched GEMM used by the TunableOp
// framework; like GemmParams but with per-batch strides and a batch count.
template <typename T>
struct GemmStridedBatchedParams : OpParams {
  // Cache key identifying the problem shape, including the batch count.
  std::string Signature() const override {
    return c10::str(transa, transb, "_", m, "_", n, "_", k, "_B_", batch);
  }

  // Shallow-copies every field, then gives the copy its own C buffer
  // (batch * stride_c elements) initialized from this->c.
  // NOTE(review): Delete() frees only the C buffer, not the heap-allocated
  // struct itself — confirm the caller owns/frees the object.
  GemmStridedBatchedParams* DeepCopy() const {
    GemmStridedBatchedParams* copy = new GemmStridedBatchedParams;
    *copy = *this;
    c10::DeviceIndex device = 0;
    AT_CUDA_CHECK(c10::cuda::GetDevice(&device));
    size_t c_size = batch * stride_c * sizeof(T);
    copy->c = static_cast<T*>(c10::cuda::CUDACachingAllocator::raw_alloc(c_size));
    AT_CUDA_CHECK(c10::cuda::CUDACachingAllocator::memcpyAsync(
        copy->c, device, c, device, c_size, getCurrentCUDAStream(device), true));
    return copy;
  }

  // only call on object returned by DeepCopy
  void Delete() {
    c10::cuda::CUDACachingAllocator::raw_delete(c);
  }

  // Compares this->c against other->c (viewed as 1D float tensors of
  // batch*stride_c elements) over tolerance pairs from loose to tight,
  // remembering the last pair in iteration order for which at::allclose
  // succeeded. Returns FAIL if even the loosest pair fails, OK otherwise.
  TuningStatus NumericalCheck(GemmStridedBatchedParams<T> *other) {
    auto options = at::TensorOptions().dtype(c10::CppTypeToScalarType<T>::value).device(at::kCUDA);
    // comparison done as 1D tensor
    at::Tensor ref = at::from_blob(c, {batch*stride_c}, options);
    at::Tensor oth = at::from_blob(other->c, {batch*stride_c}, options);
    at::Tensor ref_float = ref.to(at::kFloat);
    at::Tensor oth_float = oth.to(at::kFloat);
    std::vector<double> atols{1e-1, 1e-2, 1e-3, 1e-4, 1e-5};
    std::vector<double> rtols{1e-1, 1e-2, 1e-3, 1e-4, 1e-5};
    double last_succeed_atol = 1;
    double last_succeed_rtol = 1;
    for (auto& atol : atols) {
      for (auto& rtol : rtols) {
        if (at::allclose(ref_float, oth_float, rtol, atol)) {
          last_succeed_atol = atol;
          last_succeed_rtol = rtol;
        }
      }
    }
    if (last_succeed_atol == 1) {
      return FAIL;
    }
    else {
      TUNABLE_LOG("├──verify numerics: atol=", last_succeed_atol, ", rtol=", last_succeed_rtol);
    }

    return OK;
  }

  char transa;            // op for A: 'n'/'N', 't'/'T' (or 'c'/'C')
  char transb;            // op for B
  int64_t m;              // rows of op(A) / C
  int64_t n;              // cols of op(B) / C
  int64_t k;              // inner dimension
  at::opmath_type<T> alpha;
  const T* a;
  int64_t lda;            // leading dimension of A
  int64_t stride_a;       // elements between consecutive A matrices
  const T* b;
  int64_t ldb;            // leading dimension of B
  int64_t stride_b;       // elements between consecutive B matrices
  at::opmath_type<T> beta;
  T* c;                   // output buffer (owned externally, or by DeepCopy)
  int64_t ldc;            // leading dimension of C
  int64_t stride_c;       // elements between consecutive C matrices
  int64_t batch;          // number of matrices in the batch
};
173
+
174
+ } // namespace at::cuda::tunable
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/tunable/GemmHipblaslt.h ADDED
@@ -0,0 +1,379 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright (c) Microsoft Corporation. All rights reserved.
2
+ // Licensed under the MIT License.
3
+
4
+ #pragma once
5
+
6
+ #include <ATen/cuda/CUDAContext.h>
7
+ #include <ATen/cuda/tunable/TunableOp.h>
8
+ #include <ATen/cuda/tunable/GemmCommon.h>
9
+ #include <c10/cuda/CUDACachingAllocator.h>
10
+ #include <c10/util/StringUtil.h>
11
+
12
+ #include <hipblaslt/hipblaslt.h>
13
+ #include <hipblaslt/hipblaslt-ext.hpp>
14
+
15
+ #define TORCH_HIPBLASLT_CHECK(EXPR) \
16
+ do { \
17
+ hipblasStatus_t __err = EXPR; \
18
+ TORCH_CHECK(__err == HIPBLAS_STATUS_SUCCESS, \
19
+ "hipblaslt error: ", \
20
+ hipblasStatusToString(__err), \
21
+ " when calling `" #EXPR "`"); \
22
+ } while (0)
23
+
24
+ namespace at::cuda::tunable {
25
+
26
// GETINDEXFROMALGO extracts the integer algorithm index from a
// hipblasLtMatmulAlgo_t. Newer hipBLASLt exposes this via
// hipblaslt_ext::getIndexFromAlgo; older versions fall back to reading
// the leading int of the opaque algo blob directly.
#ifdef HIPBLASLT_HAS_GETINDEXFROMALGO
#define GETINDEXFROMALGO(algo) hipblaslt_ext::getIndexFromAlgo(algo)
#else
static int getIndexFromAlgo(hipblasLtMatmulAlgo_t& algo) {
    // Fallback: the index is stored as the first int of algo.data;
    // negative means "no valid index".
    int* algo_ptr = (int*)algo.data;
    if(*algo_ptr < 0) {
        return -1;
    }
    return *algo_ptr;
}
#define GETINDEXFROMALGO(algo) getIndexFromAlgo(algo)
#endif
38
+
39
+ #ifdef HIPBLASLT_CUSTOM_COMPUTE_TYPE
40
+ #define COMPUTE_TYPE_32 HIPBLASLT_COMPUTE_F32
41
+ #else
42
+ #define COMPUTE_TYPE_32 HIPBLAS_COMPUTE_32F
43
+ #endif
44
+
45
// HipBlasDataTypeFor<T>() maps a C++ element type to the hipBLAS(LT)
// datatype enum. Two variants depending on whether the installed
// hipBLASLt uses its own datatype enum (HIPBLASLT_CUSTOM_DATA_TYPE).
#ifdef HIPBLASLT_CUSTOM_DATA_TYPE

template <typename T>
constexpr hipblasltDatatype_t HipBlasDataTypeFor();

template <>
constexpr hipblasltDatatype_t HipBlasDataTypeFor<float>() {
  return HIPBLASLT_R_32F;
}

template <>
constexpr hipblasltDatatype_t HipBlasDataTypeFor<Half>() {
  return HIPBLASLT_R_16F;
}

template <>
constexpr hipblasltDatatype_t HipBlasDataTypeFor<BFloat16>() {
  return HIPBLASLT_R_16B;
}

template <>
constexpr hipblasltDatatype_t HipBlasDataTypeFor<double>() {
  return HIPBLASLT_R_64F;
}

#define DATA_TYPE_R_32 HIPBLASLT_R_32F

#else

template <typename T>
constexpr hipblasDatatype_t HipBlasDataTypeFor();

template <>
constexpr hipblasDatatype_t HipBlasDataTypeFor<float>() {
  return HIPBLAS_R_32F;
}

template <>
constexpr hipblasDatatype_t HipBlasDataTypeFor<Half>() {
  return HIPBLAS_R_16F;
}

template <>
constexpr hipblasDatatype_t HipBlasDataTypeFor<BFloat16>() {
  return HIPBLAS_R_16B;
}

template <>
constexpr hipblasDatatype_t HipBlasDataTypeFor<double>() {
  return HIPBLAS_R_64F;
}

// The 32-bit float compute/data macro also differs under the HIPBLAS v2 API.
#ifdef HIPBLAS_V2
#define DATA_TYPE_R_32 HIP_R_32F
#else
#define DATA_TYPE_R_32 HIPBLAS_R_32F
#endif

#endif
104
+
105
// Returns the batch count of a params pack; non-batched packs report 1.
template <typename T, typename ParamsT>
int GetBatchFromParams(const ParamsT* params) {
  return 1;
}

// Batched overload: reads the pack's batch field.
// NOTE(review): returns int while params->batch is int64_t — a batch
// count above INT_MAX would silently truncate; confirm acceptable.
template <typename T>
int GetBatchFromParams(const GemmStridedBatchedParams<T>* params) {
  return params->batch;
}
114
+
115
+ template <typename T, typename ParamsT>
116
+ int GetStrideAFromParams(const ParamsT* params) {
117
+ return 1;
118
+ }
119
+
120
+ template <typename T>
121
+ int GetStrideAFromParams(const GemmStridedBatchedParams<T>* params) {
122
+ return params->stride_a;
123
+ }
124
+
125
+ template <typename T, typename ParamsT>
126
+ int GetStrideBFromParams(const ParamsT* params) {
127
+ return 1;
128
+ }
129
+
130
+ template <typename T>
131
+ int GetStrideBFromParams(const GemmStridedBatchedParams<T>* params) {
132
+ return params->stride_b;
133
+ }
134
+
135
+ template <typename T, typename ParamsT>
136
+ int GetStrideCFromParams(const ParamsT* params) {
137
+ return 1;
138
+ }
139
+
140
+ template <typename T>
141
+ int GetStrideCFromParams(const GemmStridedBatchedParams<T>* params) {
142
+ return params->stride_c;
143
+ }
144
+
145
+ static hipblasOperation_t _hipblasOpFromChar(char op) {
146
+ switch (op) {
147
+ case 'n':
148
+ case 'N':
149
+ return HIPBLAS_OP_N;
150
+ case 't':
151
+ case 'T':
152
+ return HIPBLAS_OP_T;
153
+ case 'c':
154
+ case 'C':
155
+ return HIPBLAS_OP_C;
156
+ }
157
+ AT_ERROR(
158
+ "_hipblasOpFromChar input should be 't', 'n' or 'c' but got `", op, "`");
159
+ }
160
+
161
+ static char _charFromhipblasOp(hipblasOperation_t op) {
162
+ switch (op) {
163
+ case HIPBLAS_OP_N:
164
+ return 'N';
165
+ case HIPBLAS_OP_T:
166
+ return 'T';
167
+ case HIPBLAS_OP_C:
168
+ return 'C';
169
+ }
170
+ AT_ERROR(
171
+ "_charFromhipblasOp input should be HIPBLAS_OP_N/T/C but got `", op, "`");
172
+ }
173
+
174
+ static hipblasOperation_t MapLayoutToHipBlasLt(BlasOp layout) {
175
+ if (layout == BlasOp::N) {
176
+ return HIPBLAS_OP_N;
177
+ }
178
+ return HIPBLAS_OP_T;
179
+ }
180
+
181
+ static size_t GetHipblasltWorkspaceSize() {
182
+ static const char * env = getenv("HIPBLASLT_WORKSPACE_SIZE");
183
+ // 256MB is max workspace size allowed for hipblaslt
184
+ // hipblaslt-bench uses 32MB
185
+ // recommendation from hipblaslt author was 76MB
186
+ size_t workspace_size = 2*128*1024*1024; // default 256MB
187
+ if (env) {
188
+ try {
189
+ workspace_size = std::stoi(env);
190
+ } catch(std::invalid_argument const& e) {
191
+ TORCH_WARN("invalid HIPBLASLT_WORKSPACE_SIZE,",
192
+ " using default workspace size of ", workspace_size, " bytes.");
193
+ } catch(std::out_of_range const& e) {
194
+ TORCH_WARN("HIPBLASLT_WORKSPACE_SIZE out of range,",
195
+ " using default workspace size of ", workspace_size, " bytes.");
196
+ }
197
+ }
198
+ return workspace_size;
199
+ }
200
+
201
+ template <typename T, BlasOp ALayout, BlasOp BLayout, typename ParamsT>
202
+ class HipblasltGemmOp : public Callable<ParamsT> {
203
+ public:
204
+ HipblasltGemmOp(hipblasLtMatmulAlgo_t algo) : algo_{algo} {}
205
+
206
+ TuningStatus Call(const ParamsT* params) override {
207
+ hipblasOperation_t transa_outer = MapLayoutToHipBlasLt(ALayout);
208
+ hipblasOperation_t transb_outer = MapLayoutToHipBlasLt(BLayout);
209
+ auto in_out_datatype = HipBlasDataTypeFor<T>();
210
+ auto opa = _hipblasOpFromChar(params->transa);
211
+ auto opb = _hipblasOpFromChar(params->transb);
212
+
213
+ TORCH_CHECK(transa_outer == opa && transb_outer == opb, "trans mismatch, shouldn't happen");
214
+
215
+ float alpha = static_cast<float>(params->alpha);
216
+ float beta = static_cast<float>(params->beta);
217
+
218
+ hipblasLtMatrixLayout_t mat_a, mat_b, mat_c;
219
+ hipblasLtMatmulDesc_t matmul;
220
+ if (opa == HIPBLAS_OP_N) {
221
+ TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutCreate(&mat_a, in_out_datatype, params->m, params->k, params->lda));
222
+ }
223
+ else {
224
+ TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutCreate(&mat_a, in_out_datatype, params->k, params->m, params->lda));
225
+ }
226
+ if (opb == HIPBLAS_OP_N) {
227
+ TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutCreate(&mat_b, in_out_datatype, params->k, params->n, params->ldb));
228
+ }
229
+ else {
230
+ TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutCreate(&mat_b, in_out_datatype, params->n, params->k, params->ldb));
231
+ }
232
+ TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutCreate(&mat_c, in_out_datatype, params->m, params->n, params->ldc));
233
+ TORCH_HIPBLASLT_CHECK(hipblasLtMatmulDescCreate(&matmul, COMPUTE_TYPE_32, DATA_TYPE_R_32));
234
+
235
+ int batch = GetBatchFromParams<T>(params);
236
+ if (batch > 1) {
237
+ int64_t stride_a = GetStrideAFromParams<T>(params);
238
+ int64_t stride_b = GetStrideBFromParams<T>(params);
239
+ int64_t stride_c = GetStrideCFromParams<T>(params);
240
+ TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutSetAttribute(
241
+ mat_a, HIPBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch, sizeof(batch)));
242
+ TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutSetAttribute(
243
+ mat_a, HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stride_a, sizeof(stride_a)));
244
+ TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutSetAttribute(
245
+ mat_b, HIPBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch, sizeof(batch)));
246
+ TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutSetAttribute(
247
+ mat_b, HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stride_b, sizeof(stride_b)));
248
+ TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutSetAttribute(
249
+ mat_c, HIPBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch, sizeof(batch)));
250
+ TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutSetAttribute(
251
+ mat_c, HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stride_c, sizeof(stride_c)));
252
+ }
253
+
254
+ TORCH_HIPBLASLT_CHECK(hipblasLtMatmulDescSetAttribute(
255
+ matmul, HIPBLASLT_MATMUL_DESC_TRANSA, &opa, sizeof(int32_t)));
256
+ TORCH_HIPBLASLT_CHECK(hipblasLtMatmulDescSetAttribute(
257
+ matmul, HIPBLASLT_MATMUL_DESC_TRANSB, &opb, sizeof(int32_t)));
258
+
259
+ size_t workspace_size = GetHipblasltWorkspaceSize();
260
+
261
+ auto op_handle = at::cuda::getCurrentCUDABlasLtHandle();
262
+
263
+ size_t ret_workspace_size = 0;
264
+ auto status = hipblaslt_ext::matmulIsAlgoSupported(op_handle,
265
+ matmul,
266
+ &alpha,
267
+ mat_a,
268
+ mat_b,
269
+ &beta,
270
+ mat_c,
271
+ mat_c,
272
+ algo_,
273
+ ret_workspace_size);
274
+
275
+ if (status == HIPBLAS_STATUS_SUCCESS) {
276
+ if (ret_workspace_size >= workspace_size) {
277
+ //TUNABLE_LOG("[hipBLASLt] Solution #", algo_index, " workspace too large");
278
+ return FAIL;
279
+ }
280
+ }
281
+ else {
282
+ //TUNABLE_LOG("[hipBLASLt] Solution #", algo_index, " not supported");
283
+ return FAIL;
284
+ }
285
+
286
+ void* workspace_buffer = nullptr;
287
+ if (workspace_size > 0) {
288
+ workspace_buffer = c10::cuda::CUDACachingAllocator::raw_alloc(workspace_size);
289
+ }
290
+
291
+ TORCH_HIPBLASLT_CHECK(hipblasLtMatmul(op_handle,
292
+ matmul,
293
+ &alpha,
294
+ params->a,
295
+ mat_a,
296
+ params->b,
297
+ mat_b,
298
+ &beta,
299
+ params->c,
300
+ mat_c,
301
+ params->c,
302
+ mat_c,
303
+ &algo_,
304
+ workspace_buffer,
305
+ workspace_size,
306
+ at::cuda::getCurrentCUDAStream()));
307
+
308
+ TORCH_HIPBLASLT_CHECK(hipblasLtMatmulDescDestroy(matmul));
309
+ TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutDestroy(mat_a));
310
+ TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutDestroy(mat_b));
311
+ TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutDestroy(mat_c));
312
+ if (workspace_size > 0) {
313
+ c10::cuda::CUDACachingAllocator::raw_delete(workspace_buffer);
314
+ }
315
+ return OK;
316
+ }
317
+
318
+ private:
319
+ hipblasLtMatmulAlgo_t algo_;
320
+ };
321
+
322
+ template <typename T, BlasOp ALayout, BlasOp BLayout, typename ParamsT>
323
+ auto GetHipBlasLtTypeStringAndOps() {
324
+ hipblasOperation_t transa_outer = MapLayoutToHipBlasLt(ALayout);
325
+ hipblasOperation_t transb_outer = MapLayoutToHipBlasLt(BLayout);
326
+ auto in_out_datatype = HipBlasDataTypeFor<T>();
327
+ std::vector<hipblasLtMatmulHeuristicResult_t> heuristic_result;
328
+
329
+ hipblasLtHandle_t handle;
330
+ TORCH_HIPBLASLT_CHECK(hipblasLtCreate(&handle));
331
+ TORCH_HIPBLASLT_CHECK(hipblaslt_ext::getAllAlgos(handle,
332
+ hipblaslt_ext::GemmType::HIPBLASLT_GEMM,
333
+ transa_outer,
334
+ transb_outer,
335
+ in_out_datatype,
336
+ in_out_datatype,
337
+ in_out_datatype,
338
+ in_out_datatype,
339
+ COMPUTE_TYPE_32,
340
+ heuristic_result));
341
+ TORCH_HIPBLASLT_CHECK(hipblasLtDestroy(handle));
342
+
343
+ // Sort heuristic_result by algo index to make sure the order of returned algos is deterministic.
344
+ std::sort(heuristic_result.begin(),
345
+ heuristic_result.end(),
346
+ [](hipblasLtMatmulHeuristicResult_t& a, hipblasLtMatmulHeuristicResult_t& b) {
347
+ return GETINDEXFROMALGO(a.algo) < GETINDEXFROMALGO(b.algo);
348
+ });
349
+
350
+ int returned_algo_count = heuristic_result.size();
351
+ std::vector<std::pair<std::string, std::unique_ptr<Callable<ParamsT>>>> ret;
352
+ for (int i = 0; i < returned_algo_count; i++) {
353
+ auto algo = heuristic_result[i].algo;
354
+ int algo_index = GETINDEXFROMALGO(algo);
355
+ auto callable = std::make_unique<HipblasltGemmOp<T, ALayout, BLayout, ParamsT>>(algo);
356
+ std::string type_string = c10::str(
357
+ "Gemm_Hipblaslt_", _charFromhipblasOp(transa_outer), _charFromhipblasOp(transb_outer), "_", algo_index);
358
+ ret.emplace_back(type_string, std::move(callable));
359
+ }
360
+
361
+ return ret;
362
+ }
363
+
364
+ template <typename T, BlasOp ALayout, BlasOp BLayout>
365
+ auto GetHipBlasLtGemmTypeStringAndOps() {
366
+ return GetHipBlasLtTypeStringAndOps<T, ALayout, BLayout, GemmParams<T>>();
367
+ }
368
+
369
+ template <typename T, BlasOp ALayout, BlasOp BLayout>
370
+ auto GetHipBlasLtGemmStridedBatchedTypeStringAndOps() {
371
+ return GetHipBlasLtTypeStringAndOps<T, ALayout, BLayout, GemmStridedBatchedParams<T>>();
372
+ }
373
+
374
+ #undef TORCH_HIPBLASLT_CHECK
375
+ #undef GETINDEXFROMALGO
376
+ #undef COMPUTE_TYPE_32
377
+ #undef DATA_TYPE_R_32
378
+
379
+ } // namespace at::cuda::tunable
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/tunable/StreamTimer.h ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Original TunableOp is from onnxruntime.
2
+ // https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/framework/tunable.h
3
+ // https://github.com/microsoft/onnxruntime/tree/main/onnxruntime/core/providers/rocm/tunable
4
+ // Copyright (c) Microsoft Corporation.
5
+ // Licensed under the MIT license.
6
+ //
7
+ // Adapting TunableOp into PyTorch
8
+ // Copyright (c) Advanced Micro Devices, Inc.
9
+ //
10
+ #pragma once
11
+
12
+ #include <cuda_runtime.h>
13
+
14
+ #include <ATen/cuda/tunable/Tunable.h>
15
+
16
+ namespace at::cuda::tunable {
17
+
18
+ class StreamTimer : public ITimer {
19
+ public:
20
+ StreamTimer();
21
+ virtual ~StreamTimer();
22
+
23
+ void Start() override;
24
+
25
+ void End() override;
26
+
27
+ float Duration() override;
28
+
29
+ private:
30
+ cudaEvent_t start_;
31
+ cudaEvent_t end_;
32
+ };
33
+
34
+ } // namespace at::cuda::tunable
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/TensorConversions.h ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <c10/core/Device.h>
4
+ #include <c10/core/Layout.h>
5
+ #include <c10/core/MemoryFormat.h>
6
+ #include <c10/core/ScalarType.h>
7
+ #include <c10/util/Optional.h>
8
+
9
+ namespace at {
10
+ class Tensor;
11
+ namespace native {
12
+ bool to_will_alias(
13
+ const Tensor& self,
14
+ c10::optional<ScalarType> dtype,
15
+ c10::optional<Layout> layout,
16
+ c10::optional<Device> device,
17
+ bool copy,
18
+ c10::optional<c10::MemoryFormat> optional_memory_format);
19
+
20
+ Tensor to_meta(const Tensor& tensor);
21
+ c10::optional<Tensor> to_meta(const c10::optional<Tensor>& tensor);
22
+ std::vector<Tensor> to_meta(at::ITensorListRef t_list);
23
+ Tensor dense_to_sparse_with_mask(const Tensor& self, const Tensor& mask, c10::optional<c10::Layout> layout, OptionalIntArrayRef blocksize, c10::optional<int64_t> dense_dim_opt);
24
+
25
+ } // namespace native
26
+ } // namespace at
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/group_norm.h ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/native/DispatchStub.h>
4
+ #include <cstdint>
5
+
6
+ namespace at {
7
+ class Tensor;
8
+
9
+ namespace native {
10
+
11
+ using forward_fn = void (*)(
12
+ const Tensor& /* X */,
13
+ const Tensor& /* gamma */,
14
+ const Tensor& /* beta */,
15
+ int64_t /* N */,
16
+ int64_t /* C */,
17
+ int64_t /* HxW */,
18
+ int64_t /* group */,
19
+ double /* eps */,
20
+ Tensor& /* Y */,
21
+ Tensor& /* mean */,
22
+ Tensor& /* rstd */);
23
+
24
+ using backward_fn = void (*)(
25
+ const Tensor& /* dY */,
26
+ const Tensor& /* X */,
27
+ const Tensor& /* mean */,
28
+ const Tensor& /* rstd */,
29
+ const Tensor& /* gamma */,
30
+ int64_t /* N */,
31
+ int64_t /* C */,
32
+ int64_t /* HxW */,
33
+ int64_t /* group */,
34
+ Tensor& /* dX */,
35
+ Tensor& /* dgamma */,
36
+ Tensor& /* dbeta */);
37
+
38
+ DECLARE_DISPATCH(forward_fn, GroupNormKernel);
39
+ DECLARE_DISPATCH(backward_fn, GroupNormBackwardKernel);
40
+
41
+ } // namespace native
42
+ } // namespace at
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_cdist_forward_native.h ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ // @generated by torchgen/gen.py from NativeFunction.h
4
+
5
+ #include <c10/core/Scalar.h>
6
+ #include <c10/core/Storage.h>
7
+ #include <c10/core/TensorOptions.h>
8
+ #include <c10/util/Deprecated.h>
9
+ #include <c10/util/Optional.h>
10
+ #include <c10/core/QScheme.h>
11
+ #include <ATen/core/Reduction.h>
12
+ #include <ATen/core/Tensor.h>
13
+ #include <tuple>
14
+ #include <vector>
15
+
16
+
17
+ namespace at {
18
+ namespace native {
19
+ TORCH_API at::Tensor & _cdist_forward_out(const at::Tensor & x1, const at::Tensor & x2, double p, c10::optional<int64_t> compute_mode, at::Tensor & out);
20
+ TORCH_API at::Tensor _cdist_forward(const at::Tensor & x1, const at::Tensor & x2, double p, c10::optional<int64_t> compute_mode);
21
+ } // namespace native
22
+ } // namespace at
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_foreach_clamp_min_cpu_dispatch.h ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ // @generated by torchgen/gen.py from DispatchKeyFunction.h
3
+
4
+ // NB: The implementing C++ file is RegisterDispatchKey.cpp
5
+
6
+ // The only #includes we need are for custom classes that have defaults in the C++ API
7
+ #include <c10/core/MemoryFormat.h>
8
+ #include <c10/core/Scalar.h>
9
+ #include <ATen/core/Reduction.h>
10
+
11
+ // Forward declarations of any types needed in the operator signatures.
12
+ // We can't directly include these classes because it will cause circular include dependencies.
13
+ // This file is included by TensorBody.h, which defines the Tensor class.
14
+ #include <ATen/core/ATen_fwd.h>
15
+
16
+ namespace at {
17
+
18
+ namespace cpu {
19
+
20
+ TORCH_API ::std::vector<at::Tensor> _foreach_clamp_min(at::TensorList self, const at::Scalar & scalar);
21
+ TORCH_API void _foreach_clamp_min_(at::TensorList self, const at::Scalar & scalar);
22
+ TORCH_API ::std::vector<at::Tensor> _foreach_clamp_min(at::TensorList self, at::TensorList other);
23
+ TORCH_API void _foreach_clamp_min_(at::TensorList self, at::TensorList other);
24
+ TORCH_API ::std::vector<at::Tensor> _foreach_clamp_min(at::TensorList self, at::ArrayRef<at::Scalar> scalars);
25
+ TORCH_API void _foreach_clamp_min_(at::TensorList self, at::ArrayRef<at::Scalar> scalars);
26
+
27
+ } // namespace cpu
28
+ } // namespace at
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_log_softmax_meta_dispatch.h ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ // @generated by torchgen/gen.py from DispatchKeyFunction.h
3
+
4
+ // NB: The implementing C++ file is RegisterDispatchKey.cpp
5
+
6
+ // The only #includes we need are for custom classes that have defaults in the C++ API
7
+ #include <c10/core/MemoryFormat.h>
8
+ #include <c10/core/Scalar.h>
9
+ #include <ATen/core/Reduction.h>
10
+
11
+ // Forward declarations of any types needed in the operator signatures.
12
+ // We can't directly include these classes because it will cause circular include dependencies.
13
+ // This file is included by TensorBody.h, which defines the Tensor class.
14
+ #include <ATen/core/ATen_fwd.h>
15
+
16
+ namespace at {
17
+
18
+ namespace meta {
19
+
20
+ TORCH_API at::Tensor _log_softmax(const at::Tensor & self, int64_t dim, bool half_to_float);
21
+ TORCH_API at::Tensor & _log_softmax_out(at::Tensor & out, const at::Tensor & self, int64_t dim, bool half_to_float);
22
+ TORCH_API at::Tensor & _log_softmax_outf(const at::Tensor & self, int64_t dim, bool half_to_float, at::Tensor & out);
23
+
24
+ } // namespace meta
25
+ } // namespace at
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_make_per_tensor_quantized_tensor_cpu_dispatch.h ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ // @generated by torchgen/gen.py from DispatchKeyFunction.h
3
+
4
+ // NB: The implementing C++ file is RegisterDispatchKey.cpp
5
+
6
+ // The only #includes we need are for custom classes that have defaults in the C++ API
7
+ #include <c10/core/MemoryFormat.h>
8
+ #include <c10/core/Scalar.h>
9
+ #include <ATen/core/Reduction.h>
10
+
11
+ // Forward declarations of any types needed in the operator signatures.
12
+ // We can't directly include these classes because it will cause circular include dependencies.
13
+ // This file is included by TensorBody.h, which defines the Tensor class.
14
+ #include <ATen/core/ATen_fwd.h>
15
+
16
+ namespace at {
17
+
18
+ namespace cpu {
19
+
20
+ TORCH_API at::Tensor _make_per_tensor_quantized_tensor(const at::Tensor & self, double scale, int64_t zero_point);
21
+
22
+ } // namespace cpu
23
+ } // namespace at
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_masked_softmax_backward_compositeexplicitautograd_dispatch.h ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ // @generated by torchgen/gen.py from DispatchKeyFunction.h
3
+
4
+ // NB: The implementing C++ file is RegisterDispatchKey.cpp
5
+
6
+ // The only #includes we need are for custom classes that have defaults in the C++ API
7
+ #include <c10/core/MemoryFormat.h>
8
+ #include <c10/core/Scalar.h>
9
+ #include <ATen/core/Reduction.h>
10
+
11
+ // Forward declarations of any types needed in the operator signatures.
12
+ // We can't directly include these classes because it will cause circular include dependencies.
13
+ // This file is included by TensorBody.h, which defines the Tensor class.
14
+ #include <ATen/core/ATen_fwd.h>
15
+
16
+ namespace at {
17
+
18
+ namespace compositeexplicitautograd {
19
+
20
+ TORCH_API at::Tensor & _masked_softmax_backward_out(at::Tensor & out, const at::Tensor & grad_output, const at::Tensor & output, const at::Tensor & mask, c10::optional<int64_t> dim=c10::nullopt);
21
+ TORCH_API at::Tensor & _masked_softmax_backward_outf(const at::Tensor & grad_output, const at::Tensor & output, const at::Tensor & mask, c10::optional<int64_t> dim, at::Tensor & out);
22
+
23
+ } // namespace compositeexplicitautograd
24
+ } // namespace at
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_masked_softmax_backward_native.h ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ // @generated by torchgen/gen.py from NativeFunction.h
4
+
5
+ #include <c10/core/Scalar.h>
6
+ #include <c10/core/Storage.h>
7
+ #include <c10/core/TensorOptions.h>
8
+ #include <c10/util/Deprecated.h>
9
+ #include <c10/util/Optional.h>
10
+ #include <c10/core/QScheme.h>
11
+ #include <ATen/core/Reduction.h>
12
+ #include <ATen/core/Tensor.h>
13
+ #include <tuple>
14
+ #include <vector>
15
+
16
+
17
+ namespace at {
18
+ namespace native {
19
+ TORCH_API at::Tensor & _masked_softmax_backward_out(const at::Tensor & grad_output, const at::Tensor & output, const at::Tensor & mask, c10::optional<int64_t> dim, at::Tensor & out);
20
+ TORCH_API at::Tensor masked_softmax_backward_cpu(const at::Tensor & grad_output, const at::Tensor & output, const at::Tensor & mask, c10::optional<int64_t> dim=c10::nullopt);
21
+ TORCH_API at::Tensor masked_softmax_backward_cuda(const at::Tensor & grad_output, const at::Tensor & output, const at::Tensor & mask, c10::optional<int64_t> dim=c10::nullopt);
22
+ } // namespace native
23
+ } // namespace at
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_masked_softmax_cpu_dispatch.h ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ // @generated by torchgen/gen.py from DispatchKeyFunction.h
3
+
4
+ // NB: The implementing C++ file is RegisterDispatchKey.cpp
5
+
6
+ // The only #includes we need are for custom classes that have defaults in the C++ API
7
+ #include <c10/core/MemoryFormat.h>
8
+ #include <c10/core/Scalar.h>
9
+ #include <ATen/core/Reduction.h>
10
+
11
+ // Forward declarations of any types needed in the operator signatures.
12
+ // We can't directly include these classes because it will cause circular include dependencies.
13
+ // This file is included by TensorBody.h, which defines the Tensor class.
14
+ #include <ATen/core/ATen_fwd.h>
15
+
16
+ namespace at {
17
+
18
+ namespace cpu {
19
+
20
+ TORCH_API at::Tensor _masked_softmax(const at::Tensor & self, const at::Tensor & mask, c10::optional<int64_t> dim=c10::nullopt, c10::optional<int64_t> mask_type=c10::nullopt);
21
+
22
+ } // namespace cpu
23
+ } // namespace at
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_nnpack_spatial_convolution_ops.h ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ // @generated by torchgen/gen.py from Operator.h
4
+
5
+ #include <tuple>
6
+ #include <vector>
7
+
8
+ // Forward declarations of any types needed in the operator signatures.
9
+ // We can't directly include these classes because it will cause circular include dependencies.
10
+ // This file is included by TensorBody.h, which defines the Tensor class.
11
+ #include <ATen/core/ATen_fwd.h>
12
+
13
+ namespace at {
14
+ namespace _ops {
15
+
16
+
17
+ struct TORCH_API _nnpack_spatial_convolution {
18
+ using schema = at::Tensor (const at::Tensor &, const at::Tensor &, const c10::optional<at::Tensor> &, c10::SymIntArrayRef, c10::SymIntArrayRef);
19
+ using ptr_schema = schema*;
20
+ // See Note [static constexpr char* members for windows NVCC]
21
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::_nnpack_spatial_convolution")
22
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "")
23
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "_nnpack_spatial_convolution(Tensor input, Tensor weight, Tensor? bias, SymInt[2] padding, SymInt[2] stride=1) -> Tensor")
24
+ static at::Tensor call(const at::Tensor & input, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride);
25
+ static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride);
26
+ };
27
+
28
+ struct TORCH_API _nnpack_spatial_convolution_out {
29
+ using schema = at::Tensor & (const at::Tensor &, const at::Tensor &, const c10::optional<at::Tensor> &, c10::SymIntArrayRef, c10::SymIntArrayRef, at::Tensor &);
30
+ using ptr_schema = schema*;
31
+ // See Note [static constexpr char* members for windows NVCC]
32
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::_nnpack_spatial_convolution")
33
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "out")
34
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "_nnpack_spatial_convolution.out(Tensor input, Tensor weight, Tensor? bias, SymInt[2] padding, SymInt[2] stride=1, *, Tensor(a!) out) -> Tensor(a!)")
35
+ static at::Tensor & call(const at::Tensor & input, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride, at::Tensor & out);
36
+ static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride, at::Tensor & out);
37
+ };
38
+
39
+ }} // namespace at::_ops
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_scaled_dot_product_efficient_attention_backward_ops.h ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ // @generated by torchgen/gen.py from Operator.h
4
+
5
+ #include <tuple>
6
+ #include <vector>
7
+
8
+ // Forward declarations of any types needed in the operator signatures.
9
+ // We can't directly include these classes because it will cause circular include dependencies.
10
+ // This file is included by TensorBody.h, which defines the Tensor class.
11
+ #include <ATen/core/ATen_fwd.h>
12
+
13
+ namespace at {
14
+ namespace _ops {
15
+
16
+
17
+ struct TORCH_API _scaled_dot_product_efficient_attention_backward {
18
+ using schema = ::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor> (const at::Tensor &, const at::Tensor &, const at::Tensor &, const at::Tensor &, const at::Tensor &, const at::Tensor &, const at::Tensor &, const at::Tensor &, const at::Tensor &, double, ::std::array<bool,4>, bool, c10::optional<double>);
19
+ using ptr_schema = schema*;
20
+ // See Note [static constexpr char* members for windows NVCC]
21
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::_scaled_dot_product_efficient_attention_backward")
22
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "")
23
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "_scaled_dot_product_efficient_attention_backward(Tensor grad_out_, Tensor query, Tensor key, Tensor value, Tensor attn_bias, Tensor out, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset, float dropout_p, bool[4] grad_input_mask, bool is_causal=False, *, float? scale=None) -> (Tensor, Tensor, Tensor, Tensor)")
24
+ static ::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor> call(const at::Tensor & grad_out_, const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const at::Tensor & attn_bias, const at::Tensor & out, const at::Tensor & logsumexp, const at::Tensor & philox_seed, const at::Tensor & philox_offset, double dropout_p, ::std::array<bool,4> grad_input_mask, bool is_causal, c10::optional<double> scale);
25
+ static ::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor> redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_out_, const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const at::Tensor & attn_bias, const at::Tensor & out, const at::Tensor & logsumexp, const at::Tensor & philox_seed, const at::Tensor & philox_offset, double dropout_p, ::std::array<bool,4> grad_input_mask, bool is_causal, c10::optional<double> scale);
26
+ };
27
+
28
+ }} // namespace at::_ops
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_test_serialization_subcmul_ops.h ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ // @generated by torchgen/gen.py from Operator.h
4
+
5
+ #include <tuple>
6
+ #include <vector>
7
+
8
+ // Forward declarations of any types needed in the operator signatures.
9
+ // We can't directly include these classes because it will cause circular include dependencies.
10
+ // This file is included by TensorBody.h, which defines the Tensor class.
11
+ #include <ATen/core/ATen_fwd.h>
12
+
13
+ namespace at {
14
+ namespace _ops {
15
+
16
+
17
+ struct TORCH_API _test_serialization_subcmul {
18
+ using schema = at::Tensor (const at::Tensor &, const at::Tensor &, const at::Scalar &);
19
+ using ptr_schema = schema*;
20
+ // See Note [static constexpr char* members for windows NVCC]
21
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::_test_serialization_subcmul")
22
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "")
23
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "_test_serialization_subcmul(Tensor self, Tensor other, Scalar alpha=1) -> Tensor")
24
+ static at::Tensor call(const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha);
25
+ static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha);
26
+ };
27
+
28
+ }} // namespace at::_ops