koichi12 commited on
Commit
2fabb86
·
verified ·
1 Parent(s): f3e6968

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .venv/lib/python3.11/site-packages/torch/include/ATen/ArrayRef.h +2 -0
  2. .venv/lib/python3.11/site-packages/torch/include/ATen/DimVector.h +2 -0
  3. .venv/lib/python3.11/site-packages/torch/include/ATen/Dimname.h +1 -0
  4. .venv/lib/python3.11/site-packages/torch/include/ATen/Formatting.h +1 -0
  5. .venv/lib/python3.11/site-packages/torch/include/ATen/MetaFunctions_inl.h +325 -0
  6. .venv/lib/python3.11/site-packages/torch/include/ATen/PTThreadPool.h +17 -0
  7. .venv/lib/python3.11/site-packages/torch/include/ATen/SequenceNumber.h +13 -0
  8. .venv/lib/python3.11/site-packages/torch/include/ATen/ThreadLocalPythonObjects.h +21 -0
  9. .venv/lib/python3.11/site-packages/torch/include/ATen/ThreadLocalState.h +120 -0
  10. .venv/lib/python3.11/site-packages/torch/include/ATen/Version.h +18 -0
  11. .venv/lib/python3.11/site-packages/torch/include/ATen/cuda/AsmUtils.cuh +149 -0
  12. .venv/lib/python3.11/site-packages/torch/include/ATen/cuda/CUDAContextLight.h +99 -0
  13. .venv/lib/python3.11/site-packages/torch/include/ATen/cuda/CUDADataType.h +105 -0
  14. .venv/lib/python3.11/site-packages/torch/include/ATen/cuda/CUDADevice.h +23 -0
  15. .venv/lib/python3.11/site-packages/torch/include/ATen/cuda/CUDAUtils.h +20 -0
  16. .venv/lib/python3.11/site-packages/torch/include/ATen/cuda/NumericLimits.cuh +121 -0
  17. .venv/lib/python3.11/site-packages/torch/include/ATen/cuda/PeerToPeerAccess.h +11 -0
  18. .venv/lib/python3.11/site-packages/torch/include/ATen/cuda/Sleep.h +13 -0
  19. .venv/lib/python3.11/site-packages/torch/include/ATen/cuda/cub_definitions.cuh +53 -0
  20. .venv/lib/python3.11/site-packages/torch/include/ATen/cuda/detail/CUDAHooks.h +58 -0
  21. .venv/lib/python3.11/site-packages/torch/include/ATen/cuda/detail/DeviceThreadHandles.h +151 -0
  22. .venv/lib/python3.11/site-packages/torch/include/ATen/cuda/detail/IntegerDivider.cuh +124 -0
  23. .venv/lib/python3.11/site-packages/torch/include/ATen/cuda/detail/KernelUtils.h +37 -0
  24. .venv/lib/python3.11/site-packages/torch/include/ATen/cuda/detail/LazyNVRTC.h +11 -0
  25. .venv/lib/python3.11/site-packages/torch/include/ATen/cuda/detail/PhiloxCudaStateRaw.cuh +43 -0
  26. .venv/lib/python3.11/site-packages/torch/include/ATen/cuda/detail/TensorInfo.cuh +116 -0
  27. .venv/lib/python3.11/site-packages/torch/include/ATen/cuda/detail/UnpackRaw.cuh +28 -0
  28. .venv/lib/python3.11/site-packages/torch/include/ATen/cuda/llvm_jit_strings.h +14 -0
  29. .venv/lib/python3.11/site-packages/torch/include/ATen/cuda/tunable/GemmCommon.h +397 -0
  30. .venv/lib/python3.11/site-packages/torch/include/ATen/cuda/tunable/GemmHipblaslt.h +611 -0
  31. .venv/lib/python3.11/site-packages/torch/include/ATen/cuda/tunable/GemmRocblas.h +275 -0
  32. .venv/lib/python3.11/site-packages/torch/include/ATen/cuda/tunable/StreamTimer.h +34 -0
  33. .venv/lib/python3.11/site-packages/torch/include/ATen/cuda/tunable/Tunable.h +246 -0
  34. .venv/lib/python3.11/site-packages/torch/include/ATen/cuda/tunable/TunableGemm.h +307 -0
  35. .venv/lib/python3.11/site-packages/torch/include/ATen/cuda/tunable/TunableOp.h +286 -0
  36. .venv/lib/python3.11/site-packages/torch/include/ATen/native/Activation.h +98 -0
  37. .venv/lib/python3.11/site-packages/torch/include/ATen/native/AdaptivePooling.h +49 -0
  38. .venv/lib/python3.11/site-packages/torch/include/ATen/native/BatchLinearAlgebra.h +321 -0
  39. .venv/lib/python3.11/site-packages/torch/include/ATen/native/BinaryOps.h +119 -0
  40. .venv/lib/python3.11/site-packages/torch/include/ATen/native/ComplexHelper.h +97 -0
  41. .venv/lib/python3.11/site-packages/torch/include/ATen/native/Distance.h +20 -0
  42. .venv/lib/python3.11/site-packages/torch/include/ATen/native/Fill.h +21 -0
  43. .venv/lib/python3.11/site-packages/torch/include/ATen/native/FractionalMaxPooling.h +80 -0
  44. .venv/lib/python3.11/site-packages/torch/include/ATen/native/FunctionOfAMatrixUtils.h +20 -0
  45. .venv/lib/python3.11/site-packages/torch/include/ATen/native/GridSampler.h +298 -0
  46. .venv/lib/python3.11/site-packages/torch/include/ATen/native/LossMulti.h +69 -0
  47. .venv/lib/python3.11/site-packages/torch/include/ATen/native/Math.h +0 -0
  48. .venv/lib/python3.11/site-packages/torch/include/ATen/native/MathBitFallThroughLists.h +71 -0
  49. .venv/lib/python3.11/site-packages/torch/include/ATen/native/NonEmptyUtils.h +27 -0
  50. .venv/lib/python3.11/site-packages/torch/include/ATen/native/Normalization.h +19 -0
