Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/binary_folding.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/dedupe_symint_uses.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/efficient_conv_bn_eval.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/freezing_patterns.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/joint_graph.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/misc_patterns.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/pad_mm.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/post_grad.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/reinplace.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_14.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_8.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_9.py +233 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/kernel/__pycache__/mm.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/kernel/__pycache__/unpack_mixed_mm.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/kernel/conv.py +495 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/detail/KernelUtils.h +37 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/detail/UnpackRaw.cuh +28 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/tunable/GemmRocblas.h +275 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/tunable/Tunable.h +205 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/tunable/TunableGemm.h +278 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/tunable/TunableOp.h +242 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/BinaryOps.h +119 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/CPUFallback.h +45 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/CanUse32BitIndexMath.h +13 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/ComplexHelper.h +97 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/Distributions.h +518 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/Math.h +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/NonSymbolicBC.h +26 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/PixelShuffle.h +47 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/RangeFactories.h +12 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/ReduceAllOps.h +16 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/ReduceOpsUtils.h +449 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/ReductionType.h +40 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/ScatterGatherChecks.h +128 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/SegmentReduce.h +50 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/TensorAdvancedIndexing.h +49 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/TypeProperties.h +20 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/Unfold2d.h +30 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/batch_norm.h +33 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/im2col_shape_check.h +232 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_cholesky_solve_helper_cpu_dispatch.h +23 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_foreach_lgamma_ops.h +50 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_scaled_dot_product_cudnn_attention_cuda_dispatch.h +23 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_triton_scaled_dot_attention_compositeexplicitautograd_dispatch.h +24 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_values_native.h +21 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/adaptive_max_pool3d_backward_native.h +26 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/amin_compositeexplicitautogradnonfunctional_dispatch.h +23 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/atleast_2d_compositeimplicitautograd_dispatch.h +24 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/atleast_2d_native.h +22 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/avg_pool3d_backward_compositeexplicitautogradnonfunctional_dispatch.h +23 -0
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/binary_folding.cpython-311.pyc
ADDED
|
Binary file (13.4 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/dedupe_symint_uses.cpython-311.pyc
ADDED
|
Binary file (5.3 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/efficient_conv_bn_eval.cpython-311.pyc
ADDED
|
Binary file (6.53 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/freezing_patterns.cpython-311.pyc
ADDED
|
Binary file (10.1 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/joint_graph.cpython-311.pyc
ADDED
|
Binary file (17.3 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/misc_patterns.cpython-311.pyc
ADDED
|
Binary file (6.81 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/pad_mm.cpython-311.pyc
ADDED
|
Binary file (26.9 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/post_grad.cpython-311.pyc
ADDED
|
Binary file (50.2 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/reinplace.cpython-311.pyc
ADDED
|
Binary file (26.4 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_14.cpython-311.pyc
ADDED
|
Binary file (17.6 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_8.cpython-311.pyc
ADDED
|
Binary file (17.4 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_9.py
ADDED
|
@@ -0,0 +1,233 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: ignore-errors
|
| 2 |
+
|
| 3 |
+
# noqa: F401, E501
|
| 4 |
+
# This is an auto-generated file. Please do not modify it by hand.
|
| 5 |
+
# To re-generate, run:
|
| 6 |
+
# cd ~/pytorch && python
|
| 7 |
+
# torchgen/fuse_attention_patterns/gen_attention_patterns.py
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import torch._inductor
|
| 11 |
+
|
| 12 |
+
aten = torch.ops.aten
|
| 13 |
+
prims = torch.ops.prims
|
| 14 |
+
|
| 15 |
+
from torch._inductor.pattern_matcher import (
|
| 16 |
+
Arg,
|
| 17 |
+
CallFunction,
|
| 18 |
+
CallFunctionVarArgs,
|
| 19 |
+
CallMethod,
|
| 20 |
+
CallMethodVarArgs,
|
| 21 |
+
CallModule,
|
| 22 |
+
CallModuleVarArgs,
|
| 23 |
+
ExclusiveKeywordArg,
|
| 24 |
+
Ignored,
|
| 25 |
+
KeywordArg,
|
| 26 |
+
ListOf,
|
| 27 |
+
MultiOutputPattern,
|
| 28 |
+
PatternExpr,
|
| 29 |
+
RepeatedExpr,
|
| 30 |
+
_TargetArgsExpr,
|
| 31 |
+
_TargetExpr,
|
| 32 |
+
_TargetExprVarArgs,
|
| 33 |
+
)
|
| 34 |
+
rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
|
| 35 |
+
gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2)
|
| 36 |
+
permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
|
| 37 |
+
div_Tensor = CallFunction(aten.div.Tensor, permute_default, Ignored())
|
| 38 |
+
expand_default = CallFunction(aten.expand.default, div_Tensor, Ignored())
|
| 39 |
+
clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
|
| 40 |
+
view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2)
|
| 41 |
+
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
|
| 42 |
+
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
|
| 43 |
+
expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
|
| 44 |
+
clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
|
| 45 |
+
view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2)
|
| 46 |
+
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
|
| 47 |
+
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored(), _users=2)
|
| 48 |
+
amax_default = CallFunction(aten.amax.default, view_default_2, Ignored(), True)
|
| 49 |
+
sub_Tensor = CallFunction(aten.sub.Tensor, view_default_2, amax_default)
|
| 50 |
+
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
|
| 51 |
+
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
|
| 52 |
+
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2)
|
| 53 |
+
mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1)
|
| 54 |
+
mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored())
|
| 55 |
+
convert_element_type_default = CallFunction(prims.convert_element_type.default, mul_Tensor_1, Ignored())
|
| 56 |
+
expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored())
|
| 57 |
+
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
|
| 58 |
+
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
|
| 59 |
+
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
|
| 60 |
+
clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
|
| 61 |
+
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2)
|
| 62 |
+
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
|
| 63 |
+
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
|
| 64 |
+
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
|
| 65 |
+
permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
|
| 66 |
+
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
|
| 67 |
+
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, bmm_default_2, Ignored())
|
| 68 |
+
view_default_7 = CallFunction(aten.view.default, convert_element_type_default_1, Ignored())
|
| 69 |
+
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored())
|
| 70 |
+
convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
|
| 71 |
+
mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored())
|
| 72 |
+
mul_Tensor_3 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, mul_Tensor_2)
|
| 73 |
+
clone_default_3 = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format)
|
| 74 |
+
alias_default = CallFunction(aten.alias.default, div_Tensor_1)
|
| 75 |
+
alias_default_1 = CallFunction(aten.alias.default, alias_default)
|
| 76 |
+
alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
|
| 77 |
+
alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2)
|
| 78 |
+
mul_Tensor_4 = CallFunction(aten.mul.Tensor, clone_default_3, alias_default_3, _users=2)
|
| 79 |
+
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
|
| 80 |
+
mul_Tensor_5 = CallFunction(aten.mul.Tensor, alias_default_3, sum_dim_IntList_1)
|
| 81 |
+
sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor_4, mul_Tensor_5)
|
| 82 |
+
view_default_8 = CallFunction(aten.view.default, sub_Tensor_1, Ignored(), _users=2)
|
| 83 |
+
permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored())
|
| 84 |
+
bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5)
|
| 85 |
+
view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
|
| 86 |
+
div_Tensor_2 = CallFunction(aten.div.Tensor, view_default_9, Ignored())
|
| 87 |
+
permute_default_6 = CallFunction(aten.permute.default, div_Tensor_2, Ignored())
|
| 88 |
+
permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored())
|
| 89 |
+
bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8)
|
| 90 |
+
view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
|
| 91 |
+
permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored())
|
| 92 |
+
permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored())
|
| 93 |
+
permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored())
|
| 94 |
+
bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6)
|
| 95 |
+
view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
|
| 96 |
+
permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored())
|
| 97 |
+
_sfdp_pattern_9_training = MultiOutputPattern([view_default_5,
|
| 98 |
+
permute_default_6,
|
| 99 |
+
permute_default_9,
|
| 100 |
+
permute_default_11,
|
| 101 |
+
None
|
| 102 |
+
])
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
|
| 106 |
+
div_Tensor = CallFunction(aten.div.Tensor, permute_default, Ignored())
|
| 107 |
+
expand_default = CallFunction(aten.expand.default, div_Tensor, Ignored())
|
| 108 |
+
clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
|
| 109 |
+
view_default = CallFunction(aten.view.default, clone_default, Ignored())
|
| 110 |
+
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
|
| 111 |
+
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
|
| 112 |
+
expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
|
| 113 |
+
clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
|
| 114 |
+
view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored())
|
| 115 |
+
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
|
| 116 |
+
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored(), _users=2)
|
| 117 |
+
amax_default = CallFunction(aten.amax.default, view_default_2, Ignored(), True)
|
| 118 |
+
sub_Tensor = CallFunction(aten.sub.Tensor, view_default_2, amax_default)
|
| 119 |
+
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
|
| 120 |
+
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
|
| 121 |
+
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
|
| 122 |
+
clone_default_2 = CallFunction(aten.clone.default, div_Tensor_1)
|
| 123 |
+
convert_element_type_default = CallFunction(prims.convert_element_type.default, clone_default_2, Ignored())
|
| 124 |
+
expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored())
|
| 125 |
+
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
|
| 126 |
+
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
|
| 127 |
+
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
|
| 128 |
+
clone_default_3 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
|
| 129 |
+
view_default_4 = CallFunction(aten.view.default, clone_default_3, Ignored())
|
| 130 |
+
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
|
| 131 |
+
_sfdp_pattern_9_inference = CallFunction(aten.view.default, bmm_default_1, Ignored())
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
|
| 135 |
+
gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2)
|
| 136 |
+
permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
|
| 137 |
+
div_Tensor = CallFunction(aten.div.Tensor, permute_default, Ignored())
|
| 138 |
+
expand_default = CallFunction(aten.expand.default, div_Tensor, Ignored())
|
| 139 |
+
clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
|
| 140 |
+
view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2)
|
| 141 |
+
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
|
| 142 |
+
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
|
| 143 |
+
expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
|
| 144 |
+
clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
|
| 145 |
+
view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2)
|
| 146 |
+
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
|
| 147 |
+
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
|
| 148 |
+
convert_element_type_default = CallFunction(prims.convert_element_type.default, view_default_2, Ignored(), _users=2)
|
| 149 |
+
amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
|
| 150 |
+
sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
|
| 151 |
+
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
|
| 152 |
+
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
|
| 153 |
+
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2)
|
| 154 |
+
mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1)
|
| 155 |
+
mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored())
|
| 156 |
+
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, mul_Tensor_1, Ignored())
|
| 157 |
+
expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
|
| 158 |
+
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
|
| 159 |
+
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
|
| 160 |
+
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
|
| 161 |
+
clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
|
| 162 |
+
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2)
|
| 163 |
+
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
|
| 164 |
+
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
|
| 165 |
+
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
|
| 166 |
+
permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
|
| 167 |
+
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
|
| 168 |
+
view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
|
| 169 |
+
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored())
|
| 170 |
+
convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
|
| 171 |
+
mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored())
|
| 172 |
+
mul_Tensor_3 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, mul_Tensor_2)
|
| 173 |
+
clone_default_3 = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format)
|
| 174 |
+
alias_default = CallFunction(aten.alias.default, div_Tensor_1)
|
| 175 |
+
alias_default_1 = CallFunction(aten.alias.default, alias_default)
|
| 176 |
+
alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
|
| 177 |
+
alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2)
|
| 178 |
+
mul_Tensor_4 = CallFunction(aten.mul.Tensor, clone_default_3, alias_default_3, _users=2)
|
| 179 |
+
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
|
| 180 |
+
mul_Tensor_5 = CallFunction(aten.mul.Tensor, alias_default_3, sum_dim_IntList_1)
|
| 181 |
+
sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor_4, mul_Tensor_5)
|
| 182 |
+
convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, sub_Tensor_1, Ignored())
|
| 183 |
+
view_default_8 = CallFunction(aten.view.default, convert_element_type_default_4, Ignored(), _users=2)
|
| 184 |
+
permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored())
|
| 185 |
+
bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5)
|
| 186 |
+
view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
|
| 187 |
+
div_Tensor_2 = CallFunction(aten.div.Tensor, view_default_9, Ignored())
|
| 188 |
+
permute_default_6 = CallFunction(aten.permute.default, div_Tensor_2, Ignored())
|
| 189 |
+
permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored())
|
| 190 |
+
bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8)
|
| 191 |
+
view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
|
| 192 |
+
permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored())
|
| 193 |
+
permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored())
|
| 194 |
+
permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored())
|
| 195 |
+
bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6)
|
| 196 |
+
view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
|
| 197 |
+
permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored())
|
| 198 |
+
_sfdp_pattern_9_half_training = MultiOutputPattern([view_default_5,
|
| 199 |
+
permute_default_6,
|
| 200 |
+
permute_default_9,
|
| 201 |
+
permute_default_11,
|
| 202 |
+
None
|
| 203 |
+
])
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
|
| 207 |
+
div_Tensor = CallFunction(aten.div.Tensor, permute_default, Ignored())
|
| 208 |
+
expand_default = CallFunction(aten.expand.default, div_Tensor, Ignored())
|
| 209 |
+
clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
|
| 210 |
+
view_default = CallFunction(aten.view.default, clone_default, Ignored())
|
| 211 |
+
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
|
| 212 |
+
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
|
| 213 |
+
expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
|
| 214 |
+
clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
|
| 215 |
+
view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored())
|
| 216 |
+
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
|
| 217 |
+
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
|
| 218 |
+
convert_element_type_default = CallFunction(prims.convert_element_type.default, view_default_2, Ignored(), _users=2)
|
| 219 |
+
amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
|
| 220 |
+
sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
|
| 221 |
+
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
|
| 222 |
+
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
|
| 223 |
+
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
|
| 224 |
+
clone_default_2 = CallFunction(aten.clone.default, div_Tensor_1)
|
| 225 |
+
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, clone_default_2, Ignored())
|
| 226 |
+
expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
|
| 227 |
+
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
|
| 228 |
+
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
|
| 229 |
+
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
|
| 230 |
+
clone_default_3 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
|
| 231 |
+
view_default_4 = CallFunction(aten.view.default, clone_default_3, Ignored())
|
| 232 |
+
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
|
| 233 |
+
_sfdp_pattern_9_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored())
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/kernel/__pycache__/mm.cpython-311.pyc
ADDED
|
Binary file (13 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/kernel/__pycache__/unpack_mixed_mm.cpython-311.pyc
ADDED
|
Binary file (3.55 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/kernel/conv.py
ADDED
|
@@ -0,0 +1,495 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import functools
|
| 4 |
+
import logging
|
| 5 |
+
from typing import cast, List, Optional, Sequence, Tuple, TypedDict
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
from .. import config, ir
|
| 9 |
+
from ..ir import TensorBox
|
| 10 |
+
|
| 11 |
+
from ..lowering import (
|
| 12 |
+
add_layout_constraint,
|
| 13 |
+
constrain_to_fx_strides,
|
| 14 |
+
lowerings as L,
|
| 15 |
+
register_lowering,
|
| 16 |
+
)
|
| 17 |
+
from ..select_algorithm import (
|
| 18 |
+
autotune_select_algorithm,
|
| 19 |
+
ExternKernelChoice,
|
| 20 |
+
TritonTemplate,
|
| 21 |
+
)
|
| 22 |
+
from ..utils import (
|
| 23 |
+
ceildiv,
|
| 24 |
+
is_ones,
|
| 25 |
+
is_zeros,
|
| 26 |
+
pad_listlike,
|
| 27 |
+
sympy_product,
|
| 28 |
+
use_triton_template,
|
| 29 |
+
)
|
| 30 |
+
from ..virtualized import V
|
| 31 |
+
from .mm_common import filtered_configs
|
| 32 |
+
|
| 33 |
+
log = logging.getLogger(__name__)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
aten = torch.ops.aten
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def conv_grid(n, c, h, w, meta):
|
| 40 |
+
return (
|
| 41 |
+
ceildiv(n * h * w, meta["BLOCK_M"]),
|
| 42 |
+
ceildiv(c, meta["BLOCK_N"]),
|
| 43 |
+
meta["GROUPS"],
|
| 44 |
+
)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
# List of dictionaries to store the kernel configs. Configs that evaluate to true
|
| 48 |
+
# will be utilised on the target platform
|
| 49 |
+
kernel_configs = [
|
| 50 |
+
# "BLOCK_M", "BLOCK_N", "BLOCK_K", "num_stages", "num_warps"
|
| 51 |
+
{"config": (64, 256, 16, 2, 4), "cond": True},
|
| 52 |
+
{"config": (256, 64, 16, 2, 4), "cond": True},
|
| 53 |
+
{"config": (1024, 16, 16, 1, 8), "cond": True},
|
| 54 |
+
{"config": (128, 128, 32, 2, 8), "cond": True},
|
| 55 |
+
{"config": (64, 64, 32, 2, 4), "cond": True},
|
| 56 |
+
{"config": (64, 256, 32, 2, 8), "cond": True},
|
| 57 |
+
{"config": (256, 64, 32, 2, 8), "cond": True},
|
| 58 |
+
]
|
| 59 |
+
|
| 60 |
+
# Create filtered list of configs based on conv
|
| 61 |
+
platform_configs = tuple(
|
| 62 |
+
cast(Tuple[int, int, int, int, int], config["config"])
|
| 63 |
+
for config in kernel_configs
|
| 64 |
+
if config["cond"]
|
| 65 |
+
)
|
| 66 |
+
|
| 67 |
+
# On ROCm convert num_stages to 1 as pipelining provides no benefit
|
| 68 |
+
if torch.version.hip:
|
| 69 |
+
platform_configs = tuple(
|
| 70 |
+
(config[0], config[1], config[2], 1, config[4]) for config in platform_configs
|
| 71 |
+
)
|
| 72 |
+
|
| 73 |
+
conv_configs = functools.partial(
|
| 74 |
+
filtered_configs,
|
| 75 |
+
configs=platform_configs,
|
| 76 |
+
)
|
| 77 |
+
|
| 78 |
+
LOOP_BODY = """
|
| 79 |
+
idx_x_h = i - PADDING_H + idx_y_h * STRIDE_H
|
| 80 |
+
idx_x_w = j - PADDING_W + idx_y_w * STRIDE_W
|
| 81 |
+
idx_x_c = tl.arange(0, BLOCK_K) + k
|
| 82 |
+
|
| 83 |
+
x_ptrs = x_base + (
|
| 84 |
+
(idx_x_h * stride_xh)[:, None]
|
| 85 |
+
+ (idx_x_w * stride_xw)[:, None]
|
| 86 |
+
+ (idx_x_c * stride_xc)[None, :]
|
| 87 |
+
)
|
| 88 |
+
mask_x = (
|
| 89 |
+
(idx_n < BATCH)[:, None]
|
| 90 |
+
& (idx_x_h >= 0)[:, None]
|
| 91 |
+
& (idx_x_h < IN_H)[:, None]
|
| 92 |
+
& (idx_x_w >= 0)[:, None]
|
| 93 |
+
& (idx_x_w < IN_W)[:, None]
|
| 94 |
+
& (idx_x_c < GROUP_IN_C)[None, :]
|
| 95 |
+
)
|
| 96 |
+
matrix_x = tl.load(x_ptrs, mask=mask_x, other=0.0)
|
| 97 |
+
|
| 98 |
+
w_ptrs = w_base + (
|
| 99 |
+
(idx_x_c * stride_wc_in)[:, None] + (i * stride_wh) + (j * stride_ww)
|
| 100 |
+
)
|
| 101 |
+
mask_w = (idx_x_c[:, None] < GROUP_IN_C) & (idx_y_c[None, :] < GROUP_OUT_C)
|
| 102 |
+
matrix_w = tl.load(w_ptrs, mask=mask_w, other=0.0)
|
| 103 |
+
acc += tl.dot(matrix_x, matrix_w, allow_tf32=ALLOW_TF32)
|
| 104 |
+
"""
|
| 105 |
+
|
| 106 |
+
"""
|
| 107 |
+
This is a relatively simple conv implementation that can likely be
|
| 108 |
+
improved. Many alternate conv versions can be found here:
|
| 109 |
+
https://github.com/pytorch/torchdynamo/pull/971
|
| 110 |
+
"""
|
| 111 |
+
conv2d_template = TritonTemplate(
|
| 112 |
+
name="convolution",
|
| 113 |
+
grid=conv_grid,
|
| 114 |
+
source=r"""
|
| 115 |
+
{{def_kernel("X", "W")}}
|
| 116 |
+
# Tensor dimensions
|
| 117 |
+
BATCH = {{size("X", 0)}}
|
| 118 |
+
IN_C = {{size("X", 1)}}
|
| 119 |
+
IN_H = {{size("X", 2)}}
|
| 120 |
+
IN_W = {{size("X", 3)}}
|
| 121 |
+
OUT_C = {{size(None, 1)}}
|
| 122 |
+
OUT_H = {{size(None, 2)}}
|
| 123 |
+
OUT_W = {{size(None, 3)}}
|
| 124 |
+
|
| 125 |
+
# Strides:
|
| 126 |
+
stride_xn = {{stride("X", 0)}}
|
| 127 |
+
stride_xc = {{stride("X", 1)}}
|
| 128 |
+
stride_xh = {{stride("X", 2)}}
|
| 129 |
+
stride_xw = {{stride("X", 3)}}
|
| 130 |
+
stride_wc_out = {{stride("W", 0)}}
|
| 131 |
+
stride_wc_in = {{stride("W", 1)}}
|
| 132 |
+
stride_wh = {{stride("W", 2)}}
|
| 133 |
+
stride_ww = {{stride("W", 3)}}
|
| 134 |
+
|
| 135 |
+
nhw = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
|
| 136 |
+
idx_y_w = nhw % OUT_W
|
| 137 |
+
nh = nhw // OUT_W
|
| 138 |
+
idx_y_h = nh % OUT_H
|
| 139 |
+
idx_n = nh // OUT_H
|
| 140 |
+
idx_y_c = tl.program_id(1) * BLOCK_N + tl.arange(0, BLOCK_N)
|
| 141 |
+
|
| 142 |
+
{% if GROUPS == 1 %}
|
| 143 |
+
group = 0
|
| 144 |
+
GROUP_IN_C = IN_C
|
| 145 |
+
GROUP_OUT_C = OUT_C
|
| 146 |
+
{% else %}
|
| 147 |
+
group = tl.program_id(2)
|
| 148 |
+
GROUP_IN_C = IN_C // GROUPS
|
| 149 |
+
GROUP_OUT_C = OUT_C // GROUPS
|
| 150 |
+
{% endif %}
|
| 151 |
+
|
| 152 |
+
x_base = X + (group * stride_xc * GROUP_IN_C + idx_n * stride_xn)[:, None]
|
| 153 |
+
w_base = (
|
| 154 |
+
W + (group * stride_wc_out * GROUP_OUT_C + idx_y_c * stride_wc_out)[None, :]
|
| 155 |
+
)
|
| 156 |
+
|
| 157 |
+
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
|
| 158 |
+
|
| 159 |
+
{% if UNROLL %}
|
| 160 |
+
{% for i in range(KERNEL_H) %}
|
| 161 |
+
{% for j in range(KERNEL_W) %}
|
| 162 |
+
i = {{i}}
|
| 163 |
+
j = {{j}}
|
| 164 |
+
for k in range(0, GROUP_IN_C, BLOCK_K):
|
| 165 |
+
"""
|
| 166 |
+
+ LOOP_BODY
|
| 167 |
+
+ """
|
| 168 |
+
{% endfor %}
|
| 169 |
+
{% endfor %}
|
| 170 |
+
{% else %}
|
| 171 |
+
# Could be simplified, but slightly slower:
|
| 172 |
+
# for i in range(KERNEL_H):
|
| 173 |
+
# for j in range(KERNEL_W):
|
| 174 |
+
# for k in range(0, GROUP_IN_C, BLOCK_K):
|
| 175 |
+
BLOCK_K_COUNT = (GROUP_IN_C + BLOCK_K - 1) // BLOCK_K
|
| 176 |
+
for ijk in range(KERNEL_H * KERNEL_W * BLOCK_K_COUNT):
|
| 177 |
+
k = (ijk % BLOCK_K_COUNT) * BLOCK_K
|
| 178 |
+
ij = ijk // BLOCK_K_COUNT
|
| 179 |
+
i = ij // KERNEL_W
|
| 180 |
+
j = ij % KERNEL_W
|
| 181 |
+
"""
|
| 182 |
+
+ LOOP_BODY
|
| 183 |
+
+ """
|
| 184 |
+
{% endif %}
|
| 185 |
+
|
| 186 |
+
mask = (
|
| 187 |
+
(idx_n < BATCH)[:, None]
|
| 188 |
+
& (idx_y_h < OUT_H)[:, None]
|
| 189 |
+
& (idx_y_w < OUT_W)[:, None]
|
| 190 |
+
& (idx_y_c < GROUP_OUT_C)[None, :]
|
| 191 |
+
)
|
| 192 |
+
idx_n = idx_n[:, None]
|
| 193 |
+
idx_c = idx_y_c[None, :] + group * GROUP_OUT_C
|
| 194 |
+
idx_h = idx_y_h[:, None]
|
| 195 |
+
idx_w = idx_y_w[:, None]
|
| 196 |
+
|
| 197 |
+
# inductor generates a suffix
|
| 198 |
+
{{store_output(("idx_n", "idx_c", "idx_h", "idx_w"), "acc", "mask")}}
|
| 199 |
+
""",
|
| 200 |
+
)
|
| 201 |
+
|
| 202 |
+
aten_convolution = ExternKernelChoice(
|
| 203 |
+
torch.convolution,
|
| 204 |
+
"at::convolution",
|
| 205 |
+
has_out_variant=False,
|
| 206 |
+
op_overload=aten.convolution.default,
|
| 207 |
+
)
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
def conv1x1_via_mm(x, w, *, out):
|
| 211 |
+
w = torch.squeeze(torch.squeeze(w, -1), -1)
|
| 212 |
+
return torch.matmul(
|
| 213 |
+
x.permute(0, 2, 3, 1), w.permute(1, 0), out=out.permute(0, 2, 3, 1)
|
| 214 |
+
)
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
aten_conv1x1_via_mm = ExternKernelChoice(conv1x1_via_mm, None)
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
class ConvLayoutParams(TypedDict):
|
| 221 |
+
stride: tuple[int, ...]
|
| 222 |
+
padding: tuple[int, ...]
|
| 223 |
+
dilation: tuple[int, ...]
|
| 224 |
+
transposed: bool
|
| 225 |
+
output_padding: tuple[int, ...]
|
| 226 |
+
groups: int
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
def conv_layout(
|
| 230 |
+
x: TensorBox,
|
| 231 |
+
weight: TensorBox,
|
| 232 |
+
bias: Optional[TensorBox],
|
| 233 |
+
stride: Sequence[int],
|
| 234 |
+
padding: tuple[int, ...],
|
| 235 |
+
dilation: tuple[int, ...],
|
| 236 |
+
transposed: bool,
|
| 237 |
+
output_padding: tuple[int, ...],
|
| 238 |
+
groups: int,
|
| 239 |
+
) -> ir.Layout:
|
| 240 |
+
"""Determine output layout for a convolution"""
|
| 241 |
+
with V.graph.fake_mode:
|
| 242 |
+
output = torch.ops.aten.convolution(
|
| 243 |
+
ir.ir_node_to_tensor(x, guard_shape=True),
|
| 244 |
+
ir.ir_node_to_tensor(weight, guard_shape=True),
|
| 245 |
+
ir.ir_node_to_tensor(bias, guard_shape=True),
|
| 246 |
+
stride,
|
| 247 |
+
tuple(V.graph.sizevars.size_hint(p) for p in padding), # type: ignore[arg-type]
|
| 248 |
+
dilation,
|
| 249 |
+
transposed,
|
| 250 |
+
tuple(V.graph.sizevars.size_hint(p) for p in output_padding), # type: ignore[arg-type]
|
| 251 |
+
groups,
|
| 252 |
+
)
|
| 253 |
+
sizes = ir.convert_shape_to_inductor(output.size())
|
| 254 |
+
stride = ir.convert_shape_to_inductor(output.stride()) # type: ignore[assignment]
|
| 255 |
+
|
| 256 |
+
return ir.FixedLayout(
|
| 257 |
+
x.get_device(),
|
| 258 |
+
x.get_dtype(),
|
| 259 |
+
sizes,
|
| 260 |
+
stride,
|
| 261 |
+
)
|
| 262 |
+
|
| 263 |
+
|
| 264 |
+
def channels_last_order(rank):
|
| 265 |
+
order = list(reversed(range(rank)))
|
| 266 |
+
order.insert(1, order.pop(-1))
|
| 267 |
+
return order
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
def convert_1x1_conv_to_mm(x, weight, bias):
|
| 271 |
+
# special case for 1x1 convolution, which is actually just a matmul
|
| 272 |
+
rank = len(weight.get_size())
|
| 273 |
+
for _ in range(rank - 2):
|
| 274 |
+
weight = L[aten.squeeze](weight, dim=-1)
|
| 275 |
+
weight = L[aten.permute](weight, [1, 0])
|
| 276 |
+
|
| 277 |
+
if x.get_size()[0] != 1:
|
| 278 |
+
x = ir.ExternKernel.require_stride_order(x, channels_last_order(rank))
|
| 279 |
+
else:
|
| 280 |
+
x.realize()
|
| 281 |
+
x.freeze_layout()
|
| 282 |
+
|
| 283 |
+
x_permute = list(range(rank))
|
| 284 |
+
x_permute.append(x_permute.pop(1))
|
| 285 |
+
x = L[aten.permute](x, x_permute)
|
| 286 |
+
*sizes, in_chan = x.get_size()
|
| 287 |
+
x = L[aten.reshape](x, [sympy_product(sizes), in_chan])
|
| 288 |
+
if bias is None:
|
| 289 |
+
result = L[aten.mm](x, weight)
|
| 290 |
+
else:
|
| 291 |
+
result = L[aten.addmm](bias, x, weight)
|
| 292 |
+
result = L[aten.reshape](result, [*sizes, -1])
|
| 293 |
+
result_permute = list(range(rank))
|
| 294 |
+
result_permute.insert(1, result_permute.pop(-1))
|
| 295 |
+
return L[aten.permute](result, result_permute)
|
| 296 |
+
|
| 297 |
+
|
| 298 |
+
@register_lowering(aten.convolution)
|
| 299 |
+
def convolution(
|
| 300 |
+
x: TensorBox,
|
| 301 |
+
weight: TensorBox,
|
| 302 |
+
bias: TensorBox,
|
| 303 |
+
stride: List[int],
|
| 304 |
+
padding: List[int],
|
| 305 |
+
dilation: List[int],
|
| 306 |
+
transposed: bool,
|
| 307 |
+
output_padding: List[int],
|
| 308 |
+
groups: int,
|
| 309 |
+
):
|
| 310 |
+
stride = tuple(stride)
|
| 311 |
+
padding = tuple(padding)
|
| 312 |
+
dilation = tuple(dilation)
|
| 313 |
+
output_padding = tuple(output_padding)
|
| 314 |
+
if not isinstance(groups, int):
|
| 315 |
+
groups = V.graph.sizevars.evaluate_static_shape(groups)
|
| 316 |
+
assert isinstance(groups, int)
|
| 317 |
+
kwargs: ConvLayoutParams = {
|
| 318 |
+
"stride": stride,
|
| 319 |
+
"padding": padding,
|
| 320 |
+
"dilation": dilation,
|
| 321 |
+
"transposed": transposed,
|
| 322 |
+
"output_padding": output_padding,
|
| 323 |
+
"groups": groups,
|
| 324 |
+
}
|
| 325 |
+
|
| 326 |
+
if len(x.get_size()) == len(weight.get_size()) - 1:
|
| 327 |
+
# add batch dimension to simplify rest of function
|
| 328 |
+
return L[aten.squeeze](
|
| 329 |
+
convolution(L[aten.expand](x, [1, *x.get_size()]), weight, bias, **kwargs),
|
| 330 |
+
dim=0,
|
| 331 |
+
)
|
| 332 |
+
|
| 333 |
+
out_chan, in_chan, *kernel_shape = V.graph.sizevars.evaluate_static_shapes(
|
| 334 |
+
weight.get_size()
|
| 335 |
+
)
|
| 336 |
+
ndim = len(kernel_shape)
|
| 337 |
+
stride = pad_listlike(stride, ndim)
|
| 338 |
+
padding = pad_listlike(padding, ndim)
|
| 339 |
+
dilation = pad_listlike(dilation, ndim)
|
| 340 |
+
output_padding = pad_listlike(output_padding, ndim)
|
| 341 |
+
|
| 342 |
+
def channels_last_conv():
|
| 343 |
+
if V.graph.layout_opt and ndim == 2:
|
| 344 |
+
return True
|
| 345 |
+
|
| 346 |
+
layout = conv_layout(x, weight, None, **kwargs)
|
| 347 |
+
req_stride_order = ir.get_stride_order(
|
| 348 |
+
V.graph.sizevars.size_hints(layout.stride)
|
| 349 |
+
)
|
| 350 |
+
return req_stride_order == ir.NHWC_STRIDE_ORDER
|
| 351 |
+
|
| 352 |
+
autotuning_gemm = config.max_autotune or config.max_autotune_gemm
|
| 353 |
+
|
| 354 |
+
if (
|
| 355 |
+
(config.conv_1x1_as_mm or (autotuning_gemm and channels_last_conv()))
|
| 356 |
+
and is_ones(kernel_shape)
|
| 357 |
+
and is_ones(stride)
|
| 358 |
+
and is_zeros(padding)
|
| 359 |
+
and is_ones(dilation)
|
| 360 |
+
and not transposed
|
| 361 |
+
and is_zeros(output_padding)
|
| 362 |
+
and groups == 1
|
| 363 |
+
):
|
| 364 |
+
return convert_1x1_conv_to_mm(x, weight, bias)
|
| 365 |
+
|
| 366 |
+
if bias is not None and ir.get_device_type(x) != "cpu":
|
| 367 |
+
# peel off the bias, cudnn is slower with it
|
| 368 |
+
result = convolution(x, weight, None, **kwargs)
|
| 369 |
+
return L[aten.add](
|
| 370 |
+
result, L[aten.view](bias, [result.get_size()[1]] + ndim * [1])
|
| 371 |
+
)
|
| 372 |
+
|
| 373 |
+
x.realize()
|
| 374 |
+
weight.realize()
|
| 375 |
+
|
| 376 |
+
# ndim can be 1 for convolution in models such as demucs
|
| 377 |
+
# TODO: check if it's beneficial to convert Conv1d to Conv2d and then
|
| 378 |
+
# apply channels last.
|
| 379 |
+
if V.graph.layout_opt and ndim == 2:
|
| 380 |
+
V.graph.num_channels_last_conv += 1
|
| 381 |
+
x = ir.ExternKernel.require_channels_last(x)
|
| 382 |
+
# TODO maybe we can convert weights to channels last just once before
|
| 383 |
+
# running the model.
|
| 384 |
+
weight = ir.ExternKernel.require_channels_last(weight)
|
| 385 |
+
layout = conv_layout(x, weight, None, **kwargs)
|
| 386 |
+
else:
|
| 387 |
+
layout = conv_layout(x, weight, None, **kwargs)
|
| 388 |
+
req_stride_order = ir.get_stride_order(
|
| 389 |
+
V.graph.sizevars.size_hints(layout.stride)
|
| 390 |
+
)
|
| 391 |
+
x = ir.ExternKernel.require_stride_order(x, req_stride_order)
|
| 392 |
+
weight = ir.ExternKernel.require_stride_order(weight, req_stride_order)
|
| 393 |
+
|
| 394 |
+
ordered_kwargs_for_cpp_kernel = [
|
| 395 |
+
"stride",
|
| 396 |
+
"padding",
|
| 397 |
+
"dilation",
|
| 398 |
+
"transposed",
|
| 399 |
+
"output_padding",
|
| 400 |
+
"groups",
|
| 401 |
+
]
|
| 402 |
+
if bias is None:
|
| 403 |
+
args = [x, weight]
|
| 404 |
+
kwargs["bias"] = None # type: ignore[typeddict-unknown-key]
|
| 405 |
+
ordered_kwargs_for_cpp_kernel.insert(0, "bias")
|
| 406 |
+
else:
|
| 407 |
+
args = [x, weight, bias]
|
| 408 |
+
bias.realize()
|
| 409 |
+
bias.freeze_layout()
|
| 410 |
+
V.graph.sizevars.evaluate_static_shapes(bias.get_size())
|
| 411 |
+
choices = [
|
| 412 |
+
aten_convolution.bind(
|
| 413 |
+
args,
|
| 414 |
+
layout,
|
| 415 |
+
ordered_kwargs_for_cpp_kernel,
|
| 416 |
+
**kwargs,
|
| 417 |
+
)
|
| 418 |
+
]
|
| 419 |
+
|
| 420 |
+
if (
|
| 421 |
+
use_triton_template(layout)
|
| 422 |
+
# templates only support these:
|
| 423 |
+
and ndim == 2
|
| 424 |
+
and is_ones(dilation)
|
| 425 |
+
and not transposed
|
| 426 |
+
and is_zeros(output_padding)
|
| 427 |
+
# there are some odd models where this check fails (e.g. shufflenet_v2_x1_0)
|
| 428 |
+
and V.graph.sizevars.statically_known_equals(in_chan, x.get_size()[1]) # type: ignore[arg-type]
|
| 429 |
+
):
|
| 430 |
+
if (
|
| 431 |
+
is_ones(kernel_shape)
|
| 432 |
+
and is_ones(stride)
|
| 433 |
+
and is_zeros(padding)
|
| 434 |
+
and groups == 1
|
| 435 |
+
):
|
| 436 |
+
choices.append(aten_conv1x1_via_mm.bind(args, layout))
|
| 437 |
+
|
| 438 |
+
for cfg in conv_configs(
|
| 439 |
+
sympy_product([x.get_size()[0], *x.get_size()[2:]]),
|
| 440 |
+
out_chan,
|
| 441 |
+
in_chan,
|
| 442 |
+
):
|
| 443 |
+
conv2d_template.maybe_append_choice(
|
| 444 |
+
choices,
|
| 445 |
+
input_nodes=(x, weight),
|
| 446 |
+
layout=layout,
|
| 447 |
+
KERNEL_H=kernel_shape[0],
|
| 448 |
+
KERNEL_W=kernel_shape[1],
|
| 449 |
+
STRIDE_H=stride[0],
|
| 450 |
+
STRIDE_W=stride[1],
|
| 451 |
+
PADDING_H=padding[0],
|
| 452 |
+
PADDING_W=padding[1],
|
| 453 |
+
GROUPS=groups,
|
| 454 |
+
# TODO(jansel): try unroll for bigger kernels once fixed:
|
| 455 |
+
# https://github.com/openai/triton/issues/1254
|
| 456 |
+
UNROLL=is_ones(kernel_shape),
|
| 457 |
+
ALLOW_TF32=torch.backends.cudnn.allow_tf32,
|
| 458 |
+
num_stages=cfg.num_stages,
|
| 459 |
+
num_warps=cfg.num_warps,
|
| 460 |
+
**cfg.kwargs,
|
| 461 |
+
)
|
| 462 |
+
|
| 463 |
+
return autotune_select_algorithm("convolution", choices, args, layout)
|
| 464 |
+
|
| 465 |
+
|
| 466 |
+
@register_lowering(aten._convolution)
|
| 467 |
+
def _convolution(
|
| 468 |
+
x,
|
| 469 |
+
weight,
|
| 470 |
+
bias,
|
| 471 |
+
stride,
|
| 472 |
+
padding,
|
| 473 |
+
dilation,
|
| 474 |
+
transposed,
|
| 475 |
+
output_padding,
|
| 476 |
+
groups,
|
| 477 |
+
benchmark,
|
| 478 |
+
deterministic,
|
| 479 |
+
cudnn_enabled,
|
| 480 |
+
allow_tf32,
|
| 481 |
+
):
|
| 482 |
+
return convolution(
|
| 483 |
+
x, weight, bias, stride, padding, dilation, transposed, output_padding, groups
|
| 484 |
+
)
|
| 485 |
+
|
| 486 |
+
|
| 487 |
+
def constrain_conv_to_fx_strides(fx_node, *args, **kwargs):
|
| 488 |
+
assert fx_node.target == torch.ops.aten.convolution.default
|
| 489 |
+
if V.graph.layout_opt:
|
| 490 |
+
return args, kwargs
|
| 491 |
+
else:
|
| 492 |
+
return constrain_to_fx_strides(fx_node, *args, **kwargs)
|
| 493 |
+
|
| 494 |
+
|
| 495 |
+
add_layout_constraint(aten.convolution, constrain_conv_to_fx_strides)
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/detail/KernelUtils.h
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <limits>
|
| 4 |
+
#include <c10/util/Exception.h>
|
| 5 |
+
|
| 6 |
+
namespace at::cuda::detail {
|
| 7 |
+
|
| 8 |
+
// CUDA: grid stride looping
|
| 9 |
+
//
|
| 10 |
+
// int64_t _i_n_d_e_x specifically prevents overflow in the loop increment.
|
| 11 |
+
// If input.numel() < INT_MAX, _i_n_d_e_x < INT_MAX, except after the final
|
| 12 |
+
// iteration of the loop where _i_n_d_e_x += blockDim.x * gridDim.x can be
|
| 13 |
+
// greater than INT_MAX. But in that case _i_n_d_e_x >= n, so there are no
|
| 14 |
+
// further iterations and the overflowed value in i=_i_n_d_e_x is not used.
|
| 15 |
+
#define CUDA_KERNEL_LOOP_TYPE(i, n, index_type) \
|
| 16 |
+
int64_t _i_n_d_e_x = blockIdx.x * blockDim.x + threadIdx.x; \
|
| 17 |
+
for (index_type i=_i_n_d_e_x; _i_n_d_e_x < (n); _i_n_d_e_x+=blockDim.x * gridDim.x, i=_i_n_d_e_x)
|
| 18 |
+
|
| 19 |
+
#define CUDA_KERNEL_LOOP(i, n) CUDA_KERNEL_LOOP_TYPE(i, n, int)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
// Use 1024 threads per block, which requires cuda sm_2x or above
|
| 23 |
+
constexpr int CUDA_NUM_THREADS = 1024;
|
| 24 |
+
|
| 25 |
+
// CUDA: number of blocks for threads.
|
| 26 |
+
inline int GET_BLOCKS(const int64_t N, const int64_t max_threads_per_block=CUDA_NUM_THREADS) {
|
| 27 |
+
TORCH_INTERNAL_ASSERT(N > 0, "CUDA kernel launch blocks must be positive, but got N=", N);
|
| 28 |
+
constexpr int64_t max_int = std::numeric_limits<int>::max();
|
| 29 |
+
|
| 30 |
+
// Round up division for positive number that cannot cause integer overflow
|
| 31 |
+
auto block_num = (N - 1) / max_threads_per_block + 1;
|
| 32 |
+
TORCH_INTERNAL_ASSERT(block_num <= max_int, "Can't schedule too many blocks on CUDA device");
|
| 33 |
+
|
| 34 |
+
return static_cast<int>(block_num);
|
| 35 |
+
}
|
| 36 |
+
|
| 37 |
+
} // namespace at::cuda::detail
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/detail/UnpackRaw.cuh
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// No "#pragma once" because this is a raw definition that can be copied by jit codegen.
|
| 2 |
+
// Eager mode clients should not include this file directly, instead,
|
| 3 |
+
// they should #include <ATen/cuda/PhiloxUtils.cuh>, which has a #pragma once.
|
| 4 |
+
|
| 5 |
+
namespace at::cuda::philox {
|
| 6 |
+
|
| 7 |
+
// In-kernel call to retrieve philox seed and offset from a PhiloxCudaState instance whether
|
| 8 |
+
// that instance was created with graph capture underway or not.
|
| 9 |
+
// See Note [CUDA Graph-safe RNG states].
|
| 10 |
+
//
|
| 11 |
+
// We can't write a __device__ function in CUDAGeneratorImpl.h, because it's in ATen.
|
| 12 |
+
// Also, whatever call unpacks PhiloxCudaState in consumer kernels must be inlineable.
|
| 13 |
+
// Easiest thing that comes to mind is, define a __device__ unpack helper here, in ATen/cuda.
|
| 14 |
+
//
|
| 15 |
+
// The raw definition lives in its own file so jit codegen can easily copy it.
|
| 16 |
+
__host__ __device__ __forceinline__ std::tuple<uint64_t, uint64_t>
|
| 17 |
+
unpack(at::PhiloxCudaState arg) {
|
| 18 |
+
if (arg.captured_) {
|
| 19 |
+
// static_cast avoids "warning: invalid narrowing conversion from "long" to "unsigned long".
|
| 20 |
+
// *(arg.offset_.ptr) is a broadcast load of a single int64_t to the entire kernel.
|
| 21 |
+
// For most threads' reads it will hit in cache, so it shouldn't hurt performance.
|
| 22 |
+
return std::make_tuple(static_cast<uint64_t>(*arg.seed_.ptr), static_cast<uint64_t>(*(arg.offset_.ptr) + arg.offset_intragraph_));
|
| 23 |
+
} else {
|
| 24 |
+
return std::make_tuple(arg.seed_.val, arg.offset_.val);
|
| 25 |
+
}
|
| 26 |
+
}
|
| 27 |
+
|
| 28 |
+
} // namespace at::cuda::philox
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/tunable/GemmRocblas.h
ADDED
|
@@ -0,0 +1,275 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Copyright (c) Microsoft Corporation. All rights reserved.
|
| 2 |
+
// Licensed under the MIT License.
|
| 3 |
+
|
| 4 |
+
#pragma once
|
| 5 |
+
|
| 6 |
+
#include <ATen/cuda/CUDAContext.h>
|
| 7 |
+
#include <ATen/cuda/tunable/TunableOp.h>
|
| 8 |
+
#include <ATen/cuda/tunable/GemmCommon.h>
|
| 9 |
+
#include <c10/util/StringUtil.h>
|
| 10 |
+
|
| 11 |
+
#define ROCBLAS_BETA_FEATURES_API
|
| 12 |
+
#include <rocblas/rocblas.h>
|
| 13 |
+
|
| 14 |
+
#define TORCH_ROCBLAS_CHECK(EXPR) \
|
| 15 |
+
do { \
|
| 16 |
+
rocblas_status __err = EXPR; \
|
| 17 |
+
TORCH_CHECK(__err == rocblas_status_success, \
|
| 18 |
+
"rocblas error: ", \
|
| 19 |
+
rocblas_status_to_string(__err), \
|
| 20 |
+
" when calling `" #EXPR "`"); \
|
| 21 |
+
} while (0)
|
| 22 |
+
|
| 23 |
+
namespace at::cuda::tunable {
|
| 24 |
+
|
| 25 |
+
template <typename T>
|
| 26 |
+
constexpr rocblas_datatype RocBlasDataTypeFor();
|
| 27 |
+
|
| 28 |
+
template <>
|
| 29 |
+
constexpr rocblas_datatype RocBlasDataTypeFor<float>() {
|
| 30 |
+
return rocblas_datatype_f32_r;
|
| 31 |
+
}
|
| 32 |
+
|
| 33 |
+
template <>
|
| 34 |
+
constexpr rocblas_datatype RocBlasDataTypeFor<double>() {
|
| 35 |
+
return rocblas_datatype_f64_r;
|
| 36 |
+
}
|
| 37 |
+
|
| 38 |
+
template <>
|
| 39 |
+
constexpr rocblas_datatype RocBlasDataTypeFor<Half>() {
|
| 40 |
+
return rocblas_datatype_f16_r;
|
| 41 |
+
}
|
| 42 |
+
|
| 43 |
+
template <>
|
| 44 |
+
constexpr rocblas_datatype RocBlasDataTypeFor<BFloat16>() {
|
| 45 |
+
return rocblas_datatype_bf16_r;
|
| 46 |
+
}
|
| 47 |
+
|
| 48 |
+
template <>
|
| 49 |
+
constexpr rocblas_datatype RocBlasDataTypeFor<c10::complex<float>>() {
|
| 50 |
+
return rocblas_datatype_f32_c;
|
| 51 |
+
}
|
| 52 |
+
|
| 53 |
+
template <>
|
| 54 |
+
constexpr rocblas_datatype RocBlasDataTypeFor<c10::complex<double>>() {
|
| 55 |
+
return rocblas_datatype_f64_c;
|
| 56 |
+
}
|
| 57 |
+
|
| 58 |
+
template <typename T>
|
| 59 |
+
constexpr rocblas_datatype RocBlasComputeTypeFor();
|
| 60 |
+
|
| 61 |
+
template <>
|
| 62 |
+
constexpr rocblas_datatype RocBlasComputeTypeFor<float>() {
|
| 63 |
+
return rocblas_datatype_f32_r;
|
| 64 |
+
}
|
| 65 |
+
|
| 66 |
+
template <>
|
| 67 |
+
constexpr rocblas_datatype RocBlasComputeTypeFor<double>() {
|
| 68 |
+
return rocblas_datatype_f64_r;
|
| 69 |
+
}
|
| 70 |
+
|
| 71 |
+
template <>
|
| 72 |
+
constexpr rocblas_datatype RocBlasComputeTypeFor<Half>() {
|
| 73 |
+
// Note that we're returning the _compute_ type for a given datatype.
|
| 74 |
+
// As of 12/2022, using compute type FP16 for 16-bit floats was much
|
| 75 |
+
// slower than using compute type FP32. So we use FP32 compute even for
|
| 76 |
+
// FP16 datatypes. This is how GEMM is implemented even in the function
|
| 77 |
+
// rocblasGemmHelper (see fpgeneric.h)
|
| 78 |
+
return rocblas_datatype_f32_r;
|
| 79 |
+
}
|
| 80 |
+
|
| 81 |
+
template <>
|
| 82 |
+
constexpr rocblas_datatype RocBlasComputeTypeFor<BFloat16>() {
|
| 83 |
+
// Note that we're returning the _compute_ type for a given datatype.
|
| 84 |
+
// As of 12/2022, using compute type FP16 for 16-bit floats was much
|
| 85 |
+
// slower than using compute type FP32. So we use FP32 compute even for
|
| 86 |
+
// BF16 datatypes. This is how GEMM is implemented even in the function
|
| 87 |
+
// rocblasGemmHelper (see fpgeneric.h)
|
| 88 |
+
return rocblas_datatype_f32_r;
|
| 89 |
+
}
|
| 90 |
+
|
| 91 |
+
template <>
|
| 92 |
+
constexpr rocblas_datatype RocBlasComputeTypeFor<c10::complex<float>>() {
|
| 93 |
+
return rocblas_datatype_f32_c;
|
| 94 |
+
}
|
| 95 |
+
|
| 96 |
+
template <>
|
| 97 |
+
constexpr rocblas_datatype RocBlasComputeTypeFor<c10::complex<double>>() {
|
| 98 |
+
return rocblas_datatype_f64_c;
|
| 99 |
+
}
|
| 100 |
+
|
| 101 |
+
template <typename T>
|
| 102 |
+
auto DoCastForHalfOrBfloat16(const T fp) {
|
| 103 |
+
return fp;
|
| 104 |
+
}
|
| 105 |
+
|
| 106 |
+
template <>
|
| 107 |
+
inline auto DoCastForHalfOrBfloat16<Half>(const Half fp) {
|
| 108 |
+
// alpha and beta should be the same as compute_type, in Half case it is float.
|
| 109 |
+
float h = fp;
|
| 110 |
+
return h;
|
| 111 |
+
}
|
| 112 |
+
|
| 113 |
+
template <>
|
| 114 |
+
inline auto DoCastForHalfOrBfloat16<BFloat16>(const BFloat16 fp) {
|
| 115 |
+
// alpha and beta should be the same as compute_type, in bfloat16 case it is float.
|
| 116 |
+
float h = fp;
|
| 117 |
+
return h;
|
| 118 |
+
}
|
| 119 |
+
|
| 120 |
+
static rocblas_operation _rocblasOpFromChar(char op) {
|
| 121 |
+
switch (op) {
|
| 122 |
+
case 'n':
|
| 123 |
+
case 'N':
|
| 124 |
+
return rocblas_operation_none;
|
| 125 |
+
case 't':
|
| 126 |
+
case 'T':
|
| 127 |
+
return rocblas_operation_transpose;
|
| 128 |
+
case 'c':
|
| 129 |
+
case 'C':
|
| 130 |
+
return rocblas_operation_conjugate_transpose;
|
| 131 |
+
}
|
| 132 |
+
AT_ERROR(
|
| 133 |
+
"_rocblasOpFromChar input should be 't', 'n' or 'c' but got `", op, "`");
|
| 134 |
+
}
|
| 135 |
+
|
| 136 |
+
template <typename T>
|
| 137 |
+
class RocblasGemmOp : public Callable<GemmParams<T>> {
|
| 138 |
+
public:
|
| 139 |
+
RocblasGemmOp(int solution) : solution_{solution} {}
|
| 140 |
+
|
| 141 |
+
TuningStatus Call(const GemmParams<T>* params) override {
|
| 142 |
+
auto input_output_type = RocBlasDataTypeFor<T>();
|
| 143 |
+
auto compute_type = RocBlasComputeTypeFor<T>();
|
| 144 |
+
auto h_a = DoCastForHalfOrBfloat16(params->alpha);
|
| 145 |
+
auto h_b = DoCastForHalfOrBfloat16(params->beta);
|
| 146 |
+
auto status = rocblas_gemm_ex(
|
| 147 |
+
(rocblas_handle)at::cuda::getCurrentCUDABlasHandle(),
|
| 148 |
+
_rocblasOpFromChar(params->transa),
|
| 149 |
+
_rocblasOpFromChar(params->transb),
|
| 150 |
+
params->m, params->n, params->k,
|
| 151 |
+
&h_a,
|
| 152 |
+
params->a, input_output_type, params->lda,
|
| 153 |
+
params->b, input_output_type, params->ldb,
|
| 154 |
+
&h_b,
|
| 155 |
+
params->c, input_output_type, params->ldc,
|
| 156 |
+
params->c, input_output_type, params->ldc,
|
| 157 |
+
compute_type,
|
| 158 |
+
rocblas_gemm_algo_solution_index,
|
| 159 |
+
solution_,
|
| 160 |
+
rocblas_gemm_flags_none);
|
| 161 |
+
if (status != rocblas_status_success) {
|
| 162 |
+
return FAIL;
|
| 163 |
+
}
|
| 164 |
+
return OK;
|
| 165 |
+
}
|
| 166 |
+
|
| 167 |
+
private:
|
| 168 |
+
int solution_;
|
| 169 |
+
};
|
| 170 |
+
|
| 171 |
+
template <typename T>
|
| 172 |
+
auto GetRocBlasGemmTypeStringAndOps() {
|
| 173 |
+
rocblas_handle handle = (rocblas_handle)at::cuda::getCurrentCUDABlasHandle();
|
| 174 |
+
int solution_size;
|
| 175 |
+
auto input_output_type = RocBlasDataTypeFor<T>();
|
| 176 |
+
auto compute_type = RocBlasComputeTypeFor<T>();
|
| 177 |
+
// Get the number of available solutions
|
| 178 |
+
TORCH_ROCBLAS_CHECK(rocblas_gemm_ex_get_solutions_by_type(handle,
|
| 179 |
+
input_output_type,
|
| 180 |
+
input_output_type,
|
| 181 |
+
compute_type,
|
| 182 |
+
rocblas_gemm_flags_none,
|
| 183 |
+
nullptr,
|
| 184 |
+
&solution_size));
|
| 185 |
+
std::vector<int> solutions(solution_size);
|
| 186 |
+
// Get the list of available solutions
|
| 187 |
+
TORCH_ROCBLAS_CHECK(rocblas_gemm_ex_get_solutions_by_type(handle,
|
| 188 |
+
input_output_type,
|
| 189 |
+
input_output_type,
|
| 190 |
+
compute_type,
|
| 191 |
+
rocblas_gemm_flags_none,
|
| 192 |
+
solutions.data(),
|
| 193 |
+
&solution_size));
|
| 194 |
+
// Sort the solutions in ascending order to make the solution vector deterministic across runs
|
| 195 |
+
std::sort(solutions.begin(), solutions.end());
|
| 196 |
+
|
| 197 |
+
std::vector<std::pair<std::string, std::unique_ptr<Callable<GemmParams<T>>>>> ret;
|
| 198 |
+
for (size_t i = 0; i < solutions.size(); ++i) {
|
| 199 |
+
auto callable = std::make_unique<RocblasGemmOp<T>>(solutions[i]);
|
| 200 |
+
ret.emplace_back(std::make_pair(c10::str("Gemm_Rocblas_", solutions[i]), std::move(callable)));
|
| 201 |
+
}
|
| 202 |
+
return ret;
|
| 203 |
+
}
|
| 204 |
+
|
| 205 |
+
template <typename T>
|
| 206 |
+
class RocblasGemmStridedBatchedOp : public Callable<GemmStridedBatchedParams<T>> {
|
| 207 |
+
public:
|
| 208 |
+
RocblasGemmStridedBatchedOp(int solution) : solution_{solution} {}
|
| 209 |
+
|
| 210 |
+
TuningStatus Call(const GemmStridedBatchedParams<T>* params) override {
|
| 211 |
+
auto input_output_type = RocBlasDataTypeFor<T>();
|
| 212 |
+
auto compute_type = RocBlasComputeTypeFor<T>();
|
| 213 |
+
auto h_a = DoCastForHalfOrBfloat16(params->alpha);
|
| 214 |
+
auto h_b = DoCastForHalfOrBfloat16(params->beta);
|
| 215 |
+
auto status = rocblas_gemm_strided_batched_ex(
|
| 216 |
+
(rocblas_handle)at::cuda::getCurrentCUDABlasHandle(),
|
| 217 |
+
_rocblasOpFromChar(params->transa),
|
| 218 |
+
_rocblasOpFromChar(params->transb),
|
| 219 |
+
params->m, params->n, params->k,
|
| 220 |
+
&h_a,
|
| 221 |
+
params->a, input_output_type, params->lda, params->stride_a,
|
| 222 |
+
params->b, input_output_type, params->ldb, params->stride_b,
|
| 223 |
+
&h_b,
|
| 224 |
+
params->c, input_output_type, params->ldc, params->stride_c,
|
| 225 |
+
params->c, input_output_type, params->ldc, params->stride_c,
|
| 226 |
+
params->batch,
|
| 227 |
+
compute_type,
|
| 228 |
+
rocblas_gemm_algo_solution_index,
|
| 229 |
+
solution_,
|
| 230 |
+
rocblas_gemm_flags_none);
|
| 231 |
+
if (status != rocblas_status_success) {
|
| 232 |
+
return FAIL;
|
| 233 |
+
}
|
| 234 |
+
return OK;
|
| 235 |
+
}
|
| 236 |
+
|
| 237 |
+
private:
|
| 238 |
+
int solution_;
|
| 239 |
+
};
|
| 240 |
+
|
| 241 |
+
template <typename T>
|
| 242 |
+
auto GetRocBlasGemmStridedBatchedTypeStringAndOps() {
|
| 243 |
+
rocblas_handle handle = (rocblas_handle)at::cuda::getCurrentCUDABlasHandle();
|
| 244 |
+
int solution_size;
|
| 245 |
+
auto input_output_type = RocBlasDataTypeFor<T>();
|
| 246 |
+
auto compute_type = RocBlasComputeTypeFor<T>();
|
| 247 |
+
// Get the number of available solutions
|
| 248 |
+
TORCH_ROCBLAS_CHECK(rocblas_gemm_ex_get_solutions_by_type(handle,
|
| 249 |
+
input_output_type,
|
| 250 |
+
input_output_type,
|
| 251 |
+
compute_type,
|
| 252 |
+
rocblas_gemm_flags_none,
|
| 253 |
+
nullptr,
|
| 254 |
+
&solution_size));
|
| 255 |
+
std::vector<int> solutions(solution_size);
|
| 256 |
+
// Get the list of available solutions
|
| 257 |
+
TORCH_ROCBLAS_CHECK(rocblas_gemm_ex_get_solutions_by_type(handle,
|
| 258 |
+
input_output_type,
|
| 259 |
+
input_output_type,
|
| 260 |
+
compute_type,
|
| 261 |
+
rocblas_gemm_flags_none,
|
| 262 |
+
solutions.data(),
|
| 263 |
+
&solution_size));
|
| 264 |
+
// Sort the solutions in ascending order to make the solution vector deterministic across runs
|
| 265 |
+
std::sort(solutions.begin(), solutions.end());
|
| 266 |
+
|
| 267 |
+
std::vector<std::pair<std::string, std::unique_ptr<Callable<GemmStridedBatchedParams<T>>>>> ret;
|
| 268 |
+
for (size_t i = 0; i < solutions.size(); ++i) {
|
| 269 |
+
auto callable = std::make_unique<RocblasGemmStridedBatchedOp<T>>(solutions[i]);
|
| 270 |
+
ret.emplace_back(std::make_pair(c10::str("Gemm_Rocblas_", solutions[i]), std::move(callable)));
|
| 271 |
+
}
|
| 272 |
+
return ret;
|
| 273 |
+
}
|
| 274 |
+
|
| 275 |
+
} // namespace at::cuda::tunable
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/tunable/Tunable.h
ADDED
|
@@ -0,0 +1,205 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Original TunableOp is from onnxruntime.
|
| 2 |
+
// https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/framework/tunable.h
|
| 3 |
+
// https://github.com/microsoft/onnxruntime/tree/main/onnxruntime/core/providers/rocm/tunable
|
| 4 |
+
// Copyright (c) Microsoft Corporation.
|
| 5 |
+
// Licensed under the MIT license.
|
| 6 |
+
//
|
| 7 |
+
// Adapting TunableOp into PyTorch
|
| 8 |
+
// Copyright (c) Advanced Micro Devices, Inc.
|
| 9 |
+
//
|
| 10 |
+
#pragma once
|
| 11 |
+
|
| 12 |
+
#include <c10/util/CallOnce.h>
|
| 13 |
+
|
| 14 |
+
#include <functional>
|
| 15 |
+
#include <iostream>
|
| 16 |
+
#include <memory>
|
| 17 |
+
#include <mutex>
|
| 18 |
+
#include <string>
|
| 19 |
+
#include <type_traits>
|
| 20 |
+
#include <unordered_map>
|
| 21 |
+
#include <utility>
|
| 22 |
+
#include <vector>
|
| 23 |
+
|
| 24 |
+
namespace at::cuda::tunable {
|
| 25 |
+
|
| 26 |
+
static void TunableLog(const std::string& msg) {
|
| 27 |
+
static const char *env = getenv("PYTORCH_TUNABLEOP_VERBOSE");
|
| 28 |
+
if (env != nullptr && strcmp(env, "1") == 0) {
|
| 29 |
+
std::cerr << msg << std::endl;
|
| 30 |
+
}
|
| 31 |
+
}
|
| 32 |
+
#define TUNABLE_LOG(...) TunableLog(c10::str(__VA_ARGS__))
|
| 33 |
+
|
| 34 |
+
// Outcome of invoking or validating a tunable Callable.
enum TuningStatus {
  OK = 0,           // the call (or validation) succeeded
  FAIL = 1,         // the call was attempted but failed
  UNSUPPORTED = 2,  // the op cannot handle these parameters at all
};
|
| 39 |
+
|
| 40 |
+
// Mapping from params signature to kernel id
|
| 41 |
+
class ResultEntry {
|
| 42 |
+
public:
|
| 43 |
+
explicit ResultEntry(const std::string& key, double time) : key_(key), time_(time) {}
|
| 44 |
+
bool operator==(const ResultEntry& other) { return key_ == other.key_; }
|
| 45 |
+
bool operator!=(const ResultEntry& other) { return key_ != other.key_; }
|
| 46 |
+
operator std::string () { return key_; }
|
| 47 |
+
friend std::ostream& operator<<(std::ostream& stream, const ResultEntry& entry);
|
| 48 |
+
static ResultEntry Null() { return ResultEntry("Null", 0.0); }
|
| 49 |
+
static ResultEntry Default() { return ResultEntry("Default", 0.0); }
|
| 50 |
+
|
| 51 |
+
private:
|
| 52 |
+
std::string key_;
|
| 53 |
+
double time_;
|
| 54 |
+
};
|
| 55 |
+
|
| 56 |
+
// One op's tuning results: params signature -> best kernel entry.
using KernelMap = std::unordered_map<std::string, ResultEntry>;
// All results: op signature -> that op's KernelMap.
using ResultsMap = std::unordered_map<std::string, KernelMap>;
|
| 58 |
+
|
| 59 |
+
// Bundle of tuning results plus the validator values recorded with them.
struct TuningResults {
  // Validates if these results are compatible with the libraries
  std::unordered_map<std::string, std::string> validators;

  // Mapping from Callable signature to Callable's tuning result
  ResultsMap results;
};
|
| 66 |
+
|
| 67 |
+
// Stores tuning results keyed first by op signature, then by params
// signature. Member functions are defined out of line; the comments below
// reflect the declared intent, and lock_ presumably guards results_ against
// concurrent access -- confirm in the implementation.
class TuningResultsManager {
  public:
    TuningResultsManager() = default;
    ~TuningResultsManager() = default;

    // Look up all recorded entries for one op signature.
    KernelMap Lookup(const std::string& op_signature);

    // Look up the single entry recorded for (op, params).
    ResultEntry Lookup(const std::string& op_signature, const std::string& params_signature);

    // Insert `best` into an already-located kernel map (helper for Add).
    inline void AddImpl(const std::string& op_signature,
        const std::string& params_signature,
        ResultEntry best,
        KernelMap& kernel_map);

    // Record the winning entry for (op, params).
    void Add(const std::string& op_signature,
        const std::string& params_signature,
        ResultEntry best);

    // Remove the entry recorded for (op, params).
    void Delete(const std::string& op_signature, const std::string& params_signature);

    // Merge one op's kernel map into `results` (helper for DisjointMerge).
    inline void DisjointMergeImpl(
        const std::string& op_signature,
        const KernelMap& kernel_map,
        /*out*/ ResultsMap& results);

    // Load a full set of previously-saved results.
    void Load(const ResultsMap& results_to_load);

    // Snapshot of everything recorded so far.
    ResultsMap Dump();

    // Merge one op's kernel map into the stored results.
    void DisjointMerge(const std::string& op_signature, const KernelMap& kernel_map);

    // Number of recorded entries.
    size_t GetSize();

  private:
    std::mutex lock_;      // presumably protects results_ -- see .cpp
    ResultsMap results_;
};
|
| 104 |
+
|
| 105 |
+
// Registry of environment "validators" (e.g. library version strings) used
// to decide whether stored tuning results are compatible with the current
// process. Each validator pairs a getter (current value) with a checker
// (compare a stored value). Implementations are out of line.
class TuningResultsValidator {
  public:
    // Produces the current value for a validator key.
    using GetFunc = std::function<std::string()>;
    // Checks a stored value against the current environment.
    using ValidateFunc = std::function<TuningStatus(const std::string&)>;
    using GetValidateFuncs = std::unordered_map<std::string, std::pair<GetFunc, ValidateFunc>>;

    TuningResultsValidator();
    ~TuningResultsValidator() = default;

    // Snapshot of all registered validators as key -> current value.
    std::unordered_map<std::string, std::string> GetAllValidators() const;
    // Runs the registered checkers against the given key/value pairs.
    TuningStatus ValidateAll(const std::unordered_map<std::string, std::string>& to_validate) const;
    // Registers a getter/checker pair under `key`.
    void RegisterValidator(const std::string& key, const GetFunc& gf, const ValidateFunc& vf);

  protected:
    std::string GetPyTorchVersion() const;
    TuningStatus ValidatePyTorchVersion(const std::string& value) const;

  public:
    // Keys that must always be present when validating.
    static constexpr const std::array mandatory_keys{"PT_VERSION"};

  private:
    GetValidateFuncs validators_;
};
|
| 128 |
+
|
| 129 |
+
// Process-wide configuration and state for TunableOp: enable switches,
// tuning/warmup budgets, the results manager and validator, and the
// filename used for persistence. Obtained via getTuningContext().
class TuningContext {
  public:
    TuningContext();
    ~TuningContext();
    // Non-copyable and non-movable: intended to be used as a singleton.
    TuningContext(TuningContext &) = delete;
    TuningContext(TuningContext &&) = delete;
    TuningContext &operator=(TuningContext &) = delete;
    TuningContext &operator=(TuningContext &&) = delete;

    // Whether TunableOp participates at all (lookup + dispatch).
    void EnableTunableOp();
    void DisableTunableOp();
    bool IsTunableOpEnabled() const;

    // Whether new tuning runs when no stored result exists
    // (see TunableOp::operator(), which gates FindFastest on this).
    void EnableTuning();
    void DisableTuning();
    bool IsTuningEnabled() const;

    // Budget for the timed tuning loop, in milliseconds.
    void SetMaxTuningDurationMs(int max_duration_ms);
    int GetMaxTuningDurationMs() const;

    // Budget for the timed tuning loop, in iterations.
    void SetMaxTuningIterations(int max_iter);
    int GetMaxTuningIterations() const;

    // Budget for the warmup phase preceding timing, in milliseconds.
    void SetMaxWarmupDurationMs(int max_duration_ms);
    int GetMaxWarmupDurationMs() const;

    // Budget for the warmup phase preceding timing, in iterations.
    void SetMaxWarmupIterations(int max_iter);
    int GetMaxWarmupIterations() const;

    // Convenience toggles that flip both switches together.
    void EnableTunableOpAndTuning();
    void DisableTunableOpAndTuning();

    TuningResultsManager& GetTuningResultsManager();

    TuningResultsValidator& GetTuningResultsValidator();

    TuningResults GetTuningResults();

    TuningStatus LoadTuningResults(const TuningResults& tr);

    // File used to persist and restore tuning results.
    void SetFilename(const std::string& filename);
    std::string GetFilename() const;

  protected:
    bool ReadFile(const std::string& filename);
    bool WriteFile(const std::string& filename);

  private:
    bool enable_;                    // TunableOp enabled?
    bool tuning_enable_;             // new tuning allowed?
    bool manager_initialized_;
    int max_tuning_duration_ms_;
    int max_tuning_iterations_;
    int max_warmup_duration_ms_;
    int max_warmup_iterations_;
    // mutable + once_flag suggest lazy initialization of the manager from
    // const accessors -- confirm in the implementation.
    mutable TuningResultsManager manager_;
    mutable c10::once_flag manager_init_once_;
    TuningResultsValidator validator_;
    std::string filename_;
    size_t results_count_from_input_file_;
};
|
| 190 |
+
|
| 191 |
+
TuningContext* getTuningContext();
|
| 192 |
+
|
| 193 |
+
// Abstract timer interface used by TunableOp to measure candidate kernels.
class ITimer {
  public:
    ITimer() = default;
    virtual ~ITimer() = default;

    // Marks the beginning of the timed region.
    virtual void Start() = 0;
    // Marks the end of the timed region.
    virtual void End() = 0;

    /// Computes the elapsed time in milliseconds between Start() and End()
    virtual float Duration() = 0;
};
|
| 204 |
+
|
| 205 |
+
} // namespace at::cuda::tunable
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/tunable/TunableGemm.h
ADDED
|
@@ -0,0 +1,278 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Original TunableOp is from onnxruntime.
|
| 2 |
+
// https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/framework/tunable.h
|
| 3 |
+
// https://github.com/microsoft/onnxruntime/tree/main/onnxruntime/core/providers/rocm/tunable
|
| 4 |
+
// Copyright (c) Microsoft Corporation.
|
| 5 |
+
// Licensed under the MIT license.
|
| 6 |
+
//
|
| 7 |
+
// Adapting TunableOp into PyTorch
|
| 8 |
+
// Copyright (c) Advanced Micro Devices, Inc.
|
| 9 |
+
//
|
| 10 |
+
#pragma once
|
| 11 |
+
|
| 12 |
+
#include <ATen/cuda/tunable/GemmCommon.h>
|
| 13 |
+
#ifdef USE_ROCM
|
| 14 |
+
#if ROCM_VERSION >= 50700
|
| 15 |
+
#include <ATen/cuda/tunable/GemmHipblaslt.h>
|
| 16 |
+
#endif
|
| 17 |
+
#include <ATen/cuda/tunable/GemmRocblas.h>
|
| 18 |
+
#endif
|
| 19 |
+
#include <ATen/cuda/tunable/StreamTimer.h>
|
| 20 |
+
#include <ATen/cuda/tunable/TunableOp.h>
|
| 21 |
+
#include <c10/cuda/CUDACachingAllocator.h>
|
| 22 |
+
#include <c10/util/StringUtil.h>
|
| 23 |
+
|
| 24 |
+
#ifdef USE_ROCM
|
| 25 |
+
#include <rocm-core/rocm_version.h>
|
| 26 |
+
#endif
|
| 27 |
+
|
| 28 |
+
#define STRINGIFY(s) #s
|
| 29 |
+
#define XSTRINGIFY(s) STRINGIFY(s)
|
| 30 |
+
|
| 31 |
+
namespace at::cuda::tunable {
|
| 32 |
+
|
| 33 |
+
// Baseline GEMM candidate: forwards directly to the regular
// at::cuda::blas implementation and always reports OK.
template <typename T>
class DefaultGemmOp : public Callable<GemmParams<T>> {
  public:
    TuningStatus Call(const GemmParams<T>* params) override {
      at::cuda::blas::gemm_internal<T>(
          params->transa, params->transb,
          params->m, params->n, params->k,
          params->alpha,
          params->a, params->lda,
          params->b, params->ldb,
          params->beta,
          params->c, params->ldc);
      return OK;
    }
};
|
| 48 |
+
|
| 49 |
+
// Baseline strided-batched GEMM candidate: forwards directly to the
// regular at::cuda::blas implementation and always reports OK.
template <typename T>
class DefaultGemmStridedBatchedOp : public Callable<GemmStridedBatchedParams<T>> {
  public:
    TuningStatus Call(const GemmStridedBatchedParams<T>* params) override {
      at::cuda::blas::bgemm_internal<T>(
          params->transa, params->transb,
          params->m, params->n, params->k,
          params->alpha,
          params->a, params->lda, params->stride_a,
          params->b, params->ldb, params->stride_b,
          params->beta,
          params->c, params->ldc, params->stride_c,
          params->batch);
      return OK;
    }
};
|
| 65 |
+
|
| 66 |
+
// Generic scalar zero test; specializations below handle the types that
// cannot compare directly against a float literal.
template <typename T>
bool IsZero(T value) {
  return value == 0.0f;
}
|
| 70 |
+
|
| 71 |
+
template <>
|
| 72 |
+
bool IsZero(BFloat16 v) {
|
| 73 |
+
return v.x == 0;
|
| 74 |
+
}
|
| 75 |
+
|
| 76 |
+
template <>
|
| 77 |
+
bool IsZero(Half v) {
|
| 78 |
+
return float(v) == 0.0f;
|
| 79 |
+
}
|
| 80 |
+
|
| 81 |
+
template <>
|
| 82 |
+
bool IsZero(c10::complex<double> v) {
|
| 83 |
+
return v == 0.0;
|
| 84 |
+
}
|
| 85 |
+
|
| 86 |
+
template <>
|
| 87 |
+
bool IsZero(c10::complex<float> v) {
|
| 88 |
+
return v == 0.0f;
|
| 89 |
+
}
|
| 90 |
+
|
| 91 |
+
template <typename T>
|
| 92 |
+
std::string TypeName(T v) {
|
| 93 |
+
return "unknown";
|
| 94 |
+
}
|
| 95 |
+
|
| 96 |
+
template <>
|
| 97 |
+
std::string TypeName(float v) {
|
| 98 |
+
return "float";
|
| 99 |
+
}
|
| 100 |
+
|
| 101 |
+
template <>
|
| 102 |
+
std::string TypeName(double v) {
|
| 103 |
+
return "double";
|
| 104 |
+
}
|
| 105 |
+
|
| 106 |
+
template <>
|
| 107 |
+
std::string TypeName(BFloat16 v) {
|
| 108 |
+
return "BFloat16";
|
| 109 |
+
}
|
| 110 |
+
|
| 111 |
+
template <>
|
| 112 |
+
std::string TypeName(Half v) {
|
| 113 |
+
return "Half";
|
| 114 |
+
}
|
| 115 |
+
|
| 116 |
+
template <>
|
| 117 |
+
std::string TypeName(c10::complex<double> v) {
|
| 118 |
+
return "c10::complex<double>";
|
| 119 |
+
}
|
| 120 |
+
|
| 121 |
+
template <>
|
| 122 |
+
std::string TypeName(c10::complex<float> v) {
|
| 123 |
+
return "c10::complex<float>";
|
| 124 |
+
}
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
// TunableOp for plain GEMM over scalar type T with fixed A/B layouts.
// The constructor registers the "Default" blas implementation plus, on
// ROCm, every rocBLAS (and optionally hipBLASLt) candidate, along with
// library-version validators so stored results are only reused when the
// environment matches.
template <typename T, BlasOp ALayout, BlasOp BLayout>
class GemmTunableOp : public TunableOp<GemmParams<T>, StreamTimer> {
  public:
    GemmTunableOp() {
      // "Default" must always exist; it is the fallback ResultEntry.
      this->RegisterOp(std::string("Default"), std::make_unique<DefaultGemmOp<T>>());

      auto validators = getTuningContext()->GetTuningResultsValidator().GetAllValidators();

#ifdef USE_ROCM
      // One candidate per rocBLAS solution index.
      for (auto&& [name, op] : GetRocBlasGemmTypeStringAndOps<T>()) {
        this->RegisterOp(std::move(name), std::move(op));
      }

      // Each validator is registered only if not already present (another
      // TunableOp instance may have registered it first).
      if (validators.find("ROCM_VERSION") == validators.end()) {
        std::string rocm_version = ROCM_BUILD_INFO;
        getTuningContext()->GetTuningResultsValidator().RegisterValidator(
            "ROCM_VERSION",
            [rocm_version]() { return rocm_version; },
            [rocm_version](auto&& k) { return rocm_version == k ? OK : FAIL; });
      }

      if (validators.find("GCN_ARCH_NAME") == validators.end()) {
        std::string gcn_arch_name = at::cuda::getCurrentDeviceProperties()->gcnArchName;
        getTuningContext()->GetTuningResultsValidator().RegisterValidator(
            "GCN_ARCH_NAME",
            [gcn_arch_name]() { return gcn_arch_name; },
            [gcn_arch_name](auto&& k) { return gcn_arch_name == k ? OK : FAIL; });
      }

      if (validators.find("ROCBLAS_VERSION") == validators.end()) {
        std::string rocblas_version = c10::str(
            XSTRINGIFY(ROCBLAS_VERSION_MAJOR), ".",
            XSTRINGIFY(ROCBLAS_VERSION_MINOR), ".",
            XSTRINGIFY(ROCBLAS_VERSION_PATCH), "-",
            XSTRINGIFY(ROCBLAS_VERSION_TWEAK));
        getTuningContext()->GetTuningResultsValidator().RegisterValidator(
            "ROCBLAS_VERSION",
            [rocblas_version]() { return rocblas_version; },
            [rocblas_version](auto&& k) { return rocblas_version == k ? OK : FAIL; });
      }
#endif

#if defined(USE_ROCM) && ROCM_VERSION >= 50700
      // hipBLASLt candidates are enabled unless the env var is set to
      // something other than "1".
      static const char *env = std::getenv("PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED");
      if (env == nullptr || strcmp(env, "1") == 0) {
        // disallow tuning of hipblaslt with c10::complex
        if constexpr (
            !std::is_same_v<T, c10::complex<float>> &&
            !std::is_same_v<T, c10::complex<double>>) {
          for (auto&& [name, op] : GetHipBlasLtGemmTypeStringAndOps<T, ALayout, BLayout>()) {
            this->RegisterOp(std::move(name), std::move(op));
          }
        }

        if (validators.find("HIPBLASLT_VERSION") == validators.end()) {
          std::string hipblaslt_version = c10::str(
              XSTRINGIFY(HIPBLASLT_VERSION_MAJOR), ".",
              XSTRINGIFY(HIPBLASLT_VERSION_MINOR), ".",
              XSTRINGIFY(HIPBLASLT_VERSION_PATCH), "-",
              XSTRINGIFY(HIPBLASLT_VERSION_TWEAK));
          getTuningContext()->GetTuningResultsValidator().RegisterValidator(
              "HIPBLASLT_VERSION",
              [hipblaslt_version]() { return hipblaslt_version; },
              [hipblaslt_version](auto&& k) { return hipblaslt_version == k ? OK : FAIL; });
        }
      }
#endif
    }

    // Unique signature under which this op's results are stored.
    std::string Signature() override {
      return c10::str("GemmTunableOp_", TypeName<T>(T{}), "_", BlasOpToString(ALayout), BlasOpToString(BLayout));
    }
};
|
| 200 |
+
|
| 201 |
+
// TunableOp for strided-batched GEMM over scalar type T with fixed A/B
// layouts. Mirrors GemmTunableOp: registers the default implementation,
// the rocBLAS (and optionally hipBLASLt) candidates, and library-version
// validators.
template <typename T, BlasOp ALayout, BlasOp BLayout>
class GemmStridedBatchedTunableOp : public TunableOp<GemmStridedBatchedParams<T>, StreamTimer> {
  public:
    GemmStridedBatchedTunableOp() {
      // "Default" must always exist; it is the fallback ResultEntry.
      this->RegisterOp(std::string("Default"), std::make_unique<DefaultGemmStridedBatchedOp<T>>());

      auto validators = getTuningContext()->GetTuningResultsValidator().GetAllValidators();

#ifdef USE_ROCM
      // One candidate per rocBLAS solution index.
      for (auto&& [name, op] : GetRocBlasGemmStridedBatchedTypeStringAndOps<T>()) {
        this->RegisterOp(std::move(name), std::move(op));
      }

      // Each validator is registered only if not already present (another
      // TunableOp instance may have registered it first).
      if (validators.find("ROCM_VERSION") == validators.end()) {
        std::string rocm_version = ROCM_BUILD_INFO;
        getTuningContext()->GetTuningResultsValidator().RegisterValidator(
            "ROCM_VERSION",
            [rocm_version]() { return rocm_version; },
            [rocm_version](auto&& k) { return rocm_version == k ? OK : FAIL; });
      }

      if (validators.find("GCN_ARCH_NAME") == validators.end()) {
        std::string gcn_arch_name = at::cuda::getCurrentDeviceProperties()->gcnArchName;
        getTuningContext()->GetTuningResultsValidator().RegisterValidator(
            "GCN_ARCH_NAME",
            [gcn_arch_name]() { return gcn_arch_name; },
            [gcn_arch_name](auto&& k) { return gcn_arch_name == k ? OK : FAIL; });
      }

      if (validators.find("ROCBLAS_VERSION") == validators.end()) {
        std::string rocblas_version = c10::str(
            XSTRINGIFY(ROCBLAS_VERSION_MAJOR), ".",
            XSTRINGIFY(ROCBLAS_VERSION_MINOR), ".",
            XSTRINGIFY(ROCBLAS_VERSION_PATCH), "-",
            XSTRINGIFY(ROCBLAS_VERSION_TWEAK));
        getTuningContext()->GetTuningResultsValidator().RegisterValidator(
            "ROCBLAS_VERSION",
            [rocblas_version]() { return rocblas_version; },
            [rocblas_version](auto&& k) { return rocblas_version == k ? OK : FAIL; });
      }
#endif

#if defined(USE_ROCM) && ROCM_VERSION >= 50700
      // hipBLASLt candidates are enabled unless the env var is set to
      // something other than "1".
      static const char *env = std::getenv("PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED");
      if (env == nullptr || strcmp(env, "1") == 0) {
        // disallow tuning of hipblaslt with c10::complex
        if constexpr (
            !std::is_same_v<T, c10::complex<float>> &&
            !std::is_same_v<T, c10::complex<double>>) {
          for (auto&& [name, op] : GetHipBlasLtGemmStridedBatchedTypeStringAndOps<T, ALayout, BLayout>()) {
            this->RegisterOp(std::move(name), std::move(op));
          }
        }

        if (validators.find("HIPBLASLT_VERSION") == validators.end()) {
          std::string hipblaslt_version = c10::str(
              XSTRINGIFY(HIPBLASLT_VERSION_MAJOR), ".",
              XSTRINGIFY(HIPBLASLT_VERSION_MINOR), ".",
              XSTRINGIFY(HIPBLASLT_VERSION_PATCH), "-",
              XSTRINGIFY(HIPBLASLT_VERSION_TWEAK));
          getTuningContext()->GetTuningResultsValidator().RegisterValidator(
              "HIPBLASLT_VERSION",
              [hipblaslt_version]() { return hipblaslt_version; },
              [hipblaslt_version](auto&& k) { return hipblaslt_version == k ? OK : FAIL; });
        }
      }
#endif
    }

    // Unique signature under which this op's results are stored.
    std::string Signature() override {
      return c10::str("GemmStridedBatchedTunableOp_", TypeName<T>(T{}), "_", BlasOpToString(ALayout), BlasOpToString(BLayout));
    }
};
|
| 274 |
+
|
| 275 |
+
#undef XSTRINGIFY
|
| 276 |
+
#undef STRINGIFY
|
| 277 |
+
|
| 278 |
+
} // namespace at::cuda::tunable
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/tunable/TunableOp.h
ADDED
|
@@ -0,0 +1,242 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Original TunableOp is from onnxruntime.
|
| 2 |
+
// https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/framework/tunable.h
|
| 3 |
+
// https://github.com/microsoft/onnxruntime/tree/main/onnxruntime/core/providers/rocm/tunable
|
| 4 |
+
// Copyright (c) Microsoft Corporation.
|
| 5 |
+
// Licensed under the MIT license.
|
| 6 |
+
//
|
| 7 |
+
// Adapting TunableOp into PyTorch
|
| 8 |
+
// Copyright (c) Advanced Micro Devices, Inc.
|
| 9 |
+
//
|
| 10 |
+
#pragma once
|
| 11 |
+
|
| 12 |
+
#include <ATen/cuda/tunable/Tunable.h>
|
| 13 |
+
#include <c10/cuda/CUDACachingAllocator.h>
|
| 14 |
+
|
| 15 |
+
#ifndef _WIN32
|
| 16 |
+
#include <cxxabi.h>
|
| 17 |
+
#endif
|
| 18 |
+
|
| 19 |
+
#include <string>
|
| 20 |
+
#include <type_traits>
|
| 21 |
+
#include <unordered_map>
|
| 22 |
+
#include <vector>
|
| 23 |
+
|
| 24 |
+
namespace at::cuda::tunable {
|
| 25 |
+
|
| 26 |
+
// Base class for one tunable kernel implementation. ParamsT bundles all
// arguments for a single invocation.
template <typename ParamsT>
class Callable {
  public:
    Callable() = default;
    Callable(Callable&&) = default;
    virtual ~Callable() = default;
    // Executes the kernel; the base implementation always fails, so
    // concrete ops must override it.
    virtual TuningStatus Call(const ParamsT*) {
      return FAIL;
    }
    // By default an op is considered supported iff an actual Call succeeds.
    virtual TuningStatus IsSupported(const ParamsT* params) {
      return Call(params);
    }
};
|
| 39 |
+
|
| 40 |
+
template <typename ParamsT, typename TimerT>
|
| 41 |
+
class TunableOp {
|
| 42 |
+
public:
|
| 43 |
+
TunableOp() = default;
|
| 44 |
+
TunableOp(TunableOp&&) = default;
|
| 45 |
+
virtual ~TunableOp() = default;
|
| 46 |
+
|
| 47 |
+
TuningStatus operator()(const ParamsT* params) {
|
| 48 |
+
ResultEntry result = ResultEntry::Null();
|
| 49 |
+
TuningContext* ctx = getTuningContext();
|
| 50 |
+
if (ctx->IsTunableOpEnabled()) {
|
| 51 |
+
auto& mgr = ctx->GetTuningResultsManager();
|
| 52 |
+
auto op_sig = Signature();
|
| 53 |
+
auto params_sig = params->Signature();
|
| 54 |
+
result = mgr.Lookup(op_sig, params_sig);
|
| 55 |
+
// If there is not previous tuning result been found, we do the tuning iff tuning is enabled
|
| 56 |
+
if (result == ResultEntry::Null() && ctx->IsTuningEnabled()) {
|
| 57 |
+
result = FindFastest(params);
|
| 58 |
+
mgr.Add(op_sig, params_sig, result);
|
| 59 |
+
}
|
| 60 |
+
}
|
| 61 |
+
else {
|
| 62 |
+
result = ResultEntry::Default();
|
| 63 |
+
}
|
| 64 |
+
if (result == ResultEntry::Null()) {
|
| 65 |
+
TUNABLE_LOG("no result, using default");
|
| 66 |
+
result = ResultEntry::Default();
|
| 67 |
+
}
|
| 68 |
+
auto iter = ops_.find(result);
|
| 69 |
+
TORCH_CHECK(iter != ops_.end());
|
| 70 |
+
return iter->second->Call(params);
|
| 71 |
+
}
|
| 72 |
+
|
| 73 |
+
virtual std::string Signature() {
|
| 74 |
+
// According to C++17 standard https://wg21.link/n4659 section 15.7.4
|
| 75 |
+
// > if the operand of typeid refers to the
|
| 76 |
+
// > object under construction or destruction, typeid yields the std::type_info object representing the constructor
|
| 77 |
+
// > or destructor’s class.
|
| 78 |
+
// So delay the op signature generation.
|
| 79 |
+
c10::call_once(signature_init_once_, [this]() { signature_ = CreateSignature(); });
|
| 80 |
+
return signature_;
|
| 81 |
+
}
|
| 82 |
+
|
| 83 |
+
protected:
|
| 84 |
+
void RegisterOp(const std::string& name, std::unique_ptr<Callable<ParamsT>> op) {
|
| 85 |
+
this->op_names_.emplace_back(name);
|
| 86 |
+
this->ops_.emplace(name, std::move(op));
|
| 87 |
+
}
|
| 88 |
+
|
| 89 |
+
private:
|
| 90 |
+
static void WarmUp(Callable<ParamsT> *op, ParamsT* param, size_t num_iter) {
|
| 91 |
+
for (size_t i = 0; i < num_iter; i++) {
|
| 92 |
+
TORCH_CHECK(op->Call(param) == OK);
|
| 93 |
+
}
|
| 94 |
+
}
|
| 95 |
+
|
| 96 |
+
static double Profile(Callable<ParamsT> *op, ParamsT* param, size_t num_iter) {
|
| 97 |
+
TimerT timer{};
|
| 98 |
+
timer.Start();
|
| 99 |
+
for (size_t i = 0; i < num_iter; i++) {
|
| 100 |
+
TORCH_CHECK(op->Call(param) == OK);
|
| 101 |
+
}
|
| 102 |
+
timer.End();
|
| 103 |
+
return timer.Duration() / num_iter;
|
| 104 |
+
}
|
| 105 |
+
|
| 106 |
+
protected:
|
| 107 |
+
bool IsNumericsCheckEnabled() {
|
| 108 |
+
static const char *env = getenv("PYTORCH_TUNABLEOP_NUMERICAL_CHECK");
|
| 109 |
+
if (env != nullptr && strcmp(env, "0") == 0) {
|
| 110 |
+
return false;
|
| 111 |
+
}
|
| 112 |
+
return true;
|
| 113 |
+
}
|
| 114 |
+
|
| 115 |
+
  // Benchmarks every registered candidate implementation on `params` and
  // returns the fastest as ResultEntry(name, mean_duration_ms).
  // Per candidate: (1) a trial call rejects unsupported configurations,
  // (2) an optional numerics check compares the candidate's output against
  // the Default op's output, (3) a 3-iteration pilot run estimates per-call
  // time and skips candidates more than 2x slower than the current best,
  // (4) warmup/tuning iteration counts are derived from the TuningContext's
  // max-duration / max-iteration limits before the timed run.
  virtual ResultEntry FindFastest(const ParamsT* params) {
    TuningContext* ctx = getTuningContext();
    auto op_sig = Signature();
    auto params_sig = params->Signature();
    TUNABLE_LOG("finding fastest for ", op_sig, '(', params_sig, ')', " out of ", op_names_.size(), " candidates");
    auto min_duration_ms = std::numeric_limits<double>::infinity();
    std::string id_name = "Default";

    // calculate a reference answer for numerical check
    ParamsT* reference_params = params->DeepCopy();
    TORCH_CHECK(ops_[ResultEntry::Default()]->Call(reference_params) == OK);

    // need a copy of params to reuse
    ParamsT* reusable_params = params->DeepCopy();

    for (size_t i = 0; i < op_names_.size(); i++) {
      auto* candidate = ops_[op_names_[i]].get(); // borrow pointer
      // Trial call: a non-OK status means this candidate cannot handle
      // these params at all.
      auto status = candidate->Call(reusable_params);
      if (status != OK) {
        TUNABLE_LOG("├──unsupported id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]);
        continue;
      }

      if (IsNumericsCheckEnabled()) {
        // Run the candidate on a fresh copy and compare to the reference.
        ParamsT* numerical_params = params->DeepCopy();
        WarmUp(candidate, numerical_params, 1);
        status = reference_params->NumericalCheck(numerical_params);
        numerical_params->Delete();
        if (status != OK) {
          TUNABLE_LOG("├──numerics check failed for id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]);
          continue;
        }
      }

      // collect a small profile
      constexpr const int approx_num_iter = 3;
      auto approx_duration = Profile(candidate, reusable_params, approx_num_iter);
      // bail if too slow (2x the best seen so far)
      if (approx_duration > 2 * min_duration_ms) {
        TUNABLE_LOG("├──skip slow instance id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]);
        continue;
      }

      // for warmup does user set max duration, max iters, or both?
      double max_warmup_duration = ctx->GetMaxWarmupDurationMs();
      int max_warmup_iter = ctx->GetMaxWarmupIterations();
      int warmup_iter = 1; // default
      if (max_warmup_duration > 0) {
        // fit as many iterations as the duration budget allows
        int duration_iters = max_warmup_duration / approx_duration;
        if (max_warmup_iter > 0) {
          warmup_iter = std::min(max_warmup_iter, duration_iters);
        }
        else {
          warmup_iter = duration_iters;
        }
      }
      else if (max_warmup_iter > 0) {
        warmup_iter = max_warmup_iter;
      }

      // for tuning does user set max duration, max iters, or both?
      double max_tuning_duration = ctx->GetMaxTuningDurationMs();
      int max_tuning_iter = ctx->GetMaxTuningIterations();
      int tuning_iter = 100; // default
      if (max_tuning_duration > 0) {
        int duration_iters = max_tuning_duration / approx_duration;
        if (max_tuning_iter > 0) {
          tuning_iter = std::min(max_tuning_iter, duration_iters);
        }
        else {
          tuning_iter = duration_iters;
        }
      }
      else if (max_tuning_iter > 0) {
        tuning_iter = max_tuning_iter;
      }

      // do the full warmup followed by tuning
      double warmup_ms = warmup_iter * approx_duration;
      double tuning_ms = tuning_iter * approx_duration;
      TUNABLE_LOG("├──tuning using "
          "warmup iters ", warmup_iter, " [", warmup_ms, " ms] "
          "and tuning iters ", tuning_iter, " [", tuning_ms, " ms] ",
          "instance id=", i, ", ", op_sig, "(", params_sig, ") ", op_names_[i]);
      WarmUp(candidate, reusable_params, warmup_iter);
      auto duration_ms = Profile(candidate, reusable_params, tuning_iter);
      if (duration_ms < min_duration_ms) {
        TUNABLE_LOG("├──found better instance id=", i, ". " , duration_ms, "ms. ", op_names_[i]);
        min_duration_ms = duration_ms;
        id_name = op_names_[i];
      }
    }

    reusable_params->Delete();
    reference_params->Delete();

    TUNABLE_LOG("└──found fastest for ", op_sig, '(', params_sig, ") ", id_name);
    return ResultEntry(id_name, min_duration_ms);
  }
|
| 214 |
+
|
| 215 |
+
private:
|
| 216 |
+
std::string CreateSignature() {
|
| 217 |
+
#ifndef _WIN32
|
| 218 |
+
const auto* name = typeid(*this).name();
|
| 219 |
+
char buf[256];
|
| 220 |
+
size_t buf_len = 256;
|
| 221 |
+
abi::__cxa_demangle(name, buf, &buf_len, nullptr);
|
| 222 |
+
buf[255] = '\0';
|
| 223 |
+
return buf;
|
| 224 |
+
#else
|
| 225 |
+
return typeid(*this).name();
|
| 226 |
+
#endif
|
| 227 |
+
}
|
| 228 |
+
|
| 229 |
+
mutable c10::once_flag signature_init_once_;
|
| 230 |
+
std::string signature_;
|
| 231 |
+
|
| 232 |
+
std::unordered_map<std::string, std::unique_ptr<Callable<ParamsT>>> ops_;
|
| 233 |
+
std::vector<std::string> op_names_;
|
| 234 |
+
};
|
| 235 |
+
|
| 236 |
+
// Abstract base for the parameter bundles consumed by tunable ops.
// Concrete subclasses describe one problem instance (shapes, dtypes, ...)
// and provide Signature() so tuning results can be keyed per-configuration.
struct OpParams {
  OpParams() = default;
  virtual ~OpParams() = default;
  // Returns a string uniquely identifying this parameter configuration.
  virtual std::string Signature() const = 0;
};
|
| 241 |
+
|
| 242 |
+
} // namespace at::cuda::tunable
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/BinaryOps.h
ADDED
|
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/core/TensorBase.h>
|
| 4 |
+
#include <ATen/native/DispatchStub.h>
|
| 5 |
+
#include <c10/core/Scalar.h>
|
| 6 |
+
#include <c10/util/TypeSafeSignMath.h>
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
namespace at {
|
| 10 |
+
struct TensorIterator;
|
| 11 |
+
struct TensorIteratorBase;
|
| 12 |
+
}
|
| 13 |
+
|
| 14 |
+
namespace at::native {
|
| 15 |
+
|
| 16 |
+
// Validates that `alpha` is compatible with the result dtype of an
// add/sub-style op: a boolean alpha is only allowed with bool results,
// a floating-point alpha is rejected for integral results, and a complex
// alpha is rejected for non-complex results.
inline void alpha_check(const ScalarType dtype, const Scalar& alpha) {
  const bool dtype_is_complex = isComplexType(dtype);
  TORCH_CHECK(!alpha.isBoolean() || dtype == ScalarType::Bool,
              "Boolean alpha only supported for Boolean results.");
  TORCH_CHECK(isFloatingType(dtype) || dtype_is_complex
              || alpha.isIntegral(true),
              "For integral input tensors, argument alpha must not be a floating point number.");
  TORCH_CHECK(dtype_is_complex || !alpha.isComplex(),
              "For non-complex input tensors, argument alpha must not be a complex number.")
}
|
| 25 |
+
|
| 26 |
+
// Basic checking for all sub functions.
|
| 27 |
+
// Basic checking for all sub functions: bool tensors do not support the
// `-` operator. The first check fires only when BOTH operands are bool
// (suggesting logical_xor); the second fires when EITHER operand is bool
// (suggesting logical_not for mask inversion). Order matters so the more
// specific message wins.
inline void sub_check(const TensorBase& self, const TensorBase& other) {
  TORCH_CHECK(self.scalar_type() != kBool || other.scalar_type() != kBool,
              "Subtraction, the `-` operator, with two bool tensors is not supported. "
              "Use the `^` or `logical_xor()` operator instead.")
  TORCH_CHECK(self.scalar_type() != kBool && other.scalar_type() != kBool,
              "Subtraction, the `-` operator, with a bool tensor is not supported. "
              "If you are trying to invert a mask, use the `~` or `logical_not()` operator instead.");
}
|
| 35 |
+
|
| 36 |
+
// Tensor-Scalar overload of sub_check: same rules as the tensor-tensor
// version, with the scalar's isBoolean() standing in for a bool dtype.
// The first check fires when both tensor and scalar are boolean; the
// second when either is.
inline void sub_check(const TensorBase& self, const Scalar& scalar) {
  TORCH_CHECK(self.scalar_type() != kBool || !scalar.isBoolean(),
              "Subtraction, the `-` operator, with two bool tensors is not supported. "
              "Use the `^` or `logical_xor()` operator instead.")
  TORCH_CHECK(self.scalar_type() != kBool && !scalar.isBoolean(),
              "Subtraction, the `-` operator, with a bool tensor is not supported. "
              "If you are trying to invert a mask, use the `~` or `logical_not()` operator instead.");
}
|
| 44 |
+
|
| 45 |
+
using structured_binary_fn_alpha = void(*)(TensorIteratorBase&, const Scalar& alpha);
|
| 46 |
+
using structured_binary_fn_double = void(*)(TensorIteratorBase&, double);
|
| 47 |
+
using structured_binary_fn = void(*)(TensorIteratorBase&);
|
| 48 |
+
|
| 49 |
+
using binary_fn_alpha = void(*)(TensorIteratorBase&, const Scalar& alpha);
|
| 50 |
+
using binary_fn_double = void(*)(TensorIterator&, double);
|
| 51 |
+
using binary_fn = void(*)(TensorIterator&);
|
| 52 |
+
using binary_clamp_fn_alpha =
|
| 53 |
+
void(*)(TensorIterator&, const Scalar& alpha, const Scalar& min_val, const Scalar& max_val);
|
| 54 |
+
|
| 55 |
+
// NB: codegenned
|
| 56 |
+
DECLARE_DISPATCH(structured_binary_fn_alpha, add_stub);
|
| 57 |
+
|
| 58 |
+
DECLARE_DISPATCH(binary_clamp_fn_alpha, add_clamp_stub);
|
| 59 |
+
DECLARE_DISPATCH(structured_binary_fn_alpha, sub_stub);
|
| 60 |
+
DECLARE_DISPATCH(structured_binary_fn, mul_stub);
|
| 61 |
+
DECLARE_DISPATCH(structured_binary_fn, div_true_stub);
|
| 62 |
+
DECLARE_DISPATCH(structured_binary_fn, div_floor_stub);
|
| 63 |
+
DECLARE_DISPATCH(structured_binary_fn, div_trunc_stub);
|
| 64 |
+
DECLARE_DISPATCH(structured_binary_fn, atan2_stub);
|
| 65 |
+
DECLARE_DISPATCH(structured_binary_fn, remainder_stub);
|
| 66 |
+
DECLARE_DISPATCH(structured_binary_fn, bitwise_and_stub);
|
| 67 |
+
DECLARE_DISPATCH(structured_binary_fn, bitwise_or_stub);
|
| 68 |
+
DECLARE_DISPATCH(structured_binary_fn, bitwise_xor_stub);
|
| 69 |
+
DECLARE_DISPATCH(structured_binary_fn, lshift_stub);
|
| 70 |
+
DECLARE_DISPATCH(structured_binary_fn, rshift_stub);
|
| 71 |
+
DECLARE_DISPATCH(binary_fn, logical_xor_stub);
|
| 72 |
+
DECLARE_DISPATCH(binary_fn, logical_and_stub);
|
| 73 |
+
DECLARE_DISPATCH(binary_fn, logical_or_stub);
|
| 74 |
+
DECLARE_DISPATCH(structured_binary_fn, lt_stub);
|
| 75 |
+
DECLARE_DISPATCH(structured_binary_fn, le_stub);
|
| 76 |
+
DECLARE_DISPATCH(structured_binary_fn, gt_stub);
|
| 77 |
+
DECLARE_DISPATCH(structured_binary_fn, ge_stub);
|
| 78 |
+
DECLARE_DISPATCH(structured_binary_fn, eq_stub);
|
| 79 |
+
DECLARE_DISPATCH(structured_binary_fn, ne_stub);
|
| 80 |
+
DECLARE_DISPATCH(binary_fn, max_elementwise_stub);
|
| 81 |
+
DECLARE_DISPATCH(binary_fn, min_elementwise_stub);
|
| 82 |
+
DECLARE_DISPATCH(structured_binary_fn, maximum_stub);
|
| 83 |
+
DECLARE_DISPATCH(structured_binary_fn, minimum_stub);
|
| 84 |
+
DECLARE_DISPATCH(structured_binary_fn, fmax_stub);
|
| 85 |
+
DECLARE_DISPATCH(structured_binary_fn, fmin_stub);
|
| 86 |
+
DECLARE_DISPATCH(structured_binary_fn_double, smooth_l1_stub);
|
| 87 |
+
DECLARE_DISPATCH(binary_fn_double, huber_stub);
|
| 88 |
+
DECLARE_DISPATCH(structured_binary_fn, sigmoid_backward_stub);
|
| 89 |
+
DECLARE_DISPATCH(binary_fn_alpha, logit_backward_stub);
|
| 90 |
+
DECLARE_DISPATCH(structured_binary_fn, tanh_backward_stub);
|
| 91 |
+
DECLARE_DISPATCH(structured_binary_fn, mse_stub);
|
| 92 |
+
DECLARE_DISPATCH(structured_binary_fn, fmod_stub);
|
| 93 |
+
DECLARE_DISPATCH(structured_binary_fn, logaddexp_stub);
|
| 94 |
+
DECLARE_DISPATCH(structured_binary_fn, logaddexp2_stub);
|
| 95 |
+
DECLARE_DISPATCH(structured_binary_fn, gcd_stub);
|
| 96 |
+
DECLARE_DISPATCH(structured_binary_fn, lcm_stub);
|
| 97 |
+
DECLARE_DISPATCH(structured_binary_fn, hypot_stub);
|
| 98 |
+
DECLARE_DISPATCH(structured_binary_fn, igamma_stub);
|
| 99 |
+
DECLARE_DISPATCH(structured_binary_fn, igammac_stub);
|
| 100 |
+
DECLARE_DISPATCH(structured_binary_fn, nextafter_stub);
|
| 101 |
+
DECLARE_DISPATCH(structured_binary_fn, heaviside_stub);
|
| 102 |
+
DECLARE_DISPATCH(structured_binary_fn, copysign_stub);
|
| 103 |
+
DECLARE_DISPATCH(structured_binary_fn, xlogy_stub);
|
| 104 |
+
DECLARE_DISPATCH(structured_binary_fn, xlog1py_stub);
|
| 105 |
+
DECLARE_DISPATCH(structured_binary_fn, zeta_stub);
|
| 106 |
+
DECLARE_DISPATCH(structured_binary_fn, chebyshev_polynomial_t_stub);
|
| 107 |
+
DECLARE_DISPATCH(structured_binary_fn, chebyshev_polynomial_u_stub);
|
| 108 |
+
DECLARE_DISPATCH(structured_binary_fn, chebyshev_polynomial_v_stub);
|
| 109 |
+
DECLARE_DISPATCH(structured_binary_fn, chebyshev_polynomial_w_stub);
|
| 110 |
+
DECLARE_DISPATCH(structured_binary_fn, hermite_polynomial_h_stub);
|
| 111 |
+
DECLARE_DISPATCH(structured_binary_fn, hermite_polynomial_he_stub);
|
| 112 |
+
DECLARE_DISPATCH(structured_binary_fn, laguerre_polynomial_l_stub);
|
| 113 |
+
DECLARE_DISPATCH(structured_binary_fn, legendre_polynomial_p_stub);
|
| 114 |
+
DECLARE_DISPATCH(structured_binary_fn, shifted_chebyshev_polynomial_t_stub);
|
| 115 |
+
DECLARE_DISPATCH(structured_binary_fn, shifted_chebyshev_polynomial_u_stub);
|
| 116 |
+
DECLARE_DISPATCH(structured_binary_fn, shifted_chebyshev_polynomial_v_stub);
|
| 117 |
+
DECLARE_DISPATCH(structured_binary_fn, shifted_chebyshev_polynomial_w_stub);
|
| 118 |
+
|
| 119 |
+
} // namespace at::native
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/CPUFallback.h
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/core/ivalue.h>
|
| 4 |
+
#include <ATen/core/stack.h>
|
| 5 |
+
#include <ATen/core/boxing/KernelFunction.h>
|
| 6 |
+
#include <ATen/core/dispatch/Dispatcher.h>
|
| 7 |
+
#include <c10/util/Metaprogramming.h>
|
| 8 |
+
#include <torch/library.h>
|
| 9 |
+
|
| 10 |
+
namespace at::native {
|
| 11 |
+
|
| 12 |
+
// This function implements a boxed fallback to CPU.
|
| 13 |
+
// External backends can add their own custom logging on top if it to customize their own CPU fallbacks.
|
| 14 |
+
TORCH_API void cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack, bool error_on_views = false);
|
| 15 |
+
|
| 16 |
+
// This is a helper function that backends can use to directly call their boxed CPU fallback
|
| 17 |
+
// TODO: update and add a usage example after https://github.com/pytorch/pytorch/pull/58092 lands.
|
| 18 |
+
// Helper for invoking a boxed CPU fallback through the typed dispatcher
// API. The primary template is never instantiated directly; the partial
// specialization below unpacks the operator schema ReturnType(Params...).
template<c10::KernelFunction::BoxedKernelFunction* fallback_fn, class Op, bool symint, class ReturnType, class... ParameterTypes>
struct _call_fallback_fn final {};

template<c10::KernelFunction::BoxedKernelFunction* fallback_fn, class Op, bool symint, class ReturnType, class... ParameterTypes>
struct _call_fallback_fn<fallback_fn, Op, symint, ReturnType(ParameterTypes...)> final {
  // Looks up Op's schema in the dispatcher and routes the call through
  // `fallback_fn` via the boxed-kernel wrapper. When `symint` is true the
  // SymInt-flavored parameter types are preserved; otherwise they are
  // replaced by their concrete int equivalents (c10::maybe_keep_symint).
  static ReturnType call(typename c10::maybe_keep_symint<symint, ParameterTypes>::type... args) {
    auto op = c10::Dispatcher::singleton()
      // TODO: figure out how to make compiler happy without dynamic casts
      .findSchemaOrThrow((const char*) Op::name, (const char*) Op::overload_name)
      //.findSchemaOrThrow("a", "b")
      .typed<ReturnType (typename c10::maybe_keep_symint<symint, ParameterTypes>::type...)>();
    return c10::impl::BoxedKernelWrapper<ReturnType (typename c10::maybe_keep_symint<symint, ParameterTypes>::type...)>::call(
      c10::BoxedKernel::makeFromFunction<fallback_fn>(),
      op,
      c10::DispatchKeySet(), // we know that the cpu_fallback doesn't use the dispatch keyset.
      // TODO: get std::forward<> to work
      args...
    );
  }
};
|
| 38 |
+
|
| 39 |
+
template<c10::KernelFunction::BoxedKernelFunction* fallback_fn, class Op>
|
| 40 |
+
using call_fallback_fn_symint = _call_fallback_fn<fallback_fn, Op, true, typename Op::schema>;
|
| 41 |
+
|
| 42 |
+
template<c10::KernelFunction::BoxedKernelFunction* fallback_fn, class Op>
|
| 43 |
+
using call_fallback_fn = _call_fallback_fn<fallback_fn, Op, false, typename Op::schema>;
|
| 44 |
+
|
| 45 |
+
} // namespace at::native
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/CanUse32BitIndexMath.h
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
#include <c10/macros/Export.h>
|
| 3 |
+
#include <limits>
|
| 4 |
+
|
| 5 |
+
namespace at {
|
| 6 |
+
class TensorBase;
|
| 7 |
+
}
|
| 8 |
+
|
| 9 |
+
namespace at::native {
|
| 10 |
+
|
| 11 |
+
TORCH_API bool canUse32BitIndexMath(const at::TensorBase &t, int64_t max_elem=std::numeric_limits<int32_t>::max());
|
| 12 |
+
|
| 13 |
+
}
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/ComplexHelper.h
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/core/Tensor.h>
|
| 4 |
+
#include <c10/util/irange.h>
|
| 5 |
+
|
| 6 |
+
#ifndef AT_PER_OPERATOR_HEADERS
|
| 7 |
+
#include <ATen/NativeFunctions.h>
|
| 8 |
+
#else
|
| 9 |
+
#include <ATen/ops/view_as_real_native.h>
|
| 10 |
+
#include <ATen/ops/view_as_complex_native.h>
|
| 11 |
+
|
| 12 |
+
#include <utility>
|
| 13 |
+
#endif
|
| 14 |
+
|
| 15 |
+
// WARNING: this header contains non-inline functions and should be only
|
| 16 |
+
// included from ONE cpp file
|
| 17 |
+
|
| 18 |
+
namespace at::native {
|
| 19 |
+
|
| 20 |
+
// View tensor with new dtype, storage offset, sizes and strides.
// Creates a fresh TensorImpl sharing `tensor`'s storage. The Conjugate
// dispatch key is dropped because the view reinterprets raw memory and
// cannot carry the lazy-conjugation bit.
inline Tensor view_tensor(
    const Tensor &tensor, ScalarType dtype,
    c10::SymInt offset, SymIntArrayRef sizes, SymIntArrayRef strides) {
  Storage storage = tensor.storage();
  auto key_set = tensor.key_set().remove(DispatchKey::Conjugate);
  auto new_tensor = detail::make_tensor<TensorImpl>(
      c10::TensorImpl::VIEW, std::move(storage), key_set, scalarTypeToTypeMeta(dtype));
  auto * impl = new_tensor.unsafeGetTensorImpl();
  impl->set_sizes_and_strides(sizes, strides, offset);
  return new_tensor;
}
|
| 32 |
+
|
| 33 |
+
// Strides for viewing a complex tensor as real: every existing stride is
// doubled (one complex element spans two real elements) and a new
// innermost dimension of stride 1 holds the (real, imag) pair.
inline SymDimVector computeStrideForViewAsReal(SymIntArrayRef oldstride) {
  const auto ndim = oldstride.size();
  SymDimVector result(ndim + 1);
  for (const auto dim : c10::irange(ndim)) {
    result[dim] = oldstride[dim] * 2;
  }
  result[ndim] = 1;
  return result;
}
|
| 41 |
+
|
| 42 |
+
// Implements view_as_real without the conjugate-bit check: returns a
// real-dtype view of a complex tensor with an extra trailing dimension of
// size 2 holding (real, imag). Sizes gain that dimension, strides are
// doubled, and the storage offset is doubled to land on the same element.
inline Tensor _view_as_real_physical(const Tensor& self) {
  TORCH_CHECK(self.is_complex(), "view_as_real is only supported for complex tensors");
  auto old_sizes = self.sym_sizes();
  SymDimVector new_sizes(old_sizes.size() + 1);
  std::copy(old_sizes.begin(), old_sizes.end(), new_sizes.begin());
  // last dimension will always have two elements containing the real and imag vals
  new_sizes.back() = 2;
  auto new_strides = computeStrideForViewAsReal(self.sym_strides());
  auto new_storage_offset = self.sym_storage_offset() * 2;
  const auto float_type = c10::toRealValueType(self.scalar_type());
  auto real_tensor = view_tensor(self, float_type, std::move(new_storage_offset), new_sizes, new_strides);
  return real_tensor;
}
|
| 55 |
+
|
| 56 |
+
// expects as input a complex tensor and returns back a tensor
// with corresponding real dtype containing the complex values
// in the last two dimensions
Tensor view_as_real(const Tensor& self) {
  // Conjugated tensors must be materialized first: a plain real view
  // cannot represent the lazy conjugation bit.
  TORCH_CHECK(!self.is_conj(), "view_as_real doesn't work on unresolved conjugated tensors. To resolve the conjugate tensor so you can view it as real, use self.resolve_conj(); however, be warned that the resulting tensor will NOT alias the original.");
  return _view_as_real_physical(self);
}
|
| 63 |
+
|
| 64 |
+
// Strides for viewing a real tensor (last dim of size 2) as complex: the
// last dimension is dropped and every remaining stride is halved, since
// one complex element spans two consecutive real elements.
inline SymDimVector computeStrideForViewAsComplex(SymIntArrayRef oldstride) {
  const int64_t dim = oldstride.size();
  // Guard against reading oldstride[-1] on a 0-dim input. Callers also
  // validate sizes, but this keeps the helper safe on its own.
  TORCH_CHECK(dim > 0, "Input tensor must have one or more dimensions");
  TORCH_CHECK(oldstride[dim-1] == 1, "Tensor must have a last dimension with stride 1");

  SymDimVector res(dim - 1);
  for (const auto i : c10::irange(res.size())) {
    TORCH_CHECK(oldstride[i] % 2 == 0, "Tensor must have a stride divisible by 2 for all but last dimension");
    res[i] = oldstride[i] / 2;
  }
  return res;
}
|
| 75 |
+
|
| 76 |
+
// expects as input a float or double tensor with last dimension of size 2
// and returns back a tensor with corresponding complex dtype
Tensor view_as_complex(const Tensor& self) {
  TORCH_CHECK(
    self.scalar_type() == kFloat || self.scalar_type() == kDouble || self.scalar_type() == kHalf,
    "view_as_complex is only supported for half, float and double tensors, but got a tensor of scalar type: ", self.scalar_type());

  auto old_sizes = self.sym_sizes();
  TORCH_CHECK(!old_sizes.empty(), "Input tensor must have one or more dimensions");
  TORCH_CHECK(old_sizes[old_sizes.size()-1] == 2, "Tensor must have a last dimension of size 2");
  // Drop the trailing (real, imag) dimension from the view's sizes.
  SymDimVector new_sizes(old_sizes.begin(), old_sizes.end() - 1);

  const auto new_strides = computeStrideForViewAsComplex(self.sym_strides());
  const auto complex_type = c10::toComplexType(self.scalar_type());

  // One complex element spans two real elements, so the storage offset
  // must be even and is halved in the complex view.
  TORCH_CHECK(self.sym_storage_offset() % 2 == 0, "Tensor must have a storage_offset divisible by 2");
  const auto new_storage_offset = self.sym_storage_offset() / 2;

  return view_tensor(self, complex_type, new_storage_offset, new_sizes, new_strides);
}
|
| 96 |
+
|
| 97 |
+
} // namespace at::native
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/Distributions.h
ADDED
|
@@ -0,0 +1,518 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/native/Math.h>
|
| 4 |
+
#include <c10/macros/Macros.h>
|
| 5 |
+
#include <c10/util/MathConstants.h>
|
| 6 |
+
|
| 7 |
+
// ROCM hcc doesn't work well with using std:: in kernel functions
|
| 8 |
+
#if defined(__CUDA_ARCH__)
|
| 9 |
+
#include <c10/cuda/CUDAMathCompat.h>
|
| 10 |
+
#define compat_exp c10::cuda::compat::exp
|
| 11 |
+
#define compat_ceil c10::cuda::compat::ceil
|
| 12 |
+
#define compat_floor c10::cuda::compat::floor
|
| 13 |
+
#define compat_log c10::cuda::compat::log
|
| 14 |
+
#define compat_pow c10::cuda::compat::pow
|
| 15 |
+
#define compat_sqrt c10::cuda::compat::sqrt
|
| 16 |
+
#define compat_tan c10::cuda::compat::tan
|
| 17 |
+
#define compat_abs c10::cuda::compat::abs
|
| 18 |
+
#define compat_log1p c10::cuda::compat::log1p
|
| 19 |
+
#elif defined(__HIPCC__)
|
| 20 |
+
#include <c10/hip/HIPMathCompat.h>
|
| 21 |
+
#define compat_exp c10::hip::compat::exp
|
| 22 |
+
#define compat_ceil c10::hip::compat::ceil
|
| 23 |
+
#define compat_floor c10::hip::compat::floor
|
| 24 |
+
#define compat_log c10::hip::compat::log
|
| 25 |
+
#define compat_pow c10::hip::compat::pow
|
| 26 |
+
#define compat_sqrt c10::hip::compat::sqrt
|
| 27 |
+
#define compat_tan c10::hip::compat::tan
|
| 28 |
+
#define compat_abs c10::hip::compat::abs
|
| 29 |
+
#define compat_log1p c10::hip::compat::log1p
|
| 30 |
+
#else
|
| 31 |
+
#define compat_exp std::exp
|
| 32 |
+
#define compat_ceil std::ceil
|
| 33 |
+
#define compat_floor std::floor
|
| 34 |
+
#define compat_log std::log
|
| 35 |
+
#define compat_pow std::pow
|
| 36 |
+
#define compat_sqrt std::sqrt
|
| 37 |
+
#define compat_tan std::tan
|
| 38 |
+
#define compat_abs std::abs
|
| 39 |
+
#define compat_log1p std::log1p
|
| 40 |
+
#endif
|
| 41 |
+
|
| 42 |
+
namespace {
|
| 43 |
+
|
| 44 |
+
#if !defined(__CUDA_ARCH__) && !defined(__HIPCC__)
|
| 45 |
+
// we cannot use std::isnan directly due to some incompatibility of
|
| 46 |
+
// gcc constexpr'ing and nvcc
|
| 47 |
+
using std::isnan;
|
| 48 |
+
#endif
|
| 49 |
+
|
| 50 |
+
// Here sampler_t should be function type scalar_t(void). For gpu
// "sampler" is a device function, but since ROCM doesn't have
// equivalent to nvstd::function, we use a template type parameter to
// capture it.
template<typename scalar_t, typename sampler_t>
struct BaseSampler {
  // The wrapped callable; invoked once per sample() call.
  sampler_t sampler;
  C10_DEVICE BaseSampler(const sampler_t& sampler): sampler(sampler) {}
  // Draws one value from the underlying distribution.
  C10_DEVICE scalar_t sample() {
    return sampler();
  }
};
|
| 62 |
+
|
| 63 |
+
// The function `sample_gamma` is
|
| 64 |
+
// is adapted from Numpy's distributions.c implementation.
|
| 65 |
+
// It is MIT licensed, so here is the copyright:
|
| 66 |
+
|
| 67 |
+
/* Copyright 2005 Robert Kern (robert.kern@gmail.com)
|
| 68 |
+
*
|
| 69 |
+
* Permission is hereby granted, free of charge, to any person obtaining a
|
| 70 |
+
* copy of this software and associated documentation files (the
|
| 71 |
+
* "Software"), to deal in the Software without restriction, including
|
| 72 |
+
* without limitation the rights to use, copy, modify, merge, publish,
|
| 73 |
+
* distribute, sublicense, and/or sell copies of the Software, and to
|
| 74 |
+
* permit persons to whom the Software is furnished to do so, subject to
|
| 75 |
+
* the following conditions:
|
| 76 |
+
*
|
| 77 |
+
* The above copyright notice and this permission notice shall be included
|
| 78 |
+
* in all copies or substantial portions of the Software.
|
| 79 |
+
*
|
| 80 |
+
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
| 81 |
+
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
| 82 |
+
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
| 83 |
+
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
| 84 |
+
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
| 85 |
+
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
| 86 |
+
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
| 87 |
+
*/
|
| 88 |
+
|
| 89 |
+
// Draws one sample from Gamma(alpha, 1) using the Marsaglia & Tsang (2000)
// squeeze/acceptance-rejection method (doi:10.1145/358407.358414).
// `standard_uniform` and `standard_normal` supply U(0,1) and N(0,1) draws;
// sampling happens in accscalar_t precision and is cast to scalar_t.
template<typename scalar_t, typename accscalar_t, typename uniform_sampler_t, typename normal_sampler_t>
C10_DEVICE scalar_t sample_gamma(scalar_t alpha, BaseSampler<accscalar_t, uniform_sampler_t>& standard_uniform, BaseSampler<accscalar_t, normal_sampler_t>& standard_normal) {
  accscalar_t scale = 1.0f;

  // Boost alpha for higher acceptance probability.
  // Uses Gamma(alpha) == Gamma(alpha+1) * U^(1/alpha): fold the U^(1/alpha)
  // factor into `scale` and sample at alpha+1 instead.
  if (alpha < 1.0f) {
    if (alpha == 0.f) return 0.f;
    scale *= compat_pow(1 - standard_uniform.sample(), 1.0f / alpha);
    alpha += 1.0f;
  }

  // This implements the acceptance-rejection method of Marsaglia and Tsang (2000)
  // doi:10.1145/358407.358414
  const accscalar_t d = alpha - 1.0f / 3.0f;
  const accscalar_t c = 1.0f / compat_sqrt(9.0f * d);
  for (;;) {
    accscalar_t x, y;
    // Sample a normal, rejecting draws that would make y = (1 + c*x)
    // non-positive (v = y^3 must be positive).
    do {
      x = standard_normal.sample();
      y = 1.0f + c * x;
    } while (y <= 0);
    const accscalar_t v = y * y * y;
    const accscalar_t u = 1 - standard_uniform.sample();
    const accscalar_t xx = x * x;
    // Cheap squeeze test first; fall through to the exact log-acceptance
    // test only when the squeeze is inconclusive.
    if (u < 1.0f - 0.0331f * xx * xx)
      return static_cast<scalar_t>(scale * d * v);
    if (compat_log(u) < 0.5f * xx + d * (1.0f - v + compat_log(v)))
      return static_cast<scalar_t>(scale * d * v);
  }
}
|
| 119 |
+
|
| 120 |
+
/* the functions stirling_approx_tail, binomial_inversion, and btrs are adapted
 * from TensorFlow's random_binomial_op.cc implementation. That code is under
 * copyright: 2019 The TensorFlow Authors.
 *
 * It was released under the Apache License, Version 2.0 (the "License"), available at:
 * http://www.apache.org/licenses/LICENSE-2.0
 */

// Tail correction of Stirling's approximation to log(k!):
//   log(k!) = 0.5*log(2*pi*k) + k*log(k) - k + stirling_approx_tail(k).
// Assumes k holds a non-negative integer value in a floating type (as used
// by btrs below).
template<typename scalar_t>
C10_DEVICE scalar_t stirling_approx_tail(scalar_t k) {
  // Precomputed tail values for k = 0..9; the asymptotic series below is
  // only used once k >= 10, where it is sufficiently accurate.
  const static scalar_t kTailValues[] = {
    0.0810614667953272,
    0.0413406959554092,
    0.0276779256849983,
    0.02079067210376509,
    0.0166446911898211,
    0.0138761288230707,
    0.0118967099458917,
    0.0104112652619720,
    0.00925546218271273,
    0.00833056343336287
  };
  if (k <= 9) {
    return kTailValues[static_cast<size_t>(k)];
  }
  // Asymptotic expansion in m = k + 1: 1/(12m) - 1/(360 m^3) + 1/(1260 m^5).
  scalar_t kp1sq = (k + 1) * (k + 1);
  return (1.0 / 12 - (1.0 / 360 - 1.0 / 1260 / kp1sq) / kp1sq) / (k + 1);
}
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
// Samples Binomial(count, prob) by inversion: sums Geometric(prob)
// inter-arrival gaps and counts how many fit within `count` trials.
// Efficient when count * prob is small; callers (sample_binomial) only use
// it with prob <= 0.5 and count * prob < 10.
template<typename scalar_t, typename accscalar_t, typename uniform_sampler_t>
C10_DEVICE scalar_t binomial_inversion(scalar_t count, scalar_t prob, BaseSampler<accscalar_t, uniform_sampler_t>& standard_uniform) {
  accscalar_t U;
  accscalar_t geom_sum = 0;
  scalar_t num_geom = 0;

  // log(1 - prob), computed stably via log1p.
  accscalar_t logprob = compat_log1p(-prob);

  while (1) {
    U = standard_uniform.sample();
    // Geometric(prob) variate via CDF inversion.
    accscalar_t geom = compat_ceil(compat_log(U) / logprob);
    geom_sum += geom;
    if (geom_sum > count) {
      break;
    }
    num_geom = num_geom + 1;
  }
  return num_geom;
}
|
| 169 |
+
|
| 170 |
+
// Samples Binomial(count, prob) with Hormann's BTRS (Binomial, Transformed
// Rejection with Squeeze) algorithm. Efficient when count * prob >= 10;
// callers (sample_binomial) guarantee prob <= 0.5 via the symmetry reflection.
template<typename scalar_t, typename accscalar_t, typename uniform_sampler_t>
C10_DEVICE scalar_t btrs(scalar_t count, scalar_t prob, BaseSampler<accscalar_t, uniform_sampler_t>& standard_uniform) {
  scalar_t k;
  accscalar_t U, V, us;

  // This is spq in the paper.
  const accscalar_t stddev = compat_sqrt(count * prob * (1 - prob));

  // Other coefficients for Transformed Rejection sampling.
  const accscalar_t b = 1.15 + 2.53 * stddev;
  const accscalar_t a = -0.0873 + 0.0248 * b + 0.01 * prob;
  const accscalar_t c = count * prob + 0.5;
  const accscalar_t v_r = 0.92 - 4.2 / b;
  const accscalar_t r = prob / (1 - prob);

  const accscalar_t alpha = (2.83 + 5.1 / b) * stddev;
  // m is the mode of the distribution, used in the upper-bound test below.
  const accscalar_t m = compat_floor((count + 1) * prob);

  while (1) {
    U = standard_uniform.sample() - 0.5;
    V = standard_uniform.sample();

    us = 0.5 - compat_abs(U);
    // Candidate sample from the transformed dominating density.
    k = static_cast<scalar_t>(compat_floor((2 * a / us + b) * U + c));

    // Reject non-sensical answers.
    if (k < 0 || k > count) {
      continue;
    }
    // Region for which the box is tight, and we can return our calculated value.
    // This should happen 0.86 * v_r times. In the limit as n * p is large,
    // the acceptance rate converges to ~79% (and in the lower regime it is ~24%).
    if (us >= 0.07 && V <= v_r) {
      return k;
    }

    // This deviates from Hormann's BTRS algorithm, as there is a log missing.
    // For all (u, v) pairs outside of the bounding box, this calculates the
    // transformed-reject ratio.
    V = compat_log(V * alpha / (a / (us * us) + b));
    accscalar_t upperbound =
        ((m + 0.5) * compat_log((m + 1) / (r * (count - m + 1))) +
         (count + 1) * compat_log((count - m + 1) / (count - k + 1)) +
         (k + 0.5) * compat_log(r * (count - k + 1) / (k + 1)) +
         stirling_approx_tail<accscalar_t>(m) + stirling_approx_tail<accscalar_t>(count - m) -
         stirling_approx_tail<accscalar_t>(k) - stirling_approx_tail<accscalar_t>(count - k));

    if (V <= upperbound) {
      return k;
    }
  }
}
|
| 222 |
+
|
| 223 |
+
// Samples from Binomial(count, prob), dispatching by parameter regime:
//  - count <= 0 or prob <= 0 -> 0; prob >= 1 -> count (degenerate cases),
//  - prob <= 0.5: BTRS when count * prob >= 10, otherwise inversion,
//  - prob >  0.5: sample with q = 1 - prob and reflect via count - k,
//  - remaining case (prob is NaN) returns NaN.
template<typename scalar_t, typename accscalar_t, typename uniform_sampler_t>
C10_DEVICE scalar_t sample_binomial(scalar_t count, scalar_t prob, BaseSampler<accscalar_t, uniform_sampler_t>& standard_uniform) {
  if (count <= 0.0 || prob <= 0.0) {
    return 0;
  } else if (prob >= 1.0) {
    return count;
  } else if (prob <= 0.5) {
    if (count * prob >= 10.0) {
      // btrs
      return btrs<scalar_t, accscalar_t, uniform_sampler_t>(count, prob, standard_uniform);
    } else {
      // binomial inversion
      return binomial_inversion<scalar_t, accscalar_t, uniform_sampler_t>(count, prob, standard_uniform);
    }
  } else if (prob > 0.5) {
    // Symmetry: Binomial(n, p) == n - Binomial(n, 1 - p).
    scalar_t qprob = 1.0 - prob;
    if (count * qprob >= 10.0) {
      // btrs
      return count - btrs<scalar_t, accscalar_t, uniform_sampler_t>(count, qprob, standard_uniform);
    } else {
      // count - binomial inversion
      return count - binomial_inversion<scalar_t, accscalar_t, uniform_sampler_t>(count, qprob, standard_uniform);
    }
  } else {
    // prob is nan?
    return static_cast<scalar_t>(NAN);
  }
}
|
| 251 |
+
|
| 252 |
+
/*
 * This function is derived from the implementation of the digamma function in the Cephes Math Library.
 * See note [3-Clause BSD License for the Cephes Math Library] in ATen/native/Math.h.
 */
// Computes the digamma function psi(x) = d/dx log Gamma(x) for one value.
// Poles (x == 0 and negative integers) return INFINITY. Negative non-integer
// x is handled via the reflection formula, the recurrence psi(x+1) =
// psi(x) + 1/x pushes the argument up to >= 10, and the Cephes asymptotic
// series (coefficients A) finishes the job.
template<typename scalar_t, typename accscalar_t>
C10_DEVICE static inline scalar_t digamma_one(scalar_t x) {
  // psi(10), precomputed so the exactly-shifted case can return directly.
  constexpr accscalar_t PSI_10 = 2.25175258906672110764;
  if (x == 0) {
    return INFINITY;
  }
  accscalar_t additional_summand = 0;
  int x_is_integer = x == compat_floor(x);
  if (x < 0) {
    if (x_is_integer) {
      return INFINITY;
    }
    // Reflection: psi(1 - x) = psi(x) + pi / tan(pi * x).
    // it is more standard to write this as recursion, but
    // nvcc does not like that
    additional_summand = -c10::pi<scalar_t> /
        compat_tan(c10::pi<scalar_t> * x);
    x = 1 - x;
  }

  // Push x to be >= 10
  accscalar_t result = 0;
  while (x < 10) {
    result -= 1 / x;
    x += 1;
  }
  if (x == 10) {
    return result + PSI_10 + additional_summand;
  }

  // Compute asymptotic digamma
  static const accscalar_t A[] = {
    8.33333333333333333333E-2,
    -2.10927960927960927961E-2,
    7.57575757575757575758E-3,
    -4.16666666666666666667E-3,
    3.96825396825396825397E-3,
    -8.33333333333333333333E-3,
    8.33333333333333333333E-2,
  };

  accscalar_t y = 0;
  if (x < 1.0e17f) {
    // Polynomial in z = 1/x^2 evaluated with Cephes' polevl.
    accscalar_t z = 1.0 / (x * x);
    y = z * polevl<accscalar_t>(z, A, 6);
  }
  return static_cast<scalar_t>(
      result + compat_log(x) - (0.5f / x) - y + additional_summand);
}
|
| 304 |
+
|
| 305 |
+
// Computes the reparameterized gradient -(d/dalpha cdf(x;alpha)) / pdf(x;alpha)
// for random number x drawn from a standard Gamma distribution Gamma(alpha).
// Three regimes are used: a Taylor series for small x, a Rice saddle point
// expansion for large alpha, and a fitted bivariate rational approximation
// elsewhere. All intermediate math is done in accscalar_t.
template <typename scalar_t, typename accscalar_t>
C10_HOST_DEVICE scalar_t standard_gamma_grad_one(scalar_t alpha_, scalar_t x_) {
  // Use a Taylor series expansion for small x.
  accscalar_t x = static_cast<accscalar_t>(x_);
  accscalar_t alpha = static_cast<accscalar_t>(alpha_);
  if (x < 0.8f) {
    accscalar_t numer = 1;
    accscalar_t denom = alpha;
    // series1 approximates the incomplete-gamma series for the CDF,
    // series2 its derivative with respect to alpha.
    auto series1 = numer / denom;
    auto series2 = numer / (denom * denom);
    for (int i = 1; i <= 5; ++i) {
      numer *= -x / static_cast<accscalar_t>(i);
      denom += 1;
      series1 += numer / denom;
      series2 += numer / (denom * denom);
    }
    const auto pow_x_alpha = compat_pow(x, alpha);
    const auto gamma_pdf = compat_pow(x, alpha - 1) * compat_exp(-x);
    const auto gamma_cdf = pow_x_alpha * series1;
    const auto gamma_cdf_alpha =
        (compat_log(x) - digamma_one<accscalar_t, accscalar_t>(alpha)) *
            gamma_cdf -
        pow_x_alpha * series2;
    const auto result = -gamma_cdf_alpha / gamma_pdf;
    // NaN (e.g. from 0/0 at the boundary) is mapped to a zero gradient.
    return isnan(result) ? static_cast<scalar_t>( 0.f ) : static_cast<scalar_t>(result);
  }

  // Use a Rice saddle point expansion for large alpha.
  if (alpha > 8.0f) {
    // Near the mean (x ~ alpha) the general expansion below has a
    // singularity; use a dedicated polynomial approximation instead.
    if (0.9f * alpha <= x && x <= 1.1f * alpha) {
      const auto numer_1 = 1 + 24 * alpha * (1 + 12 * alpha);
      const auto numer_2 = 1440 * (alpha * alpha) + 6 * x * (53 - 120 * x)
          - 65 * x * x / alpha + alpha * (107 + 3600 * x);
      const auto denom = 1244160 * (alpha * alpha) * (alpha * alpha);
      return static_cast<scalar_t>(numer_1 * numer_2 / denom);
    }
    const auto denom = compat_sqrt(8 * alpha);
    const auto term2 = denom / (alpha - x);
    const auto term3 = compat_pow(
        x - alpha - alpha * compat_log(x / alpha),
        static_cast<accscalar_t>(-1.5));
    // Sign of term3 flips depending on which side of the mean x falls.
    const auto term23 = (x < alpha) ? term2 - term3 : term2 + term3;
    const auto term1 = compat_log(x / alpha) * term23 -
        compat_sqrt(2 / alpha) * (alpha + x) / ((alpha - x) * (alpha - x));
    const auto stirling = 1 + 1 / (12 * alpha) * (1 + 1 / (24 * alpha));
    const auto numer = x * term1;
    return static_cast<scalar_t>(-stirling * numer / denom);
  }

  // Use a bivariate rational approximation to the reparameterized gradient.
  const auto u = compat_log(x / alpha);
  const auto v = compat_log(alpha);
  // Fitted coefficients of the rational approximation p(u, v) / q(u, v).
  static const accscalar_t coef_uv[3][8] = {
    {0.16009398, -0.094634809, 0.025146376, -0.0030648343,
     1, 0.32668115, 0.10406089, 0.0014179084},
    {0.53487893, 0.1298071, 0.065735949, -0.0015649758,
     0.16639465, 0.020070113, -0.0035938915, -0.00058392623},
    {0.040121004, -0.0065914022, -0.0026286047, -0.0013441777,
     0.017050642, -0.0021309326, 0.00085092367, -1.5247877e-07},
  };
  accscalar_t coef_v[8];
  for (int i = 0; i < 8; ++ i) {
    coef_v[i] = coef_uv[0][i] + u * (coef_uv[1][i] + u * coef_uv[2][i]);
  }
  // First four coefficients form the numerator, last four the denominator.
  const auto p = coef_v[0] + v * (coef_v[1] + v * (coef_v[2] + v * coef_v[3]));
  const auto q = coef_v[4] + v * (coef_v[5] + v * (coef_v[6] + v * coef_v[7]));
  return static_cast<scalar_t>(compat_exp(p / q));
}
|
| 375 |
+
|
| 376 |
+
// Approximate reparameterized gradient of Beta(x,alpha,beta) wrt alpha.
// Assumes x is close to zero and uses a Taylor expansion.
// Note: the result is scaled by 1/(1-x); see dirichlet_grad_one.
template <typename scalar_t, typename accscalar_t>
C10_DEVICE static inline scalar_t _beta_grad_alpha_small(scalar_t x, scalar_t alpha, scalar_t beta) {
  // Common alpha-derivative factor: psi(alpha) - psi(alpha+beta) - log(x).
  const scalar_t factor = digamma_one<scalar_t, accscalar_t>(alpha)
                        - digamma_one<scalar_t, accscalar_t>(alpha + beta) - compat_log(x);
  scalar_t numer = 1;
  scalar_t series = numer / alpha * (factor + 1 / alpha);
  // 10-term Taylor expansion of the scaled CDF derivative.
  for (int i = 1; i <= 10; ++i) {
    scalar_t casted_i = static_cast<scalar_t>(i);
    numer *= (casted_i - beta) * x / casted_i;
    const scalar_t denom = alpha + casted_i;
    series += numer / denom * (factor + 1 / denom);
  }
  const scalar_t result = x * compat_pow(1 - x, -beta) * series;
  // NaN (boundary/overflow artefacts) is mapped to a zero gradient.
  return isnan(result) ? static_cast<scalar_t>( 0.f ) : result;
}
|
| 393 |
+
|
| 394 |
+
// Approximate reparameterized gradient of Beta(x,alpha,beta) wrt beta.
// Assumes x is close to zero and uses a Taylor expansion.
template <typename scalar_t, typename accscalar_t>
C10_DEVICE static inline scalar_t _beta_grad_beta_small(scalar_t x, scalar_t alpha, scalar_t beta) {
  // Common beta-derivative factor: psi(alpha+beta) - psi(beta).
  const scalar_t factor = digamma_one<scalar_t, accscalar_t>(alpha + beta) - digamma_one<scalar_t, accscalar_t>(beta);
  // betas tracks the falling product (beta-1)(beta-2)..., dbetas its
  // derivative wrt beta (product rule), accumulated incrementally.
  scalar_t numer = 1, betas = 1, dbetas = 0, series = factor / alpha;
  for (int i = 1; i <= 8; ++i) {
    scalar_t casted_i = static_cast<scalar_t>(i);
    numer *= -x / casted_i;
    dbetas = dbetas * (beta - casted_i) + betas;
    betas = betas * (beta - casted_i);
    series += numer / (alpha + casted_i) * (dbetas + factor * betas);
  }
  const scalar_t result = -compat_pow(1 - x, 1 - beta) * series;
  // NaN (boundary/overflow artefacts) is mapped to a zero gradient.
  return isnan(result) ? static_cast<scalar_t>( 0.f ) : result;
}
|
| 410 |
+
|
| 411 |
+
// Approximate reparameterized gradient of Beta(x,alpha,beta) wrt alpha.
// Assumes alpha and beta are both large and uses a Rice saddle point expansion.
// To ensure numerical stability, this computation is performed at higher precision.
template<typename scalar_t, typename accscalar_t>
C10_DEVICE static inline scalar_t _beta_grad_alpha_mid(accscalar_t x, accscalar_t alpha, accscalar_t beta) {
  const accscalar_t total = alpha + beta;
  const accscalar_t mean = alpha / total;
  const accscalar_t std = compat_sqrt(alpha * beta / (total + 1)) / total;
  if (mean - 0.1 * std <= x && x <= mean + 0.1 * std) {
    // Avoid the singularity at x = mean.
    // Polynomial approximation fitted for the near-mean region.
    const accscalar_t poly = 47 * x * (beta * beta) * (beta * beta) + alpha * (
                               (43 + 20 * (16 + 27 * beta) * x) * (beta * beta) * beta + alpha * (
                                 3 * (59 + 180 * beta - 90 * x) * (beta * beta) + alpha * (
                                   (453 + 1620 * beta * (1 - x) - 455 * x) * beta + alpha * (
                                     8 * (1 - x) * (135 * beta - 11)))));
    const accscalar_t prefactor_num = (1 + 12 * alpha) * (1 + 12 * beta) / (total * total);
    const accscalar_t prefactor_den = 12960 * alpha * alpha * alpha * beta * beta * (1 + 12 * total);
    return prefactor_num / (1 - x) * poly / prefactor_den;
  }
  // General saddle point expansion away from the mean.
  const accscalar_t prefactor = -x / compat_sqrt(2 * alpha * beta / total);
  // Ratio of Stirling-series corrections for Gamma(alpha)Gamma(beta)/Gamma(total).
  const accscalar_t stirling = (1 + 1 / (12 * alpha) + 1 / (288 * alpha * alpha))
                             * (1 + 1 / (12 * beta) + 1 / (288 * beta * beta))
                             / (1 + 1 / (12 * total) + 1 / (288 * total * total));
  const accscalar_t term1_num = 2 * (alpha * alpha) * (x - 1) + alpha * beta * (x - 1) - x * (beta * beta);
  const accscalar_t axbx = alpha * (x - 1) + beta * x;
  const accscalar_t term1_den = compat_sqrt(2 * alpha / beta) * compat_pow(total, static_cast<accscalar_t>(1.5f)) * axbx * axbx;
  const accscalar_t term1 = term1_num / term1_den;
  const accscalar_t term2 = 0.5f * compat_log(alpha / (total * x));
  const accscalar_t term3_num = compat_sqrt(8 * alpha * beta / total);
  const accscalar_t term3_den = beta * x + alpha * (x - 1);
  const accscalar_t term3 = term3_num / term3_den;
  const accscalar_t term4_base = beta * compat_log(beta / (total * (1 - x))) +
                                 alpha * compat_log(alpha / (total * x));
  const accscalar_t term4 = compat_pow(term4_base, static_cast<accscalar_t>(-1.5f));
  // term4's sign depends on which side of the mean x lies.
  const accscalar_t term1234 = term1 + term2 * (term3 + (x < mean ? term4 : -term4));
  return static_cast<scalar_t>(stirling * prefactor * term1234);
}
|
| 448 |
+
|
| 449 |
+
// Computes a scaled reparameterized gradient
// -(d/dalpha cdf(x;alpha,beta)) / pdf(x;alpha,beta) / (1-x)
// for random number x drawn from a Beta distribution Beta(alpha,beta).
// This function inputs total=alpha+beta to make it easy to implement
// Dirichlet reparameterized gradients in terms of Betas.
// Dispatches between Taylor expansions near x=0/x=1, a saddle-point
// expansion for large parameters, and a fitted rational correction.
template<typename scalar_t, typename accscalar_t>
C10_HOST_DEVICE static inline scalar_t dirichlet_grad_one(scalar_t x, scalar_t alpha, scalar_t total) {
  accscalar_t x_ = static_cast<accscalar_t>(x);
  accscalar_t alpha_ = static_cast<accscalar_t>(alpha);
  accscalar_t total_ = static_cast<accscalar_t>(total);

  const scalar_t beta = total - alpha;
  const accscalar_t beta_ = total_ - alpha_;
  // boundary measures closeness to the 0/1 endpoints, scaled by total.
  const scalar_t boundary = total * x * (1 - x);

  // Use an asymptotic approximation for x close to 0.
  if (x <= 0.5f && boundary < 2.5f) {
    return _beta_grad_alpha_small<scalar_t, accscalar_t>(x, alpha, beta);
  }

  // Use an asymptotic approximation for x close to 1.
  if (x >= 0.5f && boundary < 0.75f) {
    // Symmetry: d/dalpha for x == -d/dbeta for 1-x with roles swapped.
    return -_beta_grad_beta_small<scalar_t, accscalar_t>(1 - x, beta, alpha);
  }

  // Use an asymptotic approximation when alpha and (total - alpha) are both large.
  if (alpha > 6 && beta > 6) {
    return _beta_grad_alpha_mid<scalar_t, accscalar_t>(x_, alpha_, beta_);
  }

  // Use a rational correction to an analytic approximation.
  // c[0] holds numerator coefficients, c[1] denominator coefficients, each
  // indexed by powers of u (log x), a (log(alpha/x)), and b (log(total)-a).
  static const accscalar_t c[2][3][3][4] = {
    {{{1.003668233, -0.01061107488, -0.0657888334, 0.01201642863},
      {0.6336835991, -0.3557432599, 0.05486251648, -0.001465281033},
      {-0.03276231906, 0.004474107445, 0.002429354597, -0.0001557569013}},
     {{0.221950385, -0.3187676331, 0.01799915743, 0.01074823814},
      {-0.2951249643, 0.06219954479, 0.01535556598, 0.001550077057},
      {0.02155310298, 0.004170831599, 0.001292462449, 6.976601077e-05}},
     {{-0.05980841433, 0.008441916499, 0.01085618172, 0.002319392565},
      {0.02911413504, 0.01400243777, -0.002721828457, 0.000751041181},
      {0.005900514878, -0.001936558688, -9.495446725e-06, 5.385558597e-05}}},
    {{{1, -0.02924021934, -0.04438342661, 0.007285809825},
      {0.6357567472, -0.3473456711, 0.05454656494, -0.002407477521},
      {-0.03301322327, 0.004845219414, 0.00231480583, -0.0002307248149}},
     {{0.5925320577, -0.1757678135, 0.01505928619, 0.000564515273},
      {0.1014815858, -0.06589186703, 0.01272886114, -0.0007316646956},
      {-0.007258481865, 0.001096195486, 0.0003934994223, -4.12701925e-05}},
     {{0.06469649321, -0.0236701437, 0.002902096474, -5.896963079e-05},
      {0.001925008108, -0.002869809258, 0.0008000589141, -6.063713228e-05},
      {-0.0003477407336, 6.959756487e-05, 1.097287507e-05, -1.650964693e-06}}},
  };
  const accscalar_t u = compat_log(x_);
  const accscalar_t a = compat_log(alpha_) - u;
  const accscalar_t b = compat_log(total_) - a;
  const accscalar_t pow_u[3] = {1, u, u * u};
  const accscalar_t pow_a[3] = {1, a, a * a};
  accscalar_t p = 0.0;
  accscalar_t q = 0.0;
  for (int i = 0; i < 3; ++i) {
    for (int j = 0; j < 3; ++j) {
      const accscalar_t ua = pow_u[i] * pow_a[j];
      p += ua * (c[0][i][j][0] + b * (c[0][i][j][1] + b * (c[0][i][j][2] + b * c[0][i][j][3])));
      q += ua * (c[1][i][j][0] + b * (c[1][i][j][1] + b * (c[1][i][j][2] + b * c[1][i][j][3])));
    }
  }
  // Analytic first-order approximation, corrected by the rational p/q.
  const accscalar_t approx = x_ * (digamma_one<scalar_t, accscalar_t>(total_) - digamma_one<scalar_t, accscalar_t>(alpha_)) / beta_;
  return static_cast<scalar_t>(p / q * approx);
}
|
| 517 |
+
|
| 518 |
+
} // namespace
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/Math.h
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/NonSymbolicBC.h
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
#include <ATen/core/Tensor.h>
#include <c10/util/irange.h>
#include <ATen/core/IListRef.h>

namespace at::native {
// This file contains non-symbolic signatures for ops that we have sym-intified the signature of.
// However, in certain cases (such as static runtime), we call the native versions of the ops directly.
// In those cases, we will duplicate the signature here with non-symbolic ints, and also duplicate the C++ implementation.
TORCH_API at::Tensor reshape(const at::Tensor& self, at::IntArrayRef proposed_shape);
TORCH_API at::Tensor narrow(const at::Tensor& self, int64_t dim, int64_t start, int64_t length);
TORCH_API at::Tensor _sparse_coo_tensor_unsafe(const at::Tensor & indices, const at::Tensor & values, at::IntArrayRef size, c10::optional<at::ScalarType> dtype=c10::nullopt, c10::optional<at::Layout> layout=c10::nullopt, c10::optional<at::Device> device=c10::nullopt, c10::optional<bool> pin_memory=c10::nullopt, c10::optional<bool> is_coalesced=c10::nullopt);
TORCH_API at::Tensor nll_loss(const at::Tensor & self, const at::Tensor & target, const c10::optional<at::Tensor>& weight_opt, int64_t reduction, int64_t ignore_index);
TORCH_API at::Tensor nll_loss2d(const at::Tensor & self, const at::Tensor & target, const c10::optional<at::Tensor>& weight_opt, int64_t reduction, int64_t ignore_index);
// The below ops don't get a duplicated C++ implementation.
// They are backward ops, which make them very unlikely to be called directly
// by external code (at::native::trace_backward).
// They get their own declaration for BC purposes however.
TORCH_API at::Tensor _embedding_bag_backward(const at::Tensor & grad, const at::Tensor & indices, const at::Tensor & offsets, const at::Tensor & offset2bag, const at::Tensor & bag_size, const at::Tensor & maximum_indices, int64_t num_weights, bool scale_grad_by_freq, int64_t mode, bool sparse, const c10::optional<at::Tensor> & per_sample_weights, int64_t padding_idx=-1);
TORCH_API at::Tensor _embedding_bag_sparse_backward(const at::Tensor & grad, const at::Tensor & indices, const at::Tensor & offsets, const at::Tensor & offset2bag, const at::Tensor & bag_size, int64_t num_weights, bool scale_grad_by_freq, int64_t mode, const c10::optional<at::Tensor> & per_sample_weights, int64_t padding_idx=-1);
TORCH_API at::Tensor value_selecting_reduction_backward(const at::Tensor & grad, int64_t dim, const at::Tensor & indices, at::IntArrayRef sizes, bool keepdim);
TORCH_API at::Tensor trace_backward(const at::Tensor & grad, at::IntArrayRef sizes);
TORCH_API at::Tensor index_select_backward(const at::Tensor & grad, at::IntArrayRef self_sizes, int64_t dim, const at::Tensor & index);
// Non-symbolic view/split helpers used by static runtime.
TORCH_API at::Tensor select(const at::Tensor& self, int64_t dim, int64_t index);
TORCH_API std::vector<Tensor> tensor_split(const Tensor& self, IntArrayRef indices, int64_t dim);
} // namespace at::native
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/PixelShuffle.h
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <ATen/core/Tensor.h>
|
| 2 |
+
#include <c10/util/Exception.h>
|
| 3 |
+
|
| 4 |
+
namespace at {
|
| 5 |
+
namespace native {
|
| 6 |
+
|
| 7 |
+
inline void check_pixel_shuffle_shapes(const Tensor& self, int64_t upscale_factor) {
|
| 8 |
+
TORCH_CHECK(self.dim() >= 3,
|
| 9 |
+
"pixel_shuffle expects input to have at least 3 dimensions, but got input with ",
|
| 10 |
+
self.dim(), " dimension(s)");
|
| 11 |
+
TORCH_CHECK(upscale_factor > 0,
|
| 12 |
+
"pixel_shuffle expects a positive upscale_factor, but got ",
|
| 13 |
+
upscale_factor);
|
| 14 |
+
int64_t c = self.size(-3);
|
| 15 |
+
int64_t upscale_factor_squared = upscale_factor * upscale_factor;
|
| 16 |
+
TORCH_CHECK(c % upscale_factor_squared == 0,
|
| 17 |
+
"pixel_shuffle expects its input's 'channel' dimension to be divisible by the square of "
|
| 18 |
+
"upscale_factor, but input.size(-3)=", c, " is not divisible by ", upscale_factor_squared);
|
| 19 |
+
}
|
| 20 |
+
|
| 21 |
+
inline void check_pixel_unshuffle_shapes(const Tensor& self, int64_t downscale_factor) {
|
| 22 |
+
TORCH_CHECK(
|
| 23 |
+
self.dim() >= 3,
|
| 24 |
+
"pixel_unshuffle expects input to have at least 3 dimensions, but got input with ",
|
| 25 |
+
self.dim(),
|
| 26 |
+
" dimension(s)");
|
| 27 |
+
TORCH_CHECK(
|
| 28 |
+
downscale_factor > 0,
|
| 29 |
+
"pixel_unshuffle expects a positive downscale_factor, but got ",
|
| 30 |
+
downscale_factor);
|
| 31 |
+
int64_t h = self.size(-2);
|
| 32 |
+
int64_t w = self.size(-1);
|
| 33 |
+
TORCH_CHECK(
|
| 34 |
+
h % downscale_factor == 0,
|
| 35 |
+
"pixel_unshuffle expects height to be divisible by downscale_factor, but input.size(-2)=",
|
| 36 |
+
h,
|
| 37 |
+
" is not divisible by ",
|
| 38 |
+
downscale_factor);
|
| 39 |
+
TORCH_CHECK(
|
| 40 |
+
w % downscale_factor == 0,
|
| 41 |
+
"pixel_unshuffle expects width to be divisible by downscale_factor, but input.size(-1)=",
|
| 42 |
+
w,
|
| 43 |
+
" is not divisible by ",
|
| 44 |
+
downscale_factor);
|
| 45 |
+
}
|
| 46 |
+
|
| 47 |
+
}} // namespace at::native
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/RangeFactories.h
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Fix: this header lacked an include guard; sibling headers in this tree
// (e.g. ReduceAllOps.h, NonSymbolicBC.h) all use #pragma once. Without it,
// double inclusion would redeclare the dispatch stubs.
#pragma once

#include <ATen/native/DispatchStub.h>
#include <c10/core/Scalar.h>

namespace at {
struct TensorIterator;

namespace native {

// Dispatch stubs for the range-factory kernels:
// arange(start, end, step) and linspace(start, end, steps).
DECLARE_DISPATCH(void(*)(TensorIterator&, const Scalar&, const Scalar&, const Scalar&), arange_stub);
DECLARE_DISPATCH(void(*)(TensorIterator&, const Scalar&, const Scalar&, int64_t), linspace_stub);

}} // namespace at::native
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/ReduceAllOps.h
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once

#include <ATen/native/DispatchStub.h>

namespace at {
class Tensor;
}

namespace at::native {

// Kernel signature for whole-tensor reductions (min/max over all elements).
using reduce_all_fn = void (*)(Tensor & result, const Tensor & self);
// Kernel signature for computing min and max in a single pass.
using reduce_min_max_fn = void (*)(Tensor & max_result, Tensor & min_result, const Tensor & self);
// Per-backend dispatch entries for the all-reduce min/max kernels.
DECLARE_DISPATCH(reduce_all_fn, min_all_stub);
DECLARE_DISPATCH(reduce_all_fn, max_all_stub);

} // namespace at::native
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/ReduceOpsUtils.h
ADDED
|
@@ -0,0 +1,449 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <limits>
|
| 4 |
+
#include <ATen/core/Tensor.h>
|
| 5 |
+
#include <ATen/native/Resize.h>
|
| 6 |
+
#include <ATen/native/TensorIterator.h>
|
| 7 |
+
#include <ATen/native/NonEmptyUtils.h>
|
| 8 |
+
#include <ATen/WrapDimUtilsMulti.h>
|
| 9 |
+
#include <c10/core/ScalarType.h>
|
| 10 |
+
#include <c10/util/irange.h>
|
| 11 |
+
|
| 12 |
+
#ifndef AT_PER_OPERATOR_HEADERS
|
| 13 |
+
#include <ATen/Functions.h>
|
| 14 |
+
#else
|
| 15 |
+
#include <ATen/ops/empty.h>
|
| 16 |
+
#include <ATen/ops/scalar_tensor.h>
|
| 17 |
+
#endif
|
| 18 |
+
|
| 19 |
+
namespace at::native {
|
| 20 |
+
|
| 21 |
+
// Maximum and minimum possible scalar values, including infinities
|
| 22 |
+
// Largest possible value of scalar_t: +infinity for types that have one
// (floating point), otherwise the largest finite value.
template <typename scalar_t>
constexpr scalar_t upper_bound() {
  using lim = std::numeric_limits<scalar_t>;
  if constexpr (lim::has_infinity) {
    return lim::infinity();
  } else {
    return lim::max();
  }
}
|
| 27 |
+
|
| 28 |
+
// Smallest possible value of scalar_t: -infinity for types that have one
// (floating point), otherwise the lowest finite value.
template <typename scalar_t>
constexpr scalar_t lower_bound() {
  using lim = std::numeric_limits<scalar_t>;
  if constexpr (lim::has_infinity) {
    return -lim::infinity();
  } else {
    return lim::lowest();
  }
}
|
| 33 |
+
|
| 34 |
+
// Returns a view of `src` with shape `replacement_shape` and a zero stride
// along `dim`, so a reduced result can be broadcast back across that dim
// without copying. `dim` is assumed already wrapped to a valid index.
static inline Tensor restride_dim(
  const Tensor& src, int64_t dim,
  IntArrayRef replacement_shape
) {
  auto strides = ensure_nonempty_vec(src.strides().vec());
  strides[dim] = 0;
  return src.as_strided(replacement_shape, strides);
}
|
| 42 |
+
|
| 43 |
+
inline void _dimreduce_setup(const Tensor &result, const Tensor &self,
|
| 44 |
+
int64_t dim) {
|
| 45 |
+
IntArrayRef self_sizes = self.sizes();
|
| 46 |
+
std::vector<int64_t> result_sizes;
|
| 47 |
+
result_sizes.insert(result_sizes.end(), self_sizes.begin(), self_sizes.end());
|
| 48 |
+
result_sizes[dim] = 1;
|
| 49 |
+
result.resize_(result_sizes);
|
| 50 |
+
}
|
| 51 |
+
|
| 52 |
+
inline bool _dimreduce_return_trivial(const Tensor &result, const Tensor &self,
|
| 53 |
+
const Scalar& ident, int64_t dim, bool keepdim) {
|
| 54 |
+
if (self.numel() == 1 && self.ndimension() == 0) {
|
| 55 |
+
result.resize_({});
|
| 56 |
+
result.fill_(self);
|
| 57 |
+
return true;
|
| 58 |
+
}
|
| 59 |
+
// Return identity
|
| 60 |
+
if (self.numel() == 0) {
|
| 61 |
+
_dimreduce_setup(result, self, dim);
|
| 62 |
+
result.fill_(ident);
|
| 63 |
+
if (!keepdim) result.squeeze_(dim);
|
| 64 |
+
return true;
|
| 65 |
+
}
|
| 66 |
+
return false;
|
| 67 |
+
}
|
| 68 |
+
|
| 69 |
+
// Like _dimreduce_return_trivial, but for reductions without an identity
// value: only the 0-dim single-element case can be answered trivially.
inline bool _dimreduce_return_trivial_no_ident(Tensor &result, const Tensor &self,
                                               int64_t /*dim*/, bool /*keepdim*/, const char* /*fn_name*/) {
  if (self.numel() != 1 || self.ndimension() != 0) {
    return false;
  }
  result.resize_({});
  result.fill_(self);
  return true;
}
|
| 79 |
+
|
| 80 |
+
// An all-reduce of an empty tensor trivially yields the identity scalar;
// any other input has no trivial answer and nullopt is returned.
inline c10::optional<Tensor> _allreduce_return_trivial(
    const Tensor& self,
    const Scalar& ident) {
  if (self.numel() != 0) {
    return c10::nullopt;
  }
  return at::scalar_tensor(ident, self.options());
}
|
| 89 |
+
|
| 90 |
+
// Asserts that `out` and `self` agree on the given accessor (e.g.
// scalar_type, device, layout); on mismatch raises a TORCH_CHECK error that
// names the accessor and both values.
#define OPTION_TYPE_EQUALITY_CHECK(option, out, self) \
{ \
  TORCH_CHECK(\
    out.option() == self.option(),\
    "expected ", #option, " ",\
    self.option(),\
    " but found ", out.option())\
}
|
| 98 |
+
|
| 99 |
+
// Verifies that an out= tensor matches the input in dtype, device and
// layout — the three properties a reduction output must share with `self`.
// (scalar_type is queried on the tensors directly; device/layout via options.)
static inline void check_scalar_type_device_layout_equal(const Tensor& out, const Tensor& self) {
  OPTION_TYPE_EQUALITY_CHECK(scalar_type, out, self);
  OPTION_TYPE_EQUALITY_CHECK(device, out.options(), self.options());
  OPTION_TYPE_EQUALITY_CHECK(layout, out.options(), self.options());
}
|
| 104 |
+
|
| 105 |
+
// Casts integral (including bool) tensors up to Long unless an explicit
// `dtype` is provided; non-integral tensors pass through with their own
// dtype. Barebones unsigned types (uint16/32/64) are rejected.
static inline Tensor integer_upcast(const Tensor& self, c10::optional<ScalarType> dtype) {
  ScalarType scalarType = self.scalar_type();
  TORCH_CHECK(!isBarebonesUnsignedType(scalarType), "integer upcasting for uint16, uint32 and uint64 is not currently implemented");
  const ScalarType fallback = at::isIntegralType(scalarType, /*includeBool=*/true)
      ? ScalarType::Long
      : scalarType;
  return self.toType(dtype.value_or(fallback));
}
|
| 111 |
+
|
| 112 |
+
// Bitset with one flag per tensor dimension; used below to mark which dims a
// reduction operates over.
using DimMask = TensorIterator::DimMask;
|
| 113 |
+
|
| 114 |
+
// Returns `opt_dims` as a DimVector, or [0, ndim) when no dims were given
// (an absent dim list means "reduce over every dimension").
static DimVector make_dim_vector(OptionalIntArrayRef opt_dims, int64_t ndim) {
  if (opt_dims.has_value()) {
    return DimVector(opt_dims.value());
  }
  // Fill the SmallVector directly instead of going through an intermediate
  // std::vector, avoiding a needless heap allocation.
  DimVector all_dims(ndim);
  std::iota(all_dims.begin(), all_dims.end(), 0);
  return all_dims;
}
|
| 123 |
+
|
| 124 |
+
// Builds the reduction bitmask for `opt_dims`. An absent dim list — or an
// empty one, unless `allow_empty_dims` is set — means "reduce everything",
// i.e. a mask with every bit set.
static DimMask make_dim_mask(OptionalIntArrayRef opt_dims, int64_t ndim, bool allow_empty_dims=false) {
  if (!opt_dims.has_value()) {
    return DimMask().flip();
  }
  auto dims = opt_dims.value();
  if (dims.empty() && !allow_empty_dims) {
    return DimMask().flip();
  }
  return at::dim_list_to_bitset(dims, ndim);
}
|
| 138 |
+
|
| 139 |
+
// Output shape of a reduction over the dims set in `mask`: reduced dims
// become size 1 when keepdim, otherwise they are removed.
inline DimVector shape_from_dim_mask(const Tensor& self, DimMask mask, bool keepdim) {
  auto shape = DimVector(self.sizes());
  // Walk backwards so erasing a dim does not shift indices still to visit.
  // Use int64_t for the index: plain `int` silently narrows shape.size()
  // (a size_t) and cannot represent very high-rank shapes.
  for (int64_t dim = static_cast<int64_t>(shape.size()) - 1; dim >= 0; dim--) {
    if (mask[dim]) {
      if (keepdim) {
        shape[dim] = 1;
      } else {
        shape.erase(shape.begin() + dim);
      }
    }
  }
  return shape;
}
|
| 152 |
+
|
| 153 |
+
// Resizes the (user-supplied, already-defined) out= tensor to the shape this
// reduction will produce. The dtype parameter is unused here; callers have
// already validated it.
static void resize_reduction_result(
    Tensor& result, const Tensor& self, DimMask mask, bool keepdim,
    ScalarType /*dtype*/)
{
  TORCH_CHECK(result.defined(), "Cannot create a new tensor inside a reduction op. You likely tried to call an operator with an out argument but the out argument was an undefined tensor.");
  at::native::resize_output(result, shape_from_dim_mask(self, mask, keepdim));
}
|
| 161 |
+
|
| 162 |
+
// Allocates a fresh output tensor with the shape and dtype this reduction
// will produce.
inline Tensor create_reduction_result(
  const Tensor& self, at::OptionalIntArrayRef dim, bool keepdim, ScalarType dtype
) {
  auto mask = make_dim_mask(dim, self.dim());
  auto out_shape = shape_from_dim_mask(self, mask, keepdim);
  return at::empty(out_shape, self.options().dtype(dtype));
}
|
| 169 |
+
|
| 170 |
+
// Re-views a keepdim=false `result` as if keepdim were true: each reduced
// dim is re-inserted with size 1 and stride 0, so the view has the same rank
// as the input and can be used as the reduction's output. A keepdim result
// already has the right rank and is returned unchanged.
static Tensor review_reduce_result(const Tensor& result, int ndim, DimMask mask, bool keepdim) {
  if (keepdim) {
    return result;
  }
  auto shape = DimVector(result.sizes());
  auto stride = DimVector(result.strides());
  for (const auto dim : c10::irange(ndim)) {
    if (mask[dim]) {
      // Size-1 / stride-0: every position along the reduced dim aliases the
      // same output element.
      shape.insert(shape.begin() + dim, 1);
      stride.insert(stride.begin() + dim, 0);
    }
  }
  return result.as_strided(shape, stride);
}
|
| 184 |
+
|
| 185 |
+
// Builds the TensorIterator for a single-output reduction of `self` into
// `result`. `in_dtype` is the dtype the kernel computes in and `out_dtype`
// the dtype of the output; `self` is converted up-front when its dtype
// differs from `in_dtype`. Resizes `result` and propagates named-tensor
// names as side effects.
static TensorIterator make_reduction(
    const char* name, Tensor& result, const Tensor& self,
    at::OptionalIntArrayRef dim_opt,
    bool keepdim, ScalarType in_dtype, ScalarType out_dtype) {
  // check that result type and dtype match if provided
  TORCH_CHECK(
      !result.defined() || result.scalar_type() == out_dtype,
      name, ": provided dtype must match dtype of result. Got ",
      toString(result.scalar_type()),
      " and ",
      toString(out_dtype),
      ".");
  // dim={} performs an all-reduce, same as dim=None
  IntArrayRef dim = dim_opt.value_or(IntArrayRef{});
  int64_t ndim = self.dim();
  auto mask = make_dim_mask(dim, ndim);
  resize_reduction_result(result, self, mask, keepdim, out_dtype);
  // Re-insert the reduced dims (size 1 / stride 0) so the iterator sees
  // operands of matching rank.
  auto viewed_result = review_reduce_result(result, ndim, mask, keepdim);
  namedinference::propagate_names_for_reduction(result, self, dim, keepdim);
  if (self.scalar_type() == in_dtype) {
    return TensorIterator::reduce_op(viewed_result, self);
  }
  // Mixed-dtype path: materialize the converted input first.
  return TensorIterator::reduce_op(viewed_result, self.to(in_dtype));
}
|
| 209 |
+
|
| 210 |
+
// Convenience overload that derives the compute dtype from `out_dtype`:
// CUDA half/bfloat16 inputs reducing into float keep their low-precision
// dtype for the kernel, complex inputs compute in the complex counterpart of
// out_dtype, and everything else computes directly in out_dtype.
static C10_UNUSED TensorIterator make_reduction(
    const char* name, Tensor& result, const Tensor& self,
    at::OptionalIntArrayRef dim, bool keepdim, ScalarType out_dtype) {
  // special case for type promotion in mixed precision, improves computational
  // efficiency.
  // not generalize this to common mismatched input/output types to avoid cross
  // product of templated kernel launches.
  const auto self_dtype = self.scalar_type();
  ScalarType in_dtype;
  if (self.is_cuda() && (self_dtype == kHalf || self_dtype == kBFloat16) &&
      out_dtype == kFloat) {
    in_dtype = self_dtype;
  } else if (self.is_complex()) {
    in_dtype = c10::toComplexType(out_dtype);
  } else {
    in_dtype = out_dtype;
  }
  return make_reduction(name, result, self, dim, keepdim, in_dtype, out_dtype);
}
|
| 224 |
+
|
| 225 |
+
// Builds the TensorIterator for a reduction with two outputs (e.g. ops that
// return values and indices, or mean and std). Each result is resized to the
// reduction shape and re-viewed to input rank; names are propagated to both.
// `dtype1` is also the compute dtype: `self` is converted to it unless it
// already matches, with a CUDA half→float exception that keeps half inputs.
static TensorIterator make_reduction(
    const char* name, Tensor& result1, Tensor& result2, const Tensor& self,
    at::OptionalIntArrayRef dim_opt, bool keepdim, ScalarType dtype1,
    ScalarType dtype2) {
  // check that result type and dtype match if provided
  TORCH_CHECK(
    (!result1.defined() || result1.scalar_type() == dtype1) && (!result2.defined() || result2.scalar_type() == dtype2),
    name, ": provided dtype must match dtype of result. Got ",
    toString(result1.scalar_type()), toString(result2.scalar_type()),
    " and ",
    toString(dtype1), toString(dtype2),
    ".");

  // dim={} performs an all-reduce, same as dim=None
  auto dim = dim_opt.value_or(IntArrayRef{});
  int64_t ndim = self.dim();
  DimMask mask = make_dim_mask(dim, ndim);
  resize_reduction_result(result1, self, mask, keepdim, dtype1);
  auto viewed_result1 = review_reduce_result(result1, ndim, mask, keepdim);

  resize_reduction_result(result2, self, mask, keepdim, dtype2);
  auto viewed_result2 = review_reduce_result(result2, ndim, mask, keepdim);

  namedinference::propagate_names_for_reduction(result1, self, dim, keepdim);
  namedinference::propagate_names_for_reduction(result2, self, dim, keepdim);

  // special case for type promotion in mixed precision, improves computational
  // efficiency.
  // We don't generalize this to common mismatched input/output types to avoid cross
  // product of templated kernel launches.
  if (self.scalar_type() == dtype1 ||
      (self.is_cuda() && self.scalar_type() == kHalf && dtype1 == kFloat)) {
    return TensorIterator::reduce_op(viewed_result1, viewed_result2, self);
  }
  return TensorIterator::reduce_op(viewed_result1, viewed_result2, self.to(dtype1));
}
|
| 261 |
+
|
| 262 |
+
// Two-output reduction where both outputs share a single dtype; delegates to
// the general two-dtype overload above.
static C10_UNUSED TensorIterator make_reduction(
    const char* name, Tensor& result1, Tensor& result2, const Tensor& self,
    at::OptionalIntArrayRef dim, bool keepdim, ScalarType dtype) {
  return make_reduction(name, result1, result2, self, dim, keepdim, dtype, dtype);
}
|
| 267 |
+
|
| 268 |
+
static void zero_numel_check_dims(const Tensor& self, const int64_t dim, const char *fn_name) {
|
| 269 |
+
if (self.ndimension() == 0) {
|
| 270 |
+
TORCH_CHECK_INDEX(dim == 0 || dim == -1, fn_name,
|
| 271 |
+
": Expected reduction dim -1 or 0 for scalar but got ", dim);
|
| 272 |
+
}
|
| 273 |
+
else {
|
| 274 |
+
TORCH_CHECK_INDEX(self.size(dim) != 0, fn_name,
|
| 275 |
+
": Expected reduction dim ", dim, " to have non-zero size.");
|
| 276 |
+
}
|
| 277 |
+
}
|
| 278 |
+
|
| 279 |
+
static void zero_numel_check_dims(const Tensor& self, const IntArrayRef dim, const char *fn_name) {
|
| 280 |
+
TORCH_CHECK(
|
| 281 |
+
!dim.empty(),
|
| 282 |
+
fn_name, ": Expected reduction dim to be specified for input.numel() == 0. ",
|
| 283 |
+
"Specify the reduction dim with the 'dim' argument.");
|
| 284 |
+
for (const int64_t d : dim) {
|
| 285 |
+
zero_numel_check_dims(self, d, fn_name);
|
| 286 |
+
}
|
| 287 |
+
}
|
| 288 |
+
|
| 289 |
+
static std::vector<int64_t> get_zero_numel_tensor_size(
|
| 290 |
+
const Tensor& self,
|
| 291 |
+
const int64_t dim,
|
| 292 |
+
const bool keepdim,
|
| 293 |
+
const char* fn_name) {
|
| 294 |
+
TORCH_INTERNAL_ASSERT(self.numel() == 0, fn_name, ": Expected self.numel() == 0.");
|
| 295 |
+
zero_numel_check_dims(self, dim, fn_name);
|
| 296 |
+
std::vector<int64_t> sizes;
|
| 297 |
+
if (keepdim) {
|
| 298 |
+
sizes = self.sizes().vec();
|
| 299 |
+
sizes[dim] = 1;
|
| 300 |
+
}
|
| 301 |
+
else {
|
| 302 |
+
for (const auto d : c10::irange(self.dim())) {
|
| 303 |
+
if (d != dim) {
|
| 304 |
+
sizes.push_back(self.sizes()[d]);
|
| 305 |
+
}
|
| 306 |
+
}
|
| 307 |
+
}
|
| 308 |
+
return sizes;
|
| 309 |
+
}
|
| 310 |
+
|
| 311 |
+
// Resize the result tensor and indices when result.numel() == 0 depending on values of
// dim and keepdim for returning tensors containing reduction results.
// This function should be called when you are reducing a zero-numel tensor and want to
// resize the output and return it. This function exists for resizing zero-numel
// tensors when the size of the reduction dimension is non-zero.
static C10_UNUSED void zero_numel_tensor_resize(Tensor& result, Tensor& result_indices,
                                                const Tensor& self, const int64_t dim,
                                                const bool keepdim, const char *fn_name) {
  auto sizes = get_zero_numel_tensor_size(self, dim, keepdim, fn_name);
  // The values output and the indices output share the same shape.
  at::native::resize_output(result, sizes);
  at::native::resize_output(result_indices, sizes);
}
|
| 323 |
+
|
| 324 |
+
// Dtype a reduction over `self` should produce: an explicit `dtype` wins;
// otherwise integral (including bool) inputs promote to Long when
// `promote_integers` is set, and anything else keeps self's dtype.
inline ScalarType get_dtype_from_self(
    const Tensor& self,
    const c10::optional<ScalarType>& dtype,
    bool promote_integers) {
  if (dtype.has_value()) {
    return dtype.value();
  }
  const auto src_type = self.scalar_type();
  const bool integral_src = at::isIntegralType(src_type, /*includeBool=*/true);
  return (promote_integers && integral_src) ? kLong : src_type;
}
|
| 337 |
+
|
| 338 |
+
// Dtype for an out= reduction: the explicit `dtype` if given, otherwise the
// dtype of the (required to be defined) result tensor.
inline ScalarType get_dtype_from_result(Tensor& result, c10::optional<ScalarType> dtype) {
  TORCH_CHECK(result.defined(), "Cannot create a new tensor inside a reduction op. You likely tried to call an operator with an out argument but the out argument was an undefined tensor.");
  return dtype.has_value() ? dtype.value() : result.scalar_type();
}
|
| 346 |
+
|
| 347 |
+
|
| 348 |
+
} // namespace at::native
|
| 349 |
+
|
| 350 |
+
namespace at::meta {
|
| 351 |
+
|
| 352 |
+
// Shape produced by reducing `self` over `dims` (thin wrapper around
// native::make_dim_mask + native::shape_from_dim_mask).
static C10_UNUSED DimVector get_reduction_shape(
    const Tensor& self,
    IntArrayRef dims,
    bool keepdim,
    bool allow_empty_dims=false) {
  const auto mask = native::make_dim_mask(dims, self.dim(), allow_empty_dims);
  return native::shape_from_dim_mask(self, mask, keepdim);
}
|
| 360 |
+
|
| 361 |
+
// Meta-function plumbing for structured reduction ops: wraps/canonicalizes
// the dim list, declares output 0 with the reduction's shape and dtype, and
// propagates named-tensor names onto the output.
static void resize_reduction(
    impl::MetaBase& meta,
    const Tensor& self,
    OptionalIntArrayRef opt_dims,
    bool keepdim,
    ScalarType out_dtype,
    bool allow_empty_dims=false) {
  DimVector dims_ = at::native::make_dim_vector(opt_dims, self.dim());
  // Wrap negative dims (and validate range) before computing the shape.
  maybe_wrap_dims(dims_, self.dim());
  auto shape = get_reduction_shape(self, dims_, keepdim, allow_empty_dims);
  meta.set_output_raw_strided(0, shape, {}, self.options().dtype(out_dtype));
  namedinference::propagate_names_for_reduction(
      meta.maybe_get_output(), self, dims_, keepdim);
}
|
| 375 |
+
|
| 376 |
+
// Like resize_reduction, but for structured ops that return a (values,
// indices) pair: output 0 gets `out_dtype`, output 1 is the Long index
// tensor, both with the same reduction shape; names propagate to both.
static void resize_reduction_with_indices(
    impl::MetaBase& meta,
    const Tensor& self,
    IntArrayRef dims,
    bool keepdim,
    ScalarType out_dtype) {
  DimVector dims_(dims);
  // Wrap negative dims (and validate range) before computing the shape.
  maybe_wrap_dims(dims_, self.dim());
  auto shape = get_reduction_shape(self, dims_, keepdim);
  meta.set_output_raw_strided(0, shape, {}, self.options().dtype(out_dtype));
  meta.set_output_raw_strided(1, shape, {}, self.options().dtype(kLong));
  namedinference::propagate_names_for_reduction(
      meta.maybe_get_output(0), self, dims_, keepdim);
  namedinference::propagate_names_for_reduction(
      meta.maybe_get_output(1), self, dims_, keepdim);
}
|
| 392 |
+
|
| 393 |
+
// TensorIterator for a meta/structured single-output reduction; `result`
// has already been sized (see resize_reduction), so only the rank-matching
// view and the optional input dtype conversion happen here.
static TensorIterator make_reduction(
    const Tensor& self,
    const Tensor& result,
    OptionalIntArrayRef opt_dims,
    bool keepdim,
    ScalarType in_dtype) {
  const int64_t ndim = self.dim();
  const auto mask = at::native::make_dim_mask(opt_dims, ndim);
  auto viewed_result =
      at::native::review_reduce_result(result, ndim, mask, keepdim);
  if (self.scalar_type() != in_dtype) {
    return TensorIterator::reduce_op(viewed_result, self.to(in_dtype));
  }
  return TensorIterator::reduce_op(viewed_result, self);
}
|
| 408 |
+
|
| 409 |
+
// TensorIterator for a meta/structured reduction with two (already-sized)
// outputs. `dtype1` doubles as the compute dtype: the input is converted to
// it unless it matches, with a CUDA half→float exception that keeps half
// inputs as-is. `dtype2` only shapes the signature and is unused here.
static TensorIterator make_reduction(
    const Tensor& self,
    const Tensor& result1,
    const Tensor& result2,
    IntArrayRef dims,
    bool keepdim,
    ScalarType dtype1,
    ScalarType /*dtype2*/) {
  int64_t ndim = self.dim();
  auto mask = at::native::make_dim_mask(dims, ndim);
  auto viewed_result1 = at::native::review_reduce_result(result1, ndim, mask, keepdim);
  auto viewed_result2 = at::native::review_reduce_result(result2, ndim, mask, keepdim);
  // special case for type promotion in mixed precision, improves computational efficiency.
  // We don't generalize this to common mismatched input/output types to avoid cross product
  // of templated kernel launches.
  if (self.scalar_type() == dtype1 ||
      (self.is_cuda() && self.scalar_type() == kHalf && dtype1 == kFloat)) {
    return TensorIterator::reduce_op(viewed_result1, viewed_result2, self);
  }
  return TensorIterator::reduce_op(viewed_result1, viewed_result2, self.to(dtype1));
}
|
| 430 |
+
|
| 431 |
+
// Single-output meta reduction whose compute dtype is derived from
// `out_dtype`: CUDA half/bfloat16 inputs reducing into float keep their
// low-precision dtype, everything else computes in out_dtype.
static C10_UNUSED TensorIterator make_reduction_from_out_ty(
    const Tensor& self,
    const Tensor& result,
    OptionalIntArrayRef opt_dims,
    bool keepdim,
    ScalarType out_dtype) {
  // special case for type promotion in mixed precision, improves computational
  // efficiency.
  // not generalize this to common mismatched input/output types to avoid cross
  // product of templated kernel launches.
  const auto self_dtype = self.scalar_type();
  const bool gpu_lowp_to_f32 =
      self.is_cuda() &&
      (self_dtype == kHalf || self_dtype == kBFloat16) &&
      out_dtype == kFloat;
  const ScalarType in_dtype = gpu_lowp_to_f32 ? self_dtype : out_dtype;
  return make_reduction(self, result, opt_dims, keepdim, in_dtype);
}
|
| 448 |
+
|
| 449 |
+
} // namespace at::meta
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/ReductionType.h
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <c10/core/Scalar.h>
|
| 4 |
+
|
| 5 |
+
namespace at::native {
|
| 6 |
+
|
| 7 |
+
// The reduction modes accepted by scatter/segment-style ops; see
// get_reduction_enum below for the string spellings that map to each.
enum class ReductionType {MAX, MEAN, MIN, SUM, PROD};
|
| 8 |
+
|
| 9 |
+
// Maps a reduce= string to its ReductionType. "amax"/"amin" are accepted as
// aliases for "max"/"min"; any other string raises.
static inline ReductionType get_reduction_enum(const c10::string_view& reduce) {
  if (reduce == "sum") {
    return ReductionType::SUM;
  }
  if (reduce == "prod") {
    return ReductionType::PROD;
  }
  if (reduce == "mean") {
    return ReductionType::MEAN;
  }
  if (reduce == "max" || reduce == "amax") {
    return ReductionType::MAX;
  }
  if (reduce == "min" || reduce == "amin") {
    return ReductionType::MIN;
  }
  TORCH_CHECK(false, "reduce argument must be either sum, prod, mean, amax or amin, got ", reduce);
}
|
| 24 |
+
|
| 25 |
+
// used for `scatter_reduce`, old options for BC.
|
| 26 |
+
static inline ReductionType get_operator_enum(const c10::string_view reduce, bool use_new_options) {
|
| 27 |
+
if (use_new_options) {
|
| 28 |
+
return get_reduction_enum(reduce);
|
| 29 |
+
} else {
|
| 30 |
+
if (reduce == "add") {
|
| 31 |
+
return ReductionType::SUM;
|
| 32 |
+
} else if (reduce == "multiply") {
|
| 33 |
+
return ReductionType::PROD;
|
| 34 |
+
} else {
|
| 35 |
+
TORCH_CHECK(false, "reduce argument must be either add or multiply.")
|
| 36 |
+
}
|
| 37 |
+
}
|
| 38 |
+
}
|
| 39 |
+
|
| 40 |
+
} // at::native
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/ScatterGatherChecks.h
ADDED
|
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <vector>
|
| 4 |
+
#include <ATen/core/Tensor.h>
|
| 5 |
+
#include <ATen/native/ReduceOpsUtils.h>
|
| 6 |
+
#include <c10/util/irange.h>
|
| 7 |
+
|
| 8 |
+
namespace at::native {
|
| 9 |
+
|
| 10 |
+
namespace {
|
| 11 |
+
|
| 12 |
+
// checks whether index.dtype == int64
// and self.dtype == src.dtype if src is a Tensor
static void scatter_gather_dtype_check(
  const std::string& method_name,
  const Tensor& self,
  const Tensor& index,
  const c10::optional<Tensor>& src_opt = c10::nullopt
) {
  // An empty index tensor is exempt from the dtype requirement.
  if (index.numel() != 0) {
    TORCH_CHECK(
      index.scalar_type() == at::ScalarType::Long,
      method_name, "(): Expected dtype int64 for index"
    );
  }

  // `src` (when present) must match `self` exactly in dtype.
  if (src_opt.has_value()) {
    const auto& src = src_opt.value();
    TORCH_CHECK(
      self.scalar_type() == src.scalar_type(),
      method_name, "(): Expected self.dtype to be equal to src.dtype"
    );
  }
}
|
| 35 |
+
|
| 36 |
+
// Used for `gather`-like methods
|
| 37 |
+
// Note: self means the input tensor here
|
| 38 |
+
// Test:
|
| 39 |
+
// 1. index.size(d) <= self.size(d) for all d != dim
|
| 40 |
+
// 2. index.dim() == self.dim()
|
| 41 |
+
static C10_UNUSED void gather_shape_check(const Tensor& self, int64_t dim,
|
| 42 |
+
const Tensor& index
|
| 43 |
+
) {
|
| 44 |
+
auto self_dims = ensure_nonempty_dim(self.dim());
|
| 45 |
+
TORCH_CHECK(self_dims == ensure_nonempty_dim(index.dim()),
|
| 46 |
+
"Index tensor must have the same number of dimensions as input tensor"
|
| 47 |
+
);
|
| 48 |
+
|
| 49 |
+
for (const auto i : c10::irange(self_dims)) {
|
| 50 |
+
if (i != dim) {
|
| 51 |
+
TORCH_CHECK(
|
| 52 |
+
ensure_nonempty_size(index, i) <= ensure_nonempty_size(self, i),
|
| 53 |
+
"Size does not match at dimension ", i,
|
| 54 |
+
" expected index ", index.sizes(),
|
| 55 |
+
" to be smaller than self ", self.sizes(),
|
| 56 |
+
" apart from dimension ", dim
|
| 57 |
+
);
|
| 58 |
+
}
|
| 59 |
+
}
|
| 60 |
+
}
|
| 61 |
+
|
| 62 |
+
// Used for `scatter` and `scatter_add`
// Tests:
//  1. index.size(d) <= self.size(d) for all d != dim
//  2. index.size(d) <= src.size(d) for all d if src is a Tensor
//  3. index.dim() == self.dim() == src.dim()
static C10_UNUSED void scatter_shape_check(
  const Tensor& self, int64_t dim, const Tensor& index,
  const c10::optional<Tensor>& src_opt = c10::nullopt
) {
  // An empty index is trivially valid for any self/src shapes.
  if (index.numel() == 0) return;
  TORCH_CHECK(
    ensure_nonempty_dim(self.dim()) == ensure_nonempty_dim(index.dim()),
    "Index tensor must have the same number of dimensions as self tensor"
  );

  // The size checks only record a flag so that the error raised below can
  // name all offending shapes (and mention src when it participates).
  bool is_wrong_shape = false;
  int64_t self_dims = ensure_nonempty_dim(self.dim());

  // Check: index.size(d) <= self.size(d) for all d != dim
  for (const auto d : c10::irange(self_dims)) {
    int64_t index_d_size = ensure_nonempty_size(index, d);
    if (d == dim) continue;
    if (index_d_size > ensure_nonempty_size(self, d)) {
      is_wrong_shape = true;
      break;
    }
  }

  // Check: index.size(d) <= src.size(d) for all d if src is Tensor
  if (!is_wrong_shape && src_opt.has_value()) {
    const auto& src = src_opt.value();
    for (const auto d : c10::irange(self_dims)) {
      int64_t index_d_size = ensure_nonempty_size(index, d);
      if (index_d_size > ensure_nonempty_size(src, d)) {
        is_wrong_shape = true;
        break;
      }
    }
  }

  if (src_opt.has_value()) {
    const auto& src = src_opt.value();

    TORCH_CHECK(
      ensure_nonempty_dim(src.dim()) == ensure_nonempty_dim(index.dim()),
      "Index tensor must have the same number of dimensions as src tensor"
    );

    TORCH_CHECK(!is_wrong_shape,
      "Expected index ", index.sizes(),
      " to be smaller than self ", self.sizes(),
      " apart from dimension ", dim,
      " and to be smaller size than src ", src.sizes()
    );
  }
  else {
    TORCH_CHECK(!is_wrong_shape,
      "Expected index ", index.sizes(),
      " to be smaller than self ", self.sizes(),
      " apart from dimension ", dim
    );
  }
}
|
| 125 |
+
|
| 126 |
+
} // anonymous namespace
|
| 127 |
+
|
| 128 |
+
} // namespace at::native
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/SegmentReduce.h
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/native/DispatchStub.h>
|
| 4 |
+
#include <ATen/native/ReductionType.h>
|
| 5 |
+
#include <c10/core/Scalar.h>
|
| 6 |
+
#include <c10/util/Optional.h>
|
| 7 |
+
|
| 8 |
+
namespace at {
|
| 9 |
+
class Tensor;
|
| 10 |
+
|
| 11 |
+
namespace native {
|
| 12 |
+
|
| 13 |
+
using segment_reduce_lengths_fn = Tensor (*)(
|
| 14 |
+
ReductionType,
|
| 15 |
+
const Tensor&,
|
| 16 |
+
const Tensor&,
|
| 17 |
+
int64_t,
|
| 18 |
+
const c10::optional<Scalar>&);
|
| 19 |
+
DECLARE_DISPATCH(segment_reduce_lengths_fn, _segment_reduce_lengths_stub);
|
| 20 |
+
|
| 21 |
+
using segment_reduce_offsets_fn = Tensor (*)(
|
| 22 |
+
ReductionType,
|
| 23 |
+
const Tensor&,
|
| 24 |
+
const Tensor&,
|
| 25 |
+
int64_t,
|
| 26 |
+
const c10::optional<Scalar>&);
|
| 27 |
+
DECLARE_DISPATCH(segment_reduce_offsets_fn, _segment_reduce_offsets_stub);
|
| 28 |
+
|
| 29 |
+
using segment_reduce_lengths_backward_fn = Tensor (*)(
|
| 30 |
+
const Tensor&,
|
| 31 |
+
const Tensor&,
|
| 32 |
+
const Tensor&,
|
| 33 |
+
ReductionType,
|
| 34 |
+
const Tensor&,
|
| 35 |
+
int64_t,
|
| 36 |
+
const c10::optional<Scalar>&);
|
| 37 |
+
DECLARE_DISPATCH(segment_reduce_lengths_backward_fn, _segment_reduce_lengths_backward_stub);
|
| 38 |
+
|
| 39 |
+
using segment_reduce_offsets_backward_fn = Tensor (*)(
|
| 40 |
+
const Tensor&,
|
| 41 |
+
const Tensor&,
|
| 42 |
+
const Tensor&,
|
| 43 |
+
ReductionType,
|
| 44 |
+
const Tensor&,
|
| 45 |
+
int64_t,
|
| 46 |
+
const c10::optional<Scalar>&);
|
| 47 |
+
DECLARE_DISPATCH(segment_reduce_offsets_backward_fn, _segment_reduce_offsets_backward_stub);
|
| 48 |
+
|
| 49 |
+
} // namespace native
|
| 50 |
+
} // namespace at
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/TensorAdvancedIndexing.h
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
// Indexing tensors by tensors
|
| 4 |
+
|
| 5 |
+
#include <ATen/core/List.h>
|
| 6 |
+
#include <ATen/core/Tensor.h>
|
| 7 |
+
#include <ATen/native/DispatchStub.h>
|
| 8 |
+
#include <ATen/native/ReductionType.h>
|
| 9 |
+
|
| 10 |
+
namespace at {
|
| 11 |
+
struct TensorIterator;
|
| 12 |
+
}
|
| 13 |
+
|
| 14 |
+
namespace at::native {
|
| 15 |
+
|
| 16 |
+
using index_put_with_sort_fn = void(*)(Tensor &, const c10::List<c10::optional<Tensor>> &, const Tensor &, bool accumulate, bool unsafe);
|
| 17 |
+
using index_put_with_sort_quantized_fn = void(*)(Tensor& self, const c10::List<c10::optional<Tensor>>& indices, const Tensor& value, double scale, int zero_point, bool unsafe);
|
| 18 |
+
using gather_fn = void (*)(const Tensor & result, const Tensor & self, int64_t dim, const Tensor & index);
|
| 19 |
+
using scatter_fn = void(*)(const Tensor& self, int64_t dim, const Tensor& index, const Tensor& src);
|
| 20 |
+
using scatter_fill_fn = void(*)(const Tensor& self, int64_t dim, const Tensor& index, const Scalar& src);
|
| 21 |
+
using scatter_add_fn = void(*)(const Tensor& self, int64_t dim, const Tensor& index, const Tensor& src);
|
| 22 |
+
using scatter_reduce_fn = void(*)(const Tensor& self, const int64_t dim, const Tensor& index,
|
| 23 |
+
const Tensor& src, const ReductionType& reduce);
|
| 24 |
+
using scatter_scalar_reduce_fn = void(*)(const Tensor& self, const int64_t dim, const Tensor& index,
|
| 25 |
+
const Scalar& value, const ReductionType& reduce);
|
| 26 |
+
using scatter_reduce_two_fn = void(*)(const Tensor& self, const int64_t dim, const Tensor& index,
|
| 27 |
+
const Tensor& src, const ReductionType& reduce);
|
| 28 |
+
|
| 29 |
+
DECLARE_DISPATCH(index_put_with_sort_fn, index_put_with_sort_stub);
|
| 30 |
+
DECLARE_DISPATCH(index_put_with_sort_quantized_fn, index_put_with_sort_quantized_stub);
|
| 31 |
+
DECLARE_DISPATCH(gather_fn, gather_stub);
|
| 32 |
+
DECLARE_DISPATCH(scatter_fn, scatter_stub);
|
| 33 |
+
DECLARE_DISPATCH(scatter_fill_fn, scatter_fill_stub);
|
| 34 |
+
DECLARE_DISPATCH(scatter_add_fn, scatter_add_stub);
|
| 35 |
+
DECLARE_DISPATCH(scatter_reduce_fn, scatter_reduce_stub);
|
| 36 |
+
DECLARE_DISPATCH(scatter_scalar_reduce_fn, scatter_scalar_reduce_stub);
|
| 37 |
+
DECLARE_DISPATCH(scatter_reduce_two_fn, scatter_reduce_two_stub);
|
| 38 |
+
|
| 39 |
+
TORCH_API Tensor& index_out(Tensor& result, const Tensor & self, const c10::List<c10::optional<at::Tensor>>& indices);
|
| 40 |
+
|
| 41 |
+
using scatter_add_expanded_index_fn = void(*)(const Tensor&, const Tensor&, const Tensor&);
|
| 42 |
+
using scatter_reduce_expanded_index_fn = void(*)(const Tensor&, const Tensor&, const Tensor&, const ReductionType& reduce, bool);
|
| 43 |
+
using gather_expanded_index_fn = void (*)(const Tensor&, const Tensor&, const Tensor&);
|
| 44 |
+
|
| 45 |
+
DECLARE_DISPATCH(scatter_add_expanded_index_fn, scatter_add_expanded_index_stub);
|
| 46 |
+
DECLARE_DISPATCH(scatter_reduce_expanded_index_fn, scatter_reduce_expanded_index_stub);
|
| 47 |
+
DECLARE_DISPATCH(gather_expanded_index_fn, gather_expanded_index_stub);
|
| 48 |
+
|
| 49 |
+
} // namespace at::native
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/TypeProperties.h
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/core/Tensor.h>
|
| 4 |
+
#include <ATen/core/IListRef.h>
|
| 5 |
+
|
| 6 |
+
namespace at::native {
|
| 7 |
+
|
| 8 |
+
struct ResultTypeState {
|
| 9 |
+
c10::ScalarType dimResult = ScalarType::Undefined;
|
| 10 |
+
c10::ScalarType wrappedResult = ScalarType::Undefined;
|
| 11 |
+
c10::ScalarType zeroResult = ScalarType::Undefined;
|
| 12 |
+
};
|
| 13 |
+
|
| 14 |
+
TORCH_API ResultTypeState update_result_type_state(const Tensor& tensor, const ResultTypeState& in_state);
|
| 15 |
+
TORCH_API ResultTypeState update_result_type_state(const Scalar& scalar, const ResultTypeState& in_state);
|
| 16 |
+
TORCH_API ScalarType result_type(const ResultTypeState& state);
|
| 17 |
+
|
| 18 |
+
TORCH_API ScalarType result_type(ITensorListRef tensors);
|
| 19 |
+
|
| 20 |
+
} // namespace at::native
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/Unfold2d.h
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/native/DispatchStub.h>
|
| 4 |
+
#include <c10/core/ScalarType.h>
|
| 5 |
+
#include <cstdint>
|
| 6 |
+
|
| 7 |
+
namespace at::native {
|
| 8 |
+
|
| 9 |
+
using unfold2d_fn = void (*)(
|
| 10 |
+
ScalarType dtype,
|
| 11 |
+
void *finput,
|
| 12 |
+
void *input,
|
| 13 |
+
int64_t kH,
|
| 14 |
+
int64_t kW,
|
| 15 |
+
int64_t dH,
|
| 16 |
+
int64_t dW,
|
| 17 |
+
int64_t padH,
|
| 18 |
+
int64_t padW,
|
| 19 |
+
int64_t n_input_plane,
|
| 20 |
+
int64_t input_height,
|
| 21 |
+
int64_t input_width,
|
| 22 |
+
int64_t output_height,
|
| 23 |
+
int64_t output_width,
|
| 24 |
+
bool is_channels_last
|
| 25 |
+
);
|
| 26 |
+
|
| 27 |
+
DECLARE_DISPATCH(unfold2d_fn, unfolded2d_copy_stub);
|
| 28 |
+
DECLARE_DISPATCH(unfold2d_fn, unfolded2d_acc_stub);
|
| 29 |
+
|
| 30 |
+
} // namespace at::native
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/batch_norm.h
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/core/Tensor.h>
|
| 4 |
+
#include <ATen/native/DispatchStub.h>
|
| 5 |
+
|
| 6 |
+
namespace at::native {
|
| 7 |
+
|
| 8 |
+
using batch_norm_fn = void (*)(Tensor&, const Tensor&, const Tensor&,
|
| 9 |
+
const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, bool, double);
|
| 10 |
+
using batch_norm_collect_stats_fn = void (*)(Tensor&, Tensor&, const Tensor&);
|
| 11 |
+
using batch_norm_backward_fn = void(*)(Tensor&, Tensor&, Tensor&, const Tensor&,
|
| 12 |
+
const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, bool, double);
|
| 13 |
+
|
| 14 |
+
DECLARE_DISPATCH(batch_norm_fn, batch_norm_cpu_stub);
|
| 15 |
+
DECLARE_DISPATCH(batch_norm_collect_stats_fn, batch_norm_cpu_collect_stats_stub);
|
| 16 |
+
DECLARE_DISPATCH(batch_norm_backward_fn, batch_norm_cpu_backward_stub);
|
| 17 |
+
|
| 18 |
+
// TensorAccessor when it is defined to work around undefined...
|
| 19 |
+
template <typename scalar_t>
|
| 20 |
+
static TensorAccessor<scalar_t, 1> conditional_accessor_1d(const Tensor& t) {
|
| 21 |
+
if (! t.defined()) {
|
| 22 |
+
return TensorAccessor<scalar_t, 1>(nullptr, nullptr, nullptr);
|
| 23 |
+
}
|
| 24 |
+
return t.accessor<scalar_t, 1>();
|
| 25 |
+
}
|
| 26 |
+
|
| 27 |
+
template <typename scalar_t>
|
| 28 |
+
static scalar_t* conditional_data_ptr(const Tensor& t) {
|
| 29 |
+
return t.defined() ? t.contiguous().data_ptr<scalar_t>()
|
| 30 |
+
: nullptr;
|
| 31 |
+
}
|
| 32 |
+
|
| 33 |
+
} // namespace at::native
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/im2col_shape_check.h
ADDED
|
@@ -0,0 +1,232 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
#include <ATen/core/Tensor.h>
|
| 3 |
+
#include <ATen/TensorUtils.h>
|
| 4 |
+
#include <ATen/div_rtn.h>
|
| 5 |
+
|
| 6 |
+
namespace at::native {
|
| 7 |
+
|
| 8 |
+
static inline void col2im_shape_check(
|
| 9 |
+
const Tensor& input,
|
| 10 |
+
const Tensor& grad_output,
|
| 11 |
+
int64_t output_height,
|
| 12 |
+
int64_t output_width,
|
| 13 |
+
int64_t kernel_height,
|
| 14 |
+
int64_t kernel_width,
|
| 15 |
+
int64_t dilation_height,
|
| 16 |
+
int64_t dilation_width,
|
| 17 |
+
int64_t pad_height,
|
| 18 |
+
int64_t pad_width,
|
| 19 |
+
int64_t stride_height,
|
| 20 |
+
int64_t stride_width) {
|
| 21 |
+
TORCH_CHECK(
|
| 22 |
+
kernel_width > 0 && kernel_height > 0,
|
| 23 |
+
"kernel size should be greater than zero, but got kernel_height: ",
|
| 24 |
+
kernel_height,
|
| 25 |
+
" kernel_width: ",
|
| 26 |
+
kernel_width);
|
| 27 |
+
TORCH_CHECK(
|
| 28 |
+
stride_width > 0 && stride_height > 0,
|
| 29 |
+
"stride should be greater than zero, but got stride_height: ",
|
| 30 |
+
stride_height,
|
| 31 |
+
" stride_width: ",
|
| 32 |
+
stride_width);
|
| 33 |
+
TORCH_CHECK(
|
| 34 |
+
dilation_width > 0 && dilation_height > 0,
|
| 35 |
+
"dilation should be greater than zero, but got dilation_height: ",
|
| 36 |
+
dilation_height,
|
| 37 |
+
" dilation_width: ",
|
| 38 |
+
dilation_width);
|
| 39 |
+
TORCH_CHECK(
|
| 40 |
+
pad_width >= 0 && pad_height >= 0,
|
| 41 |
+
"padding should be non-negative, but got pad_height: ",
|
| 42 |
+
pad_height,
|
| 43 |
+
" pad_width: ",
|
| 44 |
+
pad_width);
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
int64_t ndim = input.ndimension();
|
| 48 |
+
// allow dim=0 only the batch dimension.
|
| 49 |
+
TORCH_CHECK(
|
| 50 |
+
(ndim == 2 && input.size(0) != 0 && input.size(1) != 0) ||
|
| 51 |
+
(ndim == 3 && input.size(1) != 0 && input.size(2) != 0),
|
| 52 |
+
"Expected 2D or 3D (batch mode) tensor for input with possibly 0 batch size and non-zero dimensions for input, but got: ",
|
| 53 |
+
input.sizes());
|
| 54 |
+
|
| 55 |
+
int64_t batch_dim = (ndim == 3) ? 0 : -1;
|
| 56 |
+
int64_t n_input_plane = input.size(batch_dim + 1);
|
| 57 |
+
|
| 58 |
+
if (n_input_plane % (kernel_width * kernel_height) != 0) {
|
| 59 |
+
AT_ERROR(
|
| 60 |
+
"Expected size of input's dimension 1 to be divisible by the "
|
| 61 |
+
"product of kernel_size, but got input.size(1)=",
|
| 62 |
+
n_input_plane,
|
| 63 |
+
" and kernel_size=(",
|
| 64 |
+
kernel_height,
|
| 65 |
+
", ",
|
| 66 |
+
kernel_width,
|
| 67 |
+
").");
|
| 68 |
+
}
|
| 69 |
+
|
| 70 |
+
int64_t input_length = input.size(batch_dim + 2);
|
| 71 |
+
int64_t n_blocks_height =
|
| 72 |
+
div_rtn<int64_t>(
|
| 73 |
+
output_height + 2 * pad_height -
|
| 74 |
+
dilation_height * (kernel_height - 1) - 1,
|
| 75 |
+
stride_height) +
|
| 76 |
+
1;
|
| 77 |
+
int64_t n_blocks_width = div_rtn<int64_t>(
|
| 78 |
+
output_width + 2 * pad_width -
|
| 79 |
+
dilation_width * (kernel_width - 1) - 1,
|
| 80 |
+
stride_width) +
|
| 81 |
+
1;
|
| 82 |
+
|
| 83 |
+
if (input_length != (n_blocks_height * n_blocks_width)) {
|
| 84 |
+
AT_ERROR(
|
| 85 |
+
"Given output_size=(",
|
| 86 |
+
output_height,
|
| 87 |
+
", ",
|
| 88 |
+
output_width,
|
| 89 |
+
"), kernel_size=(",
|
| 90 |
+
kernel_height,
|
| 91 |
+
", ",
|
| 92 |
+
kernel_width,
|
| 93 |
+
"), dilation=(",
|
| 94 |
+
dilation_height,
|
| 95 |
+
", ",
|
| 96 |
+
dilation_width,
|
| 97 |
+
"), padding=(",
|
| 98 |
+
pad_height,
|
| 99 |
+
", ",
|
| 100 |
+
pad_width,
|
| 101 |
+
"), stride=(",
|
| 102 |
+
stride_height,
|
| 103 |
+
", ",
|
| 104 |
+
stride_width,
|
| 105 |
+
"), expected size of input's dimension 2 to match the calculated number of ",
|
| 106 |
+
"sliding blocks ",
|
| 107 |
+
n_blocks_height,
|
| 108 |
+
" * ",
|
| 109 |
+
n_blocks_width,
|
| 110 |
+
" = ",
|
| 111 |
+
(n_blocks_height * n_blocks_width),
|
| 112 |
+
", but got input.size(2)=",
|
| 113 |
+
input_length,
|
| 114 |
+
".");
|
| 115 |
+
}
|
| 116 |
+
|
| 117 |
+
TORCH_CHECK(
|
| 118 |
+
n_blocks_height >= 1 && n_blocks_width >= 1,
|
| 119 |
+
"Given output_size=(", output_height, ", ", output_width, "), ",
|
| 120 |
+
"kernel_size=(", kernel_height, ", ", kernel_width, "), ",
|
| 121 |
+
"dilation=(", dilation_height, ", ", dilation_width, "), ",
|
| 122 |
+
"padding=(", pad_height, ", ", pad_width, "), ",
|
| 123 |
+
"stride=(", stride_height, ", ", stride_width, "), ",
|
| 124 |
+
"calculated shape of the array of sliding blocks as ",
|
| 125 |
+
"(", n_blocks_height, ", ", n_blocks_width, "), ",
|
| 126 |
+
"which is too small (non-positive)");
|
| 127 |
+
|
| 128 |
+
if (output_width < 1 || output_height < 1) {
|
| 129 |
+
AT_ERROR(
|
| 130 |
+
"Expected output spatial size to be positive, but got: output_size=(",
|
| 131 |
+
output_height,
|
| 132 |
+
", ",
|
| 133 |
+
output_width,
|
| 134 |
+
").");
|
| 135 |
+
}
|
| 136 |
+
}
|
| 137 |
+
|
| 138 |
+
static inline void im2col_shape_check(
|
| 139 |
+
const Tensor& input,
|
| 140 |
+
const Tensor& grad_output,
|
| 141 |
+
int64_t kernel_height,
|
| 142 |
+
int64_t kernel_width,
|
| 143 |
+
int64_t dilation_height,
|
| 144 |
+
int64_t dilation_width,
|
| 145 |
+
int64_t pad_height,
|
| 146 |
+
int64_t pad_width,
|
| 147 |
+
int64_t stride_height,
|
| 148 |
+
int64_t stride_width) {
|
| 149 |
+
TORCH_CHECK(
|
| 150 |
+
kernel_width > 0 && kernel_height > 0,
|
| 151 |
+
"kernel size should be greater than zero, but got kernel_height: ",
|
| 152 |
+
kernel_height,
|
| 153 |
+
" kernel_width: ",
|
| 154 |
+
kernel_width);
|
| 155 |
+
|
| 156 |
+
TORCH_CHECK(
|
| 157 |
+
dilation_width > 0 && dilation_height > 0,
|
| 158 |
+
"dilation should be greater than zero, but got dilation_height: ",
|
| 159 |
+
dilation_height,
|
| 160 |
+
" dilation_width: ",
|
| 161 |
+
dilation_width);
|
| 162 |
+
|
| 163 |
+
TORCH_CHECK(
|
| 164 |
+
pad_width >= 0 && pad_height >= 0,
|
| 165 |
+
"padding should be non-negative, but got pad_height: ",
|
| 166 |
+
pad_height,
|
| 167 |
+
" pad_width: ",
|
| 168 |
+
pad_width);
|
| 169 |
+
|
| 170 |
+
TORCH_CHECK(
|
| 171 |
+
stride_width > 0 && stride_height > 0,
|
| 172 |
+
"stride should be greater than zero, but got stride_height: ",
|
| 173 |
+
stride_height,
|
| 174 |
+
" stride_width: ",
|
| 175 |
+
stride_width);
|
| 176 |
+
|
| 177 |
+
int64_t ndim = input.ndimension();
|
| 178 |
+
|
| 179 |
+
// allow dim=0 only the batch dimension.
|
| 180 |
+
bool valid_dims = input.size(1) != 0 && input.size(2) != 0;
|
| 181 |
+
TORCH_CHECK(
|
| 182 |
+
(ndim == 3 && input.size(0) && valid_dims) ||
|
| 183 |
+
(ndim == 4 && valid_dims && input.size(3) != 0),
|
| 184 |
+
"Expected 3D or 4D (batch mode) tensor with possibly 0 batch size and other non-zero dimensions for input, but got: ",
|
| 185 |
+
input.sizes());
|
| 186 |
+
|
| 187 |
+
int64_t dim_batch = 0;
|
| 188 |
+
|
| 189 |
+
if (ndim == 3) {
|
| 190 |
+
dim_batch = -1;
|
| 191 |
+
}
|
| 192 |
+
|
| 193 |
+
int64_t input_height = input.size(dim_batch + 2);
|
| 194 |
+
int64_t input_width = input.size(dim_batch + 3);
|
| 195 |
+
int64_t output_height = div_rtn<int64_t>(
|
| 196 |
+
input_height + 2 * pad_height -
|
| 197 |
+
(dilation_height * (kernel_height - 1) + 1),
|
| 198 |
+
stride_height) +
|
| 199 |
+
1;
|
| 200 |
+
int64_t output_width = div_rtn<int64_t>(
|
| 201 |
+
input_width + 2 * pad_width -
|
| 202 |
+
(dilation_width * (kernel_width - 1) + 1),
|
| 203 |
+
stride_width) +
|
| 204 |
+
1;
|
| 205 |
+
|
| 206 |
+
if (output_height < 1 || output_width < 1) {
|
| 207 |
+
AT_ERROR(
|
| 208 |
+
"Given input with spatial size (",
|
| 209 |
+
input_height,
|
| 210 |
+
", ",
|
| 211 |
+
input_height,
|
| 212 |
+
"), kernel_size=(",
|
| 213 |
+
kernel_height,
|
| 214 |
+
", ",
|
| 215 |
+
kernel_width,
|
| 216 |
+
"), dilation=(",
|
| 217 |
+
dilation_height,
|
| 218 |
+
", ",
|
| 219 |
+
dilation_width,
|
| 220 |
+
"), padding=(",
|
| 221 |
+
pad_height,
|
| 222 |
+
", ",
|
| 223 |
+
pad_width,
|
| 224 |
+
"), calculated shape of the array of sliding blocks as (",
|
| 225 |
+
output_height,
|
| 226 |
+
", ",
|
| 227 |
+
output_width,
|
| 228 |
+
"), but its components must be at least one.");
|
| 229 |
+
}
|
| 230 |
+
}
|
| 231 |
+
|
| 232 |
+
} // namespace at::native
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_cholesky_solve_helper_cpu_dispatch.h
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
// @generated by torchgen/gen.py from DispatchKeyFunction.h
|
| 3 |
+
|
| 4 |
+
// NB: The implementing C++ file is RegisterDispatchKey.cpp
|
| 5 |
+
|
| 6 |
+
// The only #includes we need are for custom classes that have defaults in the C++ API
|
| 7 |
+
#include <c10/core/MemoryFormat.h>
|
| 8 |
+
#include <c10/core/Scalar.h>
|
| 9 |
+
#include <ATen/core/Reduction.h>
|
| 10 |
+
|
| 11 |
+
// Forward declarations of any types needed in the operator signatures.
|
| 12 |
+
// We can't directly include these classes because it will cause circular include dependencies.
|
| 13 |
+
// This file is included by TensorBody.h, which defines the Tensor class.
|
| 14 |
+
#include <ATen/core/ATen_fwd.h>
|
| 15 |
+
|
| 16 |
+
namespace at {
|
| 17 |
+
|
| 18 |
+
namespace cpu {
|
| 19 |
+
|
| 20 |
+
TORCH_API at::Tensor _cholesky_solve_helper(const at::Tensor & self, const at::Tensor & A, bool upper);
|
| 21 |
+
|
| 22 |
+
} // namespace cpu
|
| 23 |
+
} // namespace at
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_foreach_lgamma_ops.h
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
// @generated by torchgen/gen.py from Operator.h
|
| 4 |
+
|
| 5 |
+
#include <tuple>
|
| 6 |
+
#include <vector>
|
| 7 |
+
|
| 8 |
+
// Forward declarations of any types needed in the operator signatures.
|
| 9 |
+
// We can't directly include these classes because it will cause circular include dependencies.
|
| 10 |
+
// This file is included by TensorBody.h, which defines the Tensor class.
|
| 11 |
+
#include <ATen/core/ATen_fwd.h>
|
| 12 |
+
|
| 13 |
+
namespace at {
|
| 14 |
+
namespace _ops {
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
struct TORCH_API _foreach_lgamma {
|
| 18 |
+
using schema = ::std::vector<at::Tensor> (at::TensorList);
|
| 19 |
+
using ptr_schema = schema*;
|
| 20 |
+
// See Note [static constexpr char* members for windows NVCC]
|
| 21 |
+
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::_foreach_lgamma")
|
| 22 |
+
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "")
|
| 23 |
+
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "_foreach_lgamma(Tensor[] self) -> Tensor[]")
|
| 24 |
+
static ::std::vector<at::Tensor> call(at::TensorList self);
|
| 25 |
+
static ::std::vector<at::Tensor> redispatch(c10::DispatchKeySet dispatchKeySet, at::TensorList self);
|
| 26 |
+
};
|
| 27 |
+
|
| 28 |
+
struct TORCH_API _foreach_lgamma_ {
|
| 29 |
+
using schema = void (at::TensorList);
|
| 30 |
+
using ptr_schema = schema*;
|
| 31 |
+
// See Note [static constexpr char* members for windows NVCC]
|
| 32 |
+
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::_foreach_lgamma_")
|
| 33 |
+
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "")
|
| 34 |
+
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "_foreach_lgamma_(Tensor(a!)[] self) -> ()")
|
| 35 |
+
static void call(at::TensorList self);
|
| 36 |
+
static void redispatch(c10::DispatchKeySet dispatchKeySet, at::TensorList self);
|
| 37 |
+
};
|
| 38 |
+
|
| 39 |
+
struct TORCH_API _foreach_lgamma_out {
|
| 40 |
+
using schema = void (at::TensorList, at::TensorList);
|
| 41 |
+
using ptr_schema = schema*;
|
| 42 |
+
// See Note [static constexpr char* members for windows NVCC]
|
| 43 |
+
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::_foreach_lgamma")
|
| 44 |
+
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "out")
|
| 45 |
+
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "_foreach_lgamma.out(Tensor[] self, *, Tensor(a!)[] out) -> ()")
|
| 46 |
+
static void call(at::TensorList self, at::TensorList out);
|
| 47 |
+
static void redispatch(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList out);
|
| 48 |
+
};
|
| 49 |
+
|
| 50 |
+
}} // namespace at::_ops
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_scaled_dot_product_cudnn_attention_cuda_dispatch.h
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
// @generated by torchgen/gen.py from DispatchKeyFunction.h
|
| 3 |
+
|
| 4 |
+
// NB: The implementing C++ file is RegisterDispatchKey.cpp
|
| 5 |
+
|
| 6 |
+
// The only #includes we need are for custom classes that have defaults in the C++ API
|
| 7 |
+
#include <c10/core/MemoryFormat.h>
|
| 8 |
+
#include <c10/core/Scalar.h>
|
| 9 |
+
#include <ATen/core/Reduction.h>
|
| 10 |
+
|
| 11 |
+
// Forward declarations of any types needed in the operator signatures.
|
| 12 |
+
// We can't directly include these classes because it will cause circular include dependencies.
|
| 13 |
+
// This file is included by TensorBody.h, which defines the Tensor class.
|
| 14 |
+
#include <ATen/core/ATen_fwd.h>
|
| 15 |
+
|
| 16 |
+
namespace at {
|
| 17 |
+
|
| 18 |
+
namespace cuda {
|
| 19 |
+
|
| 20 |
+
TORCH_API ::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor> _scaled_dot_product_cudnn_attention(const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, double dropout_p=0.0, bool is_causal=false, bool return_debug_mask=false, c10::optional<double> scale=c10::nullopt);
|
| 21 |
+
|
| 22 |
+
} // namespace cuda
|
| 23 |
+
} // namespace at
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_triton_scaled_dot_attention_compositeexplicitautograd_dispatch.h
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
// @generated by torchgen/gen.py from DispatchKeyFunction.h
|
| 3 |
+
|
| 4 |
+
// NB: The implementing C++ file is RegisterDispatchKey.cpp
|
| 5 |
+
|
| 6 |
+
// The only #includes we need are for custom classes that have defaults in the C++ API
|
| 7 |
+
#include <c10/core/MemoryFormat.h>
|
| 8 |
+
#include <c10/core/Scalar.h>
|
| 9 |
+
#include <ATen/core/Reduction.h>
|
| 10 |
+
|
| 11 |
+
// Forward declarations of any types needed in the operator signatures.
|
| 12 |
+
// We can't directly include these classes because it will cause circular include dependencies.
|
| 13 |
+
// This file is included by TensorBody.h, which defines the Tensor class.
|
| 14 |
+
#include <ATen/core/ATen_fwd.h>
|
| 15 |
+
|
| 16 |
+
namespace at {
|
| 17 |
+
|
| 18 |
+
namespace compositeexplicitautograd {
|
| 19 |
+
|
| 20 |
+
TORCH_API at::Tensor & _triton_scaled_dot_attention_out(at::Tensor & out, const at::Tensor & q, const at::Tensor & k, const at::Tensor & v, double dropout_p=0.0);
|
| 21 |
+
TORCH_API at::Tensor & _triton_scaled_dot_attention_outf(const at::Tensor & q, const at::Tensor & k, const at::Tensor & v, double dropout_p, at::Tensor & out);
|
| 22 |
+
|
| 23 |
+
} // namespace compositeexplicitautograd
|
| 24 |
+
} // namespace at
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_values_native.h
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
// @generated by torchgen/gen.py from NativeFunction.h
|
| 4 |
+
|
| 5 |
+
#include <c10/core/Scalar.h>
|
| 6 |
+
#include <c10/core/Storage.h>
|
| 7 |
+
#include <c10/core/TensorOptions.h>
|
| 8 |
+
#include <c10/util/Deprecated.h>
|
| 9 |
+
#include <c10/util/Optional.h>
|
| 10 |
+
#include <c10/core/QScheme.h>
|
| 11 |
+
#include <ATen/core/Reduction.h>
|
| 12 |
+
#include <ATen/core/Tensor.h>
|
| 13 |
+
#include <tuple>
|
| 14 |
+
#include <vector>
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
namespace at {
|
| 18 |
+
namespace native {
|
| 19 |
+
TORCH_API at::Tensor _values_sparse(const at::Tensor & self);
|
| 20 |
+
} // namespace native
|
| 21 |
+
} // namespace at
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/adaptive_max_pool3d_backward_native.h
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
// @generated by torchgen/gen.py from NativeFunction.h
|
| 4 |
+
|
| 5 |
+
#include <c10/core/Scalar.h>
|
| 6 |
+
#include <c10/core/Storage.h>
|
| 7 |
+
#include <c10/core/TensorOptions.h>
|
| 8 |
+
#include <c10/util/Deprecated.h>
|
| 9 |
+
#include <c10/util/Optional.h>
|
| 10 |
+
#include <c10/core/QScheme.h>
|
| 11 |
+
#include <ATen/core/Reduction.h>
|
| 12 |
+
#include <ATen/core/Tensor.h>
|
| 13 |
+
#include <tuple>
|
| 14 |
+
#include <vector>
|
| 15 |
+
#include <ATen/ops/adaptive_max_pool3d_backward_meta.h>
|
| 16 |
+
|
| 17 |
+
namespace at {
|
| 18 |
+
namespace native {
|
| 19 |
+
struct TORCH_API structured_adaptive_max_pool3d_backward_out_cpu : public at::meta::structured_adaptive_max_pool3d_backward {
|
| 20 |
+
void impl(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & indices, const at::Tensor & grad_input);
|
| 21 |
+
};
|
| 22 |
+
struct TORCH_API structured_adaptive_max_pool3d_backward_out_cuda : public at::meta::structured_adaptive_max_pool3d_backward {
|
| 23 |
+
void impl(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & indices, const at::Tensor & grad_input);
|
| 24 |
+
};
|
| 25 |
+
} // namespace native
|
| 26 |
+
} // namespace at
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/amin_compositeexplicitautogradnonfunctional_dispatch.h
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
// @generated by torchgen/gen.py from DispatchKeyFunction.h
|
| 3 |
+
|
| 4 |
+
// NB: The implementing C++ file is RegisterDispatchKey.cpp
|
| 5 |
+
|
| 6 |
+
// The only #includes we need are for custom classes that have defaults in the C++ API
|
| 7 |
+
#include <c10/core/MemoryFormat.h>
|
| 8 |
+
#include <c10/core/Scalar.h>
|
| 9 |
+
#include <ATen/core/Reduction.h>
|
| 10 |
+
|
| 11 |
+
// Forward declarations of any types needed in the operator signatures.
|
| 12 |
+
// We can't directly include these classes because it will cause circular include dependencies.
|
| 13 |
+
// This file is included by TensorBody.h, which defines the Tensor class.
|
| 14 |
+
#include <ATen/core/ATen_fwd.h>
|
| 15 |
+
|
| 16 |
+
namespace at {
|
| 17 |
+
|
| 18 |
+
namespace compositeexplicitautogradnonfunctional {
|
| 19 |
+
|
| 20 |
+
TORCH_API at::Tensor amin(const at::Tensor & self, at::IntArrayRef dim={}, bool keepdim=false);
|
| 21 |
+
|
| 22 |
+
} // namespace compositeexplicitautogradnonfunctional
|
| 23 |
+
} // namespace at
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/atleast_2d_compositeimplicitautograd_dispatch.h
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
// @generated by torchgen/gen.py from DispatchKeyFunction.h
|
| 3 |
+
|
| 4 |
+
// NB: The implementing C++ file is RegisterDispatchKey.cpp
|
| 5 |
+
|
| 6 |
+
// The only #includes we need are for custom classes that have defaults in the C++ API
|
| 7 |
+
#include <c10/core/MemoryFormat.h>
|
| 8 |
+
#include <c10/core/Scalar.h>
|
| 9 |
+
#include <ATen/core/Reduction.h>
|
| 10 |
+
|
| 11 |
+
// Forward declarations of any types needed in the operator signatures.
|
| 12 |
+
// We can't directly include these classes because it will cause circular include dependencies.
|
| 13 |
+
// This file is included by TensorBody.h, which defines the Tensor class.
|
| 14 |
+
#include <ATen/core/ATen_fwd.h>
|
| 15 |
+
|
| 16 |
+
namespace at {
|
| 17 |
+
|
| 18 |
+
namespace compositeimplicitautograd {
|
| 19 |
+
|
| 20 |
+
TORCH_API at::Tensor atleast_2d(const at::Tensor & self);
|
| 21 |
+
TORCH_API ::std::vector<at::Tensor> atleast_2d(at::TensorList tensors);
|
| 22 |
+
|
| 23 |
+
} // namespace compositeimplicitautograd
|
| 24 |
+
} // namespace at
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/atleast_2d_native.h
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
// @generated by torchgen/gen.py from NativeFunction.h
|
| 4 |
+
|
| 5 |
+
#include <c10/core/Scalar.h>
|
| 6 |
+
#include <c10/core/Storage.h>
|
| 7 |
+
#include <c10/core/TensorOptions.h>
|
| 8 |
+
#include <c10/util/Deprecated.h>
|
| 9 |
+
#include <c10/util/Optional.h>
|
| 10 |
+
#include <c10/core/QScheme.h>
|
| 11 |
+
#include <ATen/core/Reduction.h>
|
| 12 |
+
#include <ATen/core/Tensor.h>
|
| 13 |
+
#include <tuple>
|
| 14 |
+
#include <vector>
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
namespace at {
|
| 18 |
+
namespace native {
|
| 19 |
+
TORCH_API at::Tensor atleast_2d(const at::Tensor & self);
|
| 20 |
+
TORCH_API ::std::vector<at::Tensor> atleast_2d(at::TensorList tensors);
|
| 21 |
+
} // namespace native
|
| 22 |
+
} // namespace at
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/avg_pool3d_backward_compositeexplicitautogradnonfunctional_dispatch.h
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
// @generated by torchgen/gen.py from DispatchKeyFunction.h
|
| 3 |
+
|
| 4 |
+
// NB: The implementing C++ file is RegisterDispatchKey.cpp
|
| 5 |
+
|
| 6 |
+
// The only #includes we need are for custom classes that have defaults in the C++ API
|
| 7 |
+
#include <c10/core/MemoryFormat.h>
|
| 8 |
+
#include <c10/core/Scalar.h>
|
| 9 |
+
#include <ATen/core/Reduction.h>
|
| 10 |
+
|
| 11 |
+
// Forward declarations of any types needed in the operator signatures.
|
| 12 |
+
// We can't directly include these classes because it will cause circular include dependencies.
|
| 13 |
+
// This file is included by TensorBody.h, which defines the Tensor class.
|
| 14 |
+
#include <ATen/core/ATen_fwd.h>
|
| 15 |
+
|
| 16 |
+
namespace at {
|
| 17 |
+
|
| 18 |
+
namespace compositeexplicitautogradnonfunctional {
|
| 19 |
+
|
| 20 |
+
TORCH_API at::Tensor avg_pool3d_backward(const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, c10::optional<int64_t> divisor_override);
|
| 21 |
+
|
| 22 |
+
} // namespace compositeexplicitautogradnonfunctional
|
| 23 |
+
} // namespace at
|