.venv/lib/python3.11/site-packages/torch/include/ATen/ArrayRef.h ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ #pragma once
2
+ #include <c10/util/ArrayRef.h>
.venv/lib/python3.11/site-packages/torch/include/ATen/DimVector.h ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ #pragma once
2
+ #include <ATen/core/DimVector.h>
.venv/lib/python3.11/site-packages/torch/include/ATen/Dimname.h ADDED
@@ -0,0 +1 @@
 
 
1
+ #include <ATen/core/Dimname.h>
.venv/lib/python3.11/site-packages/torch/include/ATen/Formatting.h ADDED
@@ -0,0 +1 @@
 
 
1
+ #include <ATen/core/Formatting.h>
.venv/lib/python3.11/site-packages/torch/include/ATen/MetaFunctions_inl.h ADDED
@@ -0,0 +1,325 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ // @generated by torchgen/gen.py from DispatchKeyFunctions_inl.h
3
+
4
+ // NB: The implementing C++ file is RegisterDispatchKey.cpp
5
+
6
+ // The only #includes we need are for custom classes that have defaults in the C++ API
7
+ #include <c10/core/MemoryFormat.h>
8
+ #include <c10/core/Scalar.h>
9
+ #include <ATen/core/Reduction.h>
10
+
11
+ #if defined(AT_PER_OPERATOR_HEADERS) && defined(TORCH_ASSERT_ONLY_METHOD_OPERATORS)
12
+ #error This change adds a dependency on all pytorch operators, meaning the \
13
+ file will need to be re-compiled every time an operator is changed or added. \
14
+ Consider including a specific operator from \
15
+ <ATen/ops/{my_operator}_meta_dispatch.h>. \
16
+ See NOTE [TORCH_ASSERT_ONLY_METHOD_OPERATORS].
17
+ #endif
18
+
19
+ #include <ATen/ops/_add_relu_meta_dispatch.h>
20
+ #include <ATen/ops/_addmm_activation_meta_dispatch.h>
21
+ #include <ATen/ops/_amp_update_scale_meta_dispatch.h>
22
+ #include <ATen/ops/_coalesced_meta_dispatch.h>
23
+ #include <ATen/ops/_convert_indices_from_coo_to_csr_meta_dispatch.h>
24
+ #include <ATen/ops/_convert_indices_from_csr_to_coo_meta_dispatch.h>
25
+ #include <ATen/ops/_ctc_loss_meta_dispatch.h>
26
+ #include <ATen/ops/_efficientzerotensor_meta_dispatch.h>
27
+ #include <ATen/ops/_fill_mem_eff_dropout_mask_meta_dispatch.h>
28
+ #include <ATen/ops/_fused_sdp_choice_meta_dispatch.h>
29
+ #include <ATen/ops/_index_put_impl_meta_dispatch.h>
30
+ #include <ATen/ops/_linalg_det_meta_dispatch.h>
31
+ #include <ATen/ops/_linalg_eigh_meta_dispatch.h>
32
+ #include <ATen/ops/_linalg_slogdet_meta_dispatch.h>
33
+ #include <ATen/ops/_linalg_solve_ex_meta_dispatch.h>
34
+ #include <ATen/ops/_linalg_svd_meta_dispatch.h>
35
+ #include <ATen/ops/_log_softmax_meta_dispatch.h>
36
+ #include <ATen/ops/_log_softmax_backward_data_meta_dispatch.h>
37
+ #include <ATen/ops/_mkldnn_transpose_meta_dispatch.h>
38
+ #include <ATen/ops/_reshape_alias_meta_dispatch.h>
39
+ #include <ATen/ops/_resize_output_meta_dispatch.h>
40
+ #include <ATen/ops/_softmax_meta_dispatch.h>
41
+ #include <ATen/ops/_softmax_backward_data_meta_dispatch.h>
42
+ #include <ATen/ops/_sparse_coo_tensor_with_dims_meta_dispatch.h>
43
+ #include <ATen/ops/_sparse_coo_tensor_with_dims_and_tensors_meta_dispatch.h>
44
+ #include <ATen/ops/_upsample_bicubic2d_aa_meta_dispatch.h>
45
+ #include <ATen/ops/_upsample_bicubic2d_aa_backward_meta_dispatch.h>
46
+ #include <ATen/ops/_upsample_bilinear2d_aa_meta_dispatch.h>
47
+ #include <ATen/ops/_upsample_bilinear2d_aa_backward_meta_dispatch.h>
48
+ #include <ATen/ops/_upsample_nearest_exact1d_meta_dispatch.h>
49
+ #include <ATen/ops/_upsample_nearest_exact1d_backward_meta_dispatch.h>
50
+ #include <ATen/ops/_upsample_nearest_exact2d_meta_dispatch.h>
51
+ #include <ATen/ops/_upsample_nearest_exact2d_backward_meta_dispatch.h>
52
+ #include <ATen/ops/_upsample_nearest_exact3d_meta_dispatch.h>
53
+ #include <ATen/ops/_upsample_nearest_exact3d_backward_meta_dispatch.h>
54
+ #include <ATen/ops/acos_meta_dispatch.h>
55
+ #include <ATen/ops/acosh_meta_dispatch.h>
56
+ #include <ATen/ops/adaptive_max_pool2d_meta_dispatch.h>
57
+ #include <ATen/ops/adaptive_max_pool2d_backward_meta_dispatch.h>
58
+ #include <ATen/ops/adaptive_max_pool3d_meta_dispatch.h>
59
+ #include <ATen/ops/adaptive_max_pool3d_backward_meta_dispatch.h>
60
+ #include <ATen/ops/add_meta_dispatch.h>
61
+ #include <ATen/ops/addbmm_meta_dispatch.h>
62
+ #include <ATen/ops/addcdiv_meta_dispatch.h>
63
+ #include <ATen/ops/addcmul_meta_dispatch.h>
64
+ #include <ATen/ops/addmm_meta_dispatch.h>
65
+ #include <ATen/ops/addmv_meta_dispatch.h>
66
+ #include <ATen/ops/all_meta_dispatch.h>
67
+ #include <ATen/ops/amax_meta_dispatch.h>
68
+ #include <ATen/ops/amin_meta_dispatch.h>
69
+ #include <ATen/ops/aminmax_meta_dispatch.h>
70
+ #include <ATen/ops/any_meta_dispatch.h>
71
+ #include <ATen/ops/arange_meta_dispatch.h>
72
+ #include <ATen/ops/argmax_meta_dispatch.h>
73
+ #include <ATen/ops/argmin_meta_dispatch.h>
74
+ #include <ATen/ops/as_strided_meta_dispatch.h>
75
+ #include <ATen/ops/asin_meta_dispatch.h>
76
+ #include <ATen/ops/asinh_meta_dispatch.h>
77
+ #include <ATen/ops/atan_meta_dispatch.h>
78
+ #include <ATen/ops/atan2_meta_dispatch.h>
79
+ #include <ATen/ops/atanh_meta_dispatch.h>
80
+ #include <ATen/ops/avg_pool2d_meta_dispatch.h>
81
+ #include <ATen/ops/avg_pool2d_backward_meta_dispatch.h>
82
+ #include <ATen/ops/avg_pool3d_meta_dispatch.h>
83
+ #include <ATen/ops/avg_pool3d_backward_meta_dispatch.h>
84
+ #include <ATen/ops/baddbmm_meta_dispatch.h>
85
+ #include <ATen/ops/bernoulli_meta_dispatch.h>
86
+ #include <ATen/ops/bitwise_and_meta_dispatch.h>
87
+ #include <ATen/ops/bitwise_left_shift_meta_dispatch.h>
88
+ #include <ATen/ops/bitwise_not_meta_dispatch.h>
89
+ #include <ATen/ops/bitwise_or_meta_dispatch.h>
90
+ #include <ATen/ops/bitwise_right_shift_meta_dispatch.h>
91
+ #include <ATen/ops/bitwise_xor_meta_dispatch.h>
92
+ #include <ATen/ops/bmm_meta_dispatch.h>
93
+ #include <ATen/ops/cat_meta_dispatch.h>
94
+ #include <ATen/ops/cauchy_meta_dispatch.h>
95
+ #include <ATen/ops/ceil_meta_dispatch.h>
96
+ #include <ATen/ops/clamp_meta_dispatch.h>
97
+ #include <ATen/ops/clamp_max_meta_dispatch.h>
98
+ #include <ATen/ops/clamp_min_meta_dispatch.h>
99
+ #include <ATen/ops/copy_meta_dispatch.h>
100
+ #include <ATen/ops/copy_sparse_to_sparse_meta_dispatch.h>
101
+ #include <ATen/ops/copysign_meta_dispatch.h>
102
+ #include <ATen/ops/cos_meta_dispatch.h>
103
+ #include <ATen/ops/cosh_meta_dispatch.h>
104
+ #include <ATen/ops/cumprod_meta_dispatch.h>
105
+ #include <ATen/ops/cumsum_meta_dispatch.h>
106
+ #include <ATen/ops/digamma_meta_dispatch.h>
107
+ #include <ATen/ops/div_meta_dispatch.h>
108
+ #include <ATen/ops/elu_meta_dispatch.h>
109
+ #include <ATen/ops/elu_backward_meta_dispatch.h>
110
+ #include <ATen/ops/embedding_renorm_meta_dispatch.h>
111
+ #include <ATen/ops/empty_meta_dispatch.h>
112
+ #include <ATen/ops/empty_strided_meta_dispatch.h>
113
+ #include <ATen/ops/eq_meta_dispatch.h>
114
+ #include <ATen/ops/erf_meta_dispatch.h>
115
+ #include <ATen/ops/erfc_meta_dispatch.h>
116
+ #include <ATen/ops/erfinv_meta_dispatch.h>
117
+ #include <ATen/ops/exp_meta_dispatch.h>
118
+ #include <ATen/ops/exp2_meta_dispatch.h>
119
+ #include <ATen/ops/expm1_meta_dispatch.h>
120
+ #include <ATen/ops/exponential_meta_dispatch.h>
121
+ #include <ATen/ops/eye_meta_dispatch.h>
122
+ #include <ATen/ops/fill_meta_dispatch.h>
123
+ #include <ATen/ops/floor_meta_dispatch.h>
124
+ #include <ATen/ops/floor_divide_meta_dispatch.h>
125
+ #include <ATen/ops/fmax_meta_dispatch.h>
126
+ #include <ATen/ops/fmin_meta_dispatch.h>
127
+ #include <ATen/ops/fmod_meta_dispatch.h>
128
+ #include <ATen/ops/frac_meta_dispatch.h>
129
+ #include <ATen/ops/fractional_max_pool2d_meta_dispatch.h>
130
+ #include <ATen/ops/fractional_max_pool2d_backward_meta_dispatch.h>
131
+ #include <ATen/ops/fractional_max_pool3d_meta_dispatch.h>
132
+ #include <ATen/ops/gather_meta_dispatch.h>
133
+ #include <ATen/ops/gcd_meta_dispatch.h>
134
+ #include <ATen/ops/ge_meta_dispatch.h>
135
+ #include <ATen/ops/gelu_meta_dispatch.h>
136
+ #include <ATen/ops/gelu_backward_meta_dispatch.h>
137
+ #include <ATen/ops/geometric_meta_dispatch.h>
138
+ #include <ATen/ops/glu_meta_dispatch.h>
139
+ #include <ATen/ops/gt_meta_dispatch.h>
140
+ #include <ATen/ops/hardshrink_meta_dispatch.h>
141
+ #include <ATen/ops/hardshrink_backward_meta_dispatch.h>
142
+ #include <ATen/ops/hardsigmoid_meta_dispatch.h>
143
+ #include <ATen/ops/hardsigmoid_backward_meta_dispatch.h>
144
+ #include <ATen/ops/hardswish_meta_dispatch.h>
145
+ #include <ATen/ops/hardtanh_meta_dispatch.h>
146
+ #include <ATen/ops/heaviside_meta_dispatch.h>
147
+ #include <ATen/ops/hypot_meta_dispatch.h>
148
+ #include <ATen/ops/i0_meta_dispatch.h>
149
+ #include <ATen/ops/igamma_meta_dispatch.h>
150
+ #include <ATen/ops/igammac_meta_dispatch.h>
151
+ #include <ATen/ops/index_meta_dispatch.h>
152
+ #include <ATen/ops/index_add_meta_dispatch.h>
153
+ #include <ATen/ops/index_copy_meta_dispatch.h>
154
+ #include <ATen/ops/index_fill_meta_dispatch.h>
155
+ #include <ATen/ops/index_reduce_meta_dispatch.h>
156
+ #include <ATen/ops/isin_meta_dispatch.h>
157
+ #include <ATen/ops/isneginf_meta_dispatch.h>
158
+ #include <ATen/ops/isposinf_meta_dispatch.h>
159
+ #include <ATen/ops/lcm_meta_dispatch.h>
160
+ #include <ATen/ops/le_meta_dispatch.h>
161
+ #include <ATen/ops/leaky_relu_meta_dispatch.h>
162
+ #include <ATen/ops/leaky_relu_backward_meta_dispatch.h>
163
+ #include <ATen/ops/lerp_meta_dispatch.h>
164
+ #include <ATen/ops/lgamma_meta_dispatch.h>
165
+ #include <ATen/ops/linalg_cholesky_ex_meta_dispatch.h>
166
+ #include <ATen/ops/linalg_cross_meta_dispatch.h>
167
+ #include <ATen/ops/linalg_inv_ex_meta_dispatch.h>
168
+ #include <ATen/ops/linalg_ldl_factor_ex_meta_dispatch.h>
169
+ #include <ATen/ops/linalg_ldl_solve_meta_dispatch.h>
170
+ #include <ATen/ops/linalg_lu_meta_dispatch.h>
171
+ #include <ATen/ops/linalg_lu_factor_ex_meta_dispatch.h>
172
+ #include <ATen/ops/linalg_lu_solve_meta_dispatch.h>
173
+ #include <ATen/ops/linalg_qr_meta_dispatch.h>
174
+ #include <ATen/ops/linalg_vector_norm_meta_dispatch.h>
175
+ #include <ATen/ops/linspace_meta_dispatch.h>
176
+ #include <ATen/ops/log_meta_dispatch.h>
177
+ #include <ATen/ops/log10_meta_dispatch.h>
178
+ #include <ATen/ops/log1p_meta_dispatch.h>
179
+ #include <ATen/ops/log2_meta_dispatch.h>
180
+ #include <ATen/ops/log_normal_meta_dispatch.h>
181
+ #include <ATen/ops/logaddexp_meta_dispatch.h>
182
+ #include <ATen/ops/logaddexp2_meta_dispatch.h>
183
+ #include <ATen/ops/logit_meta_dispatch.h>
184
+ #include <ATen/ops/logit_backward_meta_dispatch.h>
185
+ #include <ATen/ops/logspace_meta_dispatch.h>
186
+ #include <ATen/ops/lshift_meta_dispatch.h>
187
+ #include <ATen/ops/lt_meta_dispatch.h>
188
+ #include <ATen/ops/lu_unpack_meta_dispatch.h>
189
+ #include <ATen/ops/masked_fill_meta_dispatch.h>
190
+ #include <ATen/ops/masked_scatter_meta_dispatch.h>
191
+ #include <ATen/ops/max_meta_dispatch.h>
192
+ #include <ATen/ops/max_pool2d_with_indices_meta_dispatch.h>
193
+ #include <ATen/ops/max_pool2d_with_indices_backward_meta_dispatch.h>
194
+ #include <ATen/ops/maximum_meta_dispatch.h>
195
+ #include <ATen/ops/mean_meta_dispatch.h>
196
+ #include <ATen/ops/min_meta_dispatch.h>
197
+ #include <ATen/ops/minimum_meta_dispatch.h>
198
+ #include <ATen/ops/mish_meta_dispatch.h>
199
+ #include <ATen/ops/mm_meta_dispatch.h>
200
+ #include <ATen/ops/mse_loss_meta_dispatch.h>
201
+ #include <ATen/ops/mul_meta_dispatch.h>
202
+ #include <ATen/ops/ne_meta_dispatch.h>
203
+ #include <ATen/ops/neg_meta_dispatch.h>
204
+ #include <ATen/ops/nextafter_meta_dispatch.h>
205
+ #include <ATen/ops/nll_loss_backward_meta_dispatch.h>
206
+ #include <ATen/ops/nll_loss_forward_meta_dispatch.h>
207
+ #include <ATen/ops/norm_meta_dispatch.h>
208
+ #include <ATen/ops/normal_meta_dispatch.h>
209
+ #include <ATen/ops/polygamma_meta_dispatch.h>
210
+ #include <ATen/ops/pow_meta_dispatch.h>
211
+ #include <ATen/ops/prod_meta_dispatch.h>
212
+ #include <ATen/ops/put_meta_dispatch.h>
213
+ #include <ATen/ops/random_meta_dispatch.h>
214
+ #include <ATen/ops/range_meta_dispatch.h>
215
+ #include <ATen/ops/reciprocal_meta_dispatch.h>
216
+ #include <ATen/ops/reflection_pad1d_meta_dispatch.h>
217
+ #include <ATen/ops/reflection_pad1d_backward_meta_dispatch.h>
218
+ #include <ATen/ops/reflection_pad3d_meta_dispatch.h>
219
+ #include <ATen/ops/reflection_pad3d_backward_meta_dispatch.h>
220
+ #include <ATen/ops/relu_meta_dispatch.h>
221
+ #include <ATen/ops/remainder_meta_dispatch.h>
222
+ #include <ATen/ops/renorm_meta_dispatch.h>
223
+ #include <ATen/ops/replication_pad1d_meta_dispatch.h>
224
+ #include <ATen/ops/replication_pad1d_backward_meta_dispatch.h>
225
+ #include <ATen/ops/replication_pad2d_meta_dispatch.h>
226
+ #include <ATen/ops/replication_pad3d_meta_dispatch.h>
227
+ #include <ATen/ops/resize_meta_dispatch.h>
228
+ #include <ATen/ops/resize_as_sparse_meta_dispatch.h>
229
+ #include <ATen/ops/round_meta_dispatch.h>
230
+ #include <ATen/ops/rrelu_with_noise_meta_dispatch.h>
231
+ #include <ATen/ops/rshift_meta_dispatch.h>
232
+ #include <ATen/ops/rsqrt_meta_dispatch.h>
233
+ #include <ATen/ops/scatter_meta_dispatch.h>
234
+ #include <ATen/ops/scatter_add_meta_dispatch.h>
235
+ #include <ATen/ops/scatter_reduce_meta_dispatch.h>
236
+ #include <ATen/ops/set_meta_dispatch.h>
237
+ #include <ATen/ops/sgn_meta_dispatch.h>
238
+ #include <ATen/ops/sigmoid_meta_dispatch.h>
239
+ #include <ATen/ops/sigmoid_backward_meta_dispatch.h>
240
+ #include <ATen/ops/sign_meta_dispatch.h>
241
+ #include <ATen/ops/signbit_meta_dispatch.h>
242
+ #include <ATen/ops/silu_meta_dispatch.h>
243
+ #include <ATen/ops/silu_backward_meta_dispatch.h>
244
+ #include <ATen/ops/sin_meta_dispatch.h>
245
+ #include <ATen/ops/sinc_meta_dispatch.h>
246
+ #include <ATen/ops/sinh_meta_dispatch.h>
247
+ #include <ATen/ops/slow_conv_transpose2d_meta_dispatch.h>
248
+ #include <ATen/ops/smooth_l1_loss_meta_dispatch.h>
249
+ #include <ATen/ops/softplus_meta_dispatch.h>
250
+ #include <ATen/ops/softplus_backward_meta_dispatch.h>
251
+ #include <ATen/ops/softshrink_meta_dispatch.h>
252
+ #include <ATen/ops/softshrink_backward_meta_dispatch.h>
253
+ #include <ATen/ops/sort_meta_dispatch.h>
254
+ #include <ATen/ops/sparse_resize_meta_dispatch.h>
255
+ #include <ATen/ops/sparse_resize_and_clear_meta_dispatch.h>
256
+ #include <ATen/ops/special_airy_ai_meta_dispatch.h>
257
+ #include <ATen/ops/special_bessel_j0_meta_dispatch.h>
258
+ #include <ATen/ops/special_bessel_j1_meta_dispatch.h>
259
+ #include <ATen/ops/special_bessel_y0_meta_dispatch.h>
260
+ #include <ATen/ops/special_bessel_y1_meta_dispatch.h>
261
+ #include <ATen/ops/special_chebyshev_polynomial_t_meta_dispatch.h>
262
+ #include <ATen/ops/special_chebyshev_polynomial_u_meta_dispatch.h>
263
+ #include <ATen/ops/special_chebyshev_polynomial_v_meta_dispatch.h>
264
+ #include <ATen/ops/special_chebyshev_polynomial_w_meta_dispatch.h>
265
+ #include <ATen/ops/special_entr_meta_dispatch.h>
266
+ #include <ATen/ops/special_erfcx_meta_dispatch.h>
267
+ #include <ATen/ops/special_hermite_polynomial_h_meta_dispatch.h>
268
+ #include <ATen/ops/special_hermite_polynomial_he_meta_dispatch.h>
269
+ #include <ATen/ops/special_i0e_meta_dispatch.h>
270
+ #include <ATen/ops/special_i1_meta_dispatch.h>
271
+ #include <ATen/ops/special_i1e_meta_dispatch.h>
272
+ #include <ATen/ops/special_laguerre_polynomial_l_meta_dispatch.h>
273
+ #include <ATen/ops/special_legendre_polynomial_p_meta_dispatch.h>
274
+ #include <ATen/ops/special_log_ndtr_meta_dispatch.h>
275
+ #include <ATen/ops/special_modified_bessel_i0_meta_dispatch.h>
276
+ #include <ATen/ops/special_modified_bessel_i1_meta_dispatch.h>
277
+ #include <ATen/ops/special_modified_bessel_k0_meta_dispatch.h>
278
+ #include <ATen/ops/special_modified_bessel_k1_meta_dispatch.h>
279
+ #include <ATen/ops/special_ndtri_meta_dispatch.h>
280
+ #include <ATen/ops/special_scaled_modified_bessel_k0_meta_dispatch.h>
281
+ #include <ATen/ops/special_scaled_modified_bessel_k1_meta_dispatch.h>
282
+ #include <ATen/ops/special_shifted_chebyshev_polynomial_t_meta_dispatch.h>
283
+ #include <ATen/ops/special_shifted_chebyshev_polynomial_u_meta_dispatch.h>
284
+ #include <ATen/ops/special_shifted_chebyshev_polynomial_v_meta_dispatch.h>
285
+ #include <ATen/ops/special_shifted_chebyshev_polynomial_w_meta_dispatch.h>
286
+ #include <ATen/ops/special_spherical_bessel_j0_meta_dispatch.h>
287
+ #include <ATen/ops/special_xlog1py_meta_dispatch.h>
288
+ #include <ATen/ops/special_zeta_meta_dispatch.h>
289
+ #include <ATen/ops/sqrt_meta_dispatch.h>
290
+ #include <ATen/ops/sub_meta_dispatch.h>
291
+ #include <ATen/ops/sum_meta_dispatch.h>
292
+ #include <ATen/ops/tan_meta_dispatch.h>
293
+ #include <ATen/ops/tanh_meta_dispatch.h>
294
+ #include <ATen/ops/tanh_backward_meta_dispatch.h>
295
+ #include <ATen/ops/threshold_meta_dispatch.h>
296
+ #include <ATen/ops/threshold_backward_meta_dispatch.h>
297
+ #include <ATen/ops/topk_meta_dispatch.h>
298
+ #include <ATen/ops/triangular_solve_meta_dispatch.h>
299
+ #include <ATen/ops/tril_meta_dispatch.h>
300
+ #include <ATen/ops/triu_meta_dispatch.h>
301
+ #include <ATen/ops/trunc_meta_dispatch.h>
302
+ #include <ATen/ops/unfold_meta_dispatch.h>
303
+ #include <ATen/ops/uniform_meta_dispatch.h>
304
+ #include <ATen/ops/upsample_bicubic2d_meta_dispatch.h>
305
+ #include <ATen/ops/upsample_bicubic2d_backward_meta_dispatch.h>
306
+ #include <ATen/ops/upsample_bilinear2d_meta_dispatch.h>
307
+ #include <ATen/ops/upsample_bilinear2d_backward_meta_dispatch.h>
308
+ #include <ATen/ops/upsample_linear1d_meta_dispatch.h>
309
+ #include <ATen/ops/upsample_linear1d_backward_meta_dispatch.h>
310
+ #include <ATen/ops/upsample_nearest1d_meta_dispatch.h>
311
+ #include <ATen/ops/upsample_nearest1d_backward_meta_dispatch.h>
312
+ #include <ATen/ops/upsample_nearest2d_meta_dispatch.h>
313
+ #include <ATen/ops/upsample_nearest2d_backward_meta_dispatch.h>
314
+ #include <ATen/ops/upsample_nearest3d_meta_dispatch.h>
315
+ #include <ATen/ops/upsample_nearest3d_backward_meta_dispatch.h>
316
+ #include <ATen/ops/upsample_trilinear3d_meta_dispatch.h>
317
+ #include <ATen/ops/upsample_trilinear3d_backward_meta_dispatch.h>
318
+ #include <ATen/ops/view_meta_dispatch.h>
319
+ #include <ATen/ops/view_as_complex_meta_dispatch.h>
320
+ #include <ATen/ops/view_as_real_meta_dispatch.h>
321
+ #include <ATen/ops/xlogy_meta_dispatch.h>
322
+ #include <ATen/ops/zero_meta_dispatch.h>
323
+
324
+
325
+
.venv/lib/python3.11/site-packages/torch/include/ATen/PTThreadPool.h ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/Parallel.h>
4
+ #include <c10/core/thread_pool.h>
5
+
6
+ namespace at {
7
+
8
+ class TORCH_API PTThreadPool : public c10::ThreadPool {
9
+ public:
10
+ explicit PTThreadPool(int pool_size, int numa_node_id = -1)
11
+ : c10::ThreadPool(pool_size, numa_node_id, []() {
12
+ c10::setThreadName("PTThreadPool");
13
+ at::init_num_threads();
14
+ }) {}
15
+ };
16
+
17
+ } // namespace at
.venv/lib/python3.11/site-packages/torch/include/ATen/SequenceNumber.h ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <c10/macros/Export.h>
4
+ #include <cstdint>
5
+
6
+ // A simple thread local enumeration, used to link forward and backward pass
7
+ // ops and is used by autograd and observers framework
8
+ namespace at::sequence_number {
9
+
10
+ TORCH_API uint64_t peek();
11
+ TORCH_API uint64_t get_and_increment();
12
+
13
+ } // namespace at::sequence_number
.venv/lib/python3.11/site-packages/torch/include/ATen/ThreadLocalPythonObjects.h ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <c10/core/SafePyObject.h>
4
+ #include <c10/macros/Macros.h>
5
+ #include <unordered_map>
6
+
7
+ namespace at::impl {
8
+
9
+ struct TORCH_API ThreadLocalPythonObjects {
10
+ static void set(const std::string& key, std::shared_ptr<SafePyObject> value);
11
+ static const std::shared_ptr<SafePyObject>& get(const std::string& key);
12
+ static bool contains(const std::string& key);
13
+
14
+ static const ThreadLocalPythonObjects& get_state();
15
+ static void set_state(ThreadLocalPythonObjects state);
16
+
17
+ private:
18
+ std::unordered_map<std::string, std::shared_ptr<c10::SafePyObject>> obj_dict_;
19
+ };
20
+
21
+ } // namespace at::impl
.venv/lib/python3.11/site-packages/torch/include/ATen/ThreadLocalState.h ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <c10/core/InferenceMode.h>
4
+ #include <c10/core/impl/LocalDispatchKeySet.h>
5
+ #include <c10/util/Exception.h>
6
+ #include <c10/util/ThreadLocalDebugInfo.h>
7
+
8
+ #include <ATen/FuncTorchTLS.h>
9
+ #include <ATen/PythonTorchFunctionTLS.h>
10
+ #include <ATen/SavedTensorHooks.h>
11
+ #include <ATen/ThreadLocalPythonObjects.h>
12
+ #include <ATen/record_function.h>
13
+ #include <c10/core/impl/PythonDispatcherTLS.h>
14
+ #include <c10/core/impl/TorchDispatchModeTLS.h>
15
+
16
+ namespace at {
17
+
18
+ // Thread local state contains values that are preserved across
19
+ // thread boundaries (e.g. at::launch/JIT fork, autograd).
20
+ // Note at::parallel_for doesn't preserve TLS across thread boundaries.
21
+ class TORCH_API ThreadLocalState {
22
+ public:
23
+ // Saves the thread local variables' values and
24
+ // returns them as a ThreadLocalState
25
+ ThreadLocalState();
26
+
27
+ // set_grad_mode - force the value of the grad mode TLS in
28
+ // the current state object. This is used for example in the
29
+ // autograd engine.
30
+ void set_grad_mode(bool enabled);
31
+
32
+ // set_multithreading_enabled - force the value of the multithreadinmaximum
33
+ // threads TLS in
34
+ // the current state object. This is used for example in the
35
+ // autograd engine.
36
+ void set_multithreading_enabled(bool enabled);
37
+
38
+ // Sets thread local variables in the current thread,
39
+ // according to the thread boundary specified
40
+ static void setThreadLocalState(const ThreadLocalState& state);
41
+
42
+ private:
43
+ c10::impl::LocalDispatchKeySet dispatch_key_;
44
+
45
+ // ThreadLocalDebugInfo does not change after being created
46
+ // with DebugInfoGuard
47
+ std::shared_ptr<c10::ThreadLocalDebugInfo> debug_info_;
48
+
49
+ // RecordFunction TLS
50
+ RecordFunctionTLS rf_tls_;
51
+
52
+ // TLS for out-of-tree functorch
53
+ // See NOTE [functorch TLS in pytorch/pytorch] for why this needs to be a
54
+ // pointer (spoiler alert: it's due to the indirection)
55
+ // This needs to be a shared_ptr instead of a unique_ptr because
56
+ // ThreadLocalState is copy-able and does indeed get copied. Maybe we can
57
+ // consider adding an explicit copy constructor for ThreadLocalState in the
58
+ // future but I didn't want to add one just for this.
59
+ std::shared_ptr<const functorch::FuncTorchTLSBase> functorch_tls_;
60
+
61
+ // TLS for AutogradModes
62
+ AutogradState autograd_tls_;
63
+
64
+ // TLS for enable_torch_dispatch_mode
65
+ c10::impl::TorchDispatchModeTLS torch_dispatch_mode_state_;
66
+
67
+ // TLS for enable_python_dispatcher
68
+ c10::impl::PyInterpreter* python_dispatcher_state_;
69
+
70
+ // TLS for __torch_function__ (mode and disable_torch_function)
71
+ at::impl::PythonTorchFunctionTLS python_torch_function_state_;
72
+
73
+ // TLS for saved tensors default hooks
74
+ at::impl::SavedTensorDefaultHooksTLS saved_tensors_default_hooks_state_;
75
+
76
+ bool functionalization_reapply_views_state_;
77
+
78
+ // TLS for arbitrary python objects that is registered via hooks
79
+ at::impl::ThreadLocalPythonObjects saved_objects_;
80
+
81
+ #if !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE) && \
82
+ !defined(BUILD_LITE_INTERPRETER)
83
+ // TLS for autocast dtypes
84
+ std::array<at::ScalarType, at::COMPILE_TIME_MAX_DEVICE_TYPES>
85
+ autocast_dtypes_;
86
+ #endif
87
+
88
+ friend class ThreadLocalStateGuard;
89
+ };
90
+
91
+ // Guard to set and reset the thread local state
92
+ class TORCH_API ThreadLocalStateGuard {
93
+ public:
94
+ explicit ThreadLocalStateGuard(const ThreadLocalState& state)
95
+ : prev_state_(ThreadLocalState()) {
96
+ // set the given state across the thread boundary
97
+ ThreadLocalState::setThreadLocalState(state);
98
+ }
99
+
100
+ ~ThreadLocalStateGuard() {
101
+ // restore previously set variables
102
+ ThreadLocalState::setThreadLocalState(prev_state_);
103
+ }
104
+
105
+ private:
106
+ // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
107
+ const ThreadLocalState prev_state_;
108
+ };
109
+
110
+ template <typename T>
111
+ auto wrapPropagateTLSState(T callback) {
112
+ return [tls_state = ThreadLocalState(),
113
+ callback = std::move(callback)](auto&&... args) {
114
+ ThreadLocalStateGuard g(tls_state);
115
+ // Propagate value returned by callback().
116
+ return callback(std::forward<decltype(args)>(args)...);
117
+ };
118
+ }
119
+
120
+ } // namespace at
.venv/lib/python3.11/site-packages/torch/include/ATen/Version.h ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <ATen/Context.h>
2
+
3
+ namespace at {
4
+
5
+ /// Returns a detailed string describing the configuration PyTorch.
6
+ TORCH_API std::string show_config();
7
+
8
+ TORCH_API std::string get_mkl_version();
9
+
10
+ TORCH_API std::string get_mkldnn_version();
11
+
12
+ TORCH_API std::string get_openmp_version();
13
+
14
+ TORCH_API std::string get_cxx_flags();
15
+
16
+ TORCH_API std::string get_cpu_capability();
17
+
18
+ } // namespace at
.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/AsmUtils.cuh ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ #include <cstdint>
3
+
4
+ // Collection of direct PTX functions
5
+
6
+ namespace at::cuda {
7
+
8
// Bitfield extract/insert helpers backed by PTX bfe/bfi on device, with
// a plain-C++ fallback on the host. The primary template is empty so
// unsupported types fail to compile at the point of use.
template <typename T>
struct Bitfield {};

template <>
struct Bitfield<unsigned int> {
  // Returns the `len`-bit field of `val` starting at bit `pos`.
  // NOTE(review): the host fallback computes `1u << len`, which is
  // undefined for len >= 32; callers presumably keep len < 32 — confirm.
  static __device__ __host__ __forceinline__
  unsigned int getBitfield(unsigned int val, int pos, int len) {
#if !defined(__CUDA_ARCH__)
    // Host fallback mirroring bfe.u32: pos/len are taken modulo 256.
    pos &= 0xff;
    len &= 0xff;

    unsigned int m = (1u << len) - 1u;
    return (val >> pos) & m;
#else
    unsigned int ret;
    // PTX bitfield extract.
    asm("bfe.u32 %0, %1, %2, %3;" : "=r"(ret) : "r"(val), "r"(pos), "r"(len));
    return ret;
#endif
  }

  // Returns `val` with its `len`-bit field at `pos` replaced by the low
  // `len` bits of `toInsert`.
  static __device__ __host__ __forceinline__
  unsigned int setBitfield(unsigned int val, unsigned int toInsert, int pos, int len) {
#if !defined(__CUDA_ARCH__)
    pos &= 0xff;
    len &= 0xff;

    unsigned int m = (1u << len) - 1u;
    toInsert &= m;
    toInsert <<= pos;
    m <<= pos;

    return (val & ~m) | toInsert;
#else
    unsigned int ret;
    // PTX bitfield insert.
    asm("bfi.b32 %0, %1, %2, %3, %4;" :
        "=r"(ret) : "r"(toInsert), "r"(val), "r"(pos), "r"(len));
    return ret;
#endif
  }
};
48
+
49
+ template <>
50
+ struct Bitfield<uint64_t> {
51
+ static __device__ __host__ __forceinline__
52
+ uint64_t getBitfield(uint64_t val, int pos, int len) {
53
+ #if !defined(__CUDA_ARCH__)
54
+ pos &= 0xff;
55
+ len &= 0xff;
56
+
57
+ uint64_t m = (1u << len) - 1u;
58
+ return (val >> pos) & m;
59
+ #else
60
+ uint64_t ret;
61
+ asm("bfe.u64 %0, %1, %2, %3;" : "=l"(ret) : "l"(val), "r"(pos), "r"(len));
62
+ return ret;
63
+ #endif
64
+ }
65
+
66
+ static __device__ __host__ __forceinline__
67
+ uint64_t setBitfield(uint64_t val, uint64_t toInsert, int pos, int len) {
68
+ #if !defined(__CUDA_ARCH__)
69
+ pos &= 0xff;
70
+ len &= 0xff;
71
+
72
+ uint64_t m = (1u << len) - 1u;
73
+ toInsert &= m;
74
+ toInsert <<= pos;
75
+ m <<= pos;
76
+
77
+ return (val & ~m) | toInsert;
78
+ #else
79
+ uint64_t ret;
80
+ asm("bfi.b64 %0, %1, %2, %3, %4;" :
81
+ "=l"(ret) : "l"(toInsert), "l"(val), "r"(pos), "r"(len));
82
+ return ret;
83
+ #endif
84
+ }
85
+ };
86
+
87
// Index of the calling thread within its warp (CUDA) or wavefront (ROCm).
__device__ __forceinline__ int getLaneId() {
#if defined(USE_ROCM)
  return __lane_id();
#else
  int laneId;
  // Read the %laneid special register.
  asm("mov.s32 %0, %%laneid;" : "=r"(laneId) );
  return laneId;
#endif
}
96
+
97
// Lane-mask helpers. On CUDA these read the 32-bit %lanemask_* special
// registers; on ROCm the masks are computed from the lane id and are
// 64 bits wide (the ROCm variants return 64-bit values).
// NOTE(review): the ROCm paths use UINT64_MAX (<cstdint>) and CHAR_BIT
// (<climits>); only <cstdint> is included by this header, so CHAR_BIT
// presumably arrives transitively — confirm.
#if defined(USE_ROCM)
// Mask of lanes with id strictly less than the calling lane's.
__device__ __forceinline__ unsigned long long int getLaneMaskLt() {
  const std::uint64_t m = (1ull << getLaneId()) - 1ull;
  return m;
}
#else
__device__ __forceinline__ unsigned getLaneMaskLt() {
  unsigned mask;
  asm("mov.u32 %0, %%lanemask_lt;" : "=r"(mask));
  return mask;
}
#endif

#if defined (USE_ROCM)
// Mask of lanes with id less than or equal to the calling lane's.
__device__ __forceinline__ unsigned long long int getLaneMaskLe() {
  std::uint64_t m = UINT64_MAX >> (sizeof(std::uint64_t) * CHAR_BIT - (getLaneId() + 1));
  return m;
}
#else
__device__ __forceinline__ unsigned getLaneMaskLe() {
  unsigned mask;
  asm("mov.u32 %0, %%lanemask_le;" : "=r"(mask));
  return mask;
}
#endif

#if defined(USE_ROCM)
// Mask of lanes with id strictly greater than the calling lane's.
__device__ __forceinline__ unsigned long long int getLaneMaskGt() {
  // The "le" mask always contains the calling lane's bit, so m is
  // nonzero; the ternary appears to be a defensive guard.
  const std::uint64_t m = getLaneMaskLe();
  return m ? ~m : m;
}
#else
__device__ __forceinline__ unsigned getLaneMaskGt() {
  unsigned mask;
  asm("mov.u32 %0, %%lanemask_gt;" : "=r"(mask));
  return mask;
}
#endif

#if defined(USE_ROCM)
// Mask of lanes with id greater than or equal to the calling lane's:
// the complement of the strictly-less-than mask.
__device__ __forceinline__ unsigned long long int getLaneMaskGe() {
  const std::uint64_t m = getLaneMaskLt();
  return ~m;
}
#else
__device__ __forceinline__ unsigned getLaneMaskGe() {
  unsigned mask;
  asm("mov.u32 %0, %%lanemask_ge;" : "=r"(mask));
  return mask;
}
#endif
149
+ } // namespace at::cuda
.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/CUDAContextLight.h ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ // Light-weight version of CUDAContext.h with fewer transitive includes
3
+
4
+ #include <cstdint>
5
+
6
+ #include <cuda_runtime_api.h>
7
+ #include <cusparse.h>
8
+ #include <cublas_v2.h>
9
+
10
+ // cublasLT was introduced in CUDA 10.1 but we enable only for 11.1 that also
11
+ // added bf16 support
12
+ #include <cublasLt.h>
13
+
14
+ #ifdef CUDART_VERSION
15
+ #include <cusolverDn.h>
16
+ #endif
17
+
18
+ #if defined(USE_CUDSS)
19
+ #include <cudss.h>
20
+ #endif
21
+
22
+ #if defined(USE_ROCM)
23
+ #include <hipsolver/hipsolver.h>
24
+ #endif
25
+
26
+ #include <c10/core/Allocator.h>
27
+ #include <c10/cuda/CUDAFunctions.h>
28
+
29
+ namespace c10 {
30
+ struct Allocator;
31
+ }
32
+
33
+ namespace at::cuda {
34
+
35
+ /*
36
+ A common CUDA interface for ATen.
37
+
38
+ This interface is distinct from CUDAHooks, which defines an interface that links
39
+ to both CPU-only and CUDA builds. That interface is intended for runtime
40
+ dispatch and should be used from files that are included in both CPU-only and
41
+ CUDA builds.
42
+
43
+ CUDAContext, on the other hand, should be preferred by files only included in
44
+ CUDA builds. It is intended to expose CUDA functionality in a consistent
45
+ manner.
46
+
47
+ This means there is some overlap between the CUDAContext and CUDAHooks, but
48
+ the choice of which to use is simple: use CUDAContext when in a CUDA-only file,
49
+ use CUDAHooks otherwise.
50
+
51
+ Note that CUDAContext simply defines an interface with no associated class.
52
+ It is expected that the modules whose functions compose this interface will
53
+ manage their own state. There is only a single CUDA context/state.
54
+ */
55
+
56
/**
 * DEPRECATED: use device_count() instead
 */
inline int64_t getNumGPUs() {
  // Thin forwarding wrapper kept for backward compatibility.
  return c10::cuda::device_count();
}

/**
 * CUDA is available if we compiled with CUDA, and there are one or more
 * devices. If we compiled with CUDA but there is a driver problem, etc.,
 * this function will report CUDA is not available (rather than raise an error.)
 */
inline bool is_available() {
  return c10::cuda::device_count() > 0;
}
71
+
72
+ TORCH_CUDA_CPP_API cudaDeviceProp* getCurrentDeviceProperties();
73
+
74
+ TORCH_CUDA_CPP_API int warp_size();
75
+
76
+ TORCH_CUDA_CPP_API cudaDeviceProp* getDeviceProperties(c10::DeviceIndex device);
77
+
78
+ TORCH_CUDA_CPP_API bool canDeviceAccessPeer(
79
+ c10::DeviceIndex device,
80
+ c10::DeviceIndex peer_device);
81
+
82
+ TORCH_CUDA_CPP_API c10::Allocator* getCUDADeviceAllocator();
83
+
84
+ /* Handles */
85
+ TORCH_CUDA_CPP_API cusparseHandle_t getCurrentCUDASparseHandle();
86
+ TORCH_CUDA_CPP_API cublasHandle_t getCurrentCUDABlasHandle();
87
+ TORCH_CUDA_CPP_API cublasLtHandle_t getCurrentCUDABlasLtHandle();
88
+
89
+ TORCH_CUDA_CPP_API void clearCublasWorkspaces();
90
+
91
+ #if defined(CUDART_VERSION) || defined(USE_ROCM)
92
+ TORCH_CUDA_CPP_API cusolverDnHandle_t getCurrentCUDASolverDnHandle();
93
+ #endif
94
+
95
+ #if defined(USE_CUDSS)
96
+ TORCH_CUDA_CPP_API cudssHandle_t getCurrentCudssHandle();
97
+ #endif
98
+
99
+ } // namespace at::cuda
.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/CUDADataType.h ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <c10/core/ScalarType.h>
4
+
5
+ #include <cuda.h>
6
+ #include <library_types.h>
7
+
8
+ namespace at::cuda {
9
+
10
// Compile-time mapping from C++ scalar types to the cudaDataType
// enumeration consumed by cuBLAS/cuSPARSE/etc. The primary template's
// static_assert uses a dependent-false expression so the error fires
// only when an unsupported type is actually instantiated.
template <typename scalar_t>
cudaDataType getCudaDataType() {
  static_assert(false && sizeof(scalar_t), "Cannot convert type to cudaDataType.");
  return {};
}

// Floating-point and complex types.
template<> inline cudaDataType getCudaDataType<at::Half>() {
  return CUDA_R_16F;
}
template<> inline cudaDataType getCudaDataType<float>() {
  return CUDA_R_32F;
}
template<> inline cudaDataType getCudaDataType<double>() {
  return CUDA_R_64F;
}
template<> inline cudaDataType getCudaDataType<c10::complex<c10::Half>>() {
  return CUDA_C_16F;
}
template<> inline cudaDataType getCudaDataType<c10::complex<float>>() {
  return CUDA_C_32F;
}
template<> inline cudaDataType getCudaDataType<c10::complex<double>>() {
  return CUDA_C_64F;
}

// Integral types.
template<> inline cudaDataType getCudaDataType<uint8_t>() {
  return CUDA_R_8U;
}
template<> inline cudaDataType getCudaDataType<int8_t>() {
  return CUDA_R_8I;
}
template<> inline cudaDataType getCudaDataType<int>() {
  return CUDA_R_32I;
}

template<> inline cudaDataType getCudaDataType<int16_t>() {
  return CUDA_R_16I;
}
template<> inline cudaDataType getCudaDataType<int64_t>() {
  return CUDA_R_64I;
}
template<> inline cudaDataType getCudaDataType<at::BFloat16>() {
  return CUDA_R_16BF;
}
54
+
55
// Runtime counterpart of getCudaDataType<>: maps an ATen ScalarType to
// the corresponding cudaDataType. Asserts (TORCH_INTERNAL_ASSERT) on
// types with no CUDA equivalent.
inline cudaDataType ScalarTypeToCudaDataType(const c10::ScalarType& scalar_type) {
  switch (scalar_type) {
    case c10::ScalarType::Byte:
      return CUDA_R_8U;
    case c10::ScalarType::Char:
      return CUDA_R_8I;
    case c10::ScalarType::Int:
      return CUDA_R_32I;
    case c10::ScalarType::Half:
      return CUDA_R_16F;
    case c10::ScalarType::Float:
      return CUDA_R_32F;
    case c10::ScalarType::Double:
      return CUDA_R_64F;
    case c10::ScalarType::ComplexHalf:
      return CUDA_C_16F;
    case c10::ScalarType::ComplexFloat:
      return CUDA_C_32F;
    case c10::ScalarType::ComplexDouble:
      return CUDA_C_64F;
    case c10::ScalarType::Short:
      return CUDA_R_16I;
    case c10::ScalarType::Long:
      return CUDA_R_64I;
    case c10::ScalarType::BFloat16:
      return CUDA_R_16BF;
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11080
    // FP8 enumerators exist only from CUDA 11.8 onward.
    case c10::ScalarType::Float8_e4m3fn:
      return CUDA_R_8F_E4M3;
    case c10::ScalarType::Float8_e5m2:
      return CUDA_R_8F_E5M2;
#endif
#if defined(USE_ROCM)
#if defined(HIP_NEW_TYPE_ENUMS)
    case c10::ScalarType::Float8_e4m3fnuz:
      return HIP_R_8F_E4M3_FNUZ;
    case c10::ScalarType::Float8_e5m2fnuz:
      return HIP_R_8F_E5M2_FNUZ;
#else
    // NOTE(review): 1000/1001 presumably match the hipDataType values
    // that HIP_NEW_TYPE_ENUMS later names — confirm against hip headers.
    case c10::ScalarType::Float8_e4m3fnuz:
      return static_cast<hipDataType>(1000);
    case c10::ScalarType::Float8_e5m2fnuz:
      return static_cast<hipDataType>(1001);
#endif
#endif
    default:
      TORCH_INTERNAL_ASSERT(false, "Cannot convert ScalarType ", scalar_type, " to cudaDataType.")
  }
}
+
105
+ } // namespace at::cuda
.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/CUDADevice.h ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/cuda/Exceptions.h>
4
+
5
+ #include <cuda.h>
6
+ #include <cuda_runtime.h>
7
+
8
+ namespace at::cuda {
9
+
10
// Maps a pointer to the CUDA Device that owns its allocation.
// Throws via AT_CUDA_CHECK if cudaPointerGetAttributes fails.
inline Device getDeviceFromPtr(void* ptr) {
  cudaPointerAttributes attr{};

  AT_CUDA_CHECK(cudaPointerGetAttributes(&attr, ptr));

#if !defined(USE_ROCM)
  // cudaMemoryTypeUnregistered means plain host memory: it has no owning
  // device, so attr.device below would be meaningless.
  // NOTE(review): this check is skipped on ROCm — presumably for HIP
  // runtime compatibility; confirm.
  TORCH_CHECK(attr.type != cudaMemoryTypeUnregistered,
    "The specified pointer resides on host memory and is not registered with any CUDA device.");
#endif

  return {c10::DeviceType::CUDA, static_cast<DeviceIndex>(attr.device)};
}
22
+
23
+ } // namespace at::cuda
.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/CUDAUtils.h ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/cuda/CUDAContext.h>
4
+
5
+ namespace at::cuda {
6
+
7
+ // Check if every tensor in a list of tensors matches the current
8
+ // device.
9
+ inline bool check_device(ArrayRef<Tensor> ts) {
10
+ if (ts.empty()) {
11
+ return true;
12
+ }
13
+ Device curDevice = Device(kCUDA, current_device());
14
+ for (const Tensor& t : ts) {
15
+ if (t.device() != curDevice) return false;
16
+ }
17
+ return true;
18
+ }
19
+
20
+ } // namespace at::cuda
.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/NumericLimits.cuh ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <cuda.h>
4
+ #include <limits.h>
5
+ #include <math.h>
6
+ #include <float.h>
7
+
8
+ // NumericLimits.cuh is a holder for numeric limits definitions of commonly used
9
+ // types. This header is very specific to ROCm HIP and may be removed in the future.
10
+ // This header is derived from the legacy THCNumerics.cuh.
11
+
12
+ // The lower_bound and upper_bound constants are same as lowest and max for
13
+ // integral types, but are -inf and +inf for floating point types. They are
14
+ // useful in implementing min, max, etc.
15
+
16
+ namespace at {
17
+
18
// Primary template: intentionally empty so that using an unspecialized
// type is a compile error.
template <typename T>
struct numeric_limits {
};

// WARNING: the following at::numeric_limits definitions are there only to support
// HIP compilation for the moment. Use std::numeric_limits if you are not
// compiling for ROCm.
// from @colesbury: "The functions on numeric_limits aren't marked with
// __device__ which is why they don't work with ROCm. CUDA allows them
// because they're constexpr."

namespace {
// ROCm doesn't like INFINITY too.
// NOTE(review): an anonymous namespace in a header gives every
// translation unit its own `inf`; for a constexpr constant that is
// benign, but it is why linters may flag this header.
constexpr double inf = INFINITY;
}
33
+
34
// Integral specializations: lower_bound/upper_bound coincide with
// lowest/max (only floating-point types use +/-inf instead).
template <>
struct numeric_limits<bool> {
  static inline __host__ __device__ bool lowest() { return false; }
  static inline __host__ __device__ bool max() { return true; }
  static inline __host__ __device__ bool lower_bound() { return false; }
  static inline __host__ __device__ bool upper_bound() { return true; }
};

template <>
struct numeric_limits<uint8_t> {
  static inline __host__ __device__ uint8_t lowest() { return 0; }
  static inline __host__ __device__ uint8_t max() { return UINT8_MAX; }
  static inline __host__ __device__ uint8_t lower_bound() { return 0; }
  static inline __host__ __device__ uint8_t upper_bound() { return UINT8_MAX; }
};

template <>
struct numeric_limits<int8_t> {
  static inline __host__ __device__ int8_t lowest() { return INT8_MIN; }
  static inline __host__ __device__ int8_t max() { return INT8_MAX; }
  static inline __host__ __device__ int8_t lower_bound() { return INT8_MIN; }
  static inline __host__ __device__ int8_t upper_bound() { return INT8_MAX; }
};

template <>
struct numeric_limits<int16_t> {
  static inline __host__ __device__ int16_t lowest() { return INT16_MIN; }
  static inline __host__ __device__ int16_t max() { return INT16_MAX; }
  static inline __host__ __device__ int16_t lower_bound() { return INT16_MIN; }
  static inline __host__ __device__ int16_t upper_bound() { return INT16_MAX; }
};

template <>
struct numeric_limits<int32_t> {
  static inline __host__ __device__ int32_t lowest() { return INT32_MIN; }
  static inline __host__ __device__ int32_t max() { return INT32_MAX; }
  static inline __host__ __device__ int32_t lower_bound() { return INT32_MIN; }
  static inline __host__ __device__ int32_t upper_bound() { return INT32_MAX; }
};

template <>
struct numeric_limits<int64_t> {
#ifdef _MSC_VER
  // MSVC spells the 64-bit limits _I64_MIN/_I64_MAX.
  static inline __host__ __device__ int64_t lowest() { return _I64_MIN; }
  static inline __host__ __device__ int64_t max() { return _I64_MAX; }
  static inline __host__ __device__ int64_t lower_bound() { return _I64_MIN; }
  static inline __host__ __device__ int64_t upper_bound() { return _I64_MAX; }
#else
  static inline __host__ __device__ int64_t lowest() { return INT64_MIN; }
  static inline __host__ __device__ int64_t max() { return INT64_MAX; }
  static inline __host__ __device__ int64_t lower_bound() { return INT64_MIN; }
  static inline __host__ __device__ int64_t upper_bound() { return INT64_MAX; }
#endif
};
88
+
89
// Floating-point specializations. lowest/max are the extreme *finite*
// values; lower_bound/upper_bound are -inf/+inf (useful for min/max
// reductions). Half/BFloat16 are built from raw bit patterns.
template <>
struct numeric_limits<at::Half> {
  // 0xFBFF/0x7BFF = -/+65504, the extreme finite fp16 values.
  static inline __host__ __device__ at::Half lowest() { return at::Half(0xFBFF, at::Half::from_bits()); }
  static inline __host__ __device__ at::Half max() { return at::Half(0x7BFF, at::Half::from_bits()); }
  // 0xFC00/0x7C00 = -/+infinity in fp16.
  static inline __host__ __device__ at::Half lower_bound() { return at::Half(0xFC00, at::Half::from_bits()); }
  static inline __host__ __device__ at::Half upper_bound() { return at::Half(0x7C00, at::Half::from_bits()); }
};

template <>
struct numeric_limits<at::BFloat16> {
  // 0xFF7F/0x7F7F = extreme finite bfloat16 values (~ +/-3.39e38).
  static inline __host__ __device__ at::BFloat16 lowest() { return at::BFloat16(0xFF7F, at::BFloat16::from_bits()); }
  static inline __host__ __device__ at::BFloat16 max() { return at::BFloat16(0x7F7F, at::BFloat16::from_bits()); }
  // 0xFF80/0x7F80 = -/+infinity in bfloat16.
  static inline __host__ __device__ at::BFloat16 lower_bound() { return at::BFloat16(0xFF80, at::BFloat16::from_bits()); }
  static inline __host__ __device__ at::BFloat16 upper_bound() { return at::BFloat16(0x7F80, at::BFloat16::from_bits()); }
};

template <>
struct numeric_limits<float> {
  static inline __host__ __device__ float lowest() { return -FLT_MAX; }
  static inline __host__ __device__ float max() { return FLT_MAX; }
  static inline __host__ __device__ float lower_bound() { return -static_cast<float>(inf); }
  static inline __host__ __device__ float upper_bound() { return static_cast<float>(inf); }
};

template <>
struct numeric_limits<double> {
  static inline __host__ __device__ double lowest() { return -DBL_MAX; }
  static inline __host__ __device__ double max() { return DBL_MAX; }
  static inline __host__ __device__ double lower_bound() { return -inf; }
  static inline __host__ __device__ double upper_bound() { return inf; }
};
120
+
121
+ } // namespace at
.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/PeerToPeerAccess.h ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <c10/macros/Macros.h>
2
+ #include <cstdint>
3
+
4
+ namespace at::cuda {
5
+ namespace detail {
6
+ void init_p2p_access_cache(int64_t num_devices);
7
+ }
8
+
9
+ TORCH_CUDA_CPP_API bool get_p2p_access(int source_dev, int dest_dev);
10
+
11
+ } // namespace at::cuda
.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/Sleep.h ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ #include <c10/macros/Export.h>
3
+ #include <cstdint>
4
+
5
+ namespace at::cuda {
6
+
7
+ // enqueues a kernel that spins for the specified number of cycles
8
+ TORCH_CUDA_CU_API void sleep(int64_t cycles);
9
+
10
+ // flushes instruction cache for ROCm; no-op for CUDA
11
+ TORCH_CUDA_CU_API void flush_icache();
12
+
13
+ } // namespace at::cuda
.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/cub_definitions.cuh ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #if !defined(USE_ROCM)
4
+ #include <cuda.h> // for CUDA_VERSION
5
+ #endif
6
+
7
+ #if !defined(USE_ROCM)
8
+ #include <cub/version.cuh>
9
+ #else
10
+ #define CUB_VERSION 0
11
+ #endif
12
+
13
+ // cub sort support for __nv_bfloat16 is added to cub 1.13 in:
14
+ // https://github.com/NVIDIA/cub/pull/306
15
+ #if CUB_VERSION >= 101300
16
+ #define CUB_SUPPORTS_NV_BFLOAT16() true
17
+ #else
18
+ #define CUB_SUPPORTS_NV_BFLOAT16() false
19
+ #endif
20
+
21
+ // cub support for CUB_WRAPPED_NAMESPACE is added to cub 1.13.1 in:
22
+ // https://github.com/NVIDIA/cub/pull/326
23
+ // CUB_WRAPPED_NAMESPACE is defined globally in cmake/Dependencies.cmake
24
+ // starting from CUDA 11.5
25
+ #if defined(CUB_WRAPPED_NAMESPACE) || defined(THRUST_CUB_WRAPPED_NAMESPACE)
26
+ #define USE_GLOBAL_CUB_WRAPPED_NAMESPACE() true
27
+ #else
28
+ #define USE_GLOBAL_CUB_WRAPPED_NAMESPACE() false
29
+ #endif
30
+
31
+ // cub support for UniqueByKey is added to cub 1.16 in:
32
+ // https://github.com/NVIDIA/cub/pull/405
33
+ #if CUB_VERSION >= 101600
34
+ #define CUB_SUPPORTS_UNIQUE_BY_KEY() true
35
+ #else
36
+ #define CUB_SUPPORTS_UNIQUE_BY_KEY() false
37
+ #endif
38
+
39
+ // cub support for scan by key is added to cub 1.15
40
+ // in https://github.com/NVIDIA/cub/pull/376
41
+ #if CUB_VERSION >= 101500
42
+ #define CUB_SUPPORTS_SCAN_BY_KEY() 1
43
+ #else
44
+ #define CUB_SUPPORTS_SCAN_BY_KEY() 0
45
+ #endif
46
+
47
+ // cub support for cub::FutureValue is added to cub 1.15 in:
48
+ // https://github.com/NVIDIA/cub/pull/305
49
+ #if CUB_VERSION >= 101500
50
+ #define CUB_SUPPORTS_FUTURE_VALUE() true
51
+ #else
52
+ #define CUB_SUPPORTS_FUTURE_VALUE() false
53
+ #endif
.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/detail/CUDAHooks.h ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/detail/CUDAHooksInterface.h>
4
+
5
+ #include <ATen/Generator.h>
6
+ #include <optional>
7
+
8
+ // TODO: No need to have this whole header, we can just put it all in
9
+ // the cpp file
10
+
11
+ namespace at::cuda::detail {
12
+
13
+ // Set the callback to initialize Magma, which is set by
14
+ // torch_cuda_cu. This indirection is required so magma_init is called
15
+ // in the same library where Magma will be used.
16
+ TORCH_CUDA_CPP_API void set_magma_init_fn(void (*magma_init_fn)());
17
+
18
+
19
// The real implementation of CUDAHooksInterface. All members are
// declared-only here; definitions live in the corresponding .cpp.
struct CUDAHooks : public at::CUDAHooksInterface {
  CUDAHooks(at::CUDAHooksArgs) {}
  void initCUDA() const override;
  Device getDeviceFromPtr(void* data) const override;
  bool isPinnedPtr(const void* data) const override;
  const Generator& getDefaultCUDAGenerator(DeviceIndex device_index = -1) const override;
  // Capability / build-configuration queries.
  bool hasCUDA() const override;
  bool hasMAGMA() const override;
  bool hasCuDNN() const override;
  bool hasCuSOLVER() const override;
  bool hasCuBLASLt() const override;
  bool hasROCM() const override;
  const at::cuda::NVRTC& nvrtc() const override;
  DeviceIndex current_device() const override;
  bool hasPrimaryContext(DeviceIndex device_index) const override;
  Allocator* getCUDADeviceAllocator() const override;
  Allocator* getPinnedMemoryAllocator() const override;
  bool compiledWithCuDNN() const override;
  bool compiledWithMIOpen() const override;
  bool supportsDilatedConvolutionWithCuDNN() const override;
  bool supportsDepthwiseConvolutionWithCuDNN() const override;
  bool supportsBFloat16ConvolutionWithCuDNNv8() const override;
  bool hasCUDART() const override;
  long versionCUDART() const override;
  long versionCuDNN() const override;
  std::string showConfig() const override;
  double batchnormMinEpsilonCuDNN() const override;
  // cuFFT plan-cache management (per-device).
  int64_t cuFFTGetPlanCacheMaxSize(DeviceIndex device_index) const override;
  void cuFFTSetPlanCacheMaxSize(DeviceIndex device_index, int64_t max_size) const override;
  int64_t cuFFTGetPlanCacheSize(DeviceIndex device_index) const override;
  void cuFFTClearPlanCache(DeviceIndex device_index) const override;
  int getNumGPUs() const override;
#ifdef USE_ROCM
  bool isGPUArch(DeviceIndex device_index, const std::vector<std::string>& archs) const override;
#endif
  void deviceSynchronize(DeviceIndex device_index) const override;
};
57
+
58
+ } // at::cuda::detail
.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/detail/DeviceThreadHandles.h ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Some stateful GPU libraries, such as cuDNN, cuBLAS, use handles to store states.
2
+ // These handles are tied to device, and these libraries requires/recommends not to
3
+ // share handles across host threads.
4
+ //
5
+ // These libraries recommend using one handle per host thread. We may not want to do
6
+ // this because threads are relatively light-weight, but creating and destroying
7
+ // handles is expensive (destroying the handle causes synchronizations). DataParallel,
8
+ // for example, creates new threads for each forward pass.
9
+ //
10
+ // This file implements a handle pool mechanism. The handle pool returns handles on
11
+ // demand as threads request them. If all existing handles in the pool are in use,
12
+ // it creates a new one. As threads terminate, they release handles back into the pool.
13
+ // In this way, the handle pool never creates more handles than the high-water mark of
14
+ // active threads, so it's efficient with DataParallel.
15
+
16
+ #pragma once
17
+
18
+ #include <unordered_map>
19
+ #include <vector>
20
+ #include <utility>
21
+ #include <mutex>
22
+ #include <memory>
23
+
24
+ #include <c10/util/Exception.h>
25
+
26
+ namespace at::cuda { namespace {
27
+
28
// Pool of per-device library handles (cuDNN/cuBLAS/...), handed out one
// per (thread, device) via PoolWindow and recycled when threads exit.
template <typename Handle_t, void Create(Handle_t *), void Destroy(Handle_t)>
struct DeviceThreadHandlePool : public std::enable_shared_from_this<DeviceThreadHandlePool<Handle_t, Create, Destroy>> {

    // RAII owner of a single library handle.
    struct Handle {
        Handle_t handle;
        Handle(bool create = false) : handle(nullptr)
        {
            if(create) Create(&handle);
        }
        // std::vector.emplace() and push_back() may route through temporaries and call
        // copy/move constructors along the way. If this is the case, we don't want
        // the destructors of temporaries to call cudnnDestroy on the handle.
        // We can achieve safety (for the narrow case of stashing within std::vectors)
        // by making Handle moveable but not copyable, and transferring handle ownership
        // to the latest constructed object. This is not a substitute for full-blown
        // reference counting, but reference counting may be overkill here.
        // Another alternative is to wrap the saved Handles in unique_ptrs, i.e.,
        // unordered_map<int, vector<unique_ptr<Handle>>> created_handles;
        Handle(const Handle& rhs) = delete;
        // Following https://stackoverflow.com/questions/3279543/what-is-the-copy-and-swap-idiom
        Handle(Handle&& rhs) noexcept : Handle() { std::swap(handle, rhs.handle); }
        // operator= takes argument by value
        Handle& operator=(Handle rhs) { std::swap(handle, rhs.handle); return *this; }
        ~Handle() {
            if(handle) Destroy(handle);
        }
    };

    // Guards created_handles and available_handles below.
    std::mutex mutex;

    // Handles are lazily created as different threads request them,
    // but are never destroyed until the end of the process.
    // The maximum number of handles this process will create for each device is equal
    // to the high-water mark of the number of concurrently active threads that request
    // handles for that device.
    // When threads terminate, they release their handles back into the pool for reuse.
    // Otherwise, new handles would be created every time new threads were spawned,
    // resulting in poor performance for Python modules that repeatedly or frequently
    // spawned new sets of threads (like DataParallel, which creates a new set of threads
    // for each forward pass).
    //
    // To prevent potential deadlocks, we explicitly choose not to cap the number
    // of handles that are created per device.
    // Example of danger: If we cap the max handles at 4, and 5 threads are sharing a device,
    // only 4 can make forward progress at any time. The other 4 will not release their
    // handles until they exit, so the fifth cannot make progress until then. This is
    // not a problem...UNLESS all 5 threads attempt some sort of synchronization at an
    // intermediate point (ie, before any of them have exited). We have no way to anticipate
    // or enforce that user threads will not attempt such intermediate synchronization.
    // The only way to ensure safety is to avoid imposing a cap on the number of handles.
    std::unordered_map<int, std::vector<Handle>> created_handles;
    std::unordered_map<int, std::vector<Handle_t>> available_handles;

    // PoolWindow lazily creates and caches the handles that a particular thread is using,
    // so in the common case handle access doesn't incur either handle creation or a mutex lock.
    class PoolWindow
    {
    public:
    PoolWindow(std::shared_ptr<DeviceThreadHandlePool> parent): weak_parent(std::move(parent)) {}
    ~PoolWindow(){ release(); }

    Handle_t reserve(int device)
    {
        // If this thread already has a handle for this device, return it.
        // (Single map lookup instead of the previous find() + operator[].)
        if(auto it = my_handles.find(device); it != my_handles.end())
            return it->second;

        // otherwise, either grab a handle from the pool if one is available,
        // or if not, create a new one.
        auto parent = weak_parent.lock();
        TORCH_CHECK(parent, "Cannot create handle during program termination");
        std::lock_guard<std::mutex> guard(parent->mutex);

        if(!parent->available_handles[device].empty())
        {
            my_handles[device] = parent->available_handles[device].back();
            parent->available_handles[device].pop_back();
        }
        else
        {
            // In local testing, I do observe that emplace_back sometimes routes through temporaries
            // that incur move-constructor and destructor calls. See comments in Handle above.
            parent->created_handles[device].emplace_back(true /*create*/);
            my_handles[device] = parent->created_handles[device].back().handle;
        }

        return my_handles[device];
    }

    private:
    // Stores the per-device handles currently owned by this thread
    std::unordered_map<int, Handle_t> my_handles;

    std::weak_ptr<DeviceThreadHandlePool> weak_parent;

    // Called by the destructor. Releases this thread's handles back into the pool.
    void release() {
        if(!my_handles.empty()) {
            auto parent = weak_parent.lock();
            if (!parent) {
                // If this thread exits after atexit handlers have completed, the
                // cuda context itself may be invalid, so we must leak the handles.
                return;
            }

            std::lock_guard<std::mutex> guard(parent->mutex);
            // const-ref iteration: avoids copying each (device, handle) pair.
            for(const auto& d_h : my_handles)
                parent->available_handles[d_h.first].push_back(d_h.second);
        }
    }
    };

    // Warning:
    // If you want to change this function, be aware that this function will be called
    // by multiple threads and there is no mutex guarding the call of this function, so
    // make sure your implementation is thread-safe.
    PoolWindow *newPoolWindow() {
        // The returned pointer will be owned by a thread local variable
        // so that different threads does not share the same PoolWindow.
        return new PoolWindow(this->shared_from_this());
    }
};
150
+
151
+ }} // namespace at::cuda::detail::<anonymous>
.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/detail/IntegerDivider.cuh ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <assert.h>
4
+ #if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__)
5
+ #include <cuda_runtime.h>
6
+ #endif
7
+
8
+ namespace at::cuda::detail {
9
+
10
+ // A utility class to implement integer division by multiplication, given a fixed
11
+ // divisor.
12
+ //
13
+ // WARNING: The fast divider algorithm is only implemented for unsigned int;
14
+ // otherwise we default to plain integer division. For unsigned int,
15
+ // we further assume that the dividend is at most INT32_MAX. Thus,
16
+ // IntDivider must NOT be used for general integer division.
17
+ //
18
+ // This reduced range is enough for our purpose, and it allows us to
19
+ // slightly simplify the computation.
20
+ //
21
+ // (NOTE: Below, "2^k" denotes exponentiation, i.e., 1<<k.)
22
+ //
23
+ // For any N-bit unsigned integer d (> 0), we can find a "magic number" m (2^N
24
+ // <= m < 2^(N+1)) and shift s such that:
25
+ //
26
+ // \floor(n / d) = \floor((m * n) / 2^(N+s)).
27
+ //
28
+ // Given such m and s, the integer division can be then implemented as:
29
+ //
30
+ // let m' = m - 2^N // 0 <= m' < 2^N
31
+ //
32
+ // fast_integer_division(n):
33
+ // // Multiply two N-bit unsigned integers: the result is a 2N-bit unsigned
34
+ // // integer. Then take the higher N bits.
35
+ // t = (m' * n) >> N
36
+ //
37
+ // // Here we use the fact that n is less than 2^(N-1): otherwise the value
38
+ // // of (t + n) may not fit in an N-bit integer.
39
+ // return (t + n) >> s
40
+ //
41
+ // Finding such a magic number is surprisingly easy:
42
+ //
43
+ // s = \ceil(\log_2 d)
44
+ // m' = \floor(2^N * (2^s - d) / d) + 1 // Need 2N-bit integer arithmetic.
45
+ //
46
+ // See also:
47
+ // - Division by Invariant Integers Using Multiplication,
48
+ // Torbjörn Granlund and Peter L. Montgomery, 1994.
49
+ //
50
+ // - http://www.hackersdelight.org/magic.htm
51
+ //
52
+ // - http://ridiculousfish.com/blog/posts/labor-of-division-episode-i.html
53
+
54
+ // Result of div/mod operation stored together.
55
+ template <typename Value>
56
+ struct DivMod {
57
+ Value div, mod;
58
+
59
+ C10_HOST_DEVICE DivMod(Value div, Value mod) : div(div), mod(mod) { }
60
+ };
61
+
62
+ // Base case: we only have an implementation for uint32_t for now. For
63
+ // everything else, we use plain division.
64
+ template <typename Value>
65
+ struct IntDivider {
66
+ IntDivider() = default;
67
+ IntDivider(Value d) : divisor(d) { }
68
+
69
+ C10_HOST_DEVICE inline Value div(Value n) const { return n / divisor; }
70
+ C10_HOST_DEVICE inline Value mod(Value n) const { return n % divisor; }
71
+ C10_HOST_DEVICE inline DivMod<Value> divmod(Value n) const {
72
+ return DivMod<Value>(n / divisor, n % divisor);
73
+ }
74
+
75
+ Value divisor;
76
+ };
77
+
78
+ // Implement fast integer division.
79
+ template <>
80
+ struct IntDivider<unsigned int> {
81
+ static_assert(sizeof(unsigned int) == 4, "Assumes 32-bit unsigned int.");
82
+
83
+ IntDivider() = default;
84
+
85
+ IntDivider(unsigned int d) : divisor(d) {
86
+ assert(divisor >= 1 && divisor <= INT32_MAX);
87
+
88
+ // TODO: gcc/clang has __builtin_clz() but it's not portable.
89
+ for (shift = 0; shift < 32; shift++) if ((1U << shift) >= divisor) break;
90
+
91
+ uint64_t one = 1;
92
+ uint64_t magic = ((one << 32) * ((one << shift) - divisor)) / divisor + 1;
93
+ m1 = magic;
94
+ assert(m1 > 0 && m1 == magic); // m1 must fit in 32 bits.
95
+ }
96
+
97
+ C10_HOST_DEVICE inline unsigned int div(unsigned int n) const {
98
+ #if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__)
99
+ // 't' is the higher 32-bits of unsigned 32-bit multiplication of 'n' and
100
+ // 'm1'.
101
+ unsigned int t = __umulhi(n, m1);
102
+ return (t + n) >> shift;
103
+ #else
104
+ // Using uint64_t so that the addition does not overflow.
105
+ uint64_t t = ((uint64_t) n * m1) >> 32;
106
+ return (t + n) >> shift;
107
+ #endif
108
+ }
109
+
110
+ C10_HOST_DEVICE inline unsigned int mod(unsigned int n) const {
111
+ return n - div(n) * divisor;
112
+ }
113
+
114
+ C10_HOST_DEVICE inline DivMod<unsigned int> divmod(unsigned int n) const {
115
+ unsigned int q = div(n);
116
+ return DivMod<unsigned int>(q, n - q * divisor);
117
+ }
118
+
119
+ unsigned int divisor; // d above.
120
+ unsigned int m1; // Magic number: m' above.
121
+ unsigned int shift; // Shift amounts.
122
+ };
123
+
124
+ } // namespace at::cuda::detail
.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/detail/KernelUtils.h ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <limits>
4
+ #include <c10/util/Exception.h>
5
+
6
+ namespace at::cuda::detail {
7
+
8
+ // CUDA: grid stride looping
9
+ //
10
+ // int64_t _i_n_d_e_x specifically prevents overflow in the loop increment.
11
+ // If input.numel() < INT_MAX, _i_n_d_e_x < INT_MAX, except after the final
12
+ // iteration of the loop where _i_n_d_e_x += blockDim.x * gridDim.x can be
13
+ // greater than INT_MAX. But in that case _i_n_d_e_x >= n, so there are no
14
+ // further iterations and the overflowed value in i=_i_n_d_e_x is not used.
15
+ #define CUDA_KERNEL_LOOP_TYPE(i, n, index_type) \
16
+ int64_t _i_n_d_e_x = blockIdx.x * blockDim.x + threadIdx.x; \
17
+ for (index_type i=_i_n_d_e_x; _i_n_d_e_x < (n); _i_n_d_e_x+=blockDim.x * gridDim.x, i=_i_n_d_e_x)
18
+
19
+ #define CUDA_KERNEL_LOOP(i, n) CUDA_KERNEL_LOOP_TYPE(i, n, int)
20
+
21
+
22
+ // Use 1024 threads per block, which requires cuda sm_2x or above
23
+ constexpr int CUDA_NUM_THREADS = 1024;
24
+
25
+ // CUDA: number of blocks for threads.
26
+ inline int GET_BLOCKS(const int64_t N, const int64_t max_threads_per_block=CUDA_NUM_THREADS) {
27
+ TORCH_INTERNAL_ASSERT(N > 0, "CUDA kernel launch blocks must be positive, but got N=", N);
28
+ constexpr int64_t max_int = std::numeric_limits<int>::max();
29
+
30
+ // Round up division for positive number that cannot cause integer overflow
31
+ auto block_num = (N - 1) / max_threads_per_block + 1;
32
+ TORCH_INTERNAL_ASSERT(block_num <= max_int, "Can't schedule too many blocks on CUDA device");
33
+
34
+ return static_cast<int>(block_num);
35
+ }
36
+
37
+ } // namespace at::cuda::detail
.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/detail/LazyNVRTC.h ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ #include <ATen/detail/CUDAHooksInterface.h>
3
+ namespace at::cuda {
4
+ // Forward-declares at::cuda::NVRTC
5
+ struct NVRTC;
6
+
7
+ namespace detail {
8
+ extern NVRTC lazyNVRTC;
9
+ } // namespace detail
10
+
11
+ } // namespace at::cuda
.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/detail/PhiloxCudaStateRaw.cuh ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // No "#pragma once" because this is a raw definition that can be copied by jit codegen.
2
+ // Eager mode clients should not include this file directly, instead,
3
+ // they should #include <ATen/cuda/PhiloxCudaState.h>, which has a #pragma once.
4
+
5
+ // Stores RNG state values. Passed as a kernel argument.
6
+ // See Note [CUDA Graph-safe RNG states].
7
+ //
8
+ // The raw definition lives in its own file so jit codegen can easily copy it.
9
+ namespace at {
10
+
11
// Stores RNG state values. Passed by value as a kernel argument, so the
// member layout below is part of the de-facto ABI shared with jit codegen —
// do not reorder members. See Note [CUDA Graph-safe RNG states].
struct PhiloxCudaState {
  PhiloxCudaState() = default;

  // Eager path (no graph capture underway): seed and offset are immediate
  // host-side values.
  PhiloxCudaState(uint64_t seed,
                  uint64_t offset) {
    seed_.val = seed;
    offset_.val = offset;
  }

  // Capture path (graph capture underway): seed and offset live in device
  // memory and are dereferenced inside the kernel at replay time.
  PhiloxCudaState(int64_t* seed,
                  int64_t* offset_extragraph,
                  uint32_t offset_intragraph) {
    seed_.ptr = seed;
    offset_.ptr = offset_extragraph;
    offset_intragraph_ = offset_intragraph;
    captured_ = true;
  }

  // Public members, directly accessible by at::cuda::philox::unpack.
  // If we made them private with getters/setters, the getters/setters
  // would have to be __device__, and we can't declare __device__ in ATen.
  union Payload {
    uint64_t val;   // immediate value (eager path)
    int64_t* ptr;   // device pointer (capture path)
  };

  Payload seed_{};
  Payload offset_{};
  uint32_t offset_intragraph_ = 0;  // extra per-launch offset under capture
  bool captured_ = false;           // which Payload member is active
};
42
+
43
+ } // namespace at
.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/detail/TensorInfo.cuh ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/CollapseDims.h>
4
+
5
+ namespace at::cuda::detail {
6
+
7
+ #define MAX_TENSORINFO_DIMS 25
8
+
9
+ // CUDA kernel argument that defines tensor layout
10
+ template <typename T, typename IndexType>
11
+ struct TensorInfo {
12
+ TensorInfo();
13
+ TensorInfo(T* p,
14
+ int dim,
15
+ IndexType sz[MAX_TENSORINFO_DIMS],
16
+ IndexType st[MAX_TENSORINFO_DIMS]);
17
+
18
+ // Set the size of the given dimension to 1, as if it were a
19
+ // reduction dim (allows you to calculate offsets of the reduction
20
+ // slice)
21
+ void reduceDim(int dim);
22
+
23
+ // See note on [collapse dims].
24
+ int collapseDims(const int excludeDim = -1);
25
+
26
+ // Contiguous tensors of more than one dimension are collapsed down
27
+ // to one tensor
28
+ __host__ __device__ inline bool isContiguous() const {
29
+ return (dims == 1 && strides[0] == 1);
30
+ }
31
+
32
+ T* data;
33
+ IndexType sizes[MAX_TENSORINFO_DIMS];
34
+ IndexType strides[MAX_TENSORINFO_DIMS];
35
+ int dims;
36
+ };
37
+
38
+ template <typename T, typename IndexType>
39
+ TensorInfo<T, IndexType>::TensorInfo() {
40
+ data = nullptr;
41
+ dims = 0;
42
+ }
43
+
44
+ template <typename T, typename IndexType>
45
+ TensorInfo<T, IndexType>::TensorInfo(T* p,
46
+ int dim,
47
+ IndexType sz[MAX_TENSORINFO_DIMS],
48
+ IndexType st[MAX_TENSORINFO_DIMS]) {
49
+ data = p;
50
+ dims = dim;
51
+ TORCH_CHECK(dims < MAX_TENSORINFO_DIMS, "CUDA Tensors cannot have more than 25 dimensions");
52
+
53
+ for (int i = 0; i < dim; ++i) {
54
+ sizes[i] = sz[i];
55
+ strides[i] = st[i];
56
+ }
57
+ }
58
+
59
+ template <typename T, typename IndexType>
60
+ void
61
+ TensorInfo<T, IndexType>::reduceDim(int dim) {
62
+ TORCH_CHECK(dim < dims && dim >= 0, "expected dim between 0 and dims - 1");
63
+ sizes[dim] = 1;
64
+ }
65
+
66
+ template <typename T, typename IndexType>
67
+ int
68
+ TensorInfo<T, IndexType>::collapseDims(const int excludeDim) {
69
+ auto result = at::collapse_dims(sizes, strides, dims, excludeDim);
70
+ dims = std::get<1>(result);
71
+ return std::get<0>(result);
72
+ }
73
+
74
+ // Translate a linear index for the apply to a T* offset;
75
+ // specialized on `Dims` to reduce nvcc compilation time
76
+ template <typename T, typename IndexType, int Dims>
77
+ struct IndexToOffset {
78
+ static __host__ __device__ IndexType get(
79
+ IndexType linearId,
80
+ const TensorInfo<T, IndexType>& info) {
81
+
82
+ IndexType offset = 0;
83
+
84
+ // Uses static dims
85
+ for (int i = Dims - 1; i > 0; --i) {
86
+ IndexType curDimIndex = linearId % info.sizes[i];
87
+ IndexType curDimOffset = curDimIndex * info.strides[i];
88
+ offset += curDimOffset;
89
+ linearId /= info.sizes[i];
90
+ }
91
+
92
+ return offset + linearId * info.strides[0];
93
+ }
94
+ };
95
+
96
+ // Uses dynamic (runtime) instead of static (compiletime) dims
97
+ template <typename T, typename IndexType>
98
+ struct IndexToOffset<T, IndexType, -1> {
99
+ static inline __host__ __device__ IndexType get(
100
+ IndexType linearId,
101
+ const TensorInfo<T, IndexType>& info) {
102
+
103
+ IndexType offset = 0;
104
+
105
+ for (int i = info.dims - 1; i > 0; --i) {
106
+ IndexType curDimIndex = linearId % info.sizes[i];
107
+ IndexType curDimOffset = curDimIndex * info.strides[i];
108
+ offset += curDimOffset;
109
+ linearId /= info.sizes[i];
110
+ }
111
+
112
+ return offset + linearId * info.strides[0];
113
+ }
114
+ };
115
+
116
+ } // namespace at::cuda::detail
.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/detail/UnpackRaw.cuh ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // No "#pragma once" because this is a raw definition that can be copied by jit codegen.
2
+ // Eager mode clients should not include this file directly, instead,
3
+ // they should #include <ATen/cuda/PhiloxUtils.cuh>, which has a #pragma once.
4
+
5
+ namespace at::cuda::philox {
6
+
7
+ // In-kernel call to retrieve philox seed and offset from a PhiloxCudaState instance whether
8
+ // that instance was created with graph capture underway or not.
9
+ // See Note [CUDA Graph-safe RNG states].
10
+ //
11
+ // We can't write a __device__ function in CUDAGeneratorImpl.h, because it's in ATen.
12
+ // Also, whatever call unpacks PhiloxCudaState in consumer kernels must be inlineable.
13
+ // Easiest thing that comes to mind is, define a __device__ unpack helper here, in ATen/cuda.
14
+ //
15
+ // The raw definition lives in its own file so jit codegen can easily copy it.
16
+ __host__ __device__ __forceinline__ std::tuple<uint64_t, uint64_t>
17
+ unpack(at::PhiloxCudaState arg) {
18
+ if (arg.captured_) {
19
+ // static_cast avoids "warning: invalid narrowing conversion from "long" to "unsigned long".
20
+ // *(arg.offset_.ptr) is a broadcast load of a single int64_t to the entire kernel.
21
+ // For most threads' reads it will hit in cache, so it shouldn't hurt performance.
22
+ return std::make_tuple(static_cast<uint64_t>(*arg.seed_.ptr), static_cast<uint64_t>(*(arg.offset_.ptr) + arg.offset_intragraph_));
23
+ } else {
24
+ return std::make_tuple(arg.seed_.val, arg.offset_.val);
25
+ }
26
+ }
27
+
28
+ } // namespace at::cuda::philox
.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/llvm_jit_strings.h ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <string>
4
+ #include <c10/macros/Export.h>
5
+
6
+ namespace at::cuda {
7
+
8
+ TORCH_CUDA_CPP_API const std::string &get_traits_string();
9
+ TORCH_CUDA_CPP_API const std::string &get_cmath_string();
10
+ TORCH_CUDA_CPP_API const std::string &get_complex_body_string();
11
+ TORCH_CUDA_CPP_API const std::string &get_complex_half_body_string();
12
+ TORCH_CUDA_CPP_API const std::string &get_complex_math_string();
13
+
14
+ } // namespace at::cuda
.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/tunable/GemmCommon.h ADDED
@@ -0,0 +1,397 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Original TunableOp is from onnxruntime.
2
+ // https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/framework/tunable.h
3
+ // https://github.com/microsoft/onnxruntime/tree/main/onnxruntime/core/providers/rocm/tunable
4
+ // Copyright (c) Microsoft Corporation.
5
+ // Licensed under the MIT license.
6
+ //
7
+ // Adapting TunableOp into PyTorch
8
+ // Copyright (c) Advanced Micro Devices, Inc.
9
+ //
10
+ #pragma once
11
+
12
+ #include <string>
13
+
14
+ #include <ATen/cuda/tunable/TunableOp.h>
15
+ #include <ATen/cuda/Exceptions.h>
16
+ #include <c10/util/StringUtil.h>
17
+
18
+ #ifndef AT_PER_OPERATOR_HEADERS
19
+ #include <ATen/Functions.h>
20
+ #include <ATen/NativeFunctions.h>
21
+ #else
22
+ #include <ATen/ops/allclose.h>
23
+ #include <ATen/ops/from_blob.h>
24
+ #endif
25
+
26
+ namespace at::cuda::tunable {
27
+
28
+ enum class BlasOp {
29
+ N = 0,
30
+ T = 1
31
+ };
32
+
33
+ inline std::string BlasOpToString(BlasOp op) {
34
+ switch (op) {
35
+ case BlasOp::N:
36
+ return "N";
37
+ case BlasOp::T:
38
+ return "T";
39
+ }
40
+ TORCH_CHECK(false, "unrecognized BlasOp");
41
+ return "N";
42
+ }
43
+
44
+ namespace detail {
45
+
46
+ static bool NumericalCheck(ScalarType dtype, void* c, void* other_c, int64_t size) {
47
+ auto options = at::TensorOptions().dtype(dtype).device(at::kCUDA);
48
+ // comparison done as 1D tensor
49
+ at::Tensor ref = at::from_blob(c, {size}, options);
50
+ at::Tensor oth = at::from_blob(other_c, {size}, options);
51
+ at::Tensor ref_float = ref.to(at::kFloat);
52
+ at::Tensor oth_float = oth.to(at::kFloat);
53
+ std::vector<double> atols{1e-1, 1e-2, 1e-3, 1e-4, 1e-5};
54
+ std::vector<double> rtols{1e-1, 1e-2, 1e-3, 1e-4, 1e-5};
55
+ double last_succeed_atol = 1;
56
+ double last_succeed_rtol = 1;
57
+ for (auto& atol : atols) {
58
+ for (auto& rtol : rtols) {
59
+ if (at::allclose(ref_float, oth_float, rtol, atol)) {
60
+ last_succeed_atol = atol;
61
+ last_succeed_rtol = rtol;
62
+ }
63
+ }
64
+ }
65
+ if (last_succeed_atol == 1) {
66
+ return false;
67
+ }
68
+ else {
69
+ TUNABLE_LOG3("├──verify numerics: atol=", last_succeed_atol, ", rtol=", last_succeed_rtol);
70
+ }
71
+
72
+ return true;
73
+ }
74
+
75
+ }
76
+
77
+ template <typename T>
78
+ struct GemmParams : OpParams {
79
+ GemmParams() {
80
+ duplicate_inputs_ = false;
81
+ }
82
+
83
+ std::string Signature() const override {
84
+ return c10::str(transa, transb, "_", m, "_", n, "_", k);
85
+ }
86
+
87
+ size_t GetSizeA() const {
88
+ return sizeof(T) * lda * ((transa == 'n' || transa == 'N') ? k : m);
89
+ }
90
+
91
+ size_t GetSizeB() const {
92
+ return sizeof(T) * ldb * ((transb == 'n' || transb == 'N') ? n : k);
93
+ }
94
+
95
+ size_t GetSizeC() const {
96
+ return sizeof(T) * ldc * n;
97
+ }
98
+
99
+ size_t GetSize(bool duplicate_inputs) const {
100
+ size_t size = GetSizeC();
101
+ if (duplicate_inputs) {
102
+ size += GetSizeA();
103
+ size += GetSizeB();
104
+ }
105
+ return size;
106
+ }
107
+
108
+ GemmParams* DeepCopy(bool duplicate_inputs) const {
109
+ GemmParams* copy = new GemmParams;
110
+ *copy = *this;
111
+ c10::DeviceIndex device = 0;
112
+ AT_CUDA_CHECK(c10::cuda::GetDevice(&device));
113
+ size_t c_size = GetSizeC();
114
+ copy->c = static_cast<T*>(c10::cuda::CUDACachingAllocator::raw_alloc(c_size));
115
+ AT_CUDA_CHECK(c10::cuda::CUDACachingAllocator::memcpyAsync(
116
+ copy->c, device, c, device, c_size, getCurrentCUDAStream(device), true));
117
+ if (duplicate_inputs) {
118
+ size_t a_size = GetSizeA();
119
+ size_t b_size = GetSizeB();
120
+ copy->a = static_cast<const T*>(c10::cuda::CUDACachingAllocator::raw_alloc(a_size));
121
+ copy->b = static_cast<const T*>(c10::cuda::CUDACachingAllocator::raw_alloc(b_size));
122
+ copy->duplicate_inputs_ = true;
123
+ }
124
+ return copy;
125
+ }
126
+
127
+ // only call on object returned by DeepCopy
128
+ void Delete() {
129
+ c10::cuda::CUDACachingAllocator::raw_delete(c);
130
+ if (duplicate_inputs_) {
131
+ c10::cuda::CUDACachingAllocator::raw_delete(const_cast<T*>(a));
132
+ c10::cuda::CUDACachingAllocator::raw_delete(const_cast<T*>(b));
133
+ }
134
+ }
135
+
136
+ TuningStatus NumericalCheck(GemmParams<T> *other) {
137
+ auto c_dtype = c10::CppTypeToScalarType<T>::value;
138
+ return detail::NumericalCheck(c_dtype, c, other->c, ldc*n) ? OK : FAIL;
139
+ }
140
+
141
+ char transa;
142
+ char transb;
143
+ int64_t m;
144
+ int64_t n;
145
+ int64_t k;
146
+ at::opmath_type<T> alpha;
147
+ const T* a;
148
+ int64_t lda;
149
+ const T* b;
150
+ int64_t ldb;
151
+ at::opmath_type<T> beta;
152
+ T* c;
153
+ int64_t ldc;
154
+ private:
155
+ bool duplicate_inputs_;
156
+ };
157
+
158
+ template <typename T>
159
+ struct GemmAndBiasParams : OpParams {
160
+ std::string Signature() const override {
161
+ return c10::str(transa, transb, "_", m, "_", n, "_", k);
162
+ }
163
+
164
+ size_t GetSize(bool duplicate_inputs) const {
165
+ size_t size = sizeof(T) * ldc * n;
166
+ if (duplicate_inputs) {
167
+ size += sizeof(T) * lda * ((transa == 'n' || transa == 'N') ? k : m);
168
+ size += sizeof(T) * ldb * ((transb == 'n' || transb == 'N') ? n : k);
169
+ }
170
+ return size;
171
+ }
172
+
173
+ GemmAndBiasParams* DeepCopy(bool duplicate_inputs) const {
174
+ GemmAndBiasParams* copy = new GemmAndBiasParams;
175
+ *copy = *this;
176
+ c10::DeviceIndex device = 0;
177
+ AT_CUDA_CHECK(c10::cuda::GetDevice(&device));
178
+ size_t c_size = ldc * n * sizeof(T);
179
+ copy->c = static_cast<T*>(c10::cuda::CUDACachingAllocator::raw_alloc(c_size));
180
+ AT_CUDA_CHECK(c10::cuda::CUDACachingAllocator::memcpyAsync(
181
+ copy->c, device, c, device, c_size, getCurrentCUDAStream(device), true));
182
+ if (duplicate_inputs) {
183
+ size_t a_size = sizeof(T) * lda * ((transa == 'n' || transa == 'N') ? k : m);
184
+ size_t b_size = sizeof(T) * ldb * ((transb == 'n' || transb == 'N') ? n : k);
185
+ copy->a = static_cast<const T*>(c10::cuda::CUDACachingAllocator::raw_alloc(a_size));
186
+ copy->b = static_cast<const T*>(c10::cuda::CUDACachingAllocator::raw_alloc(b_size));
187
+ copy->duplicate_inputs_ = true;
188
+ }
189
+ return copy;
190
+ }
191
+
192
+ // only call on object returned by DeepCopy
193
+ void Delete() {
194
+ c10::cuda::CUDACachingAllocator::raw_delete(c);
195
+ if (duplicate_inputs_) {
196
+ c10::cuda::CUDACachingAllocator::raw_delete(const_cast<T*>(a));
197
+ c10::cuda::CUDACachingAllocator::raw_delete(const_cast<T*>(b));
198
+ }
199
+ }
200
+
201
+ TuningStatus NumericalCheck(GemmAndBiasParams<T> *other) {
202
+ auto c_dtype = c10::CppTypeToScalarType<T>::value;
203
+ return detail::NumericalCheck(c_dtype, c, other->c, ldc*n) ? OK : FAIL;
204
+ }
205
+
206
+ char transa;
207
+ char transb;
208
+ int64_t m;
209
+ int64_t n;
210
+ int64_t k;
211
+ at::opmath_type<T> alpha;
212
+ const T* a;
213
+ int64_t lda;
214
+ const T* b;
215
+ int64_t ldb;
216
+ T* c;
217
+ int64_t ldc;
218
+ const T* bias;
219
+ at::cuda::blas::GEMMAndBiasActivationEpilogue activation;
220
+ private:
221
+ bool duplicate_inputs_;
222
+ };
223
+
224
+ template <typename T>
225
+ struct GemmStridedBatchedParams : OpParams {
226
+ GemmStridedBatchedParams() {
227
+ duplicate_inputs_ = false;
228
+ }
229
+
230
+ std::string Signature() const override {
231
+ return c10::str(transa, transb, "_", m, "_", n, "_", k, "_B_", batch);
232
+ }
233
+
234
+ size_t GetSizeA() const {
235
+ return sizeof(T) * std::min(lda, stride_a) * ((transa == 'n' || transa == 'N') ? k : m) * batch;
236
+ }
237
+
238
+ size_t GetSizeB() const {
239
+ return sizeof(T) * std::min(ldb, stride_b) * ((transb == 'n' || transb == 'N') ? n : k) * batch;
240
+ }
241
+
242
+ size_t GetSizeC() const {
243
+ return sizeof(T) * std::min(ldc, stride_c) * n * batch;
244
+ }
245
+
246
+ size_t GetSize(bool duplicate_inputs) const {
247
+ size_t size = GetSizeC();
248
+ if (duplicate_inputs) {
249
+ size += GetSizeA();
250
+ size += GetSizeB();
251
+ }
252
+ return size;
253
+ }
254
+
255
+ GemmStridedBatchedParams* DeepCopy(bool duplicate_inputs) const {
256
+ GemmStridedBatchedParams* copy = new GemmStridedBatchedParams;
257
+ *copy = *this;
258
+ c10::DeviceIndex device = 0;
259
+ AT_CUDA_CHECK(c10::cuda::GetDevice(&device));
260
+ size_t c_size = GetSizeC();
261
+ copy->c = static_cast<T*>(c10::cuda::CUDACachingAllocator::raw_alloc(c_size));
262
+ AT_CUDA_CHECK(c10::cuda::CUDACachingAllocator::memcpyAsync(
263
+ copy->c, device, c, device, c_size, getCurrentCUDAStream(device), true));
264
+ if (duplicate_inputs) {
265
+ size_t a_size = GetSizeA();
266
+ size_t b_size = GetSizeB();
267
+ copy->a = static_cast<const T*>(c10::cuda::CUDACachingAllocator::raw_alloc(a_size));
268
+ copy->b = static_cast<const T*>(c10::cuda::CUDACachingAllocator::raw_alloc(b_size));
269
+ copy->duplicate_inputs_ = true;
270
+ }
271
+ return copy;
272
+ }
273
+
274
+ // only call on object returned by DeepCopy
275
+ void Delete() {
276
+ c10::cuda::CUDACachingAllocator::raw_delete(c);
277
+ if (duplicate_inputs_) {
278
+ c10::cuda::CUDACachingAllocator::raw_delete(const_cast<T*>(a));
279
+ c10::cuda::CUDACachingAllocator::raw_delete(const_cast<T*>(b));
280
+ }
281
+ }
282
+
283
+ TuningStatus NumericalCheck(GemmStridedBatchedParams<T> *other) {
284
+ auto c_dtype = c10::CppTypeToScalarType<T>::value;
285
+ return detail::NumericalCheck(c_dtype, c, other->c, batch*stride_c) ? OK : FAIL;
286
+ }
287
+
288
+ char transa;
289
+ char transb;
290
+ int64_t m;
291
+ int64_t n;
292
+ int64_t k;
293
+ at::opmath_type<T> alpha;
294
+ const T* a;
295
+ int64_t lda;
296
+ int64_t stride_a;
297
+ const T* b;
298
+ int64_t ldb;
299
+ int64_t stride_b;
300
+ at::opmath_type<T> beta;
301
+ T* c;
302
+ int64_t ldc;
303
+ int64_t stride_c;
304
+ int64_t batch;
305
+ private:
306
+ bool duplicate_inputs_;
307
+ };
308
+
309
+ template <typename T>
310
+ struct ScaledGemmParams : OpParams {
311
+ ScaledGemmParams() {
312
+ duplicate_inputs_ = false;
313
+ }
314
+
315
+ std::string Signature() const override {
316
+ return c10::str(transa, transb, "_", m, "_", n, "_", k);
317
+ }
318
+
319
+ size_t GetSizeA() const {
320
+ return sizeof(T) * lda * ((transa == 'n' || transa == 'N') ? k : m);
321
+ }
322
+
323
+ size_t GetSizeB() const {
324
+ return sizeof(T) * ldb * ((transb == 'n' || transb == 'N') ? n : k);
325
+ }
326
+
327
+ size_t GetSizeC() const {
328
+ return sizeof(T) * ldc * n;
329
+ }
330
+
331
+ size_t GetSize(bool duplicate_inputs) const {
332
+ size_t size = GetSizeC();
333
+ if (duplicate_inputs) {
334
+ size += GetSizeA();
335
+ size += GetSizeB();
336
+ }
337
+ return size;
338
+ }
339
+
340
+ ScaledGemmParams* DeepCopy(bool duplicate_inputs) const {
341
+ ScaledGemmParams* copy = new ScaledGemmParams;
342
+ *copy = *this;
343
+ c10::DeviceIndex device = 0;
344
+ AT_CUDA_CHECK(c10::cuda::GetDevice(&device));
345
+ size_t c_size = GetSizeC();
346
+ copy->c = c10::cuda::CUDACachingAllocator::raw_alloc(c_size);
347
+ AT_CUDA_CHECK(c10::cuda::CUDACachingAllocator::memcpyAsync(
348
+ copy->c, device, c, device, c_size, getCurrentCUDAStream(device), true));
349
+ if (duplicate_inputs) {
350
+ size_t a_size = GetSizeA();
351
+ size_t b_size = GetSizeB();
352
+ copy->a = c10::cuda::CUDACachingAllocator::raw_alloc(a_size);
353
+ copy->b = c10::cuda::CUDACachingAllocator::raw_alloc(b_size);
354
+ copy->duplicate_inputs_ = true;
355
+ }
356
+ return copy;
357
+ }
358
+
359
+ // only call on object returned by DeepCopy
360
+ void Delete() {
361
+ c10::cuda::CUDACachingAllocator::raw_delete(c);
362
+ if (duplicate_inputs_) {
363
+ c10::cuda::CUDACachingAllocator::raw_delete(const_cast<void*>(a));
364
+ c10::cuda::CUDACachingAllocator::raw_delete(const_cast<void*>(b));
365
+ }
366
+ }
367
+
368
+ TuningStatus NumericalCheck(ScaledGemmParams<T> *other) {
369
+ return detail::NumericalCheck(c_dtype, c, other->c, ldc*n) ? OK : FAIL;
370
+ }
371
+
372
+ char transa;
373
+ char transb;
374
+ int64_t m;
375
+ int64_t n;
376
+ int64_t k;
377
+ const void* a;
378
+ const void* a_scale_ptr;
379
+ int64_t lda;
380
+ ScalarType a_dtype;
381
+ const void* b;
382
+ const void* b_scale_ptr;
383
+ int64_t ldb;
384
+ ScalarType b_dtype;
385
+ const void* bias_ptr;
386
+ ScalarType bias_dtype;
387
+ void* c;
388
+ const void* c_scale_ptr;
389
+ int64_t ldc;
390
+ ScalarType c_dtype;
391
+ void* amax_ptr;
392
+ bool use_fast_accum;
393
+ private:
394
+ bool duplicate_inputs_;
395
+ };
396
+
397
+ } // namespace at::cuda::tunable
.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/tunable/GemmHipblaslt.h ADDED
@@ -0,0 +1,611 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright (c) Microsoft Corporation. All rights reserved.
2
+ // Licensed under the MIT License.
3
+
4
+ #pragma once
5
+
6
+ #include <ATen/cuda/CUDAContext.h>
7
+ #include <ATen/cuda/CUDADataType.h>
8
+ #include <ATen/cuda/tunable/TunableOp.h>
9
+ #include <ATen/cuda/tunable/GemmCommon.h>
10
+ #include <c10/cuda/CUDACachingAllocator.h>
11
+ #include <c10/util/StringUtil.h>
12
+
13
+ #include <hipblaslt/hipblaslt.h>
14
+ #include <hipblaslt/hipblaslt-ext.hpp>
15
+
16
+ #define TORCH_HIPBLASLT_CHECK(EXPR) \
17
+ do { \
18
+ hipblasStatus_t __err = EXPR; \
19
+ TORCH_CHECK(__err == HIPBLAS_STATUS_SUCCESS, \
20
+ "hipblaslt error: ", \
21
+ hipblasStatusToString(__err), \
22
+ " when calling `" #EXPR "`"); \
23
+ } while (0)
24
+
25
+ namespace at::cuda::tunable {
26
+
27
+ template <typename T>
28
+ constexpr hipblasDatatype_t HipDataTypeFor();
29
+
30
+ template <>
31
+ constexpr hipblasDatatype_t HipDataTypeFor<float>() {
32
+ return HIP_R_32F;
33
+ }
34
+
35
+ template <>
36
+ constexpr hipblasDatatype_t HipDataTypeFor<Half>() {
37
+ return HIP_R_16F;
38
+ }
39
+
40
+ template <>
41
+ constexpr hipblasDatatype_t HipDataTypeFor<BFloat16>() {
42
+ return HIP_R_16BF;
43
+ }
44
+
45
+ template <>
46
+ constexpr hipblasDatatype_t HipDataTypeFor<double>() {
47
+ return HIP_R_64F;
48
+ }
49
+
50
+ template <>
51
+ constexpr hipblasDatatype_t HipDataTypeFor<c10::Float8_e4m3fnuz>() {
52
+ return HIP_R_8F_E4M3_FNUZ;
53
+ }
54
+
55
+ template <>
56
+ constexpr hipblasDatatype_t HipDataTypeFor<c10::Float8_e5m2fnuz>() {
57
+ return HIP_R_8F_E5M2_FNUZ;
58
+ }
59
+
60
+ template <typename T>
61
+ int GetBatchFromParams(const GemmParams<T>* params) {
62
+ return 1;
63
+ }
64
+
65
+ template <typename T>
66
+ int GetBatchFromParams(const GemmAndBiasParams<T>* params) {
67
+ return 1;
68
+ }
69
+
70
+ template <typename T>
71
+ int GetBatchFromParams(const GemmStridedBatchedParams<T>* params) {
72
+ return params->batch;
73
+ }
74
+
75
+ template <typename T>
76
+ int GetBatchFromParams(const ScaledGemmParams<T>* params) {
77
+ return 1;
78
+ }
79
+
80
+ template <typename T>
81
+ int GetStrideAFromParams(const GemmParams<T>* params) {
82
+ return 1;
83
+ }
84
+
85
+ template <typename T>
86
+ int GetStrideAFromParams(const GemmAndBiasParams<T>* params) {
87
+ return 1;
88
+ }
89
+
90
+ template <typename T>
91
+ int GetStrideAFromParams(const GemmStridedBatchedParams<T>* params) {
92
+ return params->stride_a;
93
+ }
94
+
95
+ template <typename T>
96
+ int GetStrideAFromParams(const ScaledGemmParams<T>* params) {
97
+ return 1;
98
+ }
99
+
100
+ template <typename T>
101
+ int GetStrideBFromParams(const GemmParams<T>* params) {
102
+ return 1;
103
+ }
104
+
105
+ template <typename T>
106
+ int GetStrideBFromParams(const GemmAndBiasParams<T>* params) {
107
+ return 1;
108
+ }
109
+
110
+ template <typename T>
111
+ int GetStrideBFromParams(const GemmStridedBatchedParams<T>* params) {
112
+ return params->stride_b;
113
+ }
114
+
115
+ template <typename T>
116
+ int GetStrideBFromParams(const ScaledGemmParams<T>* params) {
117
+ return 1;
118
+ }
119
+
120
+ template <typename T>
121
+ int GetStrideCFromParams(const GemmParams<T>* params) {
122
+ return 1;
123
+ }
124
+
125
+ template <typename T>
126
+ int GetStrideCFromParams(const GemmAndBiasParams<T>* params) {
127
+ return 1;
128
+ }
129
+
130
+ template <typename T>
131
+ int GetStrideCFromParams(const GemmStridedBatchedParams<T>* params) {
132
+ return params->stride_c;
133
+ }
134
+
135
+ template <typename T>
136
+ int GetStrideCFromParams(const ScaledGemmParams<T>* params) {
137
+ return 1;
138
+ }
139
+
140
+ template <typename T>
141
+ float GetAlphaFromParams(const GemmParams<T>* params) {
142
+ return params->alpha;
143
+ }
144
+
145
+ template <typename T>
146
+ float GetAlphaFromParams(const GemmAndBiasParams<T>* params) {
147
+ return params->alpha;
148
+ }
149
+
150
+ template <typename T>
151
+ float GetAlphaFromParams(const GemmStridedBatchedParams<T>* params) {
152
+ return params->alpha;
153
+ }
154
+
155
+ template <typename T>
156
+ float GetAlphaFromParams(const ScaledGemmParams<T>* params) {
157
+ return 1.0;
158
+ }
159
+
160
+ template <typename T>
161
+ float GetBetaFromParams(const GemmParams<T>* params) {
162
+ return params->beta;
163
+ }
164
+
165
+ template <typename T>
166
+ float GetBetaFromParams(const GemmAndBiasParams<T>* params) {
167
+ return 0.0;
168
+ }
169
+
170
+ template <typename T>
171
+ float GetBetaFromParams(const GemmStridedBatchedParams<T>* params) {
172
+ return params->beta;
173
+ }
174
+
175
+ template <typename T>
176
+ float GetBetaFromParams(const ScaledGemmParams<T>* params) {
177
+ return 0.0;
178
+ }
179
+
180
+ template <typename T>
181
+ const void* GetAScalePointerFromParams(const GemmParams<T>* params) {
182
+ return nullptr;
183
+ }
184
+
185
+ template <typename T>
186
+ const void* GetAScalePointerFromParams(const GemmAndBiasParams<T>* params) {
187
+ return nullptr;
188
+ }
189
+
190
+ template <typename T>
191
+ const void* GetAScalePointerFromParams(const GemmStridedBatchedParams<T>* params) {
192
+ return nullptr;
193
+ }
194
+
195
+ template <typename T>
196
+ const void* GetAScalePointerFromParams(const ScaledGemmParams<T>* params) {
197
+ return params->a_scale_ptr;
198
+ }
199
+
200
+ template <typename T>
201
+ const void* GetBScalePointerFromParams(const GemmParams<T>* params) {
202
+ return nullptr;
203
+ }
204
+
205
+ template <typename T>
206
+ const void* GetBScalePointerFromParams(const GemmAndBiasParams<T>* params) {
207
+ return nullptr;
208
+ }
209
+
210
+ template <typename T>
211
+ const void* GetBScalePointerFromParams(const GemmStridedBatchedParams<T>* params) {
212
+ return nullptr;
213
+ }
214
+
215
+ template <typename T>
216
+ const void* GetBScalePointerFromParams(const ScaledGemmParams<T>* params) {
217
+ return params->b_scale_ptr;
218
+ }
219
+
220
+ template <typename T>
221
+ const void* GetDScalePointerFromParams(const GemmParams<T>* params) {
222
+ return nullptr;
223
+ }
224
+
225
+ template <typename T>
226
+ const void* GetDScalePointerFromParams(const GemmAndBiasParams<T>* params) {
227
+ return nullptr;
228
+ }
229
+
230
+ template <typename T>
231
+ const void* GetDScalePointerFromParams(const GemmStridedBatchedParams<T>* params) {
232
+ return nullptr;
233
+ }
234
+
235
+ template <typename T>
236
+ const void* GetDScalePointerFromParams(const ScaledGemmParams<T>* params) {
237
+ return params->c_scale_ptr;
238
+ }
239
+
240
+ template <typename T>
241
+ const void* GetBiasPointerFromParams(const GemmParams<T>* params) {
242
+ return nullptr;
243
+ }
244
+
245
+ template <typename T>
246
+ const void* GetBiasPointerFromParams(const GemmAndBiasParams<T>* params) {
247
+ return params->bias;
248
+ }
249
+
250
+ template <typename T>
251
+ const void* GetBiasPointerFromParams(const GemmStridedBatchedParams<T>* params) {
252
+ return nullptr;
253
+ }
254
+
255
+ template <typename T>
256
+ const void* GetBiasPointerFromParams(const ScaledGemmParams<T>* params) {
257
+ return params->bias_ptr;
258
+ }
259
+
260
+ template <typename T>
261
+ hipDataType GetBiasTypeFromParams(const GemmParams<T>* params) {
262
+ return HIP_R_32F;
263
+ }
264
+
265
+ template <typename T>
266
+ hipDataType GetBiasTypeFromParams(const GemmAndBiasParams<T>* params) {
267
+ return HipDataTypeFor<T>();
268
+ }
269
+
270
+ template <typename T>
271
+ hipDataType GetBiasTypeFromParams(const GemmStridedBatchedParams<T>* params) {
272
+ return HIP_R_32F;
273
+ }
274
+
275
+ template <typename T>
276
+ hipDataType GetBiasTypeFromParams(const ScaledGemmParams<T>* params) {
277
+ return at::cuda::ScalarTypeToCudaDataType(params->bias_dtype);
278
+ }
279
+
280
+ template <typename T>
281
+ at::cuda::blas::GEMMAndBiasActivationEpilogue GetActivationFromParams(const GemmParams<T>* params) {
282
+ return at::cuda::blas::GEMMAndBiasActivationEpilogue::None;
283
+ }
284
+
285
+ template <typename T>
286
+ at::cuda::blas::GEMMAndBiasActivationEpilogue GetActivationFromParams(const GemmAndBiasParams<T>* params) {
287
+ return params->activation;
288
+ }
289
+
290
+ template <typename T>
291
+ at::cuda::blas::GEMMAndBiasActivationEpilogue GetActivationFromParams(const GemmStridedBatchedParams<T>* params) {
292
+ return at::cuda::blas::GEMMAndBiasActivationEpilogue::None;
293
+ }
294
+
295
+ template <typename T>
296
+ at::cuda::blas::GEMMAndBiasActivationEpilogue GetActivationFromParams(const ScaledGemmParams<T>* params) {
297
+ return at::cuda::blas::GEMMAndBiasActivationEpilogue::None;
298
+ }
299
+
300
+ static hipblasOperation_t _hipblasOpFromChar(char op) {
301
+ switch (op) {
302
+ case 'n':
303
+ case 'N':
304
+ return HIPBLAS_OP_N;
305
+ case 't':
306
+ case 'T':
307
+ return HIPBLAS_OP_T;
308
+ case 'c':
309
+ case 'C':
310
+ return HIPBLAS_OP_C;
311
+ }
312
+ AT_ERROR(
313
+ "_hipblasOpFromChar input should be 't', 'n' or 'c' but got `", op, "`");
314
+ }
315
+
316
+ static char _charFromhipblasOp(hipblasOperation_t op) {
317
+ switch (op) {
318
+ case HIPBLAS_OP_N:
319
+ return 'N';
320
+ case HIPBLAS_OP_T:
321
+ return 'T';
322
+ case HIPBLAS_OP_C:
323
+ return 'C';
324
+ }
325
+ AT_ERROR(
326
+ "_charFromhipblasOp input should be HIPBLAS_OP_N/T/C but got `", op, "`");
327
+ }
328
+
329
+ static hipblasOperation_t MapLayoutToHipBlasLt(BlasOp layout) {
330
+ if (layout == BlasOp::N) {
331
+ return HIPBLAS_OP_N;
332
+ }
333
+ return HIPBLAS_OP_T;
334
+ }
335
+
336
+ static size_t GetHipblasltWorkspaceSize() {
337
+ static const char * env = getenv("HIPBLASLT_WORKSPACE_SIZE");
338
+ // 256MB is max workspace size allowed for hipblaslt
339
+ // hipblaslt-bench uses 32MB
340
+ // recommendation from hipblaslt author was 76MB
341
+ size_t workspace_size = 32*1024; // going with 32MB
342
+ if (env) {
343
+ try {
344
+ workspace_size = std::stoi(env);
345
+ } catch(std::invalid_argument const& e) {
346
+ TORCH_WARN("invalid HIPBLASLT_WORKSPACE_SIZE,",
347
+ " using default workspace size of ", workspace_size, " KiB.");
348
+ } catch(std::out_of_range const& e) {
349
+ TORCH_WARN("HIPBLASLT_WORKSPACE_SIZE out of range,",
350
+ " using default workspace size of ", workspace_size, " KiB.");
351
+ }
352
+ }
353
+ return workspace_size * 1024;
354
+ }
355
+
356
+ template <typename T, cublasStatus_t (*destructor)(T*)>
357
+ struct HipBlasLtDeleter {
358
+ void operator()(T* x) {
359
+ if (x != nullptr) {
360
+ TORCH_CUDABLAS_CHECK(destructor(x));
361
+ }
362
+ }
363
+ };
364
+
365
+ template <typename T, hipblasStatus_t (*destructor)(T*)>
366
+ class HipBlasLtDescriptor {
367
+ public:
368
+ T* descriptor() const {
369
+ return descriptor_.get();
370
+ }
371
+ T* descriptor() {
372
+ return descriptor_.get();
373
+ }
374
+
375
+ protected:
376
+ std::unique_ptr<T, HipBlasLtDeleter<T, destructor>> descriptor_;
377
+ };
378
+
379
+ class HipBlasLtMatmulDescriptor : public HipBlasLtDescriptor<
380
+ hipblasLtMatmulDescOpaque_t,
381
+ &hipblasLtMatmulDescDestroy> {
382
+ public:
383
+ HipBlasLtMatmulDescriptor(
384
+ hipblasComputeType_t compute_type,
385
+ hipDataType scale_type) {
386
+ hipblasLtMatmulDesc_t raw_descriptor = nullptr;
387
+ TORCH_HIPBLASLT_CHECK(
388
+ hipblasLtMatmulDescCreate(&raw_descriptor, compute_type, scale_type));
389
+ descriptor_.reset(raw_descriptor);
390
+ }
391
+ template <typename T>
392
+ inline void setAttribute(hipblasLtMatmulDescAttributes_t attr, const T value) {
393
+ TORCH_HIPBLASLT_CHECK(::hipblasLtMatmulDescSetAttribute(descriptor(), attr, &value, sizeof(T)));
394
+ }
395
+ };
396
+
397
+ template <typename AT, typename BT, typename CT, BlasOp ALayout, BlasOp BLayout, typename ParamsT>
398
+ class HipblasltGemmOp : public Callable<ParamsT> {
399
+ public:
400
+ HipblasltGemmOp(hipblasLtMatmulAlgo_t algo) : algo_{algo} {}
401
+
402
+ TuningStatus Call(const ParamsT* params) override {
403
+ hipblasOperation_t transa_outer = MapLayoutToHipBlasLt(ALayout);
404
+ hipblasOperation_t transb_outer = MapLayoutToHipBlasLt(BLayout);
405
+ auto a_datatype = HipDataTypeFor<AT>();
406
+ auto b_datatype = HipDataTypeFor<BT>();
407
+ auto in_out_datatype = HipDataTypeFor<CT>();
408
+ auto opa = _hipblasOpFromChar(params->transa);
409
+ auto opb = _hipblasOpFromChar(params->transb);
410
+
411
+ TORCH_CHECK(transa_outer == opa && transb_outer == opb, "trans mismatch, shouldn't happen");
412
+
413
+ float alpha = GetAlphaFromParams<CT>(params);
414
+ float beta = GetBetaFromParams<CT>(params);
415
+
416
+ hipblasLtMatrixLayout_t mat_a, mat_b, mat_c;
417
+ if (opa == HIPBLAS_OP_N) {
418
+ TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutCreate(&mat_a, a_datatype, params->m, params->k, params->lda));
419
+ }
420
+ else {
421
+ TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutCreate(&mat_a, a_datatype, params->k, params->m, params->lda));
422
+ }
423
+ if (opb == HIPBLAS_OP_N) {
424
+ TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutCreate(&mat_b, b_datatype, params->k, params->n, params->ldb));
425
+ }
426
+ else {
427
+ TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutCreate(&mat_b, b_datatype, params->n, params->k, params->ldb));
428
+ }
429
+ TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutCreate(&mat_c, in_out_datatype, params->m, params->n, params->ldc));
430
+
431
+ // specific to batched gemmm
432
+ int batch = GetBatchFromParams<CT>(params);
433
+ if (batch > 1) {
434
+ int64_t stride_a = GetStrideAFromParams<CT>(params);
435
+ int64_t stride_b = GetStrideBFromParams<CT>(params);
436
+ int64_t stride_c = GetStrideCFromParams<CT>(params);
437
+ TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutSetAttribute(
438
+ mat_a, HIPBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch, sizeof(batch)));
439
+ TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutSetAttribute(
440
+ mat_a, HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stride_a, sizeof(stride_a)));
441
+ TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutSetAttribute(
442
+ mat_b, HIPBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch, sizeof(batch)));
443
+ TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutSetAttribute(
444
+ mat_b, HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stride_b, sizeof(stride_b)));
445
+ TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutSetAttribute(
446
+ mat_c, HIPBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch, sizeof(batch)));
447
+ TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutSetAttribute(
448
+ mat_c, HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stride_c, sizeof(stride_c)));
449
+ }
450
+
451
+ HipBlasLtMatmulDescriptor matmul(HIPBLAS_COMPUTE_32F, HIP_R_32F);
452
+ matmul.setAttribute(HIPBLASLT_MATMUL_DESC_TRANSA, opa);
453
+ matmul.setAttribute(HIPBLASLT_MATMUL_DESC_TRANSB, opb);
454
+
455
+ // specific to scaled gemm
456
+ const void* mat1_scale_ptr = GetAScalePointerFromParams<CT>(params);
457
+ const void* mat2_scale_ptr = GetBScalePointerFromParams<CT>(params);
458
+ const void* result_scale_ptr = GetDScalePointerFromParams<CT>(params);
459
+ if (mat1_scale_ptr && mat2_scale_ptr) {
460
+ matmul.setAttribute(HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER, mat1_scale_ptr);
461
+ matmul.setAttribute(HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER, mat2_scale_ptr);
462
+ }
463
+ if (result_scale_ptr) {
464
+ matmul.setAttribute(HIPBLASLT_MATMUL_DESC_D_SCALE_POINTER, result_scale_ptr);
465
+ }
466
+
467
+ const void* bias_ptr = GetBiasPointerFromParams<CT>(params);
468
+ auto bias_datatype = GetBiasTypeFromParams<CT>(params);
469
+ if (bias_ptr) {
470
+ matmul.setAttribute(HIPBLASLT_MATMUL_DESC_BIAS_POINTER, bias_ptr);
471
+ matmul.setAttribute(HIPBLASLT_MATMUL_DESC_BIAS_DATA_TYPE, bias_datatype);
472
+ auto activation = GetActivationFromParams<CT>(params);
473
+ if (activation == at::cuda::blas::GEMMAndBiasActivationEpilogue::RELU) {
474
+ matmul.setAttribute(HIPBLASLT_MATMUL_DESC_EPILOGUE, HIPBLASLT_EPILOGUE_RELU_BIAS);
475
+ }
476
+ else if (activation == at::cuda::blas::GEMMAndBiasActivationEpilogue::GELU) {
477
+ matmul.setAttribute(HIPBLASLT_MATMUL_DESC_EPILOGUE, HIPBLASLT_EPILOGUE_GELU_BIAS);
478
+ }
479
+ else {
480
+ matmul.setAttribute(HIPBLASLT_MATMUL_DESC_EPILOGUE, HIPBLASLT_EPILOGUE_BIAS);
481
+ }
482
+ }
483
+
484
+ size_t workspace_size = GetHipblasltWorkspaceSize();
485
+
486
+ auto op_handle = at::cuda::getCurrentCUDABlasLtHandle();
487
+
488
+ size_t ret_workspace_size = 0;
489
+ auto status = hipblaslt_ext::matmulIsAlgoSupported(op_handle,
490
+ matmul.descriptor(),
491
+ &alpha,
492
+ mat_a,
493
+ mat_b,
494
+ &beta,
495
+ mat_c,
496
+ mat_c,
497
+ algo_,
498
+ ret_workspace_size);
499
+
500
+ if (status == HIPBLAS_STATUS_SUCCESS) {
501
+ if (ret_workspace_size >= workspace_size) {
502
+ return FAIL;
503
+ }
504
+ }
505
+ else {
506
+ return FAIL;
507
+ }
508
+
509
+ void* workspace_buffer = nullptr;
510
+ if (workspace_size > 0) {
511
+ workspace_buffer = c10::cuda::CUDACachingAllocator::raw_alloc(workspace_size);
512
+ }
513
+
514
+ TORCH_HIPBLASLT_CHECK(hipblasLtMatmul(op_handle,
515
+ matmul.descriptor(),
516
+ &alpha,
517
+ params->a,
518
+ mat_a,
519
+ params->b,
520
+ mat_b,
521
+ &beta,
522
+ params->c,
523
+ mat_c,
524
+ params->c,
525
+ mat_c,
526
+ &algo_,
527
+ workspace_buffer,
528
+ workspace_size,
529
+ at::cuda::getCurrentCUDAStream()));
530
+
531
+ //TORCH_HIPBLASLT_CHECK(hipblasLtMatmulDescDestroy(matmul));
532
+ TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutDestroy(mat_a));
533
+ TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutDestroy(mat_b));
534
+ TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutDestroy(mat_c));
535
+ if (workspace_size > 0) {
536
+ c10::cuda::CUDACachingAllocator::raw_delete(workspace_buffer);
537
+ }
538
+ return OK;
539
+ }
540
+
541
+ private:
542
+ hipblasLtMatmulAlgo_t algo_;
543
+ };
544
+
545
+ template <typename AT, typename BT, typename CT, BlasOp ALayout, BlasOp BLayout, typename ParamsT>
546
+ auto GetHipBlasLtTypeStringAndOps() {
547
+ hipblasOperation_t transa_outer = MapLayoutToHipBlasLt(ALayout);
548
+ hipblasOperation_t transb_outer = MapLayoutToHipBlasLt(BLayout);
549
+ auto a_datatype = HipDataTypeFor<AT>();
550
+ auto b_datatype = HipDataTypeFor<BT>();
551
+ auto in_out_datatype = HipDataTypeFor<CT>();
552
+ std::vector<hipblasLtMatmulHeuristicResult_t> heuristic_result;
553
+
554
+ hipblasLtHandle_t handle;
555
+ TORCH_HIPBLASLT_CHECK(hipblasLtCreate(&handle));
556
+ TORCH_HIPBLASLT_CHECK(hipblaslt_ext::getAllAlgos(handle,
557
+ hipblaslt_ext::GemmType::HIPBLASLT_GEMM,
558
+ transa_outer,
559
+ transb_outer,
560
+ a_datatype,
561
+ b_datatype,
562
+ in_out_datatype,
563
+ in_out_datatype,
564
+ HIPBLAS_COMPUTE_32F,
565
+ heuristic_result));
566
+ TORCH_HIPBLASLT_CHECK(hipblasLtDestroy(handle));
567
+
568
+ // Sort heuristic_result by algo index to make sure the order of returned algos is deterministic.
569
+ std::sort(heuristic_result.begin(),
570
+ heuristic_result.end(),
571
+ [](hipblasLtMatmulHeuristicResult_t& a, hipblasLtMatmulHeuristicResult_t& b) {
572
+ return hipblaslt_ext::getIndexFromAlgo(a.algo) < hipblaslt_ext::getIndexFromAlgo(b.algo);
573
+ });
574
+
575
+ int returned_algo_count = heuristic_result.size();
576
+ std::vector<std::pair<std::string, std::unique_ptr<Callable<ParamsT>>>> ret;
577
+ for (int i = 0; i < returned_algo_count; i++) {
578
+ auto algo = heuristic_result[i].algo;
579
+ int algo_index = hipblaslt_ext::getIndexFromAlgo(algo);
580
+ auto callable = std::make_unique<HipblasltGemmOp<AT, BT, CT, ALayout, BLayout, ParamsT>>(algo);
581
+ std::string type_string = c10::str(
582
+ "Gemm_Hipblaslt_", _charFromhipblasOp(transa_outer), _charFromhipblasOp(transb_outer), "_", algo_index);
583
+ ret.emplace_back(type_string, std::move(callable));
584
+ }
585
+
586
+ return ret;
587
+ }
588
+
589
+ template <typename T, BlasOp ALayout, BlasOp BLayout>
590
+ auto GetHipBlasLtGemmTypeStringAndOps() {
591
+ return GetHipBlasLtTypeStringAndOps<T, T, T, ALayout, BLayout, GemmParams<T>>();
592
+ }
593
+
594
+ template <typename T, BlasOp ALayout, BlasOp BLayout>
595
+ auto GetHipBlasLtGemmAndBiasTypeStringAndOps() {
596
+ return GetHipBlasLtTypeStringAndOps<T, T, T, ALayout, BLayout, GemmAndBiasParams<T>>();
597
+ }
598
+
599
+ template <typename T, BlasOp ALayout, BlasOp BLayout>
600
+ auto GetHipBlasLtGemmStridedBatchedTypeStringAndOps() {
601
+ return GetHipBlasLtTypeStringAndOps<T, T, T, ALayout, BLayout, GemmStridedBatchedParams<T>>();
602
+ }
603
+
604
+ template <typename AT, typename BT, typename CT, BlasOp ALayout, BlasOp BLayout>
605
+ auto GetHipBlasLtScaledGemmTypeStringAndOps() {
606
+ return GetHipBlasLtTypeStringAndOps<AT, BT, CT, ALayout, BLayout, ScaledGemmParams<CT>>();
607
+ }
608
+
609
+ #undef TORCH_HIPBLASLT_CHECK
610
+
611
+ } // namespace at::cuda::tunable
.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/tunable/GemmRocblas.h ADDED
@@ -0,0 +1,275 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright (c) Microsoft Corporation. All rights reserved.
2
+ // Licensed under the MIT License.
3
+
4
+ #pragma once
5
+
6
+ #include <ATen/cuda/CUDAContext.h>
7
+ #include <ATen/cuda/tunable/TunableOp.h>
8
+ #include <ATen/cuda/tunable/GemmCommon.h>
9
+ #include <c10/util/StringUtil.h>
10
+
11
+ #define ROCBLAS_BETA_FEATURES_API
12
+ #include <rocblas/rocblas.h>
13
+
14
+ #define TORCH_ROCBLAS_CHECK(EXPR) \
15
+ do { \
16
+ rocblas_status __err = EXPR; \
17
+ TORCH_CHECK(__err == rocblas_status_success, \
18
+ "rocblas error: ", \
19
+ rocblas_status_to_string(__err), \
20
+ " when calling `" #EXPR "`"); \
21
+ } while (0)
22
+
23
+ namespace at::cuda::tunable {
24
+
25
+ template <typename T>
26
+ constexpr rocblas_datatype RocBlasDataTypeFor();
27
+
28
+ template <>
29
+ constexpr rocblas_datatype RocBlasDataTypeFor<float>() {
30
+ return rocblas_datatype_f32_r;
31
+ }
32
+
33
+ template <>
34
+ constexpr rocblas_datatype RocBlasDataTypeFor<double>() {
35
+ return rocblas_datatype_f64_r;
36
+ }
37
+
38
+ template <>
39
+ constexpr rocblas_datatype RocBlasDataTypeFor<Half>() {
40
+ return rocblas_datatype_f16_r;
41
+ }
42
+
43
+ template <>
44
+ constexpr rocblas_datatype RocBlasDataTypeFor<BFloat16>() {
45
+ return rocblas_datatype_bf16_r;
46
+ }
47
+
48
+ template <>
49
+ constexpr rocblas_datatype RocBlasDataTypeFor<c10::complex<float>>() {
50
+ return rocblas_datatype_f32_c;
51
+ }
52
+
53
+ template <>
54
+ constexpr rocblas_datatype RocBlasDataTypeFor<c10::complex<double>>() {
55
+ return rocblas_datatype_f64_c;
56
+ }
57
+
58
+ template <typename T>
59
+ constexpr rocblas_datatype RocBlasComputeTypeFor();
60
+
61
+ template <>
62
+ constexpr rocblas_datatype RocBlasComputeTypeFor<float>() {
63
+ return rocblas_datatype_f32_r;
64
+ }
65
+
66
+ template <>
67
+ constexpr rocblas_datatype RocBlasComputeTypeFor<double>() {
68
+ return rocblas_datatype_f64_r;
69
+ }
70
+
71
+ template <>
72
+ constexpr rocblas_datatype RocBlasComputeTypeFor<Half>() {
73
+ // Note that we're returning the _compute_ type for a given datatype.
74
+ // As of 12/2022, using compute type FP16 for 16-bit floats was much
75
+ // slower than using compute type FP32. So we use FP32 compute even for
76
+ // FP16 datatypes. This is how GEMM is implemented even in the function
77
+ // rocblasGemmHelper (see fpgeneric.h)
78
+ return rocblas_datatype_f32_r;
79
+ }
80
+
81
+ template <>
82
+ constexpr rocblas_datatype RocBlasComputeTypeFor<BFloat16>() {
83
+ // Note that we're returning the _compute_ type for a given datatype.
84
+ // As of 12/2022, using compute type FP16 for 16-bit floats was much
85
+ // slower than using compute type FP32. So we use FP32 compute even for
86
+ // BF16 datatypes. This is how GEMM is implemented even in the function
87
+ // rocblasGemmHelper (see fpgeneric.h)
88
+ return rocblas_datatype_f32_r;
89
+ }
90
+
91
+ template <>
92
+ constexpr rocblas_datatype RocBlasComputeTypeFor<c10::complex<float>>() {
93
+ return rocblas_datatype_f32_c;
94
+ }
95
+
96
+ template <>
97
+ constexpr rocblas_datatype RocBlasComputeTypeFor<c10::complex<double>>() {
98
+ return rocblas_datatype_f64_c;
99
+ }
100
+
101
+ template <typename T>
102
+ auto DoCastForHalfOrBfloat16(const T fp) {
103
+ return fp;
104
+ }
105
+
106
+ template <>
107
+ inline auto DoCastForHalfOrBfloat16<Half>(const Half fp) {
108
+ // alpha and beta should be the same as compute_type, in Half case it is float.
109
+ float h = fp;
110
+ return h;
111
+ }
112
+
113
+ template <>
114
+ inline auto DoCastForHalfOrBfloat16<BFloat16>(const BFloat16 fp) {
115
+ // alpha and beta should be the same as compute_type, in bfloat16 case it is float.
116
+ float h = fp;
117
+ return h;
118
+ }
119
+
120
+ static rocblas_operation _rocblasOpFromChar(char op) {
121
+ switch (op) {
122
+ case 'n':
123
+ case 'N':
124
+ return rocblas_operation_none;
125
+ case 't':
126
+ case 'T':
127
+ return rocblas_operation_transpose;
128
+ case 'c':
129
+ case 'C':
130
+ return rocblas_operation_conjugate_transpose;
131
+ }
132
+ AT_ERROR(
133
+ "_rocblasOpFromChar input should be 't', 'n' or 'c' but got `", op, "`");
134
+ }
135
+
136
+ template <typename T>
137
+ class RocblasGemmOp : public Callable<GemmParams<T>> {
138
+ public:
139
+ RocblasGemmOp(int solution) : solution_{solution} {}
140
+
141
+ TuningStatus Call(const GemmParams<T>* params) override {
142
+ auto input_output_type = RocBlasDataTypeFor<T>();
143
+ auto compute_type = RocBlasComputeTypeFor<T>();
144
+ auto h_a = DoCastForHalfOrBfloat16(params->alpha);
145
+ auto h_b = DoCastForHalfOrBfloat16(params->beta);
146
+ auto status = rocblas_gemm_ex(
147
+ (rocblas_handle)at::cuda::getCurrentCUDABlasHandle(),
148
+ _rocblasOpFromChar(params->transa),
149
+ _rocblasOpFromChar(params->transb),
150
+ params->m, params->n, params->k,
151
+ &h_a,
152
+ params->a, input_output_type, params->lda,
153
+ params->b, input_output_type, params->ldb,
154
+ &h_b,
155
+ params->c, input_output_type, params->ldc,
156
+ params->c, input_output_type, params->ldc,
157
+ compute_type,
158
+ rocblas_gemm_algo_solution_index,
159
+ solution_,
160
+ rocblas_gemm_flags_none);
161
+ if (status != rocblas_status_success) {
162
+ return FAIL;
163
+ }
164
+ return OK;
165
+ }
166
+
167
+ private:
168
+ int solution_;
169
+ };
170
+
171
+ template <typename T>
172
+ auto GetRocBlasGemmTypeStringAndOps() {
173
+ rocblas_handle handle = (rocblas_handle)at::cuda::getCurrentCUDABlasHandle();
174
+ int solution_size;
175
+ auto input_output_type = RocBlasDataTypeFor<T>();
176
+ auto compute_type = RocBlasComputeTypeFor<T>();
177
+ // Get the number of available solutions
178
+ TORCH_ROCBLAS_CHECK(rocblas_gemm_ex_get_solutions_by_type(handle,
179
+ input_output_type,
180
+ input_output_type,
181
+ compute_type,
182
+ rocblas_gemm_flags_none,
183
+ nullptr,
184
+ &solution_size));
185
+ std::vector<int> solutions(solution_size);
186
+ // Get the list of available solutions
187
+ TORCH_ROCBLAS_CHECK(rocblas_gemm_ex_get_solutions_by_type(handle,
188
+ input_output_type,
189
+ input_output_type,
190
+ compute_type,
191
+ rocblas_gemm_flags_none,
192
+ solutions.data(),
193
+ &solution_size));
194
+ // Sort the solutions in ascending order to make the solution vector deterministic across runs
195
+ std::sort(solutions.begin(), solutions.end());
196
+
197
+ std::vector<std::pair<std::string, std::unique_ptr<Callable<GemmParams<T>>>>> ret;
198
+ for (size_t i = 0; i < solutions.size(); ++i) {
199
+ auto callable = std::make_unique<RocblasGemmOp<T>>(solutions[i]);
200
+ ret.emplace_back(std::make_pair(c10::str("Gemm_Rocblas_", solutions[i]), std::move(callable)));
201
+ }
202
+ return ret;
203
+ }
204
+
205
+ template <typename T>
206
+ class RocblasGemmStridedBatchedOp : public Callable<GemmStridedBatchedParams<T>> {
207
+ public:
208
+ RocblasGemmStridedBatchedOp(int solution) : solution_{solution} {}
209
+
210
+ TuningStatus Call(const GemmStridedBatchedParams<T>* params) override {
211
+ auto input_output_type = RocBlasDataTypeFor<T>();
212
+ auto compute_type = RocBlasComputeTypeFor<T>();
213
+ auto h_a = DoCastForHalfOrBfloat16(params->alpha);
214
+ auto h_b = DoCastForHalfOrBfloat16(params->beta);
215
+ auto status = rocblas_gemm_strided_batched_ex(
216
+ (rocblas_handle)at::cuda::getCurrentCUDABlasHandle(),
217
+ _rocblasOpFromChar(params->transa),
218
+ _rocblasOpFromChar(params->transb),
219
+ params->m, params->n, params->k,
220
+ &h_a,
221
+ params->a, input_output_type, params->lda, params->stride_a,
222
+ params->b, input_output_type, params->ldb, params->stride_b,
223
+ &h_b,
224
+ params->c, input_output_type, params->ldc, params->stride_c,
225
+ params->c, input_output_type, params->ldc, params->stride_c,
226
+ params->batch,
227
+ compute_type,
228
+ rocblas_gemm_algo_solution_index,
229
+ solution_,
230
+ rocblas_gemm_flags_none);
231
+ if (status != rocblas_status_success) {
232
+ return FAIL;
233
+ }
234
+ return OK;
235
+ }
236
+
237
+ private:
238
+ int solution_;
239
+ };
240
+
241
+ template <typename T>
242
+ auto GetRocBlasGemmStridedBatchedTypeStringAndOps() {
243
+ rocblas_handle handle = (rocblas_handle)at::cuda::getCurrentCUDABlasHandle();
244
+ int solution_size;
245
+ auto input_output_type = RocBlasDataTypeFor<T>();
246
+ auto compute_type = RocBlasComputeTypeFor<T>();
247
+ // Get the number of available solutions
248
+ TORCH_ROCBLAS_CHECK(rocblas_gemm_ex_get_solutions_by_type(handle,
249
+ input_output_type,
250
+ input_output_type,
251
+ compute_type,
252
+ rocblas_gemm_flags_none,
253
+ nullptr,
254
+ &solution_size));
255
+ std::vector<int> solutions(solution_size);
256
+ // Get the list of available solutions
257
+ TORCH_ROCBLAS_CHECK(rocblas_gemm_ex_get_solutions_by_type(handle,
258
+ input_output_type,
259
+ input_output_type,
260
+ compute_type,
261
+ rocblas_gemm_flags_none,
262
+ solutions.data(),
263
+ &solution_size));
264
+ // Sort the solutions in ascending order to make the solution vector deterministic across runs
265
+ std::sort(solutions.begin(), solutions.end());
266
+
267
+ std::vector<std::pair<std::string, std::unique_ptr<Callable<GemmStridedBatchedParams<T>>>>> ret;
268
+ for (size_t i = 0; i < solutions.size(); ++i) {
269
+ auto callable = std::make_unique<RocblasGemmStridedBatchedOp<T>>(solutions[i]);
270
+ ret.emplace_back(std::make_pair(c10::str("Gemm_Rocblas_", solutions[i]), std::move(callable)));
271
+ }
272
+ return ret;
273
+ }
274
+
275
+ } // namespace at::cuda::tunable
.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/tunable/StreamTimer.h ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Original TunableOp is from onnxruntime.
2
+ // https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/framework/tunable.h
3
+ // https://github.com/microsoft/onnxruntime/tree/main/onnxruntime/core/providers/rocm/tunable
4
+ // Copyright (c) Microsoft Corporation.
5
+ // Licensed under the MIT license.
6
+ //
7
+ // Adapting TunableOp into PyTorch
8
+ // Copyright (c) Advanced Micro Devices, Inc.
9
+ //
10
+ #pragma once
11
+
12
+ #include <cuda_runtime.h>
13
+
14
+ #include <ATen/cuda/tunable/Tunable.h>
15
+
16
+ namespace at::cuda::tunable {
17
+
18
+ class StreamTimer : public ITimer {
19
+ public:
20
+ StreamTimer();
21
+ virtual ~StreamTimer() override;
22
+
23
+ void Start() override;
24
+
25
+ void End() override;
26
+
27
+ float Duration() override;
28
+
29
+ private:
30
+ cudaEvent_t start_;
31
+ cudaEvent_t end_;
32
+ };
33
+
34
+ } // namespace at::cuda::tunable
.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/tunable/Tunable.h ADDED
@@ -0,0 +1,246 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Original TunableOp is from onnxruntime.
2
+ // https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/framework/tunable.h
3
+ // https://github.com/microsoft/onnxruntime/tree/main/onnxruntime/core/providers/rocm/tunable
4
+ // Copyright (c) Microsoft Corporation.
5
+ // Licensed under the MIT license.
6
+ //
7
+ // Adapting TunableOp into PyTorch
8
+ // Copyright (c) Advanced Micro Devices, Inc.
9
+ //
10
+ #pragma once
11
+
12
+ #include <c10/util/CallOnce.h>
13
+
14
+ #include <fstream>
15
+ #include <functional>
16
+ #include <iostream>
17
+ #include <memory>
18
+ #include <mutex>
19
+ #include <string>
20
+ #include <type_traits>
21
+ #include <unordered_map>
22
+ #include <utility>
23
+ #include <vector>
24
+
25
+ namespace at::cuda::tunable {
26
+
27
+ namespace detail {
28
+
29
+ struct MaybeDelete {
30
+ bool owns_pointer;
31
+ void operator()(std::ostream* os) const { if (owns_pointer) delete os; }
32
+ };
33
+
34
+ using OstreamPtr = std::unique_ptr<std::ostream, MaybeDelete>;
35
+
36
+ static OstreamPtr get_stream(std::string filename) {
37
+ if (filename.compare("out") == 0) {
38
+ return OstreamPtr { &std::cout, MaybeDelete {false} };
39
+ }
40
+ else if (filename.compare("err") == 0) {
41
+ return OstreamPtr { &std::cerr, MaybeDelete {false} };
42
+ }
43
+ else {
44
+ return OstreamPtr { new std::ofstream {filename.c_str()}, MaybeDelete {true} };
45
+ }
46
+ }
47
+
48
+ }
49
+
50
+ static void TunableLog(int level, const std::string& msg) {
51
+ static const char *env_file = getenv("PYTORCH_TUNABLEOP_VERBOSE_FILENAME");
52
+ static const char *env_verbose = getenv("PYTORCH_TUNABLEOP_VERBOSE");
53
+ static int level_user = env_verbose ? atoi(env_verbose) : 0;
54
+ static auto streamptr = detail::get_stream(env_file ? env_file : "err");
55
+ if (level_user >= level) {
56
+ (*streamptr) << msg <<std::endl;
57
+ }
58
+ }
59
+ #define TUNABLE_LOGV(LEVEL, ...) TunableLog(LEVEL, c10::str(__VA_ARGS__))
60
+ #define TUNABLE_LOG1(...) TUNABLE_LOGV(1, __VA_ARGS__)
61
+ #define TUNABLE_LOG2(...) TUNABLE_LOGV(2, __VA_ARGS__)
62
+ #define TUNABLE_LOG3(...) TUNABLE_LOGV(3, __VA_ARGS__)
63
+
64
+ enum TORCH_CUDA_CPP_API TuningStatus {
65
+ OK = 0,
66
+ FAIL = 1,
67
+ UNSUPPORTED = 2,
68
+ };
69
+
70
+ // Mapping from params signature to kernel id
71
+ class TORCH_CUDA_CPP_API ResultEntry {
72
+ public:
73
+ explicit ResultEntry(const std::string& key, double time) : key_(key), time_(time) {}
74
+ bool operator==(const ResultEntry& other) { return key_ == other.key_; }
75
+ bool operator!=(const ResultEntry& other) { return key_ != other.key_; }
76
+ operator std::string () { return key_; }
77
+ std::string GetKey() const { return key_; }
78
+ double GetTime() const { return time_; }
79
+ friend std::ostream& operator<<(std::ostream& stream, const ResultEntry& entry);
80
+ static ResultEntry Null() { return ResultEntry("Null", 0.0); }
81
+ static ResultEntry Default() { return ResultEntry("Default", 0.0); }
82
+
83
+ private:
84
+ std::string key_;
85
+ double time_;
86
+ };
87
+
88
// Per-op map: params signature -> best ResultEntry.
typedef std::unordered_map<std::string, ResultEntry> KernelMap;
// Map: op signature -> its KernelMap.
typedef std::unordered_map<std::string, KernelMap> ResultsMap;

// Serializable bundle of tuning results plus the environment they came from.
struct TORCH_CUDA_CPP_API TuningResults {
  // Validates if these results are compatible with the libraries
  // (key -> recorded value, checked by TuningResultsValidator).
  std::unordered_map<std::string, std::string> validators;

  // Mapping from Callable signature to Callable's tuning result
  ResultsMap results;
};
98
+
99
// Store of tuning results keyed by (op signature, params signature).
// Holds a mutex member, which suggests the public operations synchronize
// on it -- implementations are not visible in this header.
class TORCH_CUDA_CPP_API TuningResultsManager {
  public:
    TuningResultsManager() = default;
    ~TuningResultsManager() = default;

    // Returns the KernelMap for an op signature (by value).
    KernelMap Lookup(const std::string& op_signature);

    // Returns the recorded entry for (op, params); presumably
    // ResultEntry::Null() when absent -- defined elsewhere.
    ResultEntry Lookup(const std::string& op_signature, const std::string& params_signature);

    // Insertion helper operating on an already-located KernelMap.
    inline void AddImpl(const std::string& op_signature,
        const std::string& params_signature,
        ResultEntry best,
        KernelMap& kernel_map);

    // Records the best entry for (op, params).
    void Add(const std::string& op_signature,
        const std::string& params_signature,
        ResultEntry best);

    // Removes the entry for (op, params), if any.
    void Delete(const std::string& op_signature, const std::string& params_signature);

    // Merge helper writing into `results` (out parameter).
    inline void DisjointMergeImpl(
        const std::string& op_signature,
        const KernelMap& kernel_map,
        /*out*/ ResultsMap& results);

    // Loads a full results map (merge semantics defined in the implementation).
    void Load(const ResultsMap& results_to_load);

    // Snapshot of all stored results.
    ResultsMap Dump();

    // Merges one op's kernel map into the stored results.
    void DisjointMerge(const std::string& op_signature, const KernelMap& kernel_map);

    // Number of stored entries.
    size_t GetSize();

  private:
    std::mutex lock_;    // presumably guards results_
    ResultsMap results_;
};
136
+
137
// Checks that loaded tuning results are compatible with the current build,
// via named (get, validate) callback pairs.
class TORCH_CUDA_CPP_API TuningResultsValidator {
  public:
    // Produces the current value for a validator key.
    using GetFunc = std::function<std::string()>;
    // Validates a stored value against the current environment.
    using ValidateFunc = std::function<TuningStatus(const std::string&)>;
    using GetValidateFuncs = std::unordered_map<std::string, std::pair<GetFunc, ValidateFunc>>;

    TuningResultsValidator();
    ~TuningResultsValidator() = default;

    // Current value of every registered validator key.
    std::unordered_map<std::string, std::string> GetAllValidators() const;
    // Runs the registered validators against `to_validate`.
    TuningStatus ValidateAll(const std::unordered_map<std::string, std::string>& to_validate) const;
    // Registers a (get, validate) pair under `key`.
    void RegisterValidator(const std::string& key, const GetFunc& gf, const ValidateFunc& vf);

  protected:
    std::string GetPyTorchVersion() const;
    TuningStatus ValidatePyTorchVersion(const std::string& value) const;

  public:
    // Keys that are always expected to be present when validating.
    static constexpr const std::array mandatory_keys{"PT_VERSION"};

  private:
    GetValidateFuncs validators_;
};
160
+
161
// Process-wide configuration and state for TunableOp; obtain the singleton
// via getTuningContext(). Non-copyable and non-movable.
class TORCH_CUDA_CPP_API TuningContext {
  public:
    TuningContext();
    ~TuningContext();
    TuningContext(TuningContext &) = delete;
    TuningContext(TuningContext &&) = delete;
    TuningContext &operator=(TuningContext &) = delete;
    TuningContext &operator=(TuningContext &&) = delete;

    // Master switch: when disabled, TunableOp::operator() uses the default
    // implementation without any lookup.
    void EnableTunableOp(bool value);
    bool IsTunableOpEnabled() const;

    // When enabled, a missing result triggers on-line tuning; otherwise only
    // previously recorded results are used.
    void EnableTuning(bool value);
    bool IsTuningEnabled() const;

    // When enabled, candidates are numerically compared against the default
    // implementation during tuning (see TunableOp::FindFastest).
    void EnableNumericsCheck(bool value);
    bool IsNumericsCheckEnabled() const;

    // Per-candidate tuning budget, milliseconds.
    void SetMaxTuningDurationMs(int max_duration_ms);
    int GetMaxTuningDurationMs() const;

    // Per-candidate tuning iteration cap.
    void SetMaxTuningIterations(int max_iter);
    int GetMaxTuningIterations() const;

    // Per-candidate warmup budget, milliseconds.
    void SetMaxWarmupDurationMs(int max_duration_ms);
    int GetMaxWarmupDurationMs() const;

    // Per-candidate warmup iteration cap.
    void SetMaxWarmupIterations(int max_iter);
    int GetMaxWarmupIterations() const;

    // When enabled, the instruction cache is flushed between timed calls
    // (see at::cuda::flush_icache usage in TunableOp).
    void EnableICacheFlush(bool value);
    bool IsICacheFlushEnabled() const;

    // Size of the rotating buffer of parameter copies used while tuning;
    // 0 disables rotation. Units appear to be bytes (FindFastest divides by
    // a param byte size) -- confirm against the implementation.
    void SetRotatingBufferSize(int size);
    int GetRotatingBufferSize() const;

    TuningResultsManager& GetTuningResultsManager();

    TuningResultsValidator& GetTuningResultsValidator();

    // Snapshot: validators plus recorded results.
    TuningResults GetTuningResults();

    // Validates and loads previously serialized results.
    TuningStatus LoadTuningResults(const TuningResults& tr);

    // File used by ReadFile/WriteFile; insert_device_ordinal presumably
    // embeds the device index into the name -- defined elsewhere.
    void SetFilename(const std::string& filename, bool insert_device_ordinal=false);
    std::string GetFilename() const;

    // Whether results are flushed to the file when the context is destroyed.
    void WriteFileOnExit(bool value);

    // Read/write results from/to `filename`; empty string means the
    // configured filename.
    bool ReadFile(const std::string& filename={});
    bool WriteFile(const std::string& filename={});

  private:
    bool enable_;
    bool tuning_enable_;
    bool manager_initialized_;
    bool write_file_on_exit_;
    bool numerics_check_enable_;
    int max_tuning_duration_ms_;
    int max_tuning_iterations_;
    int max_warmup_duration_ms_;
    int max_warmup_iterations_;
    bool icache_flush_;
    int rotating_buffer_size_;
    mutable TuningResultsManager manager_;  // lazily initialized via manager_init_once_
    mutable c10::once_flag manager_init_once_;
    TuningResultsValidator validator_;
    std::string filename_;
    size_t results_count_from_input_file_;
};
231
+
232
+ TORCH_CUDA_CPP_API TuningContext* getTuningContext();
233
+
234
// Abstract timer interface used to measure candidate kernels
// (see the TimerT template parameter of TunableOp).
class ITimer {
  public:
    ITimer() = default;
    virtual ~ITimer() = default;

    virtual void Start() = 0;
    virtual void End() = 0;

    /// Computes the elapsed time in milliseconds between Start() and End()
    virtual float Duration() = 0;
};
245
+
246
+ } // namespace at::cuda::tunable
.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/tunable/TunableGemm.h ADDED
@@ -0,0 +1,307 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Original TunableOp is from onnxruntime.
2
+ // https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/framework/tunable.h
3
+ // https://github.com/microsoft/onnxruntime/tree/main/onnxruntime/core/providers/rocm/tunable
4
+ // Copyright (c) Microsoft Corporation.
5
+ // Licensed under the MIT license.
6
+ //
7
+ // Adapting TunableOp into PyTorch
8
+ // Copyright (c) Advanced Micro Devices, Inc.
9
+ //
10
+ #pragma once
11
+
12
+ #include <ATen/cuda/tunable/GemmCommon.h>
13
+ #ifdef USE_ROCM
14
+ #include <ATen/cuda/tunable/GemmHipblaslt.h>
15
+ #include <ATen/cuda/tunable/GemmRocblas.h>
16
+ #endif
17
+ #include <ATen/cuda/tunable/StreamTimer.h>
18
+ #include <ATen/cuda/tunable/TunableOp.h>
19
+ #include <c10/cuda/CUDACachingAllocator.h>
20
+ #include <c10/util/Float8_e4m3fn.h>
21
+ #include <c10/util/Float8_e4m3fnuz.h>
22
+ #include <c10/util/Float8_e5m2.h>
23
+ #include <c10/util/Float8_e5m2fnuz.h>
24
+ #include <c10/util/StringUtil.h>
25
+
26
+ namespace at::cuda::tunable {
27
+
28
// Baseline GEMM candidate: forwards directly to at::cuda::blas::gemm_internal.
// Registered under the name "Default" by GemmTunableOp.
template <typename T>
class DefaultGemmOp : public Callable<GemmParams<T>> {
  public:
    TuningStatus Call(const GemmParams<T>* params) override {
      at::cuda::blas::gemm_internal<T>(
          params->transa, params->transb,
          params->m, params->n, params->k,
          params->alpha,
          params->a, params->lda,
          params->b, params->ldb,
          params->beta,
          params->c, params->ldc);
      return OK;
    }
};
43
+
44
// Maps a BLAS-style transpose flag character to a bool:
// 't'/'T' -> true, any other character -> false.
static bool _transposeBoolFromChar(char op) {
  switch (op) {
    case 't':
    case 'T':
      return true;
    default:
      return false;
  }
}
47
+
48
// Baseline GEMM+bias(+activation) candidate: forwards to
// at::cuda::blas::gemm_and_bias, translating the char transpose flags to
// the bool form that API takes.
template <typename T>
class DefaultGemmAndBiasOp : public Callable<GemmAndBiasParams<T>> {
  public:
    TuningStatus Call(const GemmAndBiasParams<T>* params) override {
      at::cuda::blas::gemm_and_bias<T>(
          _transposeBoolFromChar(params->transa),
          _transposeBoolFromChar(params->transb),
          params->m, params->n, params->k,
          params->alpha,
          params->a, params->lda,
          params->b, params->ldb,
          params->bias,
          params->c, params->ldc,
          params->activation);
      return OK;
    }
};
65
+
66
// Baseline strided-batched GEMM candidate: forwards to
// at::cuda::blas::bgemm_internal.
template <typename T>
class DefaultGemmStridedBatchedOp : public Callable<GemmStridedBatchedParams<T>> {
  public:
    TuningStatus Call(const GemmStridedBatchedParams<T>* params) override {
      at::cuda::blas::bgemm_internal<T>(
          params->transa, params->transb,
          params->m, params->n, params->k,
          params->alpha,
          params->a, params->lda, params->stride_a,
          params->b, params->ldb, params->stride_b,
          params->beta,
          params->c, params->ldc, params->stride_c,
          params->batch);
      return OK;
    }
};
82
+
83
// Baseline scaled (FP8-style) GEMM candidate: forwards to
// at::cuda::blas::scaled_gemm, passing through per-operand scale pointers,
// dtypes, optional bias, amax output pointer, and the fast-accum flag.
template <typename T>
class DefaultScaledGemmOp : public Callable<ScaledGemmParams<T>> {
  public:
    TuningStatus Call(const ScaledGemmParams<T>* params) override {
      at::cuda::blas::scaled_gemm(
          params->transa,
          params->transb,
          params->m,
          params->n,
          params->k,
          params->a,
          params->a_scale_ptr,
          params->lda,
          params->a_dtype,
          params->b,
          params->b_scale_ptr,
          params->ldb,
          params->b_dtype,
          params->bias_ptr,
          params->bias_dtype,
          params->c,
          params->c_scale_ptr,
          params->ldc,
          params->c_dtype,
          params->amax_ptr,
          params->use_fast_accum);
      return OK;
    }
};
112
+
113
// True when the scalar compares equal to zero. Specialized for the reduced-
// precision and complex types below; callers are not visible in this header.
template <typename T>
inline bool IsZero(T v) {
  return v == 0.0f;
}

// BFloat16: compares the raw bit field, so only +0.0 matches
// (-0.0 has a different bit pattern) -- verify this is intended at call sites.
template <>
inline bool IsZero(BFloat16 v) {
  return v.x == 0;
}

// Half: compared via conversion to float, so both +0.0 and -0.0 match.
template <>
inline bool IsZero(Half v) {
  return float(v) == 0.0f;
}

template <>
inline bool IsZero(c10::complex<double> v) {
  return v == 0.0;
}

template <>
inline bool IsZero(c10::complex<float> v) {
  return v == 0.0f;
}
137
+
138
// Human-readable scalar type names used to build tunable-op signatures
// (see the Signature() overrides below). The unused value parameter exists
// only to select the overload. Unspecialized types report "unknown".
template <typename T>
inline std::string TypeName(T v) {
  return "unknown";
}

template <>
inline std::string TypeName(float v) {
  return "float";
}

template <>
inline std::string TypeName(double v) {
  return "double";
}

template <>
inline std::string TypeName(BFloat16 v) {
  return "BFloat16";
}

template <>
inline std::string TypeName(Half v) {
  return "Half";
}

template <>
inline std::string TypeName(Float8_e4m3fn v) {
  return "Float8_e4m3fn";
}

template <>
inline std::string TypeName(Float8_e5m2 v) {
  return "Float8_e5m2";
}

template <>
inline std::string TypeName(Float8_e4m3fnuz v) {
  return "Float8_e4m3fnuz";
}

template <>
inline std::string TypeName(Float8_e5m2fnuz v) {
  return "Float8_e5m2fnuz";
}

template <>
inline std::string TypeName(c10::complex<double> v) {
  return "c10::complex<double>";
}

template <>
inline std::string TypeName(c10::complex<float> v) {
  return "c10::complex<float>";
}
192
+
193
// Tunable GEMM. Registers the "Default" op and, under ROCm, rocBLAS and
// hipBLASLt candidates. Each ROCm backend is enabled when its env var
// (PYTORCH_TUNABLEOP_ROCBLAS_ENABLED / PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED)
// is unset or "1"; any other value disables it.
template <typename T, BlasOp ALayout, BlasOp BLayout>
class GemmTunableOp : public TunableOp<GemmParams<T>, StreamTimer> {
  public:
    GemmTunableOp() {
      this->RegisterOp(std::string("Default"), std::make_unique<DefaultGemmOp<T>>());

#ifdef USE_ROCM
      static const char *env_rocblas = std::getenv("PYTORCH_TUNABLEOP_ROCBLAS_ENABLED");
      if (env_rocblas == nullptr || strcmp(env_rocblas, "1") == 0) {
        for (auto&& [name, op] : GetRocBlasGemmTypeStringAndOps<T>()) {
          this->RegisterOp(std::move(name), std::move(op));
        }
      }

      static const char *env_hipblaslt = std::getenv("PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED");
      if (env_hipblaslt == nullptr || strcmp(env_hipblaslt, "1") == 0) {
        // disallow tuning of hipblaslt with c10::complex
        if constexpr (
            !std::is_same_v<T, c10::complex<float>> &&
            !std::is_same_v<T, c10::complex<double>>) {
          for (auto&& [name, op] : GetHipBlasLtGemmTypeStringAndOps<T, ALayout, BLayout>()) {
            this->RegisterOp(std::move(name), std::move(op));
          }
        }
      }
#endif
    }

    // Stable identifier combining op family, scalar type, and layout pair.
    std::string Signature() override {
      return c10::str("GemmTunableOp_", TypeName<T>(T{}), "_", BlasOpToString(ALayout), BlasOpToString(BLayout));
    }
};
225
+
226
// Tunable GEMM with fused bias/activation epilogue. Registers the "Default"
// op and, under ROCm, hipBLASLt candidates gated by
// PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED (unset or "1" enables).
template <typename T, BlasOp ALayout, BlasOp BLayout>
class GemmAndBiasTunableOp : public TunableOp<GemmAndBiasParams<T>, StreamTimer> {
  public:
    GemmAndBiasTunableOp() {
      this->RegisterOp(std::string("Default"), std::make_unique<DefaultGemmAndBiasOp<T>>());

#ifdef USE_ROCM
      static const char *env_hipblaslt = std::getenv("PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED");
      if (env_hipblaslt == nullptr || strcmp(env_hipblaslt, "1") == 0) {
        // disallow tuning of hipblaslt with c10::complex
        if constexpr (
            !std::is_same_v<T, c10::complex<float>> &&
            !std::is_same_v<T, c10::complex<double>>) {
          for (auto&& [name, op] : GetHipBlasLtGemmAndBiasTypeStringAndOps<T, ALayout, BLayout>()) {
            this->RegisterOp(std::move(name), std::move(op));
          }
        }
      }
#endif
    }

    // Stable identifier combining op family, scalar type, and layout pair.
    std::string Signature() override {
      return c10::str("GemmAndBiasTunableOp_", TypeName<T>(T{}), "_", BlasOpToString(ALayout), BlasOpToString(BLayout));
    }
};
251
+
252
// Tunable strided-batched GEMM. Registers the "Default" op and, under ROCm,
// rocBLAS and hipBLASLt candidates gated by the same env vars as
// GemmTunableOp (unset or "1" enables).
template <typename T, BlasOp ALayout, BlasOp BLayout>
class GemmStridedBatchedTunableOp : public TunableOp<GemmStridedBatchedParams<T>, StreamTimer> {
  public:
    GemmStridedBatchedTunableOp() {
      this->RegisterOp(std::string("Default"), std::make_unique<DefaultGemmStridedBatchedOp<T>>());

#ifdef USE_ROCM
      static const char *env_rocblas = std::getenv("PYTORCH_TUNABLEOP_ROCBLAS_ENABLED");
      if (env_rocblas == nullptr || strcmp(env_rocblas, "1") == 0) {
        for (auto&& [name, op] : GetRocBlasGemmStridedBatchedTypeStringAndOps<T>()) {
          this->RegisterOp(std::move(name), std::move(op));
        }
      }

      static const char *env_hipblaslt = std::getenv("PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED");
      if (env_hipblaslt == nullptr || strcmp(env_hipblaslt, "1") == 0) {
        // disallow tuning of hipblaslt with c10::complex
        if constexpr (
            !std::is_same_v<T, c10::complex<float>> &&
            !std::is_same_v<T, c10::complex<double>>) {
          for (auto&& [name, op] : GetHipBlasLtGemmStridedBatchedTypeStringAndOps<T, ALayout, BLayout>()) {
            this->RegisterOp(std::move(name), std::move(op));
          }
        }
      }
#endif
    }

    // Stable identifier combining op family, scalar type, and layout pair.
    std::string Signature() override {
      return c10::str("GemmStridedBatchedTunableOp_", TypeName<T>(T{}), "_", BlasOpToString(ALayout), BlasOpToString(BLayout));
    }
};
284
+
285
// Tunable scaled GEMM with distinct A/B/C scalar types. Registers the
// "Default" op and, under ROCm, hipBLASLt candidates unconditionally
// (no env-var gate, unlike the other tunable GEMMs above).
template <typename AT, typename BT, typename CT, BlasOp ALayout, BlasOp BLayout>
class ScaledGemmTunableOp : public TunableOp<ScaledGemmParams<CT>, StreamTimer> {
  public:
    ScaledGemmTunableOp() {
      this->RegisterOp(std::string("Default"), std::make_unique<DefaultScaledGemmOp<CT>>());

#ifdef USE_ROCM
      for (auto&& [name, op] : GetHipBlasLtScaledGemmTypeStringAndOps<AT, BT, CT, ALayout, BLayout>()) {
        this->RegisterOp(std::move(name), std::move(op));
      }
#endif
    }

    // Stable identifier combining op family, all three scalar types, and layouts.
    std::string Signature() override {
      return c10::str("ScaledGemmTunableOp",
          "_", TypeName<AT>(AT{}),
          "_", TypeName<BT>(BT{}),
          "_", TypeName<CT>(CT{}),
          "_", BlasOpToString(ALayout), BlasOpToString(BLayout));
    }
};
306
+
307
+ } // namespace at::cuda::tunable
.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/tunable/TunableOp.h ADDED
@@ -0,0 +1,286 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Original TunableOp is from onnxruntime.
2
+ // https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/framework/tunable.h
3
+ // https://github.com/microsoft/onnxruntime/tree/main/onnxruntime/core/providers/rocm/tunable
4
+ // Copyright (c) Microsoft Corporation.
5
+ // Licensed under the MIT license.
6
+ //
7
+ // Adapting TunableOp into PyTorch
8
+ // Copyright (c) Advanced Micro Devices, Inc.
9
+ //
10
+ #pragma once
11
+
12
#include <ATen/cuda/tunable/Tunable.h>
#include <ATen/cuda/Sleep.h>
#include <c10/cuda/CUDACachingAllocator.h>

#ifndef _WIN32
#include <cxxabi.h>
#endif

#include <cstdlib>
#include <memory>
#include <string>
#include <type_traits>
#include <unordered_map>
#include <vector>
24
+
25
+ namespace at::cuda::tunable {
26
+
27
+ template <typename ParamsT>
28
+ class Callable {
29
+ public:
30
+ Callable() = default;
31
+ Callable(Callable&&) = default;
32
+ virtual ~Callable() = default;
33
+ virtual TuningStatus Call(const ParamsT*) {
34
+ return FAIL;
35
+ }
36
+ virtual TuningStatus IsSupported(const ParamsT* params) {
37
+ return Call(params);
38
+ }
39
+ };
40
+
41
+ template <typename ParamsT, typename TimerT>
42
+ class TunableOp {
43
+ public:
44
+ TunableOp() = default;
45
+ TunableOp(TunableOp&&) = default;
46
+ virtual ~TunableOp() = default;
47
+
48
+ TuningStatus operator()(const ParamsT* params) {
49
+ ResultEntry result = ResultEntry::Null();
50
+ TuningContext* ctx = getTuningContext();
51
+ if (ctx->IsTunableOpEnabled()) {
52
+ auto& mgr = ctx->GetTuningResultsManager();
53
+ auto op_sig = Signature();
54
+ auto params_sig = params->Signature();
55
+ result = mgr.Lookup(op_sig, params_sig);
56
+ // If there is not previous tuning result been found, we do the tuning iff tuning is enabled
57
+ if (result == ResultEntry::Null() && ctx->IsTuningEnabled()) {
58
+ result = FindFastest(params);
59
+ mgr.Add(op_sig, params_sig, result);
60
+ }
61
+ }
62
+ else {
63
+ result = ResultEntry::Default();
64
+ }
65
+ if (result == ResultEntry::Null()) {
66
+ TUNABLE_LOG2("no result, using default");
67
+ result = ResultEntry::Default();
68
+ }
69
+ auto iter = ops_.find(result);
70
+ TORCH_CHECK(iter != ops_.end());
71
+ return iter->second->Call(params);
72
+ }
73
+
74
+ virtual std::string Signature() {
75
+ // According to C++17 standard https://wg21.link/n4659 section 15.7.4
76
+ // > if the operand of typeid refers to the
77
+ // > object under construction or destruction, typeid yields the std::type_info object representing the constructor
78
+ // > or destructor’s class.
79
+ // So delay the op signature generation.
80
+ c10::call_once(signature_init_once_, [this]() { signature_ = CreateSignature(); });
81
+ return signature_;
82
+ }
83
+
84
+ protected:
85
+ void RegisterOp(const std::string& name, std::unique_ptr<Callable<ParamsT>> op) {
86
+ this->op_names_.emplace_back(name);
87
+ this->ops_.emplace(name, std::move(op));
88
+ }
89
+
90
+ private:
91
+ static void WarmUp(Callable<ParamsT> *op, const std::vector<ParamsT*> &param, size_t num_iter, size_t &offset) {
92
+ TuningContext* ctx = getTuningContext();
93
+ bool do_flush = ctx->IsICacheFlushEnabled();
94
+ for (size_t i = 0; i < num_iter; i++) {
95
+ if (do_flush) {
96
+ at::cuda::flush_icache();
97
+ }
98
+ TORCH_CHECK(op->Call(param[(i+offset++)%param.size()]) == OK);
99
+ }
100
+ }
101
+
102
+ static double Profile(Callable<ParamsT> *op, const std::vector<ParamsT*> &param, size_t num_iter, size_t &offset) {
103
+ TuningContext* ctx = getTuningContext();
104
+ bool do_flush = ctx->IsICacheFlushEnabled();
105
+ TimerT timer{};
106
+ timer.Start();
107
+ for (size_t i = 0; i < num_iter; i++) {
108
+ if (do_flush) {
109
+ at::cuda::flush_icache();
110
+ }
111
+ TORCH_CHECK(op->Call(param[(i+offset++)%param.size()]) == OK);
112
+ }
113
+ timer.End();
114
+ return timer.Duration() / num_iter;
115
+ }
116
+
117
+ protected:
118
+ virtual ResultEntry FindFastest(const ParamsT* params) {
119
+ TuningContext* ctx = getTuningContext();
120
+ auto op_sig = Signature();
121
+ auto params_sig = params->Signature();
122
+ TUNABLE_LOG2("finding fastest for ", op_sig, '(', params_sig, ')', " out of ", op_names_.size(), " candidates");
123
+ auto min_duration_ms = std::numeric_limits<double>::infinity();
124
+ std::string id_name = "Default";
125
+ ParamsT* reference_params = nullptr;
126
+
127
+ // numeric check option is controlled by non-static env var, so check it once per tuned operator
128
+ bool do_numerics_check = ctx->IsNumericsCheckEnabled();
129
+
130
+ // calcaulte a reference answer for numerical check
131
+ if (do_numerics_check) {
132
+ reference_params = params->DeepCopy(false);
133
+ TORCH_CHECK(ops_[ResultEntry::Default()]->Call(reference_params) == OK);
134
+ }
135
+
136
+ // need copies of params to reuse
137
+ // make as many copies as will fill the requested rotating buffer size, if requested
138
+ // rotating_size guaranteed to be >= 0 even though GetRotatingBufferSize() returns int
139
+ size_t rotating_size = ctx->GetRotatingBufferSize();
140
+ bool use_buffer_rotation = (rotating_size > 0);
141
+ size_t param_size = params->GetSize(use_buffer_rotation);
142
+ size_t param_count = (rotating_size / param_size) + 1;
143
+ constexpr size_t MB = 1024*1024;
144
+ if (use_buffer_rotation) {
145
+ TUNABLE_LOG2("Rotating buffer ", rotating_size/MB, " MiB. ",
146
+ "Needed Size: ", param_size/MB, " MiB. ",
147
+ "Needed number of param copies: ", param_count);
148
+ }
149
+ TORCH_CHECK(param_count > 0);
150
+
151
+ std::vector<ParamsT*> reusable_params(param_count);
152
+ for (size_t i = 0; i < param_count; i++) {
153
+ reusable_params[i] = params->DeepCopy(use_buffer_rotation);
154
+ }
155
+
156
+ // for rotating buffer
157
+ size_t offset = 0;
158
+
159
+ for (size_t i = 0; i < op_names_.size(); i++) {
160
+ auto* candidate = ops_[op_names_[i]].get(); // borrow pointer
161
+
162
+ if (do_numerics_check) {
163
+ ParamsT* numerical_params = params->DeepCopy(false);
164
+ auto status = candidate->Call(numerical_params);
165
+ if (status != OK) {
166
+ numerical_params->Delete();
167
+ TUNABLE_LOG3("├──unsupported id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]);
168
+ continue;
169
+ }
170
+ status = reference_params->NumericalCheck(numerical_params);
171
+ numerical_params->Delete();
172
+ if (status != OK) {
173
+ TUNABLE_LOG3("├──numerics check failed for id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]);
174
+ continue;
175
+ }
176
+ }
177
+ else {
178
+ auto status = candidate->Call(reusable_params[0]);
179
+ if (status != OK) {
180
+ TUNABLE_LOG3("├──unsupported id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]);
181
+ continue;
182
+ }
183
+ }
184
+
185
+ // collect a small profile
186
+ constexpr const int approx_num_iter = 3;
187
+ auto approx_duration = Profile(candidate, reusable_params, approx_num_iter, offset);
188
+ // bail if too slow
189
+ if (approx_duration > 2 * min_duration_ms) {
190
+ TUNABLE_LOG3("├──skip slow instance id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]);
191
+ continue;
192
+ }
193
+
194
+ // for warmup does user set max duration, max iters, or both?
195
+ // warmup is allowed to be skipped by setting either iterations or duration to 0
196
+ double max_warmup_duration = ctx->GetMaxWarmupDurationMs();
197
+ int max_warmup_iter = ctx->GetMaxWarmupIterations();
198
+ int warmup_iter = 1; // default
199
+ if (max_warmup_duration >= 0) {
200
+ int duration_iters = max_warmup_duration / approx_duration;
201
+ if (max_warmup_iter >= 0) {
202
+ warmup_iter = std::min(max_warmup_iter, duration_iters);
203
+ }
204
+ else {
205
+ warmup_iter = duration_iters;
206
+ }
207
+ }
208
+ else if (max_warmup_iter >= 0) {
209
+ warmup_iter = max_warmup_iter;
210
+ }
211
+
212
+ // for tuning does user set max duration, max iters, or both?
213
+ double max_tuning_duration = ctx->GetMaxTuningDurationMs();
214
+ int max_tuning_iter = ctx->GetMaxTuningIterations();
215
+ int tuning_iter = 100; // default
216
+ if (max_tuning_duration > 0) {
217
+ int duration_iters = max_tuning_duration / approx_duration;
218
+ if (max_tuning_iter > 0) {
219
+ tuning_iter = std::min(max_tuning_iter, duration_iters);
220
+ }
221
+ else {
222
+ tuning_iter = duration_iters;
223
+ }
224
+ }
225
+ else if (max_tuning_iter > 0) {
226
+ tuning_iter = max_tuning_iter;
227
+ }
228
+ // tuning must run at least 1 iteration
229
+ tuning_iter = std::max(1, tuning_iter);
230
+
231
+ // do the full warmup followed by tuning
232
+ double warmup_ms = warmup_iter * approx_duration;
233
+ double tuning_ms = tuning_iter * approx_duration;
234
+ TUNABLE_LOG3("├──tuning using "
235
+ "warmup iters ", warmup_iter, " [", warmup_ms, " ms] "
236
+ "and tuning iters ", tuning_iter, " [", tuning_ms, " ms] ",
237
+ "instance id=", i, ", ", op_sig, "(", params_sig, ") ", op_names_[i]);
238
+ TUNABLE_LOG3("├──offset at ", offset);
239
+ WarmUp(candidate, reusable_params, warmup_iter, offset);
240
+ auto duration_ms = Profile(candidate, reusable_params, tuning_iter, offset);
241
+ if (duration_ms < min_duration_ms) {
242
+ TUNABLE_LOG3("├──found better instance id=", i, ". " , duration_ms, "ms. ", op_names_[i]);
243
+ min_duration_ms = duration_ms;
244
+ id_name = op_names_[i];
245
+ }
246
+ }
247
+
248
+ for (size_t i = 0; i < reusable_params.size(); i++) {
249
+ reusable_params[i]->Delete();
250
+ }
251
+ if (reference_params) {
252
+ reference_params->Delete();
253
+ }
254
+
255
+ TUNABLE_LOG2("└──found fastest for ", op_sig, '(', params_sig, ") ", id_name);
256
+ return ResultEntry(id_name, min_duration_ms);
257
+ }
258
+
259
+ private:
260
+ std::string CreateSignature() {
261
+ #ifndef _WIN32
262
+ const auto* name = typeid(*this).name();
263
+ char buf[256];
264
+ size_t buf_len = 256;
265
+ abi::__cxa_demangle(name, buf, &buf_len, nullptr);
266
+ buf[255] = '\0';
267
+ return buf;
268
+ #else
269
+ return typeid(*this).name();
270
+ #endif
271
+ }
272
+
273
+ mutable c10::once_flag signature_init_once_;
274
+ std::string signature_;
275
+
276
+ std::unordered_map<std::string, std::unique_ptr<Callable<ParamsT>>> ops_;
277
+ std::vector<std::string> op_names_;
278
+ };
279
+
280
+ struct OpParams {
281
+ OpParams() {}
282
+ virtual ~OpParams() = default;
283
+ virtual std::string Signature() const = 0;
284
+ };
285
+
286
+ } // namespace at::cuda::tunable
.venv/lib/python3.11/site-packages/torch/include/ATen/native/Activation.h ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/native/DispatchStub.h>
4
+ #include <c10/util/Exception.h>
5
+ #include <c10/util/string_view.h>
6
+
7
+ namespace c10 {
8
+ class Scalar;
9
+ }
10
+
11
+ namespace at {
12
+ struct TensorIterator;
13
+ struct TensorIteratorBase;
14
+ class TensorBase;
15
+ }
16
+
17
+ namespace at::native {
18
+
19
// These constants control the approximation behavior of gelu function.
enum class GeluType {
  None, // Baseline Gelu
  Tanh, // Tanh Gelu Approximation
  END   // Sentinel, not a valid mode
};
25
+
26
+ inline GeluType get_gelutype_enum(const c10::string_view approximate) {
27
+ if (approximate == "none") {
28
+ return GeluType::None;
29
+ } else if (approximate == "tanh") {
30
+ return GeluType::Tanh;
31
+ } else {
32
+ TORCH_CHECK(false, "approximate argument must be either none or tanh.");
33
+ }
34
+ }
35
+
36
+ inline std::string gelutype_to_string(const GeluType type) {
37
+ switch(type) {
38
+ case GeluType::None: return "none";
39
+ case GeluType::Tanh: return "tanh";
40
+ default: TORCH_CHECK(false, "unknown GELU type: ", static_cast<int>(type));
41
+ }
42
+ }
43
+
44
// Kernel function-pointer signatures for activation ops. The "structured_*"
// aliases take TensorIteratorBase (structured kernels); the plain activation
// aliases take TensorIterator. Scalar parameters carry op-specific arguments
// (e.g. beta/threshold for softplus, alpha/scale/input_scale for elu).
using structured_activation_fn = void (*)(TensorIteratorBase&);
using structured_activation_backward_fn = void (*)(TensorIteratorBase&);

using activation_fn = void (*)(TensorIterator&);
using activation_backward_fn = void (*)(TensorIterator&);
using softplus_fn = void (*)(TensorIteratorBase&, const c10::Scalar&, const c10::Scalar&);
using softplus_backward_fn = void (*)(TensorIteratorBase&, const c10::Scalar&, const c10::Scalar&);
using threshold_fn = void (*)(TensorIteratorBase&, const c10::Scalar&, const c10::Scalar&);
using hardtanh_backward_fn = void (*)(TensorIterator&, const c10::Scalar&, const c10::Scalar&);
using hardsigmoid_fn = void(*)(TensorIteratorBase&);
using hardsigmoid_backward_fn = void(*)(TensorIteratorBase&);
using hardswish_fn = void(*)(TensorIterator&);
using hardswish_backward_fn = void(*)(TensorIterator&);
using shrink_fn = void (*)(TensorIteratorBase&, const c10::Scalar&);
using softshrink_fn = void (*)(TensorIteratorBase&, const c10::Scalar&);
using shrink_backward_fn = void (*)(TensorIteratorBase&, const c10::Scalar&);
using elu_fn = void (*)(TensorIteratorBase&, const c10::Scalar&, const c10::Scalar&, const c10::Scalar&);
using elu_backward_fn = void (*)(TensorIteratorBase&, const c10::Scalar&, const c10::Scalar&, const c10::Scalar&, bool);
using leaky_relu_fn = void (*)(TensorIteratorBase&, const c10::Scalar&);
using leaky_relu_backward_fn = void (*)(TensorIteratorBase&, const c10::Scalar&);
using log_sigmoid_cpu_fn = void (*)(TensorBase&, TensorBase&, const TensorBase&);
using gelu_fn = void (*)(TensorIteratorBase&, GeluType);
using gelu_backward_fn = void (*)(TensorIteratorBase&, GeluType);
using glu_jvp_fn = void (*)(TensorIteratorBase&);

// Per-op dispatch stubs; device-specific kernels register themselves against
// these via the DispatchStub machinery (see ATen/native/DispatchStub.h).
DECLARE_DISPATCH(elu_fn, elu_stub);
DECLARE_DISPATCH(elu_backward_fn, elu_backward_stub);
DECLARE_DISPATCH(softplus_fn, softplus_stub);
DECLARE_DISPATCH(softplus_backward_fn, softplus_backward_stub);
DECLARE_DISPATCH(log_sigmoid_cpu_fn, log_sigmoid_cpu_stub);
DECLARE_DISPATCH(activation_backward_fn, log_sigmoid_backward_stub);
DECLARE_DISPATCH(threshold_fn, threshold_stub);
DECLARE_DISPATCH(gelu_fn, GeluKernel);
DECLARE_DISPATCH(gelu_backward_fn, GeluBackwardKernel);
DECLARE_DISPATCH(hardtanh_backward_fn, hardtanh_backward_stub);
DECLARE_DISPATCH(hardsigmoid_fn, hardsigmoid_stub);
DECLARE_DISPATCH(hardsigmoid_backward_fn, hardsigmoid_backward_stub);
DECLARE_DISPATCH(hardswish_fn, hardswish_stub);
DECLARE_DISPATCH(hardswish_backward_fn, hardswish_backward_stub);
DECLARE_DISPATCH(shrink_fn, hardshrink_stub);
DECLARE_DISPATCH(softshrink_fn, softshrink_stub);
DECLARE_DISPATCH(shrink_backward_fn, shrink_backward_stub);
DECLARE_DISPATCH(leaky_relu_fn, leaky_relu_stub);
DECLARE_DISPATCH(leaky_relu_backward_fn, leaky_relu_backward_stub);
DECLARE_DISPATCH(structured_activation_fn, glu_stub);
DECLARE_DISPATCH(activation_backward_fn, glu_backward_stub);
DECLARE_DISPATCH(glu_jvp_fn, glu_jvp_stub);
DECLARE_DISPATCH(structured_activation_fn, silu_stub);
DECLARE_DISPATCH(structured_activation_backward_fn, silu_backward_stub);
DECLARE_DISPATCH(structured_activation_fn, mish_stub);
DECLARE_DISPATCH(activation_backward_fn, mish_backward_stub);
DECLARE_DISPATCH(activation_fn, prelu_stub);
DECLARE_DISPATCH(activation_backward_fn, prelu_backward_stub);
97
+
98
+ } // namespace at::native
.venv/lib/python3.11/site-packages/torch/include/ATen/native/AdaptivePooling.h ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/core/Tensor.h>
4
+ #include <ATen/native/DispatchStub.h>
5
+ #include <c10/util/ArrayRef.h>
6
+ #include <c10/util/irange.h>
7
+ #include <cmath>
8
+
9
+ namespace at::native {
10
+
11
+ using adaptive_avg_pooling2d_fn = void(*)(Tensor& output, const Tensor& input, IntArrayRef output_size);
12
+ using adaptive_avg_pooling2d_backward_fn = void(*)(Tensor& grad_input, const Tensor& grad_output);
13
+ DECLARE_DISPATCH(adaptive_avg_pooling2d_fn, adaptive_avg_pool2d_kernel);
14
+ DECLARE_DISPATCH(adaptive_avg_pooling2d_backward_fn, adaptive_avg_pool2d_backward_kernel);
15
+
16
+ using adaptive_max_pooling2d_fn = void(*)(const Tensor& output, const Tensor& indices, const Tensor& input, IntArrayRef output_size);
17
+ using adaptive_max_pooling2d_backward_fn = void(*)(const Tensor& grad_input, const Tensor& grad_output, const Tensor& indices);
18
+ DECLARE_DISPATCH(adaptive_max_pooling2d_fn, adaptive_max_pool2d_kernel);
19
+ DECLARE_DISPATCH(adaptive_max_pooling2d_backward_fn, adaptive_max_pool2d_backward_kernel);
20
+
21
+ using adaptive_avg_pooling3d_fn = void(*)(Tensor& output, const Tensor& input, IntArrayRef output_size);
22
+ using adaptive_avg_pooling3d_backward_fn = void(*)(Tensor& grad_input, const Tensor& grad_output);
23
+ DECLARE_DISPATCH(adaptive_avg_pooling3d_fn, adaptive_avg_pool3d_kernel);
24
+ DECLARE_DISPATCH(adaptive_avg_pooling3d_backward_fn, adaptive_avg_pool3d_backward_kernel);
25
+
26
+ using adaptive_max_pooling3d_fn = void(*)(const Tensor& output, const Tensor& indices, const Tensor& input, IntArrayRef output_size);
27
+ using adaptive_max_pooling3d_backward_fn = void(*)(const Tensor& grad_input, const Tensor& grad_output, const Tensor& indices);
28
+ DECLARE_DISPATCH(adaptive_max_pooling3d_fn, adaptive_max_pool3d_kernel);
29
+ DECLARE_DISPATCH(adaptive_max_pooling3d_backward_fn, adaptive_max_pool3d_backward_kernel);
30
+
31
+ inline int64_t start_index(int64_t a, int64_t b, int64_t c) {
32
+ return (a / b) * c + ((a % b) * c) / b;
33
+ }
34
+
35
+ inline int64_t end_index(int64_t a, int64_t b, int64_t c) {
36
+ return 1 + ((a + 1) * c - 1) / b;
37
+ }
38
+
39
+ inline void adaptive_pool_empty_output_check(const Tensor& gradOutput_, const char* arg_name) {
40
+ int64_t ndim = gradOutput_.ndimension();
41
+ for (const auto i : c10::irange(1, ndim)) {
42
+ TORCH_CHECK(gradOutput_.size(i) > 0,
43
+ arg_name, "(): Expected grad_output to have non-zero size for non-batch dimensions, "
44
+ "but grad_output has sizes ", gradOutput_.sizes(), " with dimension ", i,
45
+ " being empty");
46
+ }
47
+ }
48
+
49
+ } // namespace at::native
.venv/lib/python3.11/site-packages/torch/include/ATen/native/BatchLinearAlgebra.h ADDED
@@ -0,0 +1,321 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <optional>
4
+ #include <c10/util/string_view.h>
5
+ #include <ATen/Config.h>
6
+ #include <ATen/native/DispatchStub.h>
7
+
8
+ // Forward declare TI
9
+ namespace at {
10
+ class Tensor;
11
+ struct TensorIterator;
12
+
13
+ namespace native {
14
+ enum class TransposeType;
15
+ }
16
+
17
+ }
18
+
19
+ namespace at::native {
20
+
21
+ enum class LapackLstsqDriverType : int64_t { Gels, Gelsd, Gelsy, Gelss};
22
+
23
+ #if AT_BUILD_WITH_LAPACK()
24
+ // Define per-batch functions to be used in the implementation of batched
25
+ // linear algebra operations
26
+
27
+ template <class scalar_t>
28
+ void lapackCholesky(char uplo, int n, scalar_t *a, int lda, int *info);
29
+
30
+ template <class scalar_t>
31
+ void lapackCholeskyInverse(char uplo, int n, scalar_t *a, int lda, int *info);
32
+
33
+ template <class scalar_t, class value_t=scalar_t>
34
+ void lapackEig(char jobvl, char jobvr, int n, scalar_t *a, int lda, scalar_t *w, scalar_t* vl, int ldvl, scalar_t *vr, int ldvr, scalar_t *work, int lwork, value_t *rwork, int *info);
35
+
36
+ template <class scalar_t>
37
+ void lapackGeqrf(int m, int n, scalar_t *a, int lda, scalar_t *tau, scalar_t *work, int lwork, int *info);
38
+
39
+ template <class scalar_t>
40
+ void lapackOrgqr(int m, int n, int k, scalar_t *a, int lda, scalar_t *tau, scalar_t *work, int lwork, int *info);
41
+
42
+ template <class scalar_t>
43
+ void lapackOrmqr(char side, char trans, int m, int n, int k, scalar_t *a, int lda, scalar_t *tau, scalar_t *c, int ldc, scalar_t *work, int lwork, int *info);
44
+
45
+ template <class scalar_t, class value_t = scalar_t>
46
+ void lapackSyevd(char jobz, char uplo, int n, scalar_t* a, int lda, value_t* w, scalar_t* work, int lwork, value_t* rwork, int lrwork, int* iwork, int liwork, int* info);
47
+
48
+ template <class scalar_t>
49
+ void lapackGels(char trans, int m, int n, int nrhs,
50
+ scalar_t *a, int lda, scalar_t *b, int ldb,
51
+ scalar_t *work, int lwork, int *info);
52
+
53
+ template <class scalar_t, class value_t = scalar_t>
54
+ void lapackGelsd(int m, int n, int nrhs,
55
+ scalar_t *a, int lda, scalar_t *b, int ldb,
56
+ value_t *s, value_t rcond, int *rank,
57
+ scalar_t* work, int lwork,
58
+ value_t *rwork, int* iwork, int *info);
59
+
60
+ template <class scalar_t, class value_t = scalar_t>
61
+ void lapackGelsy(int m, int n, int nrhs,
62
+ scalar_t *a, int lda, scalar_t *b, int ldb,
63
+ int *jpvt, value_t rcond, int *rank,
64
+ scalar_t *work, int lwork, value_t* rwork, int *info);
65
+
66
+ template <class scalar_t, class value_t = scalar_t>
67
+ void lapackGelss(int m, int n, int nrhs,
68
+ scalar_t *a, int lda, scalar_t *b, int ldb,
69
+ value_t *s, value_t rcond, int *rank,
70
+ scalar_t *work, int lwork,
71
+ value_t *rwork, int *info);
72
+
73
+ template <LapackLstsqDriverType, class scalar_t, class value_t = scalar_t>
74
+ struct lapackLstsq_impl;
75
+
76
+ template <class scalar_t, class value_t>
77
+ struct lapackLstsq_impl<LapackLstsqDriverType::Gels, scalar_t, value_t> {
78
+ static void call(
79
+ char trans, int m, int n, int nrhs,
80
+ scalar_t *a, int lda, scalar_t *b, int ldb,
81
+ scalar_t *work, int lwork, int *info, // Gels flavor
82
+ int *jpvt, value_t rcond, int *rank, value_t* rwork, // Gelsy flavor
83
+ value_t *s, // Gelss flavor
84
+ int *iwork // Gelsd flavor
85
+ ) {
86
+ lapackGels<scalar_t>(
87
+ trans, m, n, nrhs,
88
+ a, lda, b, ldb,
89
+ work, lwork, info);
90
+ }
91
+ };
92
+
93
+ template <class scalar_t, class value_t>
94
+ struct lapackLstsq_impl<LapackLstsqDriverType::Gelsy, scalar_t, value_t> {
95
+ static void call(
96
+ char trans, int m, int n, int nrhs,
97
+ scalar_t *a, int lda, scalar_t *b, int ldb,
98
+ scalar_t *work, int lwork, int *info, // Gels flavor
99
+ int *jpvt, value_t rcond, int *rank, value_t* rwork, // Gelsy flavor
100
+ value_t *s, // Gelss flavor
101
+ int *iwork // Gelsd flavor
102
+ ) {
103
+ lapackGelsy<scalar_t, value_t>(
104
+ m, n, nrhs,
105
+ a, lda, b, ldb,
106
+ jpvt, rcond, rank,
107
+ work, lwork, rwork, info);
108
+ }
109
+ };
110
+
111
+ template <class scalar_t, class value_t>
112
+ struct lapackLstsq_impl<LapackLstsqDriverType::Gelsd, scalar_t, value_t> {
113
+ static void call(
114
+ char trans, int m, int n, int nrhs,
115
+ scalar_t *a, int lda, scalar_t *b, int ldb,
116
+ scalar_t *work, int lwork, int *info, // Gels flavor
117
+ int *jpvt, value_t rcond, int *rank, value_t* rwork, // Gelsy flavor
118
+ value_t *s, // Gelss flavor
119
+ int *iwork // Gelsd flavor
120
+ ) {
121
+ lapackGelsd<scalar_t, value_t>(
122
+ m, n, nrhs,
123
+ a, lda, b, ldb,
124
+ s, rcond, rank,
125
+ work, lwork,
126
+ rwork, iwork, info);
127
+ }
128
+ };
129
+
130
+ template <class scalar_t, class value_t>
131
+ struct lapackLstsq_impl<LapackLstsqDriverType::Gelss, scalar_t, value_t> {
132
+ static void call(
133
+ char trans, int m, int n, int nrhs,
134
+ scalar_t *a, int lda, scalar_t *b, int ldb,
135
+ scalar_t *work, int lwork, int *info, // Gels flavor
136
+ int *jpvt, value_t rcond, int *rank, value_t* rwork, // Gelsy flavor
137
+ value_t *s, // Gelss flavor
138
+ int *iwork // Gelsd flavor
139
+ ) {
140
+ lapackGelss<scalar_t, value_t>(
141
+ m, n, nrhs,
142
+ a, lda, b, ldb,
143
+ s, rcond, rank,
144
+ work, lwork,
145
+ rwork, info);
146
+ }
147
+ };
148
+
149
+ template <LapackLstsqDriverType driver_type, class scalar_t, class value_t = scalar_t>
150
+ void lapackLstsq(
151
+ char trans, int m, int n, int nrhs,
152
+ scalar_t *a, int lda, scalar_t *b, int ldb,
153
+ scalar_t *work, int lwork, int *info, // Gels flavor
154
+ int *jpvt, value_t rcond, int *rank, value_t* rwork, // Gelsy flavor
155
+ value_t *s, // Gelss flavor
156
+ int *iwork // Gelsd flavor
157
+ ) {
158
+ lapackLstsq_impl<driver_type, scalar_t, value_t>::call(
159
+ trans, m, n, nrhs,
160
+ a, lda, b, ldb,
161
+ work, lwork, info,
162
+ jpvt, rcond, rank, rwork,
163
+ s,
164
+ iwork);
165
+ }
166
+
167
+ template <class scalar_t>
168
+ void lapackLuSolve(char trans, int n, int nrhs, scalar_t *a, int lda, int *ipiv, scalar_t *b, int ldb, int *info);
169
+
170
+ template <class scalar_t>
171
+ void lapackLu(int m, int n, scalar_t *a, int lda, int *ipiv, int *info);
172
+
173
+ template <class scalar_t>
174
+ void lapackLdlHermitian(
175
+ char uplo,
176
+ int n,
177
+ scalar_t* a,
178
+ int lda,
179
+ int* ipiv,
180
+ scalar_t* work,
181
+ int lwork,
182
+ int* info);
183
+
184
+ template <class scalar_t>
185
+ void lapackLdlSymmetric(
186
+ char uplo,
187
+ int n,
188
+ scalar_t* a,
189
+ int lda,
190
+ int* ipiv,
191
+ scalar_t* work,
192
+ int lwork,
193
+ int* info);
194
+
195
+ template <class scalar_t>
196
+ void lapackLdlSolveHermitian(
197
+ char uplo,
198
+ int n,
199
+ int nrhs,
200
+ scalar_t* a,
201
+ int lda,
202
+ int* ipiv,
203
+ scalar_t* b,
204
+ int ldb,
205
+ int* info);
206
+
207
+ template <class scalar_t>
208
+ void lapackLdlSolveSymmetric(
209
+ char uplo,
210
+ int n,
211
+ int nrhs,
212
+ scalar_t* a,
213
+ int lda,
214
+ int* ipiv,
215
+ scalar_t* b,
216
+ int ldb,
217
+ int* info);
218
+
219
+ template<class scalar_t, class value_t=scalar_t>
220
+ void lapackSvd(char jobz, int m, int n, scalar_t *a, int lda, value_t *s, scalar_t *u, int ldu, scalar_t *vt, int ldvt, scalar_t *work, int lwork, value_t *rwork, int *iwork, int *info);
221
+ #endif
222
+
223
+ #if AT_BUILD_WITH_BLAS()
224
+ template <class scalar_t>
225
+ void blasTriangularSolve(char side, char uplo, char trans, char diag, int n, int nrhs, scalar_t* a, int lda, scalar_t* b, int ldb);
226
+ #endif
227
+
228
+ using cholesky_fn = void (*)(const Tensor& /*input*/, const Tensor& /*info*/, bool /*upper*/);
229
+ DECLARE_DISPATCH(cholesky_fn, cholesky_stub);
230
+
231
+ using cholesky_inverse_fn = Tensor& (*)(Tensor& /*result*/, Tensor& /*infos*/, bool /*upper*/);
232
+
233
+ DECLARE_DISPATCH(cholesky_inverse_fn, cholesky_inverse_stub);
234
+
235
+ using linalg_eig_fn = void (*)(Tensor& /*eigenvalues*/, Tensor& /*eigenvectors*/, Tensor& /*infos*/, const Tensor& /*input*/, bool /*compute_eigenvectors*/);
236
+
237
+ DECLARE_DISPATCH(linalg_eig_fn, linalg_eig_stub);
238
+
239
+ using geqrf_fn = void (*)(const Tensor& /*input*/, const Tensor& /*tau*/);
240
+ DECLARE_DISPATCH(geqrf_fn, geqrf_stub);
241
+
242
+ using orgqr_fn = Tensor& (*)(Tensor& /*result*/, const Tensor& /*tau*/);
243
+ DECLARE_DISPATCH(orgqr_fn, orgqr_stub);
244
+
245
+ using ormqr_fn = void (*)(const Tensor& /*input*/, const Tensor& /*tau*/, const Tensor& /*other*/, bool /*left*/, bool /*transpose*/);
246
+ DECLARE_DISPATCH(ormqr_fn, ormqr_stub);
247
+
248
+ using linalg_eigh_fn = void (*)(
249
+ const Tensor& /*eigenvalues*/,
250
+ const Tensor& /*eigenvectors*/,
251
+ const Tensor& /*infos*/,
252
+ bool /*upper*/,
253
+ bool /*compute_eigenvectors*/);
254
+ DECLARE_DISPATCH(linalg_eigh_fn, linalg_eigh_stub);
255
+
256
+ using lstsq_fn = void (*)(
257
+ const Tensor& /*a*/,
258
+ Tensor& /*b*/,
259
+ Tensor& /*rank*/,
260
+ Tensor& /*singular_values*/,
261
+ Tensor& /*infos*/,
262
+ double /*rcond*/,
263
+ std::string /*driver_name*/);
264
+ DECLARE_DISPATCH(lstsq_fn, lstsq_stub);
265
+
266
+ using triangular_solve_fn = void (*)(
267
+ const Tensor& /*A*/,
268
+ const Tensor& /*B*/,
269
+ bool /*left*/,
270
+ bool /*upper*/,
271
+ TransposeType /*transpose*/,
272
+ bool /*unitriangular*/);
273
+ DECLARE_DISPATCH(triangular_solve_fn, triangular_solve_stub);
274
+
275
+ using lu_factor_fn = void (*)(
276
+ const Tensor& /*input*/,
277
+ const Tensor& /*pivots*/,
278
+ const Tensor& /*infos*/,
279
+ bool /*compute_pivots*/);
280
+ DECLARE_DISPATCH(lu_factor_fn, lu_factor_stub);
281
+
282
+ using unpack_pivots_fn = void(*)(
283
+ TensorIterator& iter,
284
+ const int64_t dim_size,
285
+ const int64_t max_pivot);
286
+ DECLARE_DISPATCH(unpack_pivots_fn, unpack_pivots_stub);
287
+
288
+ using lu_solve_fn = void (*)(
289
+ const Tensor& /*LU*/,
290
+ const Tensor& /*pivots*/,
291
+ const Tensor& /*B*/,
292
+ TransposeType /*trans*/);
293
+ DECLARE_DISPATCH(lu_solve_fn, lu_solve_stub);
294
+
295
+ using ldl_factor_fn = void (*)(
296
+ const Tensor& /*LD*/,
297
+ const Tensor& /*pivots*/,
298
+ const Tensor& /*info*/,
299
+ bool /*upper*/,
300
+ bool /*hermitian*/);
301
+ DECLARE_DISPATCH(ldl_factor_fn, ldl_factor_stub);
302
+
303
+ using svd_fn = void (*)(
304
+ const Tensor& /*A*/,
305
+ const bool /*full_matrices*/,
306
+ const bool /*compute_uv*/,
307
+ const std::optional<c10::string_view>& /*driver*/,
308
+ const Tensor& /*U*/,
309
+ const Tensor& /*S*/,
310
+ const Tensor& /*Vh*/,
311
+ const Tensor& /*info*/);
312
+ DECLARE_DISPATCH(svd_fn, svd_stub);
313
+
314
+ using ldl_solve_fn = void (*)(
315
+ const Tensor& /*LD*/,
316
+ const Tensor& /*pivots*/,
317
+ const Tensor& /*result*/,
318
+ bool /*upper*/,
319
+ bool /*hermitian*/);
320
+ DECLARE_DISPATCH(ldl_solve_fn, ldl_solve_stub);
321
+ } // namespace at::native
.venv/lib/python3.11/site-packages/torch/include/ATen/native/BinaryOps.h ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/core/TensorBase.h>
4
+ #include <ATen/native/DispatchStub.h>
5
+ #include <c10/core/Scalar.h>
6
+ #include <c10/util/TypeSafeSignMath.h>
7
+
8
+
9
+ namespace at {
10
+ struct TensorIterator;
11
+ struct TensorIteratorBase;
12
+ }
13
+
14
+ namespace at::native {
15
+
16
+ inline void alpha_check(const ScalarType dtype, const Scalar& alpha) {
17
+ TORCH_CHECK(! alpha.isBoolean() || dtype == ScalarType::Bool,
18
+ "Boolean alpha only supported for Boolean results.");
19
+ TORCH_CHECK(isFloatingType(dtype) || isComplexType(dtype)
20
+ || alpha.isIntegral(true),
21
+ "For integral input tensors, argument alpha must not be a floating point number.");
22
+ TORCH_CHECK(isComplexType(dtype) || !alpha.isComplex(),
23
+ "For non-complex input tensors, argument alpha must not be a complex number.")
24
+ }
25
+
26
+ // Basic checking for all sub functions.
27
+ inline void sub_check(const TensorBase& self, const TensorBase& other) {
28
+ TORCH_CHECK(self.scalar_type() != kBool || other.scalar_type() != kBool,
29
+ "Subtraction, the `-` operator, with two bool tensors is not supported. "
30
+ "Use the `^` or `logical_xor()` operator instead.")
31
+ TORCH_CHECK(self.scalar_type() != kBool && other.scalar_type() != kBool,
32
+ "Subtraction, the `-` operator, with a bool tensor is not supported. "
33
+ "If you are trying to invert a mask, use the `~` or `logical_not()` operator instead.");
34
+ }
35
+
36
+ inline void sub_check(const TensorBase& self, const Scalar& scalar) {
37
+ TORCH_CHECK(self.scalar_type() != kBool || !scalar.isBoolean(),
38
+ "Subtraction, the `-` operator, with two bool tensors is not supported. "
39
+ "Use the `^` or `logical_xor()` operator instead.")
40
+ TORCH_CHECK(self.scalar_type() != kBool && !scalar.isBoolean(),
41
+ "Subtraction, the `-` operator, with a bool tensor is not supported. "
42
+ "If you are trying to invert a mask, use the `~` or `logical_not()` operator instead.");
43
+ }
44
+
45
+ using structured_binary_fn_alpha = void(*)(TensorIteratorBase&, const Scalar& alpha);
46
+ using structured_binary_fn_double = void(*)(TensorIteratorBase&, double);
47
+ using structured_binary_fn = void(*)(TensorIteratorBase&);
48
+
49
+ using binary_fn_alpha = void(*)(TensorIteratorBase&, const Scalar& alpha);
50
+ using binary_fn_double = void(*)(TensorIterator&, double);
51
+ using binary_fn = void(*)(TensorIterator&);
52
+ using binary_clamp_fn_alpha =
53
+ void(*)(TensorIterator&, const Scalar& alpha, const Scalar& min_val, const Scalar& max_val);
54
+
55
+ // NB: codegenned
56
+ DECLARE_DISPATCH(structured_binary_fn_alpha, add_stub);
57
+
58
+ DECLARE_DISPATCH(binary_clamp_fn_alpha, add_clamp_stub);
59
+ DECLARE_DISPATCH(structured_binary_fn_alpha, sub_stub);
60
+ DECLARE_DISPATCH(structured_binary_fn, mul_stub);
61
+ DECLARE_DISPATCH(structured_binary_fn, div_true_stub);
62
+ DECLARE_DISPATCH(structured_binary_fn, div_floor_stub);
63
+ DECLARE_DISPATCH(structured_binary_fn, div_trunc_stub);
64
+ DECLARE_DISPATCH(structured_binary_fn, atan2_stub);
65
+ DECLARE_DISPATCH(structured_binary_fn, remainder_stub);
66
+ DECLARE_DISPATCH(structured_binary_fn, bitwise_and_stub);
67
+ DECLARE_DISPATCH(structured_binary_fn, bitwise_or_stub);
68
+ DECLARE_DISPATCH(structured_binary_fn, bitwise_xor_stub);
69
+ DECLARE_DISPATCH(structured_binary_fn, lshift_stub);
70
+ DECLARE_DISPATCH(structured_binary_fn, rshift_stub);
71
+ DECLARE_DISPATCH(binary_fn, logical_xor_stub);
72
+ DECLARE_DISPATCH(binary_fn, logical_and_stub);
73
+ DECLARE_DISPATCH(binary_fn, logical_or_stub);
74
+ DECLARE_DISPATCH(structured_binary_fn, lt_stub);
75
+ DECLARE_DISPATCH(structured_binary_fn, le_stub);
76
+ DECLARE_DISPATCH(structured_binary_fn, gt_stub);
77
+ DECLARE_DISPATCH(structured_binary_fn, ge_stub);
78
+ DECLARE_DISPATCH(structured_binary_fn, eq_stub);
79
+ DECLARE_DISPATCH(structured_binary_fn, ne_stub);
80
+ DECLARE_DISPATCH(binary_fn, max_elementwise_stub);
81
+ DECLARE_DISPATCH(binary_fn, min_elementwise_stub);
82
+ DECLARE_DISPATCH(structured_binary_fn, maximum_stub);
83
+ DECLARE_DISPATCH(structured_binary_fn, minimum_stub);
84
+ DECLARE_DISPATCH(structured_binary_fn, fmax_stub);
85
+ DECLARE_DISPATCH(structured_binary_fn, fmin_stub);
86
+ DECLARE_DISPATCH(structured_binary_fn_double, smooth_l1_stub);
87
+ DECLARE_DISPATCH(binary_fn_double, huber_stub);
88
+ DECLARE_DISPATCH(structured_binary_fn, sigmoid_backward_stub);
89
+ DECLARE_DISPATCH(binary_fn_alpha, logit_backward_stub);
90
+ DECLARE_DISPATCH(structured_binary_fn, tanh_backward_stub);
91
+ DECLARE_DISPATCH(structured_binary_fn, mse_stub);
92
+ DECLARE_DISPATCH(structured_binary_fn, fmod_stub);
93
+ DECLARE_DISPATCH(structured_binary_fn, logaddexp_stub);
94
+ DECLARE_DISPATCH(structured_binary_fn, logaddexp2_stub);
95
+ DECLARE_DISPATCH(structured_binary_fn, gcd_stub);
96
+ DECLARE_DISPATCH(structured_binary_fn, lcm_stub);
97
+ DECLARE_DISPATCH(structured_binary_fn, hypot_stub);
98
+ DECLARE_DISPATCH(structured_binary_fn, igamma_stub);
99
+ DECLARE_DISPATCH(structured_binary_fn, igammac_stub);
100
+ DECLARE_DISPATCH(structured_binary_fn, nextafter_stub);
101
+ DECLARE_DISPATCH(structured_binary_fn, heaviside_stub);
102
+ DECLARE_DISPATCH(structured_binary_fn, copysign_stub);
103
+ DECLARE_DISPATCH(structured_binary_fn, xlogy_stub);
104
+ DECLARE_DISPATCH(structured_binary_fn, xlog1py_stub);
105
+ DECLARE_DISPATCH(structured_binary_fn, zeta_stub);
106
+ DECLARE_DISPATCH(structured_binary_fn, chebyshev_polynomial_t_stub);
107
+ DECLARE_DISPATCH(structured_binary_fn, chebyshev_polynomial_u_stub);
108
+ DECLARE_DISPATCH(structured_binary_fn, chebyshev_polynomial_v_stub);
109
+ DECLARE_DISPATCH(structured_binary_fn, chebyshev_polynomial_w_stub);
110
+ DECLARE_DISPATCH(structured_binary_fn, hermite_polynomial_h_stub);
111
+ DECLARE_DISPATCH(structured_binary_fn, hermite_polynomial_he_stub);
112
+ DECLARE_DISPATCH(structured_binary_fn, laguerre_polynomial_l_stub);
113
+ DECLARE_DISPATCH(structured_binary_fn, legendre_polynomial_p_stub);
114
+ DECLARE_DISPATCH(structured_binary_fn, shifted_chebyshev_polynomial_t_stub);
115
+ DECLARE_DISPATCH(structured_binary_fn, shifted_chebyshev_polynomial_u_stub);
116
+ DECLARE_DISPATCH(structured_binary_fn, shifted_chebyshev_polynomial_v_stub);
117
+ DECLARE_DISPATCH(structured_binary_fn, shifted_chebyshev_polynomial_w_stub);
118
+
119
+ } // namespace at::native
.venv/lib/python3.11/site-packages/torch/include/ATen/native/ComplexHelper.h ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/core/Tensor.h>
4
+ #include <c10/util/irange.h>
5
+
6
+ #ifndef AT_PER_OPERATOR_HEADERS
7
+ #include <ATen/NativeFunctions.h>
8
+ #else
9
+ #include <ATen/ops/view_as_real_native.h>
10
+ #include <ATen/ops/view_as_complex_native.h>
11
+
12
+ #include <utility>
13
+ #endif
14
+
15
+ // WARNING: this header contains non-inline functions and should be only
16
+ // included from ONE cpp file
17
+
18
+ namespace at::native {
19
+
20
+ // View tensor with new dtype, storage offset, sizes and strides
21
+ inline Tensor view_tensor(
22
+ const Tensor &tensor, ScalarType dtype,
23
+ c10::SymInt offset, SymIntArrayRef sizes, SymIntArrayRef strides) {
24
+ Storage storage = tensor.storage();
25
+ auto key_set = tensor.key_set().remove(DispatchKey::Conjugate);
26
+ auto new_tensor = detail::make_tensor<TensorImpl>(
27
+ c10::TensorImpl::VIEW, std::move(storage), key_set, scalarTypeToTypeMeta(dtype));
28
+ auto * impl = new_tensor.unsafeGetTensorImpl();
29
+ impl->set_sizes_and_strides(sizes, strides, offset);
30
+ return new_tensor;
31
+ }
32
+
33
+ inline SymDimVector computeStrideForViewAsReal(SymIntArrayRef oldstride) {
34
+ SymDimVector res(oldstride.size() + 1);
35
+ for (const auto i : c10::irange(oldstride.size())) {
36
+ res[i] = oldstride[i] * 2;
37
+ }
38
+ res.back() = 1;
39
+ return res;
40
+ }
41
+
42
+ inline Tensor _view_as_real_physical(const Tensor& self) {
43
+ TORCH_CHECK(self.is_complex(), "view_as_real is only supported for complex tensors");
44
+ auto old_sizes = self.sym_sizes();
45
+ SymDimVector new_sizes(old_sizes.size() + 1);
46
+ std::copy(old_sizes.begin(), old_sizes.end(), new_sizes.begin());
47
+ // last dimension will always have two elements containing the real and imag vals
48
+ new_sizes.back() = 2;
49
+ auto new_strides = computeStrideForViewAsReal(self.sym_strides());
50
+ auto new_storage_offset = self.sym_storage_offset() * 2;
51
+ const auto float_type = c10::toRealValueType(self.scalar_type());
52
+ auto real_tensor = view_tensor(self, float_type, std::move(new_storage_offset), new_sizes, new_strides);
53
+ return real_tensor;
54
+ }
55
+
56
+ // expects as input a complex tensor and returns back a tensor
57
+ // with corresponding real dtype containing the complex values
58
+ // in the last two dimensions
59
+ Tensor view_as_real(const Tensor& self) {
60
+ TORCH_CHECK(!self.is_conj(), "view_as_real doesn't work on unresolved conjugated tensors. To resolve the conjugate tensor so you can view it as real, use self.resolve_conj(); however, be warned that the resulting tensor will NOT alias the original.");
61
+ return _view_as_real_physical(self);
62
+ }
63
+
64
+ inline SymDimVector computeStrideForViewAsComplex(SymIntArrayRef oldstride) {
65
+ const auto dim = oldstride.size();
66
+ TORCH_CHECK(dim > 0 && oldstride[dim - 1] == 1, "Tensor must have a last dimension with stride 1");
67
+
68
+ SymDimVector res(dim - 1);
69
+ for (const auto i : c10::irange(res.size())) {
70
+ TORCH_CHECK(oldstride[i] % 2 == 0, "Tensor must have a stride divisible by 2 for all but last dimension");
71
+ res[i] = oldstride[i] / 2;
72
+ }
73
+ return res;
74
+ }
75
+
76
+ // expects as input a float or double tensor with last dimension of size 2
77
+ // and returns back a tensor with corresponding complex dtype
78
+ Tensor view_as_complex(const Tensor& self) {
79
+ TORCH_CHECK(
80
+ self.scalar_type() == kFloat || self.scalar_type() == kDouble || self.scalar_type() == kHalf,
81
+ "view_as_complex is only supported for half, float and double tensors, but got a tensor of scalar type: ", self.scalar_type());
82
+
83
+ auto old_sizes = self.sym_sizes();
84
+ TORCH_CHECK(!old_sizes.empty(), "Input tensor must have one or more dimensions");
85
+ TORCH_CHECK(old_sizes[old_sizes.size()-1] == 2, "Tensor must have a last dimension of size 2");
86
+ SymDimVector new_sizes(old_sizes.begin(), old_sizes.end() - 1);
87
+
88
+ const auto new_strides = computeStrideForViewAsComplex(self.sym_strides());
89
+ const auto complex_type = c10::toComplexType(self.scalar_type());
90
+
91
+ TORCH_CHECK(self.sym_storage_offset() % 2 == 0, "Tensor must have a storage_offset divisible by 2");
92
+ const auto new_storage_offset = self.sym_storage_offset() / 2;
93
+
94
+ return view_tensor(self, complex_type, new_storage_offset, new_sizes, new_strides);
95
+ }
96
+
97
+ } // namespace at::native
.venv/lib/python3.11/site-packages/torch/include/ATen/native/Distance.h ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/native/DispatchStub.h>
4
+
5
+ namespace at {
6
+ class Tensor;
7
+
8
+ namespace native {
9
+
10
+ using pdist_forward_fn = void(*)(Tensor&, const Tensor&, const double p);
11
+ using pdist_backward_fn = void(*)(Tensor&, const Tensor&, const Tensor&, const double p, const Tensor&);
12
+ using cdist_fn = void(*)(Tensor&, const Tensor&, const Tensor&, const double p);
13
+ using cdist_backward_fn = void(*)(Tensor&, const Tensor&, const Tensor&, const Tensor&, const double p, const Tensor&);
14
+
15
+ DECLARE_DISPATCH(pdist_forward_fn, pdist_forward_stub);
16
+ DECLARE_DISPATCH(pdist_backward_fn, pdist_backward_stub);
17
+ DECLARE_DISPATCH(cdist_fn, cdist_stub);
18
+ DECLARE_DISPATCH(cdist_backward_fn, cdist_backward_stub);
19
+
20
+ }} // namespace at::native
.venv/lib/python3.11/site-packages/torch/include/ATen/native/Fill.h ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Functions that fill Tensors with constants. Implementations are in Fill.cpp.
2
+
3
+ #pragma once
4
+
5
+ #include <ATen/native/DispatchStub.h>
6
+
7
+ namespace c10 {
8
+ class Scalar;
9
+ }
10
+
11
+ namespace at {
12
+ class Tensor;
13
+ struct TensorIterator;
14
+
15
+ namespace native {
16
+
17
+ DECLARE_DISPATCH(void(*)(TensorIterator&, const c10::Scalar&), fill_stub);
18
+
19
+ Tensor& fill_out(Tensor& self, const Scalar& value);
20
+
21
+ }} // namespace at::native
.venv/lib/python3.11/site-packages/torch/include/ATen/native/FractionalMaxPooling.h ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ #include <ATen/core/Tensor.h>
3
+ #include <ATen/TensorUtils.h>
4
+ #include <c10/util/irange.h>
5
+
6
+ namespace at::native {
7
+
8
+ template<typename scalar_t>
9
+ inline std::vector<int64_t> generate_intervals(
10
+ scalar_t sample,
11
+ int64_t inputSize,
12
+ int64_t outputSize,
13
+ int64_t poolSize) {
14
+ std::vector<int64_t> sequence(outputSize);
15
+ if (outputSize > 1) {
16
+ scalar_t alpha = static_cast<scalar_t>(inputSize - poolSize) /
17
+ static_cast<scalar_t>(outputSize - 1);
18
+
19
+ for (const auto i : c10::irange(outputSize - 1)) {
20
+ sequence[i] =
21
+ static_cast<int>((i + sample) * alpha) - static_cast<int>(sample * alpha);
22
+ }
23
+ }
24
+ if (outputSize > 0) {
25
+ sequence[outputSize - 1] = inputSize - poolSize;
26
+ }
27
+ return sequence;
28
+ }
29
+
30
+ template <int64_t ndim>
31
+ inline void fractional_max_pool_check_shape(
32
+ const Tensor& input,
33
+ const Tensor& randomSamples) {
34
+
35
+ TORCH_CHECK(
36
+ input.scalar_type() == randomSamples.scalar_type(),
37
+ "Expect _random_samples to have the same dtype as input");
38
+
39
+ int64_t ndimension = randomSamples.ndimension();
40
+ TORCH_CHECK(
41
+ ndimension == 3,
42
+ "Expect _random_samples to have 3 dimensions, got ", ndimension);
43
+
44
+ int64_t N = randomSamples.size(0);
45
+ int64_t C = randomSamples.size(1);
46
+ int64_t D = randomSamples.size(2);
47
+
48
+ int64_t input_batch = 0, input_channel = 0;
49
+ if (ndim == 2) {
50
+ // fractional_max_pool2d
51
+ if (input.ndimension() == 3) {
52
+ input_batch = 1;
53
+ input_channel = input.size(0);
54
+ } else {
55
+ input_batch = input.size(0);
56
+ input_channel = input.size(1);
57
+ }
58
+ } else {
59
+ // factional_max_pool3d
60
+ if (input.ndimension() == 4) {
61
+ input_batch = 1;
62
+ input_channel = input.size(0);
63
+ } else {
64
+ input_batch = input.size(0);
65
+ input_channel = input.size(1);
66
+ }
67
+ }
68
+
69
+ TORCH_CHECK(
70
+ N >= input_batch,
71
+ "Expect _random_samples.size(0) no less then input batch size.");
72
+ TORCH_CHECK(
73
+ C == input_channel,
74
+ "Expect _random_samples.size(1) equals to input channel size.");
75
+ TORCH_CHECK(
76
+ D == ndim,
77
+ "Expect _random_samples.size(2) equals to ", ndim, "; got ", D, ".");
78
+ }
79
+
80
+ } // namespace at::native
.venv/lib/python3.11/site-packages/torch/include/ATen/native/FunctionOfAMatrixUtils.h ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/native/DispatchStub.h>
4
+ #include <cstdint>
5
+
6
+ namespace at {
7
+ struct TensorIterator;
8
+
9
+ namespace native {
10
+
11
+ using _compute_linear_combination_fn = void(*)(
12
+ TensorIterator& iter,
13
+ int64_t in_stride,
14
+ int64_t coeff_stride,
15
+ int64_t num_summations
16
+ );
17
+
18
+ DECLARE_DISPATCH(_compute_linear_combination_fn, _compute_linear_combination_stub);
19
+
20
+ }} // namespace at::native
.venv/lib/python3.11/site-packages/torch/include/ATen/native/GridSampler.h ADDED
@@ -0,0 +1,298 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <algorithm>
4
+ #include <cmath>
5
+ #include <cstdint>
6
+ #include <utility>
7
+
8
+ #include <ATen/native/GridSamplerUtils.h>
9
+
10
+ namespace at::native {
11
+
12
+ using detail::GridSamplerInterpolation;
13
+ using detail::GridSamplerPadding;
14
+
15
// Maps a normalized grid coordinate in [-1, 1] onto pixel-index space,
// viewing each pixel as the unit area between (idx - 0.5) and (idx + 0.5).
//   align_corners == true : -1 / +1 map to the centers of the corner pixels,
//                           i.e. 0 and (size - 1); slope (size - 1) / 2.
//   align_corners == false: -1 / +1 map to the image edges, i.e. -0.5 and
//                           (size - 0.5); slope size / 2.
template <typename scalar_t>
static inline scalar_t grid_sampler_unnormalize(scalar_t coord, int64_t size,
                                                bool align_corners) {
  if (!align_corners) {
    // [-1, 1] -> [-0.5, size - 0.5]
    return ((coord + 1) * size - 1) / 2;
  }
  // [-1, 1] -> [0, size - 1]
  return ((coord + 1) / 2) * (size - 1);
}
36
+
37
// Same mapping as grid_sampler_unnormalize, but additionally reports the
// (constant) derivative `d output / d coord` through `grad_in`.
// Used by the grid_sampler backward pass.
template <typename scalar_t>
static inline scalar_t grid_sampler_unnormalize_set_grad(scalar_t coord, int64_t size,
                                                         bool align_corners, scalar_t *grad_in) {
  if (!align_corners) {
    // [-1, 1] -> [-0.5, size - 0.5]; slope is size / 2
    *grad_in = static_cast<scalar_t>(size) / 2;
    return ((coord + 1) * size - 1) / 2;
  }
  // [-1, 1] -> [0, size - 1]; slope is (size - 1) / 2
  *grad_in = static_cast<scalar_t>(size - 1) / 2;
  return ((coord + 1) / 2) * (size - 1);
}
54
+
55
// Clamps a coordinate into the valid pixel range [0, clip_limit - 1].
template<typename scalar_t>
static inline scalar_t clip_coordinates(scalar_t in, int64_t clip_limit) {
  const scalar_t lowest = static_cast<scalar_t>(0);
  const scalar_t highest = static_cast<scalar_t>(clip_limit - 1);
  // Keep the exact min/max argument order of the reference implementation so
  // non-finite inputs propagate identically through the comparisons.
  return std::min(highest, std::max(in, lowest));
}
60
+
61
// Clamp plus derivative: behaves like clip_coordinates but also writes
// `d output / d input` to `grad_in`.  Used by the grid_sampler backward pass.
template<typename scalar_t>
static inline scalar_t clip_coordinates_set_grad(scalar_t in, int64_t clip_limit,
                                                 scalar_t *grad_in) {
  // For the gradient computation it is important that the borders themselves
  // count as out of bounds (hence <= and >=), giving them zero gradient.
  if (in <= static_cast<scalar_t>(0)) {
    *grad_in = static_cast<scalar_t>(0);
    return static_cast<scalar_t>(0);
  }
  const scalar_t upper = static_cast<scalar_t>(clip_limit - 1);
  if (in >= upper) {
    *grad_in = static_cast<scalar_t>(0);
    return upper;
  }
  *grad_in = static_cast<scalar_t>(1);
  return in;
}
83
+
84
// Reflects a coordinate off the interval [twice_low / 2, twice_high / 2]
// until it falls inside (bounds inclusive).  The bounds arrive doubled so
// that half-integer borders can be expressed with integer arguments.
template<typename scalar_t>
static inline scalar_t reflect_coordinates(scalar_t in, int64_t twice_low,
                                           int64_t twice_high) {
  if (twice_low == twice_high) {
    // Degenerate zero-width interval: everything collapses onto 0.
    return static_cast<scalar_t>(0);
  }
  const scalar_t low = static_cast<scalar_t>(twice_low) / 2;
  const scalar_t span = static_cast<scalar_t>(twice_high - twice_low) / 2;
  in = std::fabs(in - low);
  // `fmod` returns the same sign as `in`, which is non-negative after fabs.
  const scalar_t rem = std::fmod(in, span);
  const int num_flips = static_cast<int>(std::floor(in / span));
  // An even number of flips lands on the forward pass, odd on the mirrored one.
  return (num_flips % 2 == 0) ? (rem + low) : (span - rem + low);
}
105
+
106
// Reflection plus derivative: like reflect_coordinates, but also writes
// `d output / d input` to `grad_in` (+1 or -1 per reflection parity, 0 for a
// zero-width interval).  Used by the grid_sampler backward pass.
template<typename scalar_t>
static inline scalar_t reflect_coordinates_set_grad(scalar_t in, int64_t twice_low,
                                                    int64_t twice_high, scalar_t *grad_in) {
  if (twice_low == twice_high) {
    *grad_in = static_cast<scalar_t>(0);
    return static_cast<scalar_t>(0);
  }
  const scalar_t low = static_cast<scalar_t>(twice_low) / 2;
  const scalar_t span = static_cast<scalar_t>(twice_high - twice_low) / 2;
  int sign = 1;  // remembers the absolute-value branch so the slope keeps its sign
  in = in - low;
  if (in < static_cast<scalar_t>(0)) {
    sign = -1;
    in = -in;
  }
  // `fmod` returns the same sign as `in`, which is non-negative here.
  const scalar_t rem = std::fmod(in, span);
  const int num_flips = static_cast<int>(std::floor(in / span));
  if (num_flips % 2 == 0) {
    *grad_in = static_cast<scalar_t>(sign);
    return rem + low;
  }
  *grad_in = static_cast<scalar_t>(-sign);
  return span - rem + low;
}
138
+
139
+ // Mapping the out-of-boundary points back into boundary
140
+ // This would only affect padding_mode=border or reflection
141
+ template<typename scalar_t>
142
+ static inline scalar_t compute_coordinates(scalar_t coord, int64_t size,
143
+ GridSamplerPadding padding_mode,
144
+ bool align_corners) {
145
+ if (padding_mode == GridSamplerPadding::Border) {
146
+ // clip coordinates to image borders
147
+ coord = clip_coordinates(coord, size);
148
+ } else if (padding_mode == GridSamplerPadding::Reflection) {
149
+ // reflect coordinates by image borders
150
+ if (align_corners) {
151
+ coord = reflect_coordinates(coord, 0, 2*(size - 1));
152
+ } else {
153
+ coord = reflect_coordinates(coord, -1, 2*size - 1);
154
+ }
155
+ // clip coordinates to image borders
156
+ coord = clip_coordinates(coord, size);
157
+ }
158
+ return coord;
159
+ }
160
+
161
+ // Computes the pixel source index value for a grid coordinate
162
+ template <typename scalar_t>
163
+ static inline scalar_t grid_sampler_compute_source_index(
164
+ scalar_t coord,
165
+ int64_t size,
166
+ GridSamplerPadding padding_mode,
167
+ bool align_corners) {
168
+ coord = grid_sampler_unnormalize(coord, size, align_corners);
169
+ coord = compute_coordinates(coord, size, padding_mode, align_corners);
170
+ return coord;
171
+ }
172
+
173
+ // grid_sampler_compute_source_index_set_grad works similarly to
174
+ // grid_sampler_compute_source_index except that it also returns the
175
+ // `d output / d input` via pointer argument `grad_in`.
176
+ // This is useful in the backward pass of grid_sampler.
177
+ template <typename scalar_t>
178
+ static inline scalar_t grid_sampler_compute_source_index_set_grad(
179
+ scalar_t coord,
180
+ int64_t size,
181
+ GridSamplerPadding padding_mode,
182
+ bool align_corners,
183
+ scalar_t *grad_in) {
184
+ scalar_t grad_clip, grad_refl;
185
+ coord = grid_sampler_unnormalize_set_grad(coord, size, align_corners, grad_in);
186
+ if (padding_mode == GridSamplerPadding::Border) {
187
+ // clip coordinates to image borders
188
+ coord = clip_coordinates_set_grad(coord, size, &grad_clip);
189
+ *grad_in = (*grad_in) * grad_clip;
190
+ } else if (padding_mode == GridSamplerPadding::Reflection) {
191
+ // reflect coordinates by image borders
192
+ if (align_corners) {
193
+ coord = reflect_coordinates_set_grad(coord, 0, 2*(size - 1), &grad_refl);
194
+ } else {
195
+ coord = reflect_coordinates_set_grad(coord, -1, 2*size - 1, &grad_refl);
196
+ }
197
+ // clip coordinates to image borders
198
+ coord = clip_coordinates_set_grad(coord, size, &grad_clip);
199
+ *grad_in = (*grad_in) * grad_refl * grad_clip;
200
+ }
201
+ return coord;
202
+ }
203
+
204
// True iff (h, w) indexes inside an H x W image.
static inline bool within_bounds_2d(int64_t h, int64_t w, int64_t H, int64_t W) {
  return 0 <= h && h < H && 0 <= w && w < W;
}
207
+
208
// True iff (d, h, w) indexes inside a D x H x W volume.
static inline bool within_bounds_3d(int64_t d, int64_t h, int64_t w, int64_t D, int64_t H, int64_t W) {
  return 0 <= d && d < D && 0 <= h && h < H && 0 <= w && w < W;
}
211
+
212
+ template<typename scalar_t>
213
+ static inline scalar_t get_value_bounded(
214
+ const scalar_t* data,
215
+ scalar_t x,
216
+ scalar_t y,
217
+ int64_t W,
218
+ int64_t H,
219
+ int64_t sW,
220
+ int64_t sH,
221
+ GridSamplerPadding padding_mode,
222
+ bool align_corners) {
223
+
224
+ x = compute_coordinates(x, W, padding_mode, align_corners);
225
+ y = compute_coordinates(y, H, padding_mode, align_corners);
226
+
227
+ int64_t ix = static_cast<int64_t>(x);
228
+ int64_t iy = static_cast<int64_t>(y);
229
+
230
+ if (within_bounds_2d(iy, ix, H, W)) {
231
+ return data[iy * sH + ix * sW];
232
+ }
233
+ return static_cast<scalar_t>(0);
234
+ }
235
+
236
// Accumulates `delta` into data[h * sH + w * sW], silently dropping writes
// whose (h, w) falls outside the H x W bounds.
template<typename scalar_t>
static inline void safe_add_2d(scalar_t *data, int64_t h, int64_t w,
                               int64_t sH, int64_t sW, int64_t H, int64_t W,
                               scalar_t delta) {
  // same predicate as within_bounds_2d, inlined
  const bool in_bounds = h >= 0 && h < H && w >= 0 && w < W;
  if (in_bounds) {
    data[h * sH + w * sW] += delta;
  }
}
244
+
245
// Accumulates `delta` into data[d * sD + h * sH + w * sW], silently dropping
// writes whose (d, h, w) falls outside the D x H x W bounds.
template<typename scalar_t>
static inline void safe_add_3d(scalar_t *data, int64_t d, int64_t h, int64_t w,
                               int64_t sD, int64_t sH, int64_t sW,
                               int64_t D, int64_t H, int64_t W,
                               scalar_t delta) {
  // same predicate as within_bounds_3d, inlined
  const bool in_bounds =
      d >= 0 && d < D && h >= 0 && h < H && w >= 0 && w < W;
  if (in_bounds) {
    data[d * sD + h * sH + w * sW] += delta;
  }
}
254
+
255
+ template<typename scalar_t>
256
+ static inline void add_value_bounded(
257
+ scalar_t* data,
258
+ scalar_t x,
259
+ scalar_t y,
260
+ int64_t W,
261
+ int64_t H,
262
+ int64_t sW,
263
+ int64_t sH,
264
+ scalar_t delta,
265
+ GridSamplerPadding padding_mode,
266
+ bool align_corners) {
267
+
268
+ x = compute_coordinates(x, W, padding_mode, align_corners);
269
+ y = compute_coordinates(y, H, padding_mode, align_corners);
270
+
271
+ int64_t ix = static_cast<int64_t>(x);
272
+ int64_t iy = static_cast<int64_t>(y);
273
+
274
+ safe_add_2d(data, iy, ix, sH, sW, H, W, delta);
275
+ }
276
+
277
// Differential of the cubic-convolution interpolation weights, i.e.
// `d coeff / d x` for each of the four taps.
// Must stay in sync with the forward computation in
// aten/src/ATen/native/UpSample.h:get_cubic_upsample_coefficients.
template<typename scalar_t>
static inline void get_cubic_coefficients_grad(
    scalar_t coeffs[4],
    scalar_t t) {
  const scalar_t A = -0.75;  // standard cubic-convolution constant

  const scalar_t x0 = -1 - t;  // outer-left tap: 1 < |x| < 2
  coeffs[0] = (-3 * A * x0 - 10 * A ) * x0 - 8 * A;
  const scalar_t x1 = -t;      // inner-left tap: |x| <= 1
  coeffs[1] = (-3 * (A + 2) * x1 - 2 * (A + 3)) * x1;
  const scalar_t x2 = 1 - t;   // inner-right tap: |x| <= 1
  coeffs[2] = (3 * (A + 2) * x2 - 2 * (A + 3)) * x2;
  const scalar_t x3 = 2 - t;   // outer-right tap: 1 < |x| < 2
  coeffs[3] = (3 * A * x3 - 10 * A) * x3 + 8 * A;
}
297
+
298
+ } // namespace at::native
.venv/lib/python3.11/site-packages/torch/include/ATen/native/LossMulti.h ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ #include <ATen/core/Tensor.h>
3
+ #include <ATen/AccumulateType.h>
4
+ #include <ATen/Dispatch.h>
5
+ #include <ATen/TensorUtils.h>
6
+
7
+ namespace at::native {
8
+ inline void multilabel_margin_loss_shape_check(
9
+ int64_t& nframe,
10
+ int64_t& dim,
11
+ const int64_t& ndims,
12
+ const Tensor& input,
13
+ const Tensor& target) {
14
+ TORCH_CHECK(
15
+ (ndims == 2 && input.size(1) != 0) || (ndims == 1 && input.size(0) != 0) || ndims == 0,
16
+ "Expected non-empty vector or matrix with optional 0-dim batch size, but got: ",
17
+ input.sizes());
18
+
19
+ if (ndims <= 1) {
20
+ nframe = 1;
21
+ dim = ndims == 0 ? 1 : input.size(0);
22
+ TORCH_CHECK(
23
+ target.dim() <= 1 && target.numel() == dim,
24
+ "inconsistent target size: ", target.sizes(), " for input of size: ",
25
+ input.sizes());
26
+ } else {
27
+ nframe = input.size(0);
28
+ dim = input.size(1);
29
+ TORCH_CHECK(
30
+ target.dim() == 2 && target.size(0) == nframe &&
31
+ target.size(1) == dim,
32
+ "inconsistent target size: ", target.sizes(), " for input of size: ",
33
+ input.sizes());
34
+ }
35
+ }
36
+
37
+ inline void multi_margin_loss_shape_check(
38
+ int64_t& nframe,
39
+ int64_t& dim,
40
+ const int64_t& ndims,
41
+ const Tensor& input,
42
+ const Tensor& target,
43
+ const std::optional<Tensor>& weight) {
44
+ TORCH_CHECK(
45
+ (ndims == 2 && input.size(1) != 0) || (ndims == 1 && input.size(0) != 0) || ndims == 0,
46
+ "Expected non-empty vector or matrix with optional 0-dim batch size, but got: ",
47
+ input.sizes());
48
+
49
+ if (ndims <= 1) {
50
+ nframe = 1;
51
+ dim = ndims == 0 ? 1 : input.size(0);
52
+ } else {
53
+ nframe = input.size(0);
54
+ dim = input.size(1);
55
+ }
56
+
57
+ TORCH_CHECK(
58
+ target.dim() <= 1 && target.numel() == nframe,
59
+ "inconsistent target size, expected ", nframe, " but got ",
60
+ target.sizes());
61
+ if (weight && weight->defined()) {
62
+ TORCH_CHECK(
63
+ weight->dim() <= 1 && weight->numel() == dim,
64
+ "inconsistent weight size, expected ", dim, " but got ",
65
+ weight->sizes());
66
+ }
67
+ }
68
+
69
+ } // namespace at::native
.venv/lib/python3.11/site-packages/torch/include/ATen/native/Math.h ADDED
The diff for this file is too large to render. See raw diff
 
.venv/lib/python3.11/site-packages/torch/include/ATen/native/MathBitFallThroughLists.h ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
namespace at {
// Ops that merely create views (plus their in-place variants).  Math-bit
// fallback kernels (conjugate/negative) register these as fallthrough so the
// bit propagates to the view instead of being materialized.
#define TORCH_VIEW_FNS(m) \
  m.impl("as_strided_", torch::CppFunction::makeFallthrough()); \
  m.impl("detach", torch::CppFunction::makeFallthrough()); \
  m.impl("detach_", torch::CppFunction::makeFallthrough()); \
  m.impl("diagonal", torch::CppFunction::makeFallthrough()); \
  m.impl("expand", torch::CppFunction::makeFallthrough()); \
  m.impl("expand_as", torch::CppFunction::makeFallthrough()); \
  m.impl("movedim.int", torch::CppFunction::makeFallthrough()); \
  m.impl("movedim.intlist", torch::CppFunction::makeFallthrough()); \
  m.impl("narrow", torch::CppFunction::makeFallthrough()); \
  m.impl("permute", torch::CppFunction::makeFallthrough()); \
  m.impl("select.Dimname", torch::CppFunction::makeFallthrough()); \
  m.impl("select.int", torch::CppFunction::makeFallthrough()); \
  m.impl("squeeze", torch::CppFunction::makeFallthrough()); \
  m.impl("squeeze_", torch::CppFunction::makeFallthrough()); \
  m.impl("transpose.int", torch::CppFunction::makeFallthrough()); \
  m.impl("transpose.Dimname", torch::CppFunction::makeFallthrough()); \
  m.impl("transpose_", torch::CppFunction::makeFallthrough()); \
  m.impl("t", torch::CppFunction::makeFallthrough()); \
  m.impl("t_", torch::CppFunction::makeFallthrough()); \
  m.impl("real", torch::CppFunction::makeFallthrough()); \
  m.impl("imag", torch::CppFunction::makeFallthrough()); \
  m.impl("view_as_real", torch::CppFunction::makeFallthrough()); \
  m.impl("unflatten.int", torch::CppFunction::makeFallthrough()); \
  m.impl("unflatten.Dimname", torch::CppFunction::makeFallthrough()); \
  m.impl("unfold", torch::CppFunction::makeFallthrough()); \
  m.impl("unsqueeze", torch::CppFunction::makeFallthrough()); \
  m.impl("unsqueeze_", torch::CppFunction::makeFallthrough()); \
  m.impl("view_as", torch::CppFunction::makeFallthrough()); \
  m.impl("unbind.int", torch::CppFunction::makeFallthrough()); \
  m.impl("unbind.Dimname", torch::CppFunction::makeFallthrough()); \
  m.impl("split.Tensor", torch::CppFunction::makeFallthrough()); \
  m.impl("split_with_sizes", torch::CppFunction::makeFallthrough()); \
  m.impl("swapaxes", torch::CppFunction::makeFallthrough()); \
  m.impl("swapdims", torch::CppFunction::makeFallthrough()); \
  m.impl("chunk", torch::CppFunction::makeFallthrough()); \
  m.impl("reshape", torch::CppFunction::makeFallthrough()); \
  m.impl("alias", torch::CppFunction::makeFallthrough()); \
  m.impl("hsplit.int", torch::CppFunction::makeFallthrough()); \
  m.impl("hsplit.array", torch::CppFunction::makeFallthrough()); \
  m.impl("dsplit.int", torch::CppFunction::makeFallthrough()); \
  m.impl("dsplit.array", torch::CppFunction::makeFallthrough()); \
  m.impl("vsplit.int", torch::CppFunction::makeFallthrough()); \
  m.impl("vsplit.array", torch::CppFunction::makeFallthrough()); \
  m.impl("conj", torch::CppFunction::makeFallthrough()); \
  m.impl("_conj", torch::CppFunction::makeFallthrough()); \
  m.impl("_unsafe_view", torch::CppFunction::makeFallthrough()); \
  m.impl("resize_", torch::CppFunction::makeFallthrough());

// Shape/metadata queries and constructors that never need the math bit
// resolved, so they can fall through as well.
#define TENSOR_UTILITIES_AND_CONSTRUCTORS(m) \
  m.impl("empty_like", torch::CppFunction::makeFallthrough()); \
  m.impl("empty.memory_format", torch::CppFunction::makeFallthrough()); \
  m.impl("empty.out", torch::CppFunction::makeFallthrough()); \
  m.impl("empty_strided", torch::CppFunction::makeFallthrough()); \
  m.impl("full_like", torch::CppFunction::makeFallthrough()); \
  m.impl("stride.int", torch::CppFunction::makeFallthrough()); \
  m.impl("stride.Dimname", torch::CppFunction::makeFallthrough()); \
  m.impl("size.int", torch::CppFunction::makeFallthrough()); \
  m.impl("size.Dimname", torch::CppFunction::makeFallthrough()); \
  m.impl("is_complex", torch::CppFunction::makeFallthrough()); \
  m.impl("is_floating_point", torch::CppFunction::makeFallthrough()); \
  m.impl("requires_grad_", torch::CppFunction::makeFallthrough());
} // namespace at

// View ops whose registrations live with the native functions; kept in a
// separate list (outside the at namespace, matching the original layout).
#define TORCH_VIEW_FNS_NATIVE_FN_REGISTRATION(m) \
  m.impl("as_strided", torch::CppFunction::makeFallthrough()); \
  m.impl("view", torch::CppFunction::makeFallthrough());
.venv/lib/python3.11/site-packages/torch/include/ATen/native/NonEmptyUtils.h ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <ATen/core/TensorBase.h>
2
+ #include <algorithm>
3
+ #include <vector>
4
+
5
+ namespace at::native {
6
+
7
// Treats a 0 (or negative) dimension count as 1 so downstream loops always
// have at least one dimension to walk over.
inline int64_t ensure_nonempty_dim(int64_t dim) {
  return dim < 1 ? 1 : dim;
}
10
+
11
+ inline int64_t ensure_nonempty_size(const TensorBase &t, int64_t dim) {
12
+ return t.dim() == 0 ? 1 : t.size(dim);
13
+ }
14
+
15
+ inline int64_t ensure_nonempty_stride(const TensorBase &t, int64_t dim) {
16
+ return t.dim() == 0 ? 1 : t.stride(dim);
17
+ }
18
+
19
using IdxVec = std::vector<int64_t>;
// Returns `vec` unchanged unless it is empty, in which case it is replaced
// by the single-entry shape {1} (a fake non-empty scalar).
inline IdxVec ensure_nonempty_vec(IdxVec vec) {
  if (vec.empty()) {
    vec.assign(1, 1);
  }
  return vec;
}
26
+
27
+ } // namespace at::native
.venv/lib/python3.11/site-packages/torch/include/ATen/native/Normalization.h ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/TensorIterator.h>
4
+ #include <ATen/native/DispatchStub.h>
5
+
6
+ namespace at::native {
7
+
8
+ using renorm_scale_factor_fn = void (*) (TensorIteratorBase& iter, double maxnorm);
9
+ DECLARE_DISPATCH(renorm_scale_factor_fn, renorm_scale_factor_stub);
10
+
11
+ enum class BatchNormBackend {
12
+ Native,
13
+ Cudnn,
14
+ Miopen,
15
+ };
16
+
17
+ TORCH_API BatchNormBackend _select_batch_norm_backend(const Tensor& input, const Tensor& weight, const Tensor& bias, const Tensor& running_mean, const Tensor& running_var, bool training, double eps);
18
+
19
+ } // namespace at::native