koichi12 commited on
Commit
e52b3b0
·
verified ·
1 Parent(s): b755ec5

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/binary_folding.cpython-311.pyc +0 -0
  2. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/dedupe_symint_uses.cpython-311.pyc +0 -0
  3. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/efficient_conv_bn_eval.cpython-311.pyc +0 -0
  4. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/freezing_patterns.cpython-311.pyc +0 -0
  5. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/joint_graph.cpython-311.pyc +0 -0
  6. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/misc_patterns.cpython-311.pyc +0 -0
  7. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/pad_mm.cpython-311.pyc +0 -0
  8. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/post_grad.cpython-311.pyc +0 -0
  9. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/reinplace.cpython-311.pyc +0 -0
  10. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_14.cpython-311.pyc +0 -0
  11. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_8.cpython-311.pyc +0 -0
  12. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_9.py +233 -0
  13. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/kernel/__pycache__/mm.cpython-311.pyc +0 -0
  14. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/kernel/__pycache__/unpack_mixed_mm.cpython-311.pyc +0 -0
  15. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/kernel/conv.py +495 -0
  16. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/detail/KernelUtils.h +37 -0
  17. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/detail/UnpackRaw.cuh +28 -0
  18. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/tunable/GemmRocblas.h +275 -0
  19. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/tunable/Tunable.h +205 -0
  20. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/tunable/TunableGemm.h +278 -0
  21. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/tunable/TunableOp.h +242 -0
  22. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/BinaryOps.h +119 -0
  23. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/CPUFallback.h +45 -0
  24. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/CanUse32BitIndexMath.h +13 -0
  25. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/ComplexHelper.h +97 -0
  26. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/Distributions.h +518 -0
  27. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/Math.h +0 -0
  28. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/NonSymbolicBC.h +26 -0
  29. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/PixelShuffle.h +47 -0
  30. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/RangeFactories.h +12 -0
  31. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/ReduceAllOps.h +16 -0
  32. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/ReduceOpsUtils.h +449 -0
  33. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/ReductionType.h +40 -0
  34. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/ScatterGatherChecks.h +128 -0
  35. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/SegmentReduce.h +50 -0
  36. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/TensorAdvancedIndexing.h +49 -0
  37. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/TypeProperties.h +20 -0
  38. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/Unfold2d.h +30 -0
  39. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/batch_norm.h +33 -0
  40. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/im2col_shape_check.h +232 -0
  41. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_cholesky_solve_helper_cpu_dispatch.h +23 -0
  42. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_foreach_lgamma_ops.h +50 -0
  43. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_scaled_dot_product_cudnn_attention_cuda_dispatch.h +23 -0
  44. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_triton_scaled_dot_attention_compositeexplicitautograd_dispatch.h +24 -0
  45. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_values_native.h +21 -0
  46. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/adaptive_max_pool3d_backward_native.h +26 -0
  47. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/amin_compositeexplicitautogradnonfunctional_dispatch.h +23 -0
  48. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/atleast_2d_compositeimplicitautograd_dispatch.h +24 -0
  49. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/atleast_2d_native.h +22 -0
  50. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/avg_pool3d_backward_compositeexplicitautogradnonfunctional_dispatch.h +23 -0
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/binary_folding.cpython-311.pyc ADDED
Binary file (13.4 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/dedupe_symint_uses.cpython-311.pyc ADDED
Binary file (5.3 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/efficient_conv_bn_eval.cpython-311.pyc ADDED
Binary file (6.53 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/freezing_patterns.cpython-311.pyc ADDED
Binary file (10.1 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/joint_graph.cpython-311.pyc ADDED
Binary file (17.3 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/misc_patterns.cpython-311.pyc ADDED
Binary file (6.81 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/pad_mm.cpython-311.pyc ADDED
Binary file (26.9 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/post_grad.cpython-311.pyc ADDED
Binary file (50.2 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/reinplace.cpython-311.pyc ADDED
Binary file (26.4 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_14.cpython-311.pyc ADDED
Binary file (17.6 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_8.cpython-311.pyc ADDED
Binary file (17.4 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_9.py ADDED
@@ -0,0 +1,233 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: ignore-errors
2
+
3
+ # noqa: F401, E501
4
+ # This is an auto-generated file. Please do not modify it by hand.
5
+ # To re-generate, run:
6
+ # cd ~/pytorch && python
7
+ # torchgen/fuse_attention_patterns/gen_attention_patterns.py
8
+
9
+ import torch
10
+ import torch._inductor
11
+
12
+ aten = torch.ops.aten
13
+ prims = torch.ops.prims
14
+
15
+ from torch._inductor.pattern_matcher import (
16
+ Arg,
17
+ CallFunction,
18
+ CallFunctionVarArgs,
19
+ CallMethod,
20
+ CallMethodVarArgs,
21
+ CallModule,
22
+ CallModuleVarArgs,
23
+ ExclusiveKeywordArg,
24
+ Ignored,
25
+ KeywordArg,
26
+ ListOf,
27
+ MultiOutputPattern,
28
+ PatternExpr,
29
+ RepeatedExpr,
30
+ _TargetArgsExpr,
31
+ _TargetExpr,
32
+ _TargetExprVarArgs,
33
+ )
34
+ rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
35
+ gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2)
36
+ permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
37
+ div_Tensor = CallFunction(aten.div.Tensor, permute_default, Ignored())
38
+ expand_default = CallFunction(aten.expand.default, div_Tensor, Ignored())
39
+ clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
40
+ view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2)
41
+ permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
42
+ permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
43
+ expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
44
+ clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
45
+ view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2)
46
+ bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
47
+ view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored(), _users=2)
48
+ amax_default = CallFunction(aten.amax.default, view_default_2, Ignored(), True)
49
+ sub_Tensor = CallFunction(aten.sub.Tensor, view_default_2, amax_default)
50
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
51
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
52
+ div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2)
53
+ mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1)
54
+ mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored())
55
+ convert_element_type_default = CallFunction(prims.convert_element_type.default, mul_Tensor_1, Ignored())
56
+ expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored())
57
+ view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
58
+ permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
59
+ expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
60
+ clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
61
+ view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2)
62
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
63
+ view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
64
+ view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
65
+ permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
66
+ bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
67
+ convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, bmm_default_2, Ignored())
68
+ view_default_7 = CallFunction(aten.view.default, convert_element_type_default_1, Ignored())
69
+ convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored())
70
+ convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
71
+ mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored())
72
+ mul_Tensor_3 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, mul_Tensor_2)
73
+ clone_default_3 = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format)
74
+ alias_default = CallFunction(aten.alias.default, div_Tensor_1)
75
+ alias_default_1 = CallFunction(aten.alias.default, alias_default)
76
+ alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
77
+ alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2)
78
+ mul_Tensor_4 = CallFunction(aten.mul.Tensor, clone_default_3, alias_default_3, _users=2)
79
+ sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
80
+ mul_Tensor_5 = CallFunction(aten.mul.Tensor, alias_default_3, sum_dim_IntList_1)
81
+ sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor_4, mul_Tensor_5)
82
+ view_default_8 = CallFunction(aten.view.default, sub_Tensor_1, Ignored(), _users=2)
83
+ permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored())
84
+ bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5)
85
+ view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
86
+ div_Tensor_2 = CallFunction(aten.div.Tensor, view_default_9, Ignored())
87
+ permute_default_6 = CallFunction(aten.permute.default, div_Tensor_2, Ignored())
88
+ permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored())
89
+ bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8)
90
+ view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
91
+ permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored())
92
+ permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored())
93
+ permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored())
94
+ bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6)
95
+ view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
96
+ permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored())
97
+ _sfdp_pattern_9_training = MultiOutputPattern([view_default_5,
98
+ permute_default_6,
99
+ permute_default_9,
100
+ permute_default_11,
101
+ None
102
+ ])
103
+
104
+
105
+ permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
106
+ div_Tensor = CallFunction(aten.div.Tensor, permute_default, Ignored())
107
+ expand_default = CallFunction(aten.expand.default, div_Tensor, Ignored())
108
+ clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
109
+ view_default = CallFunction(aten.view.default, clone_default, Ignored())
110
+ permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
111
+ permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
112
+ expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
113
+ clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
114
+ view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored())
115
+ bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
116
+ view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored(), _users=2)
117
+ amax_default = CallFunction(aten.amax.default, view_default_2, Ignored(), True)
118
+ sub_Tensor = CallFunction(aten.sub.Tensor, view_default_2, amax_default)
119
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
120
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
121
+ div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
122
+ clone_default_2 = CallFunction(aten.clone.default, div_Tensor_1)
123
+ convert_element_type_default = CallFunction(prims.convert_element_type.default, clone_default_2, Ignored())
124
+ expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored())
125
+ view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
126
+ permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
127
+ expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
128
+ clone_default_3 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
129
+ view_default_4 = CallFunction(aten.view.default, clone_default_3, Ignored())
130
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
131
+ _sfdp_pattern_9_inference = CallFunction(aten.view.default, bmm_default_1, Ignored())
132
+
133
+
134
+ rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
135
+ gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2)
136
+ permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
137
+ div_Tensor = CallFunction(aten.div.Tensor, permute_default, Ignored())
138
+ expand_default = CallFunction(aten.expand.default, div_Tensor, Ignored())
139
+ clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
140
+ view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2)
141
+ permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
142
+ permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
143
+ expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
144
+ clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
145
+ view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2)
146
+ bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
147
+ view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
148
+ convert_element_type_default = CallFunction(prims.convert_element_type.default, view_default_2, Ignored(), _users=2)
149
+ amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
150
+ sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
151
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
152
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
153
+ div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2)
154
+ mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1)
155
+ mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored())
156
+ convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, mul_Tensor_1, Ignored())
157
+ expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
158
+ view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
159
+ permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
160
+ expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
161
+ clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
162
+ view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2)
163
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
164
+ view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
165
+ view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
166
+ permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
167
+ bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
168
+ view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
169
+ convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored())
170
+ convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
171
+ mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored())
172
+ mul_Tensor_3 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, mul_Tensor_2)
173
+ clone_default_3 = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format)
174
+ alias_default = CallFunction(aten.alias.default, div_Tensor_1)
175
+ alias_default_1 = CallFunction(aten.alias.default, alias_default)
176
+ alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
177
+ alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2)
178
+ mul_Tensor_4 = CallFunction(aten.mul.Tensor, clone_default_3, alias_default_3, _users=2)
179
+ sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
180
+ mul_Tensor_5 = CallFunction(aten.mul.Tensor, alias_default_3, sum_dim_IntList_1)
181
+ sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor_4, mul_Tensor_5)
182
+ convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, sub_Tensor_1, Ignored())
183
+ view_default_8 = CallFunction(aten.view.default, convert_element_type_default_4, Ignored(), _users=2)
184
+ permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored())
185
+ bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5)
186
+ view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
187
+ div_Tensor_2 = CallFunction(aten.div.Tensor, view_default_9, Ignored())
188
+ permute_default_6 = CallFunction(aten.permute.default, div_Tensor_2, Ignored())
189
+ permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored())
190
+ bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8)
191
+ view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
192
+ permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored())
193
+ permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored())
194
+ permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored())
195
+ bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6)
196
+ view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
197
+ permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored())
198
+ _sfdp_pattern_9_half_training = MultiOutputPattern([view_default_5,
199
+ permute_default_6,
200
+ permute_default_9,
201
+ permute_default_11,
202
+ None
203
+ ])
204
+
205
+
206
+ permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
207
+ div_Tensor = CallFunction(aten.div.Tensor, permute_default, Ignored())
208
+ expand_default = CallFunction(aten.expand.default, div_Tensor, Ignored())
209
+ clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
210
+ view_default = CallFunction(aten.view.default, clone_default, Ignored())
211
+ permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
212
+ permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
213
+ expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
214
+ clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
215
+ view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored())
216
+ bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
217
+ view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
218
+ convert_element_type_default = CallFunction(prims.convert_element_type.default, view_default_2, Ignored(), _users=2)
219
+ amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
220
+ sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
221
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
222
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
223
+ div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
224
+ clone_default_2 = CallFunction(aten.clone.default, div_Tensor_1)
225
+ convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, clone_default_2, Ignored())
226
+ expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
227
+ view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
228
+ permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
229
+ expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
230
+ clone_default_3 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
231
+ view_default_4 = CallFunction(aten.view.default, clone_default_3, Ignored())
232
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
233
+ _sfdp_pattern_9_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored())
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/kernel/__pycache__/mm.cpython-311.pyc ADDED
Binary file (13 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/kernel/__pycache__/unpack_mixed_mm.cpython-311.pyc ADDED
Binary file (3.55 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/kernel/conv.py ADDED
@@ -0,0 +1,495 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import functools
4
+ import logging
5
+ from typing import cast, List, Optional, Sequence, Tuple, TypedDict
6
+
7
+ import torch
8
+ from .. import config, ir
9
+ from ..ir import TensorBox
10
+
11
+ from ..lowering import (
12
+ add_layout_constraint,
13
+ constrain_to_fx_strides,
14
+ lowerings as L,
15
+ register_lowering,
16
+ )
17
+ from ..select_algorithm import (
18
+ autotune_select_algorithm,
19
+ ExternKernelChoice,
20
+ TritonTemplate,
21
+ )
22
+ from ..utils import (
23
+ ceildiv,
24
+ is_ones,
25
+ is_zeros,
26
+ pad_listlike,
27
+ sympy_product,
28
+ use_triton_template,
29
+ )
30
+ from ..virtualized import V
31
+ from .mm_common import filtered_configs
32
+
33
+ log = logging.getLogger(__name__)
34
+
35
+
36
+ aten = torch.ops.aten
37
+
38
+
39
+ def conv_grid(n, c, h, w, meta):
40
+ return (
41
+ ceildiv(n * h * w, meta["BLOCK_M"]),
42
+ ceildiv(c, meta["BLOCK_N"]),
43
+ meta["GROUPS"],
44
+ )
45
+
46
+
47
+ # List of dictionaries to store the kernel configs. Configs that evaluate to true
48
+ # will be utilised on the target platform
49
+ kernel_configs = [
50
+ # "BLOCK_M", "BLOCK_N", "BLOCK_K", "num_stages", "num_warps"
51
+ {"config": (64, 256, 16, 2, 4), "cond": True},
52
+ {"config": (256, 64, 16, 2, 4), "cond": True},
53
+ {"config": (1024, 16, 16, 1, 8), "cond": True},
54
+ {"config": (128, 128, 32, 2, 8), "cond": True},
55
+ {"config": (64, 64, 32, 2, 4), "cond": True},
56
+ {"config": (64, 256, 32, 2, 8), "cond": True},
57
+ {"config": (256, 64, 32, 2, 8), "cond": True},
58
+ ]
59
+
60
+ # Create filtered list of configs based on conv
61
+ platform_configs = tuple(
62
+ cast(Tuple[int, int, int, int, int], config["config"])
63
+ for config in kernel_configs
64
+ if config["cond"]
65
+ )
66
+
67
+ # On ROCm convert num_stages to 1 as pipelining provides no benefit
68
+ if torch.version.hip:
69
+ platform_configs = tuple(
70
+ (config[0], config[1], config[2], 1, config[4]) for config in platform_configs
71
+ )
72
+
73
+ conv_configs = functools.partial(
74
+ filtered_configs,
75
+ configs=platform_configs,
76
+ )
77
+
78
+ LOOP_BODY = """
79
+ idx_x_h = i - PADDING_H + idx_y_h * STRIDE_H
80
+ idx_x_w = j - PADDING_W + idx_y_w * STRIDE_W
81
+ idx_x_c = tl.arange(0, BLOCK_K) + k
82
+
83
+ x_ptrs = x_base + (
84
+ (idx_x_h * stride_xh)[:, None]
85
+ + (idx_x_w * stride_xw)[:, None]
86
+ + (idx_x_c * stride_xc)[None, :]
87
+ )
88
+ mask_x = (
89
+ (idx_n < BATCH)[:, None]
90
+ & (idx_x_h >= 0)[:, None]
91
+ & (idx_x_h < IN_H)[:, None]
92
+ & (idx_x_w >= 0)[:, None]
93
+ & (idx_x_w < IN_W)[:, None]
94
+ & (idx_x_c < GROUP_IN_C)[None, :]
95
+ )
96
+ matrix_x = tl.load(x_ptrs, mask=mask_x, other=0.0)
97
+
98
+ w_ptrs = w_base + (
99
+ (idx_x_c * stride_wc_in)[:, None] + (i * stride_wh) + (j * stride_ww)
100
+ )
101
+ mask_w = (idx_x_c[:, None] < GROUP_IN_C) & (idx_y_c[None, :] < GROUP_OUT_C)
102
+ matrix_w = tl.load(w_ptrs, mask=mask_w, other=0.0)
103
+ acc += tl.dot(matrix_x, matrix_w, allow_tf32=ALLOW_TF32)
104
+ """
105
+
106
+ """
107
+ This is a relatively simple conv implementation that can likely be
108
+ improved. Many alternate conv versions can be found here:
109
+ https://github.com/pytorch/torchdynamo/pull/971
110
+ """
111
+ conv2d_template = TritonTemplate(
112
+ name="convolution",
113
+ grid=conv_grid,
114
+ source=r"""
115
+ {{def_kernel("X", "W")}}
116
+ # Tensor dimensions
117
+ BATCH = {{size("X", 0)}}
118
+ IN_C = {{size("X", 1)}}
119
+ IN_H = {{size("X", 2)}}
120
+ IN_W = {{size("X", 3)}}
121
+ OUT_C = {{size(None, 1)}}
122
+ OUT_H = {{size(None, 2)}}
123
+ OUT_W = {{size(None, 3)}}
124
+
125
+ # Strides:
126
+ stride_xn = {{stride("X", 0)}}
127
+ stride_xc = {{stride("X", 1)}}
128
+ stride_xh = {{stride("X", 2)}}
129
+ stride_xw = {{stride("X", 3)}}
130
+ stride_wc_out = {{stride("W", 0)}}
131
+ stride_wc_in = {{stride("W", 1)}}
132
+ stride_wh = {{stride("W", 2)}}
133
+ stride_ww = {{stride("W", 3)}}
134
+
135
+ nhw = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
136
+ idx_y_w = nhw % OUT_W
137
+ nh = nhw // OUT_W
138
+ idx_y_h = nh % OUT_H
139
+ idx_n = nh // OUT_H
140
+ idx_y_c = tl.program_id(1) * BLOCK_N + tl.arange(0, BLOCK_N)
141
+
142
+ {% if GROUPS == 1 %}
143
+ group = 0
144
+ GROUP_IN_C = IN_C
145
+ GROUP_OUT_C = OUT_C
146
+ {% else %}
147
+ group = tl.program_id(2)
148
+ GROUP_IN_C = IN_C // GROUPS
149
+ GROUP_OUT_C = OUT_C // GROUPS
150
+ {% endif %}
151
+
152
+ x_base = X + (group * stride_xc * GROUP_IN_C + idx_n * stride_xn)[:, None]
153
+ w_base = (
154
+ W + (group * stride_wc_out * GROUP_OUT_C + idx_y_c * stride_wc_out)[None, :]
155
+ )
156
+
157
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
158
+
159
+ {% if UNROLL %}
160
+ {% for i in range(KERNEL_H) %}
161
+ {% for j in range(KERNEL_W) %}
162
+ i = {{i}}
163
+ j = {{j}}
164
+ for k in range(0, GROUP_IN_C, BLOCK_K):
165
+ """
166
+ + LOOP_BODY
167
+ + """
168
+ {% endfor %}
169
+ {% endfor %}
170
+ {% else %}
171
+ # Could be simplified, but slightly slower:
172
+ # for i in range(KERNEL_H):
173
+ # for j in range(KERNEL_W):
174
+ # for k in range(0, GROUP_IN_C, BLOCK_K):
175
+ BLOCK_K_COUNT = (GROUP_IN_C + BLOCK_K - 1) // BLOCK_K
176
+ for ijk in range(KERNEL_H * KERNEL_W * BLOCK_K_COUNT):
177
+ k = (ijk % BLOCK_K_COUNT) * BLOCK_K
178
+ ij = ijk // BLOCK_K_COUNT
179
+ i = ij // KERNEL_W
180
+ j = ij % KERNEL_W
181
+ """
182
+ + LOOP_BODY
183
+ + """
184
+ {% endif %}
185
+
186
+ mask = (
187
+ (idx_n < BATCH)[:, None]
188
+ & (idx_y_h < OUT_H)[:, None]
189
+ & (idx_y_w < OUT_W)[:, None]
190
+ & (idx_y_c < GROUP_OUT_C)[None, :]
191
+ )
192
+ idx_n = idx_n[:, None]
193
+ idx_c = idx_y_c[None, :] + group * GROUP_OUT_C
194
+ idx_h = idx_y_h[:, None]
195
+ idx_w = idx_y_w[:, None]
196
+
197
+ # inductor generates a suffix
198
+ {{store_output(("idx_n", "idx_c", "idx_h", "idx_w"), "acc", "mask")}}
199
+ """,
200
+ )
201
+
202
+ aten_convolution = ExternKernelChoice(
203
+ torch.convolution,
204
+ "at::convolution",
205
+ has_out_variant=False,
206
+ op_overload=aten.convolution.default,
207
+ )
208
+
209
+
210
+ def conv1x1_via_mm(x, w, *, out):
211
+ w = torch.squeeze(torch.squeeze(w, -1), -1)
212
+ return torch.matmul(
213
+ x.permute(0, 2, 3, 1), w.permute(1, 0), out=out.permute(0, 2, 3, 1)
214
+ )
215
+
216
+
217
+ aten_conv1x1_via_mm = ExternKernelChoice(conv1x1_via_mm, None)
218
+
219
+
220
+ class ConvLayoutParams(TypedDict):
221
+ stride: tuple[int, ...]
222
+ padding: tuple[int, ...]
223
+ dilation: tuple[int, ...]
224
+ transposed: bool
225
+ output_padding: tuple[int, ...]
226
+ groups: int
227
+
228
+
229
+ def conv_layout(
230
+ x: TensorBox,
231
+ weight: TensorBox,
232
+ bias: Optional[TensorBox],
233
+ stride: Sequence[int],
234
+ padding: tuple[int, ...],
235
+ dilation: tuple[int, ...],
236
+ transposed: bool,
237
+ output_padding: tuple[int, ...],
238
+ groups: int,
239
+ ) -> ir.Layout:
240
+ """Determine output layout for a convolution"""
241
+ with V.graph.fake_mode:
242
+ output = torch.ops.aten.convolution(
243
+ ir.ir_node_to_tensor(x, guard_shape=True),
244
+ ir.ir_node_to_tensor(weight, guard_shape=True),
245
+ ir.ir_node_to_tensor(bias, guard_shape=True),
246
+ stride,
247
+ tuple(V.graph.sizevars.size_hint(p) for p in padding), # type: ignore[arg-type]
248
+ dilation,
249
+ transposed,
250
+ tuple(V.graph.sizevars.size_hint(p) for p in output_padding), # type: ignore[arg-type]
251
+ groups,
252
+ )
253
+ sizes = ir.convert_shape_to_inductor(output.size())
254
+ stride = ir.convert_shape_to_inductor(output.stride()) # type: ignore[assignment]
255
+
256
+ return ir.FixedLayout(
257
+ x.get_device(),
258
+ x.get_dtype(),
259
+ sizes,
260
+ stride,
261
+ )
262
+
263
+
264
+ def channels_last_order(rank):
265
+ order = list(reversed(range(rank)))
266
+ order.insert(1, order.pop(-1))
267
+ return order
268
+
269
+
270
+ def convert_1x1_conv_to_mm(x, weight, bias):
271
+ # special case for 1x1 convolution, which is actually just a matmul
272
+ rank = len(weight.get_size())
273
+ for _ in range(rank - 2):
274
+ weight = L[aten.squeeze](weight, dim=-1)
275
+ weight = L[aten.permute](weight, [1, 0])
276
+
277
+ if x.get_size()[0] != 1:
278
+ x = ir.ExternKernel.require_stride_order(x, channels_last_order(rank))
279
+ else:
280
+ x.realize()
281
+ x.freeze_layout()
282
+
283
+ x_permute = list(range(rank))
284
+ x_permute.append(x_permute.pop(1))
285
+ x = L[aten.permute](x, x_permute)
286
+ *sizes, in_chan = x.get_size()
287
+ x = L[aten.reshape](x, [sympy_product(sizes), in_chan])
288
+ if bias is None:
289
+ result = L[aten.mm](x, weight)
290
+ else:
291
+ result = L[aten.addmm](bias, x, weight)
292
+ result = L[aten.reshape](result, [*sizes, -1])
293
+ result_permute = list(range(rank))
294
+ result_permute.insert(1, result_permute.pop(-1))
295
+ return L[aten.permute](result, result_permute)
296
+
297
+
298
+ @register_lowering(aten.convolution)
299
+ def convolution(
300
+ x: TensorBox,
301
+ weight: TensorBox,
302
+ bias: TensorBox,
303
+ stride: List[int],
304
+ padding: List[int],
305
+ dilation: List[int],
306
+ transposed: bool,
307
+ output_padding: List[int],
308
+ groups: int,
309
+ ):
310
+ stride = tuple(stride)
311
+ padding = tuple(padding)
312
+ dilation = tuple(dilation)
313
+ output_padding = tuple(output_padding)
314
+ if not isinstance(groups, int):
315
+ groups = V.graph.sizevars.evaluate_static_shape(groups)
316
+ assert isinstance(groups, int)
317
+ kwargs: ConvLayoutParams = {
318
+ "stride": stride,
319
+ "padding": padding,
320
+ "dilation": dilation,
321
+ "transposed": transposed,
322
+ "output_padding": output_padding,
323
+ "groups": groups,
324
+ }
325
+
326
+ if len(x.get_size()) == len(weight.get_size()) - 1:
327
+ # add batch dimension to simplify rest of function
328
+ return L[aten.squeeze](
329
+ convolution(L[aten.expand](x, [1, *x.get_size()]), weight, bias, **kwargs),
330
+ dim=0,
331
+ )
332
+
333
+ out_chan, in_chan, *kernel_shape = V.graph.sizevars.evaluate_static_shapes(
334
+ weight.get_size()
335
+ )
336
+ ndim = len(kernel_shape)
337
+ stride = pad_listlike(stride, ndim)
338
+ padding = pad_listlike(padding, ndim)
339
+ dilation = pad_listlike(dilation, ndim)
340
+ output_padding = pad_listlike(output_padding, ndim)
341
+
342
+ def channels_last_conv():
343
+ if V.graph.layout_opt and ndim == 2:
344
+ return True
345
+
346
+ layout = conv_layout(x, weight, None, **kwargs)
347
+ req_stride_order = ir.get_stride_order(
348
+ V.graph.sizevars.size_hints(layout.stride)
349
+ )
350
+ return req_stride_order == ir.NHWC_STRIDE_ORDER
351
+
352
+ autotuning_gemm = config.max_autotune or config.max_autotune_gemm
353
+
354
+ if (
355
+ (config.conv_1x1_as_mm or (autotuning_gemm and channels_last_conv()))
356
+ and is_ones(kernel_shape)
357
+ and is_ones(stride)
358
+ and is_zeros(padding)
359
+ and is_ones(dilation)
360
+ and not transposed
361
+ and is_zeros(output_padding)
362
+ and groups == 1
363
+ ):
364
+ return convert_1x1_conv_to_mm(x, weight, bias)
365
+
366
+ if bias is not None and ir.get_device_type(x) != "cpu":
367
+ # peel off the bias, cudnn is slower with it
368
+ result = convolution(x, weight, None, **kwargs)
369
+ return L[aten.add](
370
+ result, L[aten.view](bias, [result.get_size()[1]] + ndim * [1])
371
+ )
372
+
373
+ x.realize()
374
+ weight.realize()
375
+
376
+ # ndim can be 1 for convolution in models such as demucs
377
+ # TODO: check if it's beneficial to convert Conv1d to Conv2d and then
378
+ # apply channels last.
379
+ if V.graph.layout_opt and ndim == 2:
380
+ V.graph.num_channels_last_conv += 1
381
+ x = ir.ExternKernel.require_channels_last(x)
382
+ # TODO maybe we can convert weights to channels last just once before
383
+ # running the model.
384
+ weight = ir.ExternKernel.require_channels_last(weight)
385
+ layout = conv_layout(x, weight, None, **kwargs)
386
+ else:
387
+ layout = conv_layout(x, weight, None, **kwargs)
388
+ req_stride_order = ir.get_stride_order(
389
+ V.graph.sizevars.size_hints(layout.stride)
390
+ )
391
+ x = ir.ExternKernel.require_stride_order(x, req_stride_order)
392
+ weight = ir.ExternKernel.require_stride_order(weight, req_stride_order)
393
+
394
+ ordered_kwargs_for_cpp_kernel = [
395
+ "stride",
396
+ "padding",
397
+ "dilation",
398
+ "transposed",
399
+ "output_padding",
400
+ "groups",
401
+ ]
402
+ if bias is None:
403
+ args = [x, weight]
404
+ kwargs["bias"] = None # type: ignore[typeddict-unknown-key]
405
+ ordered_kwargs_for_cpp_kernel.insert(0, "bias")
406
+ else:
407
+ args = [x, weight, bias]
408
+ bias.realize()
409
+ bias.freeze_layout()
410
+ V.graph.sizevars.evaluate_static_shapes(bias.get_size())
411
+ choices = [
412
+ aten_convolution.bind(
413
+ args,
414
+ layout,
415
+ ordered_kwargs_for_cpp_kernel,
416
+ **kwargs,
417
+ )
418
+ ]
419
+
420
+ if (
421
+ use_triton_template(layout)
422
+ # templates only support these:
423
+ and ndim == 2
424
+ and is_ones(dilation)
425
+ and not transposed
426
+ and is_zeros(output_padding)
427
+ # there are some odd models where this check fails (e.g. shufflenet_v2_x1_0)
428
+ and V.graph.sizevars.statically_known_equals(in_chan, x.get_size()[1]) # type: ignore[arg-type]
429
+ ):
430
+ if (
431
+ is_ones(kernel_shape)
432
+ and is_ones(stride)
433
+ and is_zeros(padding)
434
+ and groups == 1
435
+ ):
436
+ choices.append(aten_conv1x1_via_mm.bind(args, layout))
437
+
438
+ for cfg in conv_configs(
439
+ sympy_product([x.get_size()[0], *x.get_size()[2:]]),
440
+ out_chan,
441
+ in_chan,
442
+ ):
443
+ conv2d_template.maybe_append_choice(
444
+ choices,
445
+ input_nodes=(x, weight),
446
+ layout=layout,
447
+ KERNEL_H=kernel_shape[0],
448
+ KERNEL_W=kernel_shape[1],
449
+ STRIDE_H=stride[0],
450
+ STRIDE_W=stride[1],
451
+ PADDING_H=padding[0],
452
+ PADDING_W=padding[1],
453
+ GROUPS=groups,
454
+ # TODO(jansel): try unroll for bigger kernels once fixed:
455
+ # https://github.com/openai/triton/issues/1254
456
+ UNROLL=is_ones(kernel_shape),
457
+ ALLOW_TF32=torch.backends.cudnn.allow_tf32,
458
+ num_stages=cfg.num_stages,
459
+ num_warps=cfg.num_warps,
460
+ **cfg.kwargs,
461
+ )
462
+
463
+ return autotune_select_algorithm("convolution", choices, args, layout)
464
+
465
+
466
+ @register_lowering(aten._convolution)
467
+ def _convolution(
468
+ x,
469
+ weight,
470
+ bias,
471
+ stride,
472
+ padding,
473
+ dilation,
474
+ transposed,
475
+ output_padding,
476
+ groups,
477
+ benchmark,
478
+ deterministic,
479
+ cudnn_enabled,
480
+ allow_tf32,
481
+ ):
482
+ return convolution(
483
+ x, weight, bias, stride, padding, dilation, transposed, output_padding, groups
484
+ )
485
+
486
+
487
+ def constrain_conv_to_fx_strides(fx_node, *args, **kwargs):
488
+ assert fx_node.target == torch.ops.aten.convolution.default
489
+ if V.graph.layout_opt:
490
+ return args, kwargs
491
+ else:
492
+ return constrain_to_fx_strides(fx_node, *args, **kwargs)
493
+
494
+
495
+ add_layout_constraint(aten.convolution, constrain_conv_to_fx_strides)
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/detail/KernelUtils.h ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <limits>
4
+ #include <c10/util/Exception.h>
5
+
6
+ namespace at::cuda::detail {
7
+
8
+ // CUDA: grid stride looping
9
+ //
10
+ // int64_t _i_n_d_e_x specifically prevents overflow in the loop increment.
11
+ // If input.numel() < INT_MAX, _i_n_d_e_x < INT_MAX, except after the final
12
+ // iteration of the loop where _i_n_d_e_x += blockDim.x * gridDim.x can be
13
+ // greater than INT_MAX. But in that case _i_n_d_e_x >= n, so there are no
14
+ // further iterations and the overflowed value in i=_i_n_d_e_x is not used.
15
+ #define CUDA_KERNEL_LOOP_TYPE(i, n, index_type) \
16
+ int64_t _i_n_d_e_x = blockIdx.x * blockDim.x + threadIdx.x; \
17
+ for (index_type i=_i_n_d_e_x; _i_n_d_e_x < (n); _i_n_d_e_x+=blockDim.x * gridDim.x, i=_i_n_d_e_x)
18
+
19
+ #define CUDA_KERNEL_LOOP(i, n) CUDA_KERNEL_LOOP_TYPE(i, n, int)
20
+
21
+
22
+ // Use 1024 threads per block, which requires cuda sm_2x or above
23
+ constexpr int CUDA_NUM_THREADS = 1024;
24
+
25
+ // CUDA: number of blocks for threads.
26
+ inline int GET_BLOCKS(const int64_t N, const int64_t max_threads_per_block=CUDA_NUM_THREADS) {
27
+ TORCH_INTERNAL_ASSERT(N > 0, "CUDA kernel launch blocks must be positive, but got N=", N);
28
+ constexpr int64_t max_int = std::numeric_limits<int>::max();
29
+
30
+ // Round up division for positive number that cannot cause integer overflow
31
+ auto block_num = (N - 1) / max_threads_per_block + 1;
32
+ TORCH_INTERNAL_ASSERT(block_num <= max_int, "Can't schedule too many blocks on CUDA device");
33
+
34
+ return static_cast<int>(block_num);
35
+ }
36
+
37
+ } // namespace at::cuda::detail
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/detail/UnpackRaw.cuh ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // No "#pragma once" because this is a raw definition that can be copied by jit codegen.
2
+ // Eager mode clients should not include this file directly, instead,
3
+ // they should #include <ATen/cuda/PhiloxUtils.cuh>, which has a #pragma once.
4
+
5
+ namespace at::cuda::philox {
6
+
7
+ // In-kernel call to retrieve philox seed and offset from a PhiloxCudaState instance whether
8
+ // that instance was created with graph capture underway or not.
9
+ // See Note [CUDA Graph-safe RNG states].
10
+ //
11
+ // We can't write a __device__ function in CUDAGeneratorImpl.h, because it's in ATen.
12
+ // Also, whatever call unpacks PhiloxCudaState in consumer kernels must be inlineable.
13
+ // Easiest thing that comes to mind is, define a __device__ unpack helper here, in ATen/cuda.
14
+ //
15
+ // The raw definition lives in its own file so jit codegen can easily copy it.
16
+ __host__ __device__ __forceinline__ std::tuple<uint64_t, uint64_t>
17
+ unpack(at::PhiloxCudaState arg) {
18
+ if (arg.captured_) {
19
+ // static_cast avoids "warning: invalid narrowing conversion from "long" to "unsigned long".
20
+ // *(arg.offset_.ptr) is a broadcast load of a single int64_t to the entire kernel.
21
+ // For most threads' reads it will hit in cache, so it shouldn't hurt performance.
22
+ return std::make_tuple(static_cast<uint64_t>(*arg.seed_.ptr), static_cast<uint64_t>(*(arg.offset_.ptr) + arg.offset_intragraph_));
23
+ } else {
24
+ return std::make_tuple(arg.seed_.val, arg.offset_.val);
25
+ }
26
+ }
27
+
28
+ } // namespace at::cuda::philox
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/tunable/GemmRocblas.h ADDED
@@ -0,0 +1,275 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright (c) Microsoft Corporation. All rights reserved.
2
+ // Licensed under the MIT License.
3
+
4
+ #pragma once
5
+
6
+ #include <ATen/cuda/CUDAContext.h>
7
+ #include <ATen/cuda/tunable/TunableOp.h>
8
+ #include <ATen/cuda/tunable/GemmCommon.h>
9
+ #include <c10/util/StringUtil.h>
10
+
11
+ #define ROCBLAS_BETA_FEATURES_API
12
+ #include <rocblas/rocblas.h>
13
+
14
+ #define TORCH_ROCBLAS_CHECK(EXPR) \
15
+ do { \
16
+ rocblas_status __err = EXPR; \
17
+ TORCH_CHECK(__err == rocblas_status_success, \
18
+ "rocblas error: ", \
19
+ rocblas_status_to_string(__err), \
20
+ " when calling `" #EXPR "`"); \
21
+ } while (0)
22
+
23
+ namespace at::cuda::tunable {
24
+
25
+ template <typename T>
26
+ constexpr rocblas_datatype RocBlasDataTypeFor();
27
+
28
+ template <>
29
+ constexpr rocblas_datatype RocBlasDataTypeFor<float>() {
30
+ return rocblas_datatype_f32_r;
31
+ }
32
+
33
+ template <>
34
+ constexpr rocblas_datatype RocBlasDataTypeFor<double>() {
35
+ return rocblas_datatype_f64_r;
36
+ }
37
+
38
+ template <>
39
+ constexpr rocblas_datatype RocBlasDataTypeFor<Half>() {
40
+ return rocblas_datatype_f16_r;
41
+ }
42
+
43
+ template <>
44
+ constexpr rocblas_datatype RocBlasDataTypeFor<BFloat16>() {
45
+ return rocblas_datatype_bf16_r;
46
+ }
47
+
48
+ template <>
49
+ constexpr rocblas_datatype RocBlasDataTypeFor<c10::complex<float>>() {
50
+ return rocblas_datatype_f32_c;
51
+ }
52
+
53
+ template <>
54
+ constexpr rocblas_datatype RocBlasDataTypeFor<c10::complex<double>>() {
55
+ return rocblas_datatype_f64_c;
56
+ }
57
+
58
+ template <typename T>
59
+ constexpr rocblas_datatype RocBlasComputeTypeFor();
60
+
61
+ template <>
62
+ constexpr rocblas_datatype RocBlasComputeTypeFor<float>() {
63
+ return rocblas_datatype_f32_r;
64
+ }
65
+
66
+ template <>
67
+ constexpr rocblas_datatype RocBlasComputeTypeFor<double>() {
68
+ return rocblas_datatype_f64_r;
69
+ }
70
+
71
+ template <>
72
+ constexpr rocblas_datatype RocBlasComputeTypeFor<Half>() {
73
+ // Note that we're returning the _compute_ type for a given datatype.
74
+ // As of 12/2022, using compute type FP16 for 16-bit floats was much
75
+ // slower than using compute type FP32. So we use FP32 compute even for
76
+ // FP16 datatypes. This is how GEMM is implemented even in the function
77
+ // rocblasGemmHelper (see fpgeneric.h)
78
+ return rocblas_datatype_f32_r;
79
+ }
80
+
81
+ template <>
82
+ constexpr rocblas_datatype RocBlasComputeTypeFor<BFloat16>() {
83
+ // Note that we're returning the _compute_ type for a given datatype.
84
+ // As of 12/2022, using compute type FP16 for 16-bit floats was much
85
+ // slower than using compute type FP32. So we use FP32 compute even for
86
+ // BF16 datatypes. This is how GEMM is implemented even in the function
87
+ // rocblasGemmHelper (see fpgeneric.h)
88
+ return rocblas_datatype_f32_r;
89
+ }
90
+
91
+ template <>
92
+ constexpr rocblas_datatype RocBlasComputeTypeFor<c10::complex<float>>() {
93
+ return rocblas_datatype_f32_c;
94
+ }
95
+
96
+ template <>
97
+ constexpr rocblas_datatype RocBlasComputeTypeFor<c10::complex<double>>() {
98
+ return rocblas_datatype_f64_c;
99
+ }
100
+
101
+ template <typename T>
102
+ auto DoCastForHalfOrBfloat16(const T fp) {
103
+ return fp;
104
+ }
105
+
106
+ template <>
107
+ inline auto DoCastForHalfOrBfloat16<Half>(const Half fp) {
108
+ // alpha and beta should be the same as compute_type, in Half case it is float.
109
+ float h = fp;
110
+ return h;
111
+ }
112
+
113
+ template <>
114
+ inline auto DoCastForHalfOrBfloat16<BFloat16>(const BFloat16 fp) {
115
+ // alpha and beta should be the same as compute_type, in bfloat16 case it is float.
116
+ float h = fp;
117
+ return h;
118
+ }
119
+
120
+ static rocblas_operation _rocblasOpFromChar(char op) {
121
+ switch (op) {
122
+ case 'n':
123
+ case 'N':
124
+ return rocblas_operation_none;
125
+ case 't':
126
+ case 'T':
127
+ return rocblas_operation_transpose;
128
+ case 'c':
129
+ case 'C':
130
+ return rocblas_operation_conjugate_transpose;
131
+ }
132
+ AT_ERROR(
133
+ "_rocblasOpFromChar input should be 't', 'n' or 'c' but got `", op, "`");
134
+ }
135
+
136
+ template <typename T>
137
+ class RocblasGemmOp : public Callable<GemmParams<T>> {
138
+ public:
139
+ RocblasGemmOp(int solution) : solution_{solution} {}
140
+
141
+ TuningStatus Call(const GemmParams<T>* params) override {
142
+ auto input_output_type = RocBlasDataTypeFor<T>();
143
+ auto compute_type = RocBlasComputeTypeFor<T>();
144
+ auto h_a = DoCastForHalfOrBfloat16(params->alpha);
145
+ auto h_b = DoCastForHalfOrBfloat16(params->beta);
146
+ auto status = rocblas_gemm_ex(
147
+ (rocblas_handle)at::cuda::getCurrentCUDABlasHandle(),
148
+ _rocblasOpFromChar(params->transa),
149
+ _rocblasOpFromChar(params->transb),
150
+ params->m, params->n, params->k,
151
+ &h_a,
152
+ params->a, input_output_type, params->lda,
153
+ params->b, input_output_type, params->ldb,
154
+ &h_b,
155
+ params->c, input_output_type, params->ldc,
156
+ params->c, input_output_type, params->ldc,
157
+ compute_type,
158
+ rocblas_gemm_algo_solution_index,
159
+ solution_,
160
+ rocblas_gemm_flags_none);
161
+ if (status != rocblas_status_success) {
162
+ return FAIL;
163
+ }
164
+ return OK;
165
+ }
166
+
167
+ private:
168
+ int solution_;
169
+ };
170
+
171
+ template <typename T>
172
+ auto GetRocBlasGemmTypeStringAndOps() {
173
+ rocblas_handle handle = (rocblas_handle)at::cuda::getCurrentCUDABlasHandle();
174
+ int solution_size;
175
+ auto input_output_type = RocBlasDataTypeFor<T>();
176
+ auto compute_type = RocBlasComputeTypeFor<T>();
177
+ // Get the number of available solutions
178
+ TORCH_ROCBLAS_CHECK(rocblas_gemm_ex_get_solutions_by_type(handle,
179
+ input_output_type,
180
+ input_output_type,
181
+ compute_type,
182
+ rocblas_gemm_flags_none,
183
+ nullptr,
184
+ &solution_size));
185
+ std::vector<int> solutions(solution_size);
186
+ // Get the list of available solutions
187
+ TORCH_ROCBLAS_CHECK(rocblas_gemm_ex_get_solutions_by_type(handle,
188
+ input_output_type,
189
+ input_output_type,
190
+ compute_type,
191
+ rocblas_gemm_flags_none,
192
+ solutions.data(),
193
+ &solution_size));
194
+ // Sort the solutions in ascending order to make the solution vector deterministic across runs
195
+ std::sort(solutions.begin(), solutions.end());
196
+
197
+ std::vector<std::pair<std::string, std::unique_ptr<Callable<GemmParams<T>>>>> ret;
198
+ for (size_t i = 0; i < solutions.size(); ++i) {
199
+ auto callable = std::make_unique<RocblasGemmOp<T>>(solutions[i]);
200
+ ret.emplace_back(std::make_pair(c10::str("Gemm_Rocblas_", solutions[i]), std::move(callable)));
201
+ }
202
+ return ret;
203
+ }
204
+
205
+ template <typename T>
206
+ class RocblasGemmStridedBatchedOp : public Callable<GemmStridedBatchedParams<T>> {
207
+ public:
208
+ RocblasGemmStridedBatchedOp(int solution) : solution_{solution} {}
209
+
210
+ TuningStatus Call(const GemmStridedBatchedParams<T>* params) override {
211
+ auto input_output_type = RocBlasDataTypeFor<T>();
212
+ auto compute_type = RocBlasComputeTypeFor<T>();
213
+ auto h_a = DoCastForHalfOrBfloat16(params->alpha);
214
+ auto h_b = DoCastForHalfOrBfloat16(params->beta);
215
+ auto status = rocblas_gemm_strided_batched_ex(
216
+ (rocblas_handle)at::cuda::getCurrentCUDABlasHandle(),
217
+ _rocblasOpFromChar(params->transa),
218
+ _rocblasOpFromChar(params->transb),
219
+ params->m, params->n, params->k,
220
+ &h_a,
221
+ params->a, input_output_type, params->lda, params->stride_a,
222
+ params->b, input_output_type, params->ldb, params->stride_b,
223
+ &h_b,
224
+ params->c, input_output_type, params->ldc, params->stride_c,
225
+ params->c, input_output_type, params->ldc, params->stride_c,
226
+ params->batch,
227
+ compute_type,
228
+ rocblas_gemm_algo_solution_index,
229
+ solution_,
230
+ rocblas_gemm_flags_none);
231
+ if (status != rocblas_status_success) {
232
+ return FAIL;
233
+ }
234
+ return OK;
235
+ }
236
+
237
+ private:
238
+ int solution_;
239
+ };
240
+
241
+ template <typename T>
242
+ auto GetRocBlasGemmStridedBatchedTypeStringAndOps() {
243
+ rocblas_handle handle = (rocblas_handle)at::cuda::getCurrentCUDABlasHandle();
244
+ int solution_size;
245
+ auto input_output_type = RocBlasDataTypeFor<T>();
246
+ auto compute_type = RocBlasComputeTypeFor<T>();
247
+ // Get the number of available solutions
248
+ TORCH_ROCBLAS_CHECK(rocblas_gemm_ex_get_solutions_by_type(handle,
249
+ input_output_type,
250
+ input_output_type,
251
+ compute_type,
252
+ rocblas_gemm_flags_none,
253
+ nullptr,
254
+ &solution_size));
255
+ std::vector<int> solutions(solution_size);
256
+ // Get the list of available solutions
257
+ TORCH_ROCBLAS_CHECK(rocblas_gemm_ex_get_solutions_by_type(handle,
258
+ input_output_type,
259
+ input_output_type,
260
+ compute_type,
261
+ rocblas_gemm_flags_none,
262
+ solutions.data(),
263
+ &solution_size));
264
+ // Sort the solutions in ascending order to make the solution vector deterministic across runs
265
+ std::sort(solutions.begin(), solutions.end());
266
+
267
+ std::vector<std::pair<std::string, std::unique_ptr<Callable<GemmStridedBatchedParams<T>>>>> ret;
268
+ for (size_t i = 0; i < solutions.size(); ++i) {
269
+ auto callable = std::make_unique<RocblasGemmStridedBatchedOp<T>>(solutions[i]);
270
+ ret.emplace_back(std::make_pair(c10::str("Gemm_Rocblas_", solutions[i]), std::move(callable)));
271
+ }
272
+ return ret;
273
+ }
274
+
275
+ } // namespace at::cuda::tunable
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/tunable/Tunable.h ADDED
@@ -0,0 +1,205 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Original TunableOp is from onnxruntime.
2
+ // https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/framework/tunable.h
3
+ // https://github.com/microsoft/onnxruntime/tree/main/onnxruntime/core/providers/rocm/tunable
4
+ // Copyright (c) Microsoft Corporation.
5
+ // Licensed under the MIT license.
6
+ //
7
+ // Adapting TunableOp into PyTorch
8
+ // Copyright (c) Advanced Micro Devices, Inc.
9
+ //
10
+ #pragma once
11
+
12
+ #include <c10/util/CallOnce.h>
13
+
14
+ #include <functional>
15
+ #include <iostream>
16
+ #include <memory>
17
+ #include <mutex>
18
+ #include <string>
19
+ #include <type_traits>
20
+ #include <unordered_map>
21
+ #include <utility>
22
+ #include <vector>
23
+
24
+ namespace at::cuda::tunable {
25
+
26
+ static void TunableLog(const std::string& msg) {
27
+ static const char *env = getenv("PYTORCH_TUNABLEOP_VERBOSE");
28
+ if (env != nullptr && strcmp(env, "1") == 0) {
29
+ std::cerr << msg << std::endl;
30
+ }
31
+ }
32
+ #define TUNABLE_LOG(...) TunableLog(c10::str(__VA_ARGS__))
33
+
34
+ enum TuningStatus {
35
+ OK = 0,
36
+ FAIL = 1,
37
+ UNSUPPORTED = 2,
38
+ };
39
+
40
+ // Mapping from params signature to kernel id
41
+ class ResultEntry {
42
+ public:
43
+ explicit ResultEntry(const std::string& key, double time) : key_(key), time_(time) {}
44
+ bool operator==(const ResultEntry& other) { return key_ == other.key_; }
45
+ bool operator!=(const ResultEntry& other) { return key_ != other.key_; }
46
+ operator std::string () { return key_; }
47
+ friend std::ostream& operator<<(std::ostream& stream, const ResultEntry& entry);
48
+ static ResultEntry Null() { return ResultEntry("Null", 0.0); }
49
+ static ResultEntry Default() { return ResultEntry("Default", 0.0); }
50
+
51
+ private:
52
+ std::string key_;
53
+ double time_;
54
+ };
55
+
56
+ typedef std::unordered_map<std::string, ResultEntry> KernelMap;
57
+ typedef std::unordered_map<std::string, KernelMap> ResultsMap;
58
+
59
+ struct TuningResults {
60
+ // Validates if these results are compatible with the libraries
61
+ std::unordered_map<std::string, std::string> validators;
62
+
63
+ // Mapping from Callable signature to Callable's tuning result
64
+ ResultsMap results;
65
+ };
66
+
67
+ class TuningResultsManager {
68
+ public:
69
+ TuningResultsManager() = default;
70
+ ~TuningResultsManager() = default;
71
+
72
+ KernelMap Lookup(const std::string& op_signature);
73
+
74
+ ResultEntry Lookup(const std::string& op_signature, const std::string& params_signature);
75
+
76
+ inline void AddImpl(const std::string& op_signature,
77
+ const std::string& params_signature,
78
+ ResultEntry best,
79
+ KernelMap& kernel_map);
80
+
81
+ void Add(const std::string& op_signature,
82
+ const std::string& params_signature,
83
+ ResultEntry best);
84
+
85
+ void Delete(const std::string& op_signature, const std::string& params_signature);
86
+
87
+ inline void DisjointMergeImpl(
88
+ const std::string& op_signature,
89
+ const KernelMap& kernel_map,
90
+ /*out*/ ResultsMap& results);
91
+
92
+ void Load(const ResultsMap& results_to_load);
93
+
94
+ ResultsMap Dump();
95
+
96
+ void DisjointMerge(const std::string& op_signature, const KernelMap& kernel_map);
97
+
98
+ size_t GetSize();
99
+
100
+ private:
101
+ std::mutex lock_;
102
+ ResultsMap results_;
103
+ };
104
+
105
+ class TuningResultsValidator {
106
+ public:
107
+ using GetFunc = std::function<std::string()>;
108
+ using ValidateFunc = std::function<TuningStatus(const std::string&)>;
109
+ using GetValidateFuncs = std::unordered_map<std::string, std::pair<GetFunc, ValidateFunc>>;
110
+
111
+ TuningResultsValidator();
112
+ ~TuningResultsValidator() = default;
113
+
114
+ std::unordered_map<std::string, std::string> GetAllValidators() const;
115
+ TuningStatus ValidateAll(const std::unordered_map<std::string, std::string>& to_validate) const;
116
+ void RegisterValidator(const std::string& key, const GetFunc& gf, const ValidateFunc& vf);
117
+
118
+ protected:
119
+ std::string GetPyTorchVersion() const;
120
+ TuningStatus ValidatePyTorchVersion(const std::string& value) const;
121
+
122
+ public:
123
+ static constexpr const std::array mandatory_keys{"PT_VERSION"};
124
+
125
+ private:
126
+ GetValidateFuncs validators_;
127
+ };
128
+
129
+ class TuningContext {
130
+ public:
131
+ TuningContext();
132
+ ~TuningContext();
133
+ TuningContext(TuningContext &) = delete;
134
+ TuningContext(TuningContext &&) = delete;
135
+ TuningContext &operator=(TuningContext &) = delete;
136
+ TuningContext &operator=(TuningContext &&) = delete;
137
+
138
+ void EnableTunableOp();
139
+ void DisableTunableOp();
140
+ bool IsTunableOpEnabled() const;
141
+
142
+ void EnableTuning();
143
+ void DisableTuning();
144
+ bool IsTuningEnabled() const;
145
+
146
+ void SetMaxTuningDurationMs(int max_duration_ms);
147
+ int GetMaxTuningDurationMs() const;
148
+
149
+ void SetMaxTuningIterations(int max_iter);
150
+ int GetMaxTuningIterations() const;
151
+
152
+ void SetMaxWarmupDurationMs(int max_duration_ms);
153
+ int GetMaxWarmupDurationMs() const;
154
+
155
+ void SetMaxWarmupIterations(int max_iter);
156
+ int GetMaxWarmupIterations() const;
157
+
158
+ void EnableTunableOpAndTuning();
159
+ void DisableTunableOpAndTuning();
160
+
161
+ TuningResultsManager& GetTuningResultsManager();
162
+
163
+ TuningResultsValidator& GetTuningResultsValidator();
164
+
165
+ TuningResults GetTuningResults();
166
+
167
+ TuningStatus LoadTuningResults(const TuningResults& tr);
168
+
169
+ void SetFilename(const std::string& filename);
170
+ std::string GetFilename() const;
171
+
172
+ protected:
173
+ bool ReadFile(const std::string& filename);
174
+ bool WriteFile(const std::string& filename);
175
+
176
+ private:
177
+ bool enable_;
178
+ bool tuning_enable_;
179
+ bool manager_initialized_;
180
+ int max_tuning_duration_ms_;
181
+ int max_tuning_iterations_;
182
+ int max_warmup_duration_ms_;
183
+ int max_warmup_iterations_;
184
+ mutable TuningResultsManager manager_;
185
+ mutable c10::once_flag manager_init_once_;
186
+ TuningResultsValidator validator_;
187
+ std::string filename_;
188
+ size_t results_count_from_input_file_;
189
+ };
190
+
191
+ TuningContext* getTuningContext();
192
+
193
+ class ITimer {
194
+ public:
195
+ ITimer() = default;
196
+ virtual ~ITimer() = default;
197
+
198
+ virtual void Start() = 0;
199
+ virtual void End() = 0;
200
+
201
+ /// Computes the elapsed time in milliseconds between Start() and End()
202
+ virtual float Duration() = 0;
203
+ };
204
+
205
+ } // namespace at::cuda::tunable
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/tunable/TunableGemm.h ADDED
@@ -0,0 +1,278 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Original TunableOp is from onnxruntime.
2
+ // https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/framework/tunable.h
3
+ // https://github.com/microsoft/onnxruntime/tree/main/onnxruntime/core/providers/rocm/tunable
4
+ // Copyright (c) Microsoft Corporation.
5
+ // Licensed under the MIT license.
6
+ //
7
+ // Adapting TunableOp into PyTorch
8
+ // Copyright (c) Advanced Micro Devices, Inc.
9
+ //
10
+ #pragma once
11
+
12
+ #include <ATen/cuda/tunable/GemmCommon.h>
13
+ #ifdef USE_ROCM
14
+ #if ROCM_VERSION >= 50700
15
+ #include <ATen/cuda/tunable/GemmHipblaslt.h>
16
+ #endif
17
+ #include <ATen/cuda/tunable/GemmRocblas.h>
18
+ #endif
19
+ #include <ATen/cuda/tunable/StreamTimer.h>
20
+ #include <ATen/cuda/tunable/TunableOp.h>
21
+ #include <c10/cuda/CUDACachingAllocator.h>
22
+ #include <c10/util/StringUtil.h>
23
+
24
+ #ifdef USE_ROCM
25
+ #include <rocm-core/rocm_version.h>
26
+ #endif
27
+
28
+ #define STRINGIFY(s) #s
29
+ #define XSTRINGIFY(s) STRINGIFY(s)
30
+
31
+ namespace at::cuda::tunable {
32
+
33
+ template <typename T>
34
+ class DefaultGemmOp : public Callable<GemmParams<T>> {
35
+ public:
36
+ TuningStatus Call(const GemmParams<T>* params) override {
37
+ at::cuda::blas::gemm_internal<T>(
38
+ params->transa, params->transb,
39
+ params->m, params->n, params->k,
40
+ params->alpha,
41
+ params->a, params->lda,
42
+ params->b, params->ldb,
43
+ params->beta,
44
+ params->c, params->ldc);
45
+ return OK;
46
+ }
47
+ };
48
+
49
+ template <typename T>
50
+ class DefaultGemmStridedBatchedOp : public Callable<GemmStridedBatchedParams<T>> {
51
+ public:
52
+ TuningStatus Call(const GemmStridedBatchedParams<T>* params) override {
53
+ at::cuda::blas::bgemm_internal<T>(
54
+ params->transa, params->transb,
55
+ params->m, params->n, params->k,
56
+ params->alpha,
57
+ params->a, params->lda, params->stride_a,
58
+ params->b, params->ldb, params->stride_b,
59
+ params->beta,
60
+ params->c, params->ldc, params->stride_c,
61
+ params->batch);
62
+ return OK;
63
+ }
64
+ };
65
+
66
+ template <typename T>
67
+ bool IsZero(T v) {
68
+ return v == 0.0f;
69
+ }
70
+
71
+ template <>
72
+ bool IsZero(BFloat16 v) {
73
+ return v.x == 0;
74
+ }
75
+
76
+ template <>
77
+ bool IsZero(Half v) {
78
+ return float(v) == 0.0f;
79
+ }
80
+
81
+ template <>
82
+ bool IsZero(c10::complex<double> v) {
83
+ return v == 0.0;
84
+ }
85
+
86
+ template <>
87
+ bool IsZero(c10::complex<float> v) {
88
+ return v == 0.0f;
89
+ }
90
+
91
+ template <typename T>
92
+ std::string TypeName(T v) {
93
+ return "unknown";
94
+ }
95
+
96
+ template <>
97
+ std::string TypeName(float v) {
98
+ return "float";
99
+ }
100
+
101
+ template <>
102
+ std::string TypeName(double v) {
103
+ return "double";
104
+ }
105
+
106
+ template <>
107
+ std::string TypeName(BFloat16 v) {
108
+ return "BFloat16";
109
+ }
110
+
111
+ template <>
112
+ std::string TypeName(Half v) {
113
+ return "Half";
114
+ }
115
+
116
+ template <>
117
+ std::string TypeName(c10::complex<double> v) {
118
+ return "c10::complex<double>";
119
+ }
120
+
121
+ template <>
122
+ std::string TypeName(c10::complex<float> v) {
123
+ return "c10::complex<float>";
124
+ }
125
+
126
+
127
+ template <typename T, BlasOp ALayout, BlasOp BLayout>
128
+ class GemmTunableOp : public TunableOp<GemmParams<T>, StreamTimer> {
129
+ public:
130
+ GemmTunableOp() {
131
+ this->RegisterOp(std::string("Default"), std::make_unique<DefaultGemmOp<T>>());
132
+
133
+ auto validators = getTuningContext()->GetTuningResultsValidator().GetAllValidators();
134
+
135
+ #ifdef USE_ROCM
136
+ for (auto&& [name, op] : GetRocBlasGemmTypeStringAndOps<T>()) {
137
+ this->RegisterOp(std::move(name), std::move(op));
138
+ }
139
+
140
+ if (validators.find("ROCM_VERSION") == validators.end()) {
141
+ std::string rocm_version = ROCM_BUILD_INFO;
142
+ getTuningContext()->GetTuningResultsValidator().RegisterValidator(
143
+ "ROCM_VERSION",
144
+ [rocm_version]() { return rocm_version; },
145
+ [rocm_version](auto&& k) { return rocm_version == k ? OK : FAIL; });
146
+ }
147
+
148
+ if (validators.find("GCN_ARCH_NAME") == validators.end()) {
149
+ std::string gcn_arch_name = at::cuda::getCurrentDeviceProperties()->gcnArchName;
150
+ getTuningContext()->GetTuningResultsValidator().RegisterValidator(
151
+ "GCN_ARCH_NAME",
152
+ [gcn_arch_name]() { return gcn_arch_name; },
153
+ [gcn_arch_name](auto&& k) { return gcn_arch_name == k ? OK : FAIL; });
154
+ }
155
+
156
+ if (validators.find("ROCBLAS_VERSION") == validators.end()) {
157
+ std::string rocblas_version = c10::str(
158
+ XSTRINGIFY(ROCBLAS_VERSION_MAJOR), ".",
159
+ XSTRINGIFY(ROCBLAS_VERSION_MINOR), ".",
160
+ XSTRINGIFY(ROCBLAS_VERSION_PATCH), "-",
161
+ XSTRINGIFY(ROCBLAS_VERSION_TWEAK));
162
+ getTuningContext()->GetTuningResultsValidator().RegisterValidator(
163
+ "ROCBLAS_VERSION",
164
+ [rocblas_version]() { return rocblas_version; },
165
+ [rocblas_version](auto&& k) { return rocblas_version == k ? OK : FAIL; });
166
+ }
167
+ #endif
168
+
169
+ #if defined(USE_ROCM) && ROCM_VERSION >= 50700
170
+ static const char *env = std::getenv("PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED");
171
+ if (env == nullptr || strcmp(env, "1") == 0) {
172
+ // disallow tuning of hipblaslt with c10::complex
173
+ if constexpr (
174
+ !std::is_same_v<T, c10::complex<float>> &&
175
+ !std::is_same_v<T, c10::complex<double>>) {
176
+ for (auto&& [name, op] : GetHipBlasLtGemmTypeStringAndOps<T, ALayout, BLayout>()) {
177
+ this->RegisterOp(std::move(name), std::move(op));
178
+ }
179
+ }
180
+
181
+ if (validators.find("HIPBLASLT_VERSION") == validators.end()) {
182
+ std::string hipblaslt_version = c10::str(
183
+ XSTRINGIFY(HIPBLASLT_VERSION_MAJOR), ".",
184
+ XSTRINGIFY(HIPBLASLT_VERSION_MINOR), ".",
185
+ XSTRINGIFY(HIPBLASLT_VERSION_PATCH), "-",
186
+ XSTRINGIFY(HIPBLASLT_VERSION_TWEAK));
187
+ getTuningContext()->GetTuningResultsValidator().RegisterValidator(
188
+ "HIPBLASLT_VERSION",
189
+ [hipblaslt_version]() { return hipblaslt_version; },
190
+ [hipblaslt_version](auto&& k) { return hipblaslt_version == k ? OK : FAIL; });
191
+ }
192
+ }
193
+ #endif
194
+ }
195
+
196
+ std::string Signature() override {
197
+ return c10::str("GemmTunableOp_", TypeName<T>(T{}), "_", BlasOpToString(ALayout), BlasOpToString(BLayout));
198
+ }
199
+ };
200
+
201
+ template <typename T, BlasOp ALayout, BlasOp BLayout>
202
+ class GemmStridedBatchedTunableOp : public TunableOp<GemmStridedBatchedParams<T>, StreamTimer> {
203
+ public:
204
+ GemmStridedBatchedTunableOp() {
205
+ this->RegisterOp(std::string("Default"), std::make_unique<DefaultGemmStridedBatchedOp<T>>());
206
+
207
+ auto validators = getTuningContext()->GetTuningResultsValidator().GetAllValidators();
208
+
209
+ #ifdef USE_ROCM
210
+ for (auto&& [name, op] : GetRocBlasGemmStridedBatchedTypeStringAndOps<T>()) {
211
+ this->RegisterOp(std::move(name), std::move(op));
212
+ }
213
+
214
+ if (validators.find("ROCM_VERSION") == validators.end()) {
215
+ std::string rocm_version = ROCM_BUILD_INFO;
216
+ getTuningContext()->GetTuningResultsValidator().RegisterValidator(
217
+ "ROCM_VERSION",
218
+ [rocm_version]() { return rocm_version; },
219
+ [rocm_version](auto&& k) { return rocm_version == k ? OK : FAIL; });
220
+ }
221
+
222
+ if (validators.find("GCN_ARCH_NAME") == validators.end()) {
223
+ std::string gcn_arch_name = at::cuda::getCurrentDeviceProperties()->gcnArchName;
224
+ getTuningContext()->GetTuningResultsValidator().RegisterValidator(
225
+ "GCN_ARCH_NAME",
226
+ [gcn_arch_name]() { return gcn_arch_name; },
227
+ [gcn_arch_name](auto&& k) { return gcn_arch_name == k ? OK : FAIL; });
228
+ }
229
+
230
+ if (validators.find("ROCBLAS_VERSION") == validators.end()) {
231
+ std::string rocblas_version = c10::str(
232
+ XSTRINGIFY(ROCBLAS_VERSION_MAJOR), ".",
233
+ XSTRINGIFY(ROCBLAS_VERSION_MINOR), ".",
234
+ XSTRINGIFY(ROCBLAS_VERSION_PATCH), "-",
235
+ XSTRINGIFY(ROCBLAS_VERSION_TWEAK));
236
+ getTuningContext()->GetTuningResultsValidator().RegisterValidator(
237
+ "ROCBLAS_VERSION",
238
+ [rocblas_version]() { return rocblas_version; },
239
+ [rocblas_version](auto&& k) { return rocblas_version == k ? OK : FAIL; });
240
+ }
241
+ #endif
242
+
243
+ #if defined(USE_ROCM) && ROCM_VERSION >= 50700
244
+ static const char *env = std::getenv("PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED");
245
+ if (env == nullptr || strcmp(env, "1") == 0) {
246
+ // disallow tuning of hipblaslt with c10::complex
247
+ if constexpr (
248
+ !std::is_same_v<T, c10::complex<float>> &&
249
+ !std::is_same_v<T, c10::complex<double>>) {
250
+ for (auto&& [name, op] : GetHipBlasLtGemmStridedBatchedTypeStringAndOps<T, ALayout, BLayout>()) {
251
+ this->RegisterOp(std::move(name), std::move(op));
252
+ }
253
+ }
254
+
255
+ if (validators.find("HIPBLASLT_VERSION") == validators.end()) {
256
+ std::string hipblaslt_version = c10::str(
257
+ XSTRINGIFY(HIPBLASLT_VERSION_MAJOR), ".",
258
+ XSTRINGIFY(HIPBLASLT_VERSION_MINOR), ".",
259
+ XSTRINGIFY(HIPBLASLT_VERSION_PATCH), "-",
260
+ XSTRINGIFY(HIPBLASLT_VERSION_TWEAK));
261
+ getTuningContext()->GetTuningResultsValidator().RegisterValidator(
262
+ "HIPBLASLT_VERSION",
263
+ [hipblaslt_version]() { return hipblaslt_version; },
264
+ [hipblaslt_version](auto&& k) { return hipblaslt_version == k ? OK : FAIL; });
265
+ }
266
+ }
267
+ #endif
268
+ }
269
+
270
+ std::string Signature() override {
271
+ return c10::str("GemmStridedBatchedTunableOp_", TypeName<T>(T{}), "_", BlasOpToString(ALayout), BlasOpToString(BLayout));
272
+ }
273
+ };
274
+
275
+ #undef XSTRINGIFY
276
+ #undef STRINGIFY
277
+
278
+ } // namespace at::cuda::tunable
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/tunable/TunableOp.h ADDED
@@ -0,0 +1,242 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Original TunableOp is from onnxruntime.
2
+ // https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/framework/tunable.h
3
+ // https://github.com/microsoft/onnxruntime/tree/main/onnxruntime/core/providers/rocm/tunable
4
+ // Copyright (c) Microsoft Corporation.
5
+ // Licensed under the MIT license.
6
+ //
7
+ // Adapting TunableOp into PyTorch
8
+ // Copyright (c) Advanced Micro Devices, Inc.
9
+ //
10
+ #pragma once
11
+
12
+ #include <ATen/cuda/tunable/Tunable.h>
13
+ #include <c10/cuda/CUDACachingAllocator.h>
14
+
15
+ #ifndef _WIN32
16
+ #include <cxxabi.h>
17
+ #endif
18
+
19
+ #include <string>
20
+ #include <type_traits>
21
+ #include <unordered_map>
22
+ #include <vector>
23
+
24
+ namespace at::cuda::tunable {
25
+
26
+ template <typename ParamsT>
27
+ class Callable {
28
+ public:
29
+ Callable() = default;
30
+ Callable(Callable&&) = default;
31
+ virtual ~Callable() = default;
32
+ virtual TuningStatus Call(const ParamsT*) {
33
+ return FAIL;
34
+ }
35
+ virtual TuningStatus IsSupported(const ParamsT* params) {
36
+ return Call(params);
37
+ }
38
+ };
39
+
40
+ template <typename ParamsT, typename TimerT>
41
+ class TunableOp {
42
+ public:
43
+ TunableOp() = default;
44
+ TunableOp(TunableOp&&) = default;
45
+ virtual ~TunableOp() = default;
46
+
47
+ TuningStatus operator()(const ParamsT* params) {
48
+ ResultEntry result = ResultEntry::Null();
49
+ TuningContext* ctx = getTuningContext();
50
+ if (ctx->IsTunableOpEnabled()) {
51
+ auto& mgr = ctx->GetTuningResultsManager();
52
+ auto op_sig = Signature();
53
+ auto params_sig = params->Signature();
54
+ result = mgr.Lookup(op_sig, params_sig);
55
+ // If there is not previous tuning result been found, we do the tuning iff tuning is enabled
56
+ if (result == ResultEntry::Null() && ctx->IsTuningEnabled()) {
57
+ result = FindFastest(params);
58
+ mgr.Add(op_sig, params_sig, result);
59
+ }
60
+ }
61
+ else {
62
+ result = ResultEntry::Default();
63
+ }
64
+ if (result == ResultEntry::Null()) {
65
+ TUNABLE_LOG("no result, using default");
66
+ result = ResultEntry::Default();
67
+ }
68
+ auto iter = ops_.find(result);
69
+ TORCH_CHECK(iter != ops_.end());
70
+ return iter->second->Call(params);
71
+ }
72
+
73
+ virtual std::string Signature() {
74
+ // According to C++17 standard https://wg21.link/n4659 section 15.7.4
75
+ // > if the operand of typeid refers to the
76
+ // > object under construction or destruction, typeid yields the std::type_info object representing the constructor
77
+ // > or destructor’s class.
78
+ // So delay the op signature generation.
79
+ c10::call_once(signature_init_once_, [this]() { signature_ = CreateSignature(); });
80
+ return signature_;
81
+ }
82
+
83
+ protected:
84
+ void RegisterOp(const std::string& name, std::unique_ptr<Callable<ParamsT>> op) {
85
+ this->op_names_.emplace_back(name);
86
+ this->ops_.emplace(name, std::move(op));
87
+ }
88
+
89
+ private:
90
+ static void WarmUp(Callable<ParamsT> *op, ParamsT* param, size_t num_iter) {
91
+ for (size_t i = 0; i < num_iter; i++) {
92
+ TORCH_CHECK(op->Call(param) == OK);
93
+ }
94
+ }
95
+
96
+ static double Profile(Callable<ParamsT> *op, ParamsT* param, size_t num_iter) {
97
+ TimerT timer{};
98
+ timer.Start();
99
+ for (size_t i = 0; i < num_iter; i++) {
100
+ TORCH_CHECK(op->Call(param) == OK);
101
+ }
102
+ timer.End();
103
+ return timer.Duration() / num_iter;
104
+ }
105
+
106
+ protected:
107
+ bool IsNumericsCheckEnabled() {
108
+ static const char *env = getenv("PYTORCH_TUNABLEOP_NUMERICAL_CHECK");
109
+ if (env != nullptr && strcmp(env, "0") == 0) {
110
+ return false;
111
+ }
112
+ return true;
113
+ }
114
+
115
+ virtual ResultEntry FindFastest(const ParamsT* params) {
116
+ TuningContext* ctx = getTuningContext();
117
+ auto op_sig = Signature();
118
+ auto params_sig = params->Signature();
119
+ TUNABLE_LOG("finding fastest for ", op_sig, '(', params_sig, ')', " out of ", op_names_.size(), " candidates");
120
+ auto min_duration_ms = std::numeric_limits<double>::infinity();
121
+ std::string id_name = "Default";
122
+
123
+ // calcaulte a reference answer for numerical check
124
+ ParamsT* reference_params = params->DeepCopy();
125
+ TORCH_CHECK(ops_[ResultEntry::Default()]->Call(reference_params) == OK);
126
+
127
+ // need a copy of params to reuse
128
+ ParamsT* reusable_params = params->DeepCopy();
129
+
130
+ for (size_t i = 0; i < op_names_.size(); i++) {
131
+ auto* candidate = ops_[op_names_[i]].get(); // borrow pointer
132
+ auto status = candidate->Call(reusable_params);
133
+ if (status != OK) {
134
+ TUNABLE_LOG("├──unsupported id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]);
135
+ continue;
136
+ }
137
+
138
+ if (IsNumericsCheckEnabled()) {
139
+ ParamsT* numerical_params = params->DeepCopy();
140
+ WarmUp(candidate, numerical_params, 1);
141
+ status = reference_params->NumericalCheck(numerical_params);
142
+ numerical_params->Delete();
143
+ if (status != OK) {
144
+ TUNABLE_LOG("├──numerics check failed for id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]);
145
+ continue;
146
+ }
147
+ }
148
+
149
+ // collect a small profile
150
+ constexpr const int approx_num_iter = 3;
151
+ auto approx_duration = Profile(candidate, reusable_params, approx_num_iter);
152
+ // bail if too slow
153
+ if (approx_duration > 2 * min_duration_ms) {
154
+ TUNABLE_LOG("├──skip slow instance id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]);
155
+ continue;
156
+ }
157
+
158
+ // for warmup does user set max duration, max iters, or both?
159
+ double max_warmup_duration = ctx->GetMaxWarmupDurationMs();
160
+ int max_warmup_iter = ctx->GetMaxWarmupIterations();
161
+ int warmup_iter = 1; // default
162
+ if (max_warmup_duration > 0) {
163
+ int duration_iters = max_warmup_duration / approx_duration;
164
+ if (max_warmup_iter > 0) {
165
+ warmup_iter = std::min(max_warmup_iter, duration_iters);
166
+ }
167
+ else {
168
+ warmup_iter = duration_iters;
169
+ }
170
+ }
171
+ else if (max_warmup_iter > 0) {
172
+ warmup_iter = max_warmup_iter;
173
+ }
174
+
175
+ // for tuning does user set max duration, max iters, or both?
176
+ double max_tuning_duration = ctx->GetMaxTuningDurationMs();
177
+ int max_tuning_iter = ctx->GetMaxTuningIterations();
178
+ int tuning_iter = 100; // default
179
+ if (max_tuning_duration > 0) {
180
+ int duration_iters = max_tuning_duration / approx_duration;
181
+ if (max_tuning_iter > 0) {
182
+ tuning_iter = std::min(max_tuning_iter, duration_iters);
183
+ }
184
+ else {
185
+ tuning_iter = duration_iters;
186
+ }
187
+ }
188
+ else if (max_tuning_iter > 0) {
189
+ tuning_iter = max_tuning_iter;
190
+ }
191
+
192
+ // do the full warmup followed by tuning
193
+ double warmup_ms = warmup_iter * approx_duration;
194
+ double tuning_ms = tuning_iter * approx_duration;
195
+ TUNABLE_LOG("├──tuning using "
196
+ "warmup iters ", warmup_iter, " [", warmup_ms, " ms] "
197
+ "and tuning iters ", tuning_iter, " [", tuning_ms, " ms] ",
198
+ "instance id=", i, ", ", op_sig, "(", params_sig, ") ", op_names_[i]);
199
+ WarmUp(candidate, reusable_params, warmup_iter);
200
+ auto duration_ms = Profile(candidate, reusable_params, tuning_iter);
201
+ if (duration_ms < min_duration_ms) {
202
+ TUNABLE_LOG("├──found better instance id=", i, ". " , duration_ms, "ms. ", op_names_[i]);
203
+ min_duration_ms = duration_ms;
204
+ id_name = op_names_[i];
205
+ }
206
+ }
207
+
208
+ reusable_params->Delete();
209
+ reference_params->Delete();
210
+
211
+ TUNABLE_LOG("└──found fastest for ", op_sig, '(', params_sig, ") ", id_name);
212
+ return ResultEntry(id_name, min_duration_ms);
213
+ }
214
+
215
+ private:
216
+ std::string CreateSignature() {
217
+ #ifndef _WIN32
218
+ const auto* name = typeid(*this).name();
219
+ char buf[256];
220
+ size_t buf_len = 256;
221
+ abi::__cxa_demangle(name, buf, &buf_len, nullptr);
222
+ buf[255] = '\0';
223
+ return buf;
224
+ #else
225
+ return typeid(*this).name();
226
+ #endif
227
+ }
228
+
229
+ mutable c10::once_flag signature_init_once_;
230
+ std::string signature_;
231
+
232
+ std::unordered_map<std::string, std::unique_ptr<Callable<ParamsT>>> ops_;
233
+ std::vector<std::string> op_names_;
234
+ };
235
+
236
+ struct OpParams {
237
+ OpParams() {}
238
+ virtual ~OpParams() = default;
239
+ virtual std::string Signature() const = 0;
240
+ };
241
+
242
+ } // namespace at::cuda::tunable
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/BinaryOps.h ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/core/TensorBase.h>
4
+ #include <ATen/native/DispatchStub.h>
5
+ #include <c10/core/Scalar.h>
6
+ #include <c10/util/TypeSafeSignMath.h>
7
+
8
+
9
+ namespace at {
10
+ struct TensorIterator;
11
+ struct TensorIteratorBase;
12
+ }
13
+
14
+ namespace at::native {
15
+
16
+ inline void alpha_check(const ScalarType dtype, const Scalar& alpha) {
17
+ TORCH_CHECK(! alpha.isBoolean() || dtype == ScalarType::Bool,
18
+ "Boolean alpha only supported for Boolean results.");
19
+ TORCH_CHECK(isFloatingType(dtype) || isComplexType(dtype)
20
+ || alpha.isIntegral(true),
21
+ "For integral input tensors, argument alpha must not be a floating point number.");
22
+ TORCH_CHECK(isComplexType(dtype) || !alpha.isComplex(),
23
+ "For non-complex input tensors, argument alpha must not be a complex number.")
24
+ }
25
+
26
+ // Basic checking for all sub functions.
27
+ inline void sub_check(const TensorBase& self, const TensorBase& other) {
28
+ TORCH_CHECK(self.scalar_type() != kBool || other.scalar_type() != kBool,
29
+ "Subtraction, the `-` operator, with two bool tensors is not supported. "
30
+ "Use the `^` or `logical_xor()` operator instead.")
31
+ TORCH_CHECK(self.scalar_type() != kBool && other.scalar_type() != kBool,
32
+ "Subtraction, the `-` operator, with a bool tensor is not supported. "
33
+ "If you are trying to invert a mask, use the `~` or `logical_not()` operator instead.");
34
+ }
35
+
36
+ inline void sub_check(const TensorBase& self, const Scalar& scalar) {
37
+ TORCH_CHECK(self.scalar_type() != kBool || !scalar.isBoolean(),
38
+ "Subtraction, the `-` operator, with two bool tensors is not supported. "
39
+ "Use the `^` or `logical_xor()` operator instead.")
40
+ TORCH_CHECK(self.scalar_type() != kBool && !scalar.isBoolean(),
41
+ "Subtraction, the `-` operator, with a bool tensor is not supported. "
42
+ "If you are trying to invert a mask, use the `~` or `logical_not()` operator instead.");
43
+ }
44
+
45
+ using structured_binary_fn_alpha = void(*)(TensorIteratorBase&, const Scalar& alpha);
46
+ using structured_binary_fn_double = void(*)(TensorIteratorBase&, double);
47
+ using structured_binary_fn = void(*)(TensorIteratorBase&);
48
+
49
+ using binary_fn_alpha = void(*)(TensorIteratorBase&, const Scalar& alpha);
50
+ using binary_fn_double = void(*)(TensorIterator&, double);
51
+ using binary_fn = void(*)(TensorIterator&);
52
+ using binary_clamp_fn_alpha =
53
+ void(*)(TensorIterator&, const Scalar& alpha, const Scalar& min_val, const Scalar& max_val);
54
+
55
+ // NB: codegenned
56
+ DECLARE_DISPATCH(structured_binary_fn_alpha, add_stub);
57
+
58
+ DECLARE_DISPATCH(binary_clamp_fn_alpha, add_clamp_stub);
59
+ DECLARE_DISPATCH(structured_binary_fn_alpha, sub_stub);
60
+ DECLARE_DISPATCH(structured_binary_fn, mul_stub);
61
+ DECLARE_DISPATCH(structured_binary_fn, div_true_stub);
62
+ DECLARE_DISPATCH(structured_binary_fn, div_floor_stub);
63
+ DECLARE_DISPATCH(structured_binary_fn, div_trunc_stub);
64
+ DECLARE_DISPATCH(structured_binary_fn, atan2_stub);
65
+ DECLARE_DISPATCH(structured_binary_fn, remainder_stub);
66
+ DECLARE_DISPATCH(structured_binary_fn, bitwise_and_stub);
67
+ DECLARE_DISPATCH(structured_binary_fn, bitwise_or_stub);
68
+ DECLARE_DISPATCH(structured_binary_fn, bitwise_xor_stub);
69
+ DECLARE_DISPATCH(structured_binary_fn, lshift_stub);
70
+ DECLARE_DISPATCH(structured_binary_fn, rshift_stub);
71
+ DECLARE_DISPATCH(binary_fn, logical_xor_stub);
72
+ DECLARE_DISPATCH(binary_fn, logical_and_stub);
73
+ DECLARE_DISPATCH(binary_fn, logical_or_stub);
74
+ DECLARE_DISPATCH(structured_binary_fn, lt_stub);
75
+ DECLARE_DISPATCH(structured_binary_fn, le_stub);
76
+ DECLARE_DISPATCH(structured_binary_fn, gt_stub);
77
+ DECLARE_DISPATCH(structured_binary_fn, ge_stub);
78
+ DECLARE_DISPATCH(structured_binary_fn, eq_stub);
79
+ DECLARE_DISPATCH(structured_binary_fn, ne_stub);
80
+ DECLARE_DISPATCH(binary_fn, max_elementwise_stub);
81
+ DECLARE_DISPATCH(binary_fn, min_elementwise_stub);
82
+ DECLARE_DISPATCH(structured_binary_fn, maximum_stub);
83
+ DECLARE_DISPATCH(structured_binary_fn, minimum_stub);
84
+ DECLARE_DISPATCH(structured_binary_fn, fmax_stub);
85
+ DECLARE_DISPATCH(structured_binary_fn, fmin_stub);
86
+ DECLARE_DISPATCH(structured_binary_fn_double, smooth_l1_stub);
87
+ DECLARE_DISPATCH(binary_fn_double, huber_stub);
88
+ DECLARE_DISPATCH(structured_binary_fn, sigmoid_backward_stub);
89
+ DECLARE_DISPATCH(binary_fn_alpha, logit_backward_stub);
90
+ DECLARE_DISPATCH(structured_binary_fn, tanh_backward_stub);
91
+ DECLARE_DISPATCH(structured_binary_fn, mse_stub);
92
+ DECLARE_DISPATCH(structured_binary_fn, fmod_stub);
93
+ DECLARE_DISPATCH(structured_binary_fn, logaddexp_stub);
94
+ DECLARE_DISPATCH(structured_binary_fn, logaddexp2_stub);
95
+ DECLARE_DISPATCH(structured_binary_fn, gcd_stub);
96
+ DECLARE_DISPATCH(structured_binary_fn, lcm_stub);
97
+ DECLARE_DISPATCH(structured_binary_fn, hypot_stub);
98
+ DECLARE_DISPATCH(structured_binary_fn, igamma_stub);
99
+ DECLARE_DISPATCH(structured_binary_fn, igammac_stub);
100
+ DECLARE_DISPATCH(structured_binary_fn, nextafter_stub);
101
+ DECLARE_DISPATCH(structured_binary_fn, heaviside_stub);
102
+ DECLARE_DISPATCH(structured_binary_fn, copysign_stub);
103
+ DECLARE_DISPATCH(structured_binary_fn, xlogy_stub);
104
+ DECLARE_DISPATCH(structured_binary_fn, xlog1py_stub);
105
+ DECLARE_DISPATCH(structured_binary_fn, zeta_stub);
106
+ DECLARE_DISPATCH(structured_binary_fn, chebyshev_polynomial_t_stub);
107
+ DECLARE_DISPATCH(structured_binary_fn, chebyshev_polynomial_u_stub);
108
+ DECLARE_DISPATCH(structured_binary_fn, chebyshev_polynomial_v_stub);
109
+ DECLARE_DISPATCH(structured_binary_fn, chebyshev_polynomial_w_stub);
110
+ DECLARE_DISPATCH(structured_binary_fn, hermite_polynomial_h_stub);
111
+ DECLARE_DISPATCH(structured_binary_fn, hermite_polynomial_he_stub);
112
+ DECLARE_DISPATCH(structured_binary_fn, laguerre_polynomial_l_stub);
113
+ DECLARE_DISPATCH(structured_binary_fn, legendre_polynomial_p_stub);
114
+ DECLARE_DISPATCH(structured_binary_fn, shifted_chebyshev_polynomial_t_stub);
115
+ DECLARE_DISPATCH(structured_binary_fn, shifted_chebyshev_polynomial_u_stub);
116
+ DECLARE_DISPATCH(structured_binary_fn, shifted_chebyshev_polynomial_v_stub);
117
+ DECLARE_DISPATCH(structured_binary_fn, shifted_chebyshev_polynomial_w_stub);
118
+
119
+ } // namespace at::native
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/CPUFallback.h ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/core/ivalue.h>
4
+ #include <ATen/core/stack.h>
5
+ #include <ATen/core/boxing/KernelFunction.h>
6
+ #include <ATen/core/dispatch/Dispatcher.h>
7
+ #include <c10/util/Metaprogramming.h>
8
+ #include <torch/library.h>
9
+
10
+ namespace at::native {
11
+
12
+ // This function implements a boxed fallback to CPU.
13
+ // External backends can add their own custom logging on top if it to customize their own CPU fallbacks.
14
+ TORCH_API void cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack, bool error_on_views = false);
15
+
16
+ // This is a helper function that backends can use to directly call their boxed CPU fallback
17
+ // TODO: update and add a usage example after https://github.com/pytorch/pytorch/pull/58092 lands.
18
+ template<c10::KernelFunction::BoxedKernelFunction* fallback_fn, class Op, bool symint, class ReturnType, class... ParameterTypes>
19
+ struct _call_fallback_fn final {};
20
+
21
+ template<c10::KernelFunction::BoxedKernelFunction* fallback_fn, class Op, bool symint, class ReturnType, class... ParameterTypes>
22
+ struct _call_fallback_fn<fallback_fn, Op, symint, ReturnType(ParameterTypes...)> final {
23
+ static ReturnType call(typename c10::maybe_keep_symint<symint, ParameterTypes>::type... args) {
24
+ auto op = c10::Dispatcher::singleton()
25
+ // TODO: figure out how to make compiler happy without dynamic casts
26
+ .findSchemaOrThrow((const char*) Op::name, (const char*) Op::overload_name)
27
+ //.findSchemaOrThrow("a", "b")
28
+ .typed<ReturnType (typename c10::maybe_keep_symint<symint, ParameterTypes>::type...)>();
29
+ return c10::impl::BoxedKernelWrapper<ReturnType (typename c10::maybe_keep_symint<symint, ParameterTypes>::type...)>::call(
30
+ c10::BoxedKernel::makeFromFunction<fallback_fn>(),
31
+ op,
32
+ c10::DispatchKeySet(), // we know that the cpu_fallback doesn't use the dispatch keyset.
33
+ // TODO: get std::forward<> to work
34
+ args...
35
+ );
36
+ }
37
+ };
38
+
39
+ template<c10::KernelFunction::BoxedKernelFunction* fallback_fn, class Op>
40
+ using call_fallback_fn_symint = _call_fallback_fn<fallback_fn, Op, true, typename Op::schema>;
41
+
42
+ template<c10::KernelFunction::BoxedKernelFunction* fallback_fn, class Op>
43
+ using call_fallback_fn = _call_fallback_fn<fallback_fn, Op, false, typename Op::schema>;
44
+
45
+ } // namespace at::native
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/CanUse32BitIndexMath.h ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ #include <c10/macros/Export.h>
3
+ #include <limits>
4
+
5
+ namespace at {
6
+ class TensorBase;
7
+ }
8
+
9
+ namespace at::native {
10
+
11
+ TORCH_API bool canUse32BitIndexMath(const at::TensorBase &t, int64_t max_elem=std::numeric_limits<int32_t>::max());
12
+
13
+ }
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/ComplexHelper.h ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/core/Tensor.h>
4
+ #include <c10/util/irange.h>
5
+
6
+ #ifndef AT_PER_OPERATOR_HEADERS
7
+ #include <ATen/NativeFunctions.h>
8
+ #else
9
+ #include <ATen/ops/view_as_real_native.h>
10
+ #include <ATen/ops/view_as_complex_native.h>
11
+
12
+ #include <utility>
13
+ #endif
14
+
15
+ // WARNING: this header contains non-inline functions and should be only
16
+ // included from ONE cpp file
17
+
18
+ namespace at::native {
19
+
20
+ // View tensor with new dtype, storage offset, sizes and strides
21
+ inline Tensor view_tensor(
22
+ const Tensor &tensor, ScalarType dtype,
23
+ c10::SymInt offset, SymIntArrayRef sizes, SymIntArrayRef strides) {
24
+ Storage storage = tensor.storage();
25
+ auto key_set = tensor.key_set().remove(DispatchKey::Conjugate);
26
+ auto new_tensor = detail::make_tensor<TensorImpl>(
27
+ c10::TensorImpl::VIEW, std::move(storage), key_set, scalarTypeToTypeMeta(dtype));
28
+ auto * impl = new_tensor.unsafeGetTensorImpl();
29
+ impl->set_sizes_and_strides(sizes, strides, offset);
30
+ return new_tensor;
31
+ }
32
+
33
+ inline SymDimVector computeStrideForViewAsReal(SymIntArrayRef oldstride) {
34
+ SymDimVector res(oldstride.size() + 1);
35
+ for (const auto i : c10::irange(oldstride.size())) {
36
+ res[i] = oldstride[i] * 2;
37
+ }
38
+ res.back() = 1;
39
+ return res;
40
+ }
41
+
42
+ inline Tensor _view_as_real_physical(const Tensor& self) {
43
+ TORCH_CHECK(self.is_complex(), "view_as_real is only supported for complex tensors");
44
+ auto old_sizes = self.sym_sizes();
45
+ SymDimVector new_sizes(old_sizes.size() + 1);
46
+ std::copy(old_sizes.begin(), old_sizes.end(), new_sizes.begin());
47
+ // last dimension will always have two elements containing the real and imag vals
48
+ new_sizes.back() = 2;
49
+ auto new_strides = computeStrideForViewAsReal(self.sym_strides());
50
+ auto new_storage_offset = self.sym_storage_offset() * 2;
51
+ const auto float_type = c10::toRealValueType(self.scalar_type());
52
+ auto real_tensor = view_tensor(self, float_type, std::move(new_storage_offset), new_sizes, new_strides);
53
+ return real_tensor;
54
+ }
55
+
56
+ // expects as input a complex tensor and returns back a tensor
57
+ // with corresponding real dtype containing the complex values
58
+ // in the last two dimensions
59
+ Tensor view_as_real(const Tensor& self) {
60
+ TORCH_CHECK(!self.is_conj(), "view_as_real doesn't work on unresolved conjugated tensors. To resolve the conjugate tensor so you can view it as real, use self.resolve_conj(); however, be warned that the resulting tensor will NOT alias the original.");
61
+ return _view_as_real_physical(self);
62
+ }
63
+
64
+ inline SymDimVector computeStrideForViewAsComplex(SymIntArrayRef oldstride) {
65
+ const int64_t dim = oldstride.size();
66
+ TORCH_CHECK(oldstride[dim-1] == 1, "Tensor must have a last dimension with stride 1");
67
+
68
+ SymDimVector res(dim - 1);
69
+ for (const auto i : c10::irange(res.size())) {
70
+ TORCH_CHECK(oldstride[i] % 2 == 0, "Tensor must have a stride divisible by 2 for all but last dimension");
71
+ res[i] = oldstride[i] / 2;
72
+ }
73
+ return res;
74
+ }
75
+
76
+ // expects as input a float or double tensor with last dimension of size 2
77
+ // and returns back a tensor with corresponding complex dtype
78
+ Tensor view_as_complex(const Tensor& self) {
79
+ TORCH_CHECK(
80
+ self.scalar_type() == kFloat || self.scalar_type() == kDouble || self.scalar_type() == kHalf,
81
+ "view_as_complex is only supported for half, float and double tensors, but got a tensor of scalar type: ", self.scalar_type());
82
+
83
+ auto old_sizes = self.sym_sizes();
84
+ TORCH_CHECK(!old_sizes.empty(), "Input tensor must have one or more dimensions");
85
+ TORCH_CHECK(old_sizes[old_sizes.size()-1] == 2, "Tensor must have a last dimension of size 2");
86
+ SymDimVector new_sizes(old_sizes.begin(), old_sizes.end() - 1);
87
+
88
+ const auto new_strides = computeStrideForViewAsComplex(self.sym_strides());
89
+ const auto complex_type = c10::toComplexType(self.scalar_type());
90
+
91
+ TORCH_CHECK(self.sym_storage_offset() % 2 == 0, "Tensor must have a storage_offset divisible by 2");
92
+ const auto new_storage_offset = self.sym_storage_offset() / 2;
93
+
94
+ return view_tensor(self, complex_type, new_storage_offset, new_sizes, new_strides);
95
+ }
96
+
97
+ } // namespace at::native
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/Distributions.h ADDED
@@ -0,0 +1,518 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/native/Math.h>
4
+ #include <c10/macros/Macros.h>
5
+ #include <c10/util/MathConstants.h>
6
+
7
+ // ROCM hcc doesn't work well with using std:: in kernel functions
8
+ #if defined(__CUDA_ARCH__)
9
+ #include <c10/cuda/CUDAMathCompat.h>
10
+ #define compat_exp c10::cuda::compat::exp
11
+ #define compat_ceil c10::cuda::compat::ceil
12
+ #define compat_floor c10::cuda::compat::floor
13
+ #define compat_log c10::cuda::compat::log
14
+ #define compat_pow c10::cuda::compat::pow
15
+ #define compat_sqrt c10::cuda::compat::sqrt
16
+ #define compat_tan c10::cuda::compat::tan
17
+ #define compat_abs c10::cuda::compat::abs
18
+ #define compat_log1p c10::cuda::compat::log1p
19
+ #elif defined(__HIPCC__)
20
+ #include <c10/hip/HIPMathCompat.h>
21
+ #define compat_exp c10::hip::compat::exp
22
+ #define compat_ceil c10::hip::compat::ceil
23
+ #define compat_floor c10::hip::compat::floor
24
+ #define compat_log c10::hip::compat::log
25
+ #define compat_pow c10::hip::compat::pow
26
+ #define compat_sqrt c10::hip::compat::sqrt
27
+ #define compat_tan c10::hip::compat::tan
28
+ #define compat_abs c10::hip::compat::abs
29
+ #define compat_log1p c10::hip::compat::log1p
30
+ #else
31
+ #define compat_exp std::exp
32
+ #define compat_ceil std::ceil
33
+ #define compat_floor std::floor
34
+ #define compat_log std::log
35
+ #define compat_pow std::pow
36
+ #define compat_sqrt std::sqrt
37
+ #define compat_tan std::tan
38
+ #define compat_abs std::abs
39
+ #define compat_log1p std::log1p
40
+ #endif
41
+
42
+ namespace {
43
+
44
+ #if !defined(__CUDA_ARCH__) && !defined(__HIPCC__)
45
+ // we cannot use std::isnan directly due to some incompatibility of
46
+ // gcc constexpr'ing and nvcc
47
+ using std::isnan;
48
+ #endif
49
+
50
// Here sampler_t should be function type scalar_t(void). For gpu
// "sampler" is a device function, but since ROCM doesn't have
// equivalent to nvstd::function, we use a template type parameter to
// capture it.
//
// Thin wrapper that stores a zero-argument callable returning scalar_t and
// invokes it via sample(), giving host and device code a uniform interface.
template<typename scalar_t, typename sampler_t>
struct BaseSampler {
  // The wrapped callable; copied in by the constructor.
  sampler_t sampler;
  C10_DEVICE BaseSampler(const sampler_t& sampler): sampler(sampler) {}
  // Draws one value from the wrapped sampler.
  C10_DEVICE scalar_t sample() {
    return sampler();
  }
};
62
+
63
+ // The function `sample_gamma` is
64
+ // is adapted from Numpy's distributions.c implementation.
65
+ // It is MIT licensed, so here is the copyright:
66
+
67
+ /* Copyright 2005 Robert Kern (robert.kern@gmail.com)
68
+ *
69
+ * Permission is hereby granted, free of charge, to any person obtaining a
70
+ * copy of this software and associated documentation files (the
71
+ * "Software"), to deal in the Software without restriction, including
72
+ * without limitation the rights to use, copy, modify, merge, publish,
73
+ * distribute, sublicense, and/or sell copies of the Software, and to
74
+ * permit persons to whom the Software is furnished to do so, subject to
75
+ * the following conditions:
76
+ *
77
+ * The above copyright notice and this permission notice shall be included
78
+ * in all copies or substantial portions of the Software.
79
+ *
80
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
81
+ * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
82
+ * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
83
+ * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
84
+ * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
85
+ * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
86
+ * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
87
+ */
88
+
89
// Draws one sample from a unit-scale Gamma(alpha) distribution using the
// supplied uniform and normal samplers. Adapted from Numpy's
// distributions.c (MIT licensed; see copyright notice above).
template<typename scalar_t, typename accscalar_t, typename uniform_sampler_t, typename normal_sampler_t>
C10_DEVICE scalar_t sample_gamma(scalar_t alpha, BaseSampler<accscalar_t, uniform_sampler_t>& standard_uniform, BaseSampler<accscalar_t, normal_sampler_t>& standard_normal) {
  accscalar_t scale = 1.0f;

  // Boost alpha for higher acceptance probability.
  if (alpha < 1.0f) {
    // Gamma(0) is the point mass at zero.
    if (alpha == 0.f) return 0.f;
    // Shape augmentation: a Gamma(alpha) sample equals U^(1/alpha) times a
    // Gamma(alpha + 1) sample; fold the U^(1/alpha) factor into `scale`.
    scale *= compat_pow(1 - standard_uniform.sample(), 1.0f / alpha);
    alpha += 1.0f;
  }

  // This implements the acceptance-rejection method of Marsaglia and Tsang (2000)
  // doi:10.1145/358407.358414
  const accscalar_t d = alpha - 1.0f / 3.0f;
  const accscalar_t c = 1.0f / compat_sqrt(9.0f * d);
  for (;;) {
    accscalar_t x, y;
    do {
      x = standard_normal.sample();
      y = 1.0f + c * x;
    } while (y <= 0);
    const accscalar_t v = y * y * y;
    const accscalar_t u = 1 - standard_uniform.sample();
    const accscalar_t xx = x * x;
    // Cheap "squeeze" acceptance test first; fall back to the exact
    // logarithmic acceptance test only when the squeeze fails.
    if (u < 1.0f - 0.0331f * xx * xx)
      return static_cast<scalar_t>(scale * d * v);
    if (compat_log(u) < 0.5f * xx + d * (1.0f - v + compat_log(v)))
      return static_cast<scalar_t>(scale * d * v);
  }
}
119
+
120
+ /* the functions stirling_approx_tail, binomial_inversion, and btrs are adapted
121
+ * from TensorFlow's random_binomial_op.cc implementation. That code is under
122
+ * copyright: 2019 The TensorFlow Authors.
123
+ *
124
+ * It was released under the Apache License, Version 2.0 (the "License"), available at:
125
+ * http://www.apache.org/licenses/LICENSE-2.0
126
+ */
127
+
128
+ template<typename scalar_t>
129
+ C10_DEVICE scalar_t stirling_approx_tail(scalar_t k) {
130
+ const static scalar_t kTailValues[] = {
131
+ 0.0810614667953272,
132
+ 0.0413406959554092,
133
+ 0.0276779256849983,
134
+ 0.02079067210376509,
135
+ 0.0166446911898211,
136
+ 0.0138761288230707,
137
+ 0.0118967099458917,
138
+ 0.0104112652619720,
139
+ 0.00925546218271273,
140
+ 0.00833056343336287
141
+ };
142
+ if (k <= 9) {
143
+ return kTailValues[static_cast<size_t>(k)];
144
+ }
145
+ scalar_t kp1sq = (k + 1) * (k + 1);
146
+ return (1.0 / 12 - (1.0 / 360 - 1.0 / 1260 / kp1sq) / kp1sq) / (k + 1);
147
+ }
148
+
149
+
150
+ template<typename scalar_t, typename accscalar_t, typename uniform_sampler_t>
151
+ C10_DEVICE scalar_t binomial_inversion(scalar_t count, scalar_t prob, BaseSampler<accscalar_t, uniform_sampler_t>& standard_uniform) {
152
+ accscalar_t U;
153
+ accscalar_t geom_sum = 0;
154
+ scalar_t num_geom = 0;
155
+
156
+ accscalar_t logprob = compat_log1p(-prob);
157
+
158
+ while (1) {
159
+ U = standard_uniform.sample();
160
+ accscalar_t geom = compat_ceil(compat_log(U) / logprob);
161
+ geom_sum += geom;
162
+ if (geom_sum > count) {
163
+ break;
164
+ }
165
+ num_geom = num_geom + 1;
166
+ }
167
+ return num_geom;
168
+ }
169
+
170
// Transformed-rejection ("BTRS", Hormann) sampler for Binomial(count, prob).
// Adapted from TensorFlow's random_binomial_op.cc (license above). The
// caller (sample_binomial below) only uses this when prob <= 0.5 and
// count * prob >= 10, where the acceptance rate is high.
template<typename scalar_t, typename accscalar_t, typename uniform_sampler_t>
C10_DEVICE scalar_t btrs(scalar_t count, scalar_t prob, BaseSampler<accscalar_t, uniform_sampler_t>& standard_uniform) {
  scalar_t k;
  accscalar_t U, V, us;

  // This is spq in the paper.
  const accscalar_t stddev = compat_sqrt(count * prob * (1 - prob));

  // Other coefficients for Transformed Rejection sampling.
  const accscalar_t b = 1.15 + 2.53 * stddev;
  const accscalar_t a = -0.0873 + 0.0248 * b + 0.01 * prob;
  const accscalar_t c = count * prob + 0.5;
  const accscalar_t v_r = 0.92 - 4.2 / b;
  const accscalar_t r = prob / (1 - prob);

  const accscalar_t alpha = (2.83 + 5.1 / b) * stddev;
  const accscalar_t m = compat_floor((count + 1) * prob);

  while (1) {
    U = standard_uniform.sample() - 0.5;
    V = standard_uniform.sample();

    us = 0.5 - compat_abs(U);
    // Candidate sample from the transformed proposal distribution.
    k = static_cast<scalar_t>(compat_floor((2 * a / us + b) * U + c));

    // Reject non-sensical answers.
    if (k < 0 || k > count) {
      continue;
    }
    // Region for which the box is tight, and we can return our calculated value.
    // This should happen 0.86 * v_r times. In the limit as n * p is large,
    // the acceptance rate converges to ~79% (and in the lower regime it is ~24%).
    if (us >= 0.07 && V <= v_r) {
      return k;
    }

    // This deviates from Hormann's BTRS algorithm, as there is a log missing.
    // For all (u, v) pairs outside of the bounding box, this calculates the
    // transformed-reject ratio.
    V = compat_log(V * alpha / (a / (us * us) + b));
    accscalar_t upperbound =
        ((m + 0.5) * compat_log((m + 1) / (r * (count - m + 1))) +
         (count + 1) * compat_log((count - m + 1) / (count - k + 1)) +
         (k + 0.5) * compat_log(r * (count - k + 1) / (k + 1)) +
         stirling_approx_tail<accscalar_t>(m) + stirling_approx_tail<accscalar_t>(count - m) -
         stirling_approx_tail<accscalar_t>(k) - stirling_approx_tail<accscalar_t>(count - k));

    if (V <= upperbound) {
      return k;
    }
  }
}
222
+
223
// Draws one Binomial(count, prob) sample. Dispatches between the BTRS
// rejection sampler (large mean, count * p >= 10) and geometric inversion
// (small mean); for prob > 0.5 it samples the complement with q = 1 - prob.
// A NaN prob fails every comparison and falls through to the final branch,
// which returns NAN.
template<typename scalar_t, typename accscalar_t, typename uniform_sampler_t>
C10_DEVICE scalar_t sample_binomial(scalar_t count, scalar_t prob, BaseSampler<accscalar_t, uniform_sampler_t>& standard_uniform) {
  if (count <= 0.0 || prob <= 0.0) {
    return 0;
  } else if (prob >= 1.0) {
    return count;
  } else if (prob <= 0.5) {
    if (count * prob >= 10.0) {
      // btrs
      return btrs<scalar_t, accscalar_t, uniform_sampler_t>(count, prob, standard_uniform);
    } else {
      // binomial inversion
      return binomial_inversion<scalar_t, accscalar_t, uniform_sampler_t>(count, prob, standard_uniform);
    }
  } else if (prob > 0.5) {
    scalar_t qprob = 1.0 - prob;
    if (count * qprob >= 10.0) {
      // btrs
      return count - btrs<scalar_t, accscalar_t, uniform_sampler_t>(count, qprob, standard_uniform);
    } else {
      // count - binomial inversion
      return count - binomial_inversion<scalar_t, accscalar_t, uniform_sampler_t>(count, qprob, standard_uniform);
    }
  } else {
    // prob is nan?
    return static_cast<scalar_t>(NAN);
  }
}
251
+
252
/*
 * This function is derived from the implementation of the digamma function in the Cephes Math Library.
 * See note [3-Clause BSD License for the Cephes Math Library] in ATen/native/Math.h.
 */
// Evaluates the digamma function at a single point x; usable in device code.
// Returns INFINITY at the poles (x = 0 and negative integers).
template<typename scalar_t, typename accscalar_t>
C10_DEVICE static inline scalar_t digamma_one(scalar_t x) {
  // digamma(10), precomputed.
  constexpr accscalar_t PSI_10 = 2.25175258906672110764;
  if (x == 0) {
    return INFINITY;
  }
  accscalar_t additional_summand = 0;
  int x_is_integer = x == compat_floor(x);
  if (x < 0) {
    if (x_is_integer) {
      return INFINITY;
    }
    // Reflection formula: digamma(x) = digamma(1-x) - pi / tan(pi * x).
    // it is more standard to write this as recursion, but
    // nvcc does not like that
    additional_summand = -c10::pi<scalar_t> /
        compat_tan(c10::pi<scalar_t> * x);
    x = 1 - x;
  }

  // Push x to be >= 10 using the recurrence digamma(x+1) = digamma(x) + 1/x.
  accscalar_t result = 0;
  while (x < 10) {
    result -= 1 / x;
    x += 1;
  }
  if (x == 10) {
    return result + PSI_10 + additional_summand;
  }

  // Compute asymptotic digamma via the series in 1/x^2 with Cephes
  // coefficients A (evaluated by polevl from ATen/native/Math.h).
  static const accscalar_t A[] = {
     8.33333333333333333333E-2,
     -2.10927960927960927961E-2,
     7.57575757575757575758E-3,
     -4.16666666666666666667E-3,
     3.96825396825396825397E-3,
     -8.33333333333333333333E-3,
     8.33333333333333333333E-2,
  };

  accscalar_t y = 0;
  // For extremely large x the series term underflows; skip it.
  if (x < 1.0e17f) {
    accscalar_t z = 1.0 / (x * x);
    y = z * polevl<accscalar_t>(z, A, 6);
  }
  return static_cast<scalar_t>(
      result + compat_log(x) - (0.5f / x) - y + additional_summand);
}
304
+
305
// Computes the reparameterized gradient -(d/dalpha cdf(x;alpha)) / pdf(x;alpha)
// for random number x drawn from a standard Gamma distribution Gamma(alpha).
// Three regimes are used: a Taylor series for small x, a Rice saddle point
// expansion for large alpha, and a bivariate rational approximation otherwise.
template <typename scalar_t, typename accscalar_t>
C10_HOST_DEVICE scalar_t standard_gamma_grad_one(scalar_t alpha_, scalar_t x_) {
  // Use a Taylor series expansion for small x.
  accscalar_t x = static_cast<accscalar_t>(x_);
  accscalar_t alpha = static_cast<accscalar_t>(alpha_);
  if (x < 0.8f) {
    accscalar_t numer = 1;
    accscalar_t denom = alpha;
    // series1 accumulates the CDF series; series2 its alpha-derivative part.
    auto series1 = numer / denom;
    auto series2 = numer / (denom * denom);
    for (int i = 1; i <= 5; ++i) {
      numer *= -x / static_cast<accscalar_t>(i);
      denom += 1;
      series1 += numer / denom;
      series2 += numer / (denom * denom);
    }
    const auto pow_x_alpha = compat_pow(x, alpha);
    const auto gamma_pdf = compat_pow(x, alpha - 1) * compat_exp(-x);
    const auto gamma_cdf = pow_x_alpha * series1;
    const auto gamma_cdf_alpha =
        (compat_log(x) - digamma_one<accscalar_t, accscalar_t>(alpha)) *
            gamma_cdf -
        pow_x_alpha * series2;
    const auto result = -gamma_cdf_alpha / gamma_pdf;
    // Guard against 0/0 and overflow at the domain edges.
    return isnan(result) ? static_cast<scalar_t>( 0.f ) : static_cast<scalar_t>(result);
  }

  // Use a Rice saddle point expansion for large alpha.
  if (alpha > 8.0f) {
    // Near x == alpha the general expansion is singular; use a dedicated
    // polynomial approximation on the band 0.9*alpha <= x <= 1.1*alpha.
    if (0.9f * alpha <= x && x <= 1.1f * alpha) {
      const auto numer_1 = 1 + 24 * alpha * (1 + 12 * alpha);
      const auto numer_2 = 1440 * (alpha * alpha) + 6 * x * (53 - 120 * x)
          - 65 * x * x / alpha + alpha * (107 + 3600 * x);
      const auto denom = 1244160 * (alpha * alpha) * (alpha * alpha);
      return static_cast<scalar_t>(numer_1 * numer_2 / denom);
    }
    const auto denom = compat_sqrt(8 * alpha);
    const auto term2 = denom / (alpha - x);
    const auto term3 = compat_pow(
        x - alpha - alpha * compat_log(x / alpha),
        static_cast<accscalar_t>(-1.5));
    const auto term23 = (x < alpha) ? term2 - term3 : term2 + term3;
    const auto term1 = compat_log(x / alpha) * term23 -
        compat_sqrt(2 / alpha) * (alpha + x) / ((alpha - x) * (alpha - x));
    const auto stirling = 1 + 1 / (12 * alpha) * (1 + 1 / (24 * alpha));
    const auto numer = x * term1;
    return static_cast<scalar_t>(-stirling * numer / denom);
  }

  // Use a bivariate rational approximation to the reparameterized gradient.
  const auto u = compat_log(x / alpha);
  const auto v = compat_log(alpha);
  // Fitted coefficients of the rational approximation in (u, v).
  static const accscalar_t coef_uv[3][8] = {
    {0.16009398, -0.094634809, 0.025146376, -0.0030648343,
     1, 0.32668115, 0.10406089, 0.0014179084},
    {0.53487893, 0.1298071, 0.065735949, -0.0015649758,
     0.16639465, 0.020070113, -0.0035938915, -0.00058392623},
    {0.040121004, -0.0065914022, -0.0026286047, -0.0013441777,
     0.017050642, -0.0021309326, 0.00085092367, -1.5247877e-07},
  };
  accscalar_t coef_v[8];
  for (int i = 0; i < 8; ++ i) {
    coef_v[i] = coef_uv[0][i] + u * (coef_uv[1][i] + u * coef_uv[2][i]);
  }
  // Horner evaluation of numerator p and denominator q in v.
  const auto p = coef_v[0] + v * (coef_v[1] + v * (coef_v[2] + v * coef_v[3]));
  const auto q = coef_v[4] + v * (coef_v[5] + v * (coef_v[6] + v * coef_v[7]));
  return static_cast<scalar_t>(compat_exp(p / q));
}
375
+
376
// Approximate reparameterized gradient of Beta(x,alpha,beta) wrt alpha.
// Assumes x is close to zero and uses a Taylor expansion.
// Returns 0 when the series produces NaN at the domain edges.
template <typename scalar_t, typename accscalar_t>
C10_DEVICE static inline scalar_t _beta_grad_alpha_small(scalar_t x, scalar_t alpha, scalar_t beta) {
  const scalar_t factor = digamma_one<scalar_t, accscalar_t>(alpha)
      - digamma_one<scalar_t, accscalar_t>(alpha + beta) - compat_log(x);
  scalar_t numer = 1;
  scalar_t series = numer / alpha * (factor + 1 / alpha);
  // 10-term Taylor expansion around x = 0.
  for (int i = 1; i <= 10; ++i) {
    scalar_t casted_i = static_cast<scalar_t>(i);
    numer *= (casted_i - beta) * x / casted_i;
    const scalar_t denom = alpha + casted_i;
    series += numer / denom * (factor + 1 / denom);
  }
  const scalar_t result = x * compat_pow(1 - x, -beta) * series;
  return isnan(result) ? static_cast<scalar_t>( 0.f ) : result;
}
393
+
394
// Approximate reparameterized gradient of Beta(x,alpha,beta) wrt beta.
// Assumes x is close to zero and uses a Taylor expansion.
// Returns 0 when the series produces NaN at the domain edges.
template <typename scalar_t, typename accscalar_t>
C10_DEVICE static inline scalar_t _beta_grad_beta_small(scalar_t x, scalar_t alpha, scalar_t beta) {
  const scalar_t factor = digamma_one<scalar_t, accscalar_t>(alpha + beta) - digamma_one<scalar_t, accscalar_t>(beta);
  // betas tracks the rising product in (beta - i); dbetas its beta-derivative.
  scalar_t numer = 1, betas = 1, dbetas = 0, series = factor / alpha;
  // 8-term Taylor expansion around x = 0.
  for (int i = 1; i <= 8; ++i) {
    scalar_t casted_i = static_cast<scalar_t>(i);
    numer *= -x / casted_i;
    dbetas = dbetas * (beta - casted_i) + betas;
    betas = betas * (beta - casted_i);
    series += numer / (alpha + casted_i) * (dbetas + factor * betas);
  }
  const scalar_t result = -compat_pow(1 - x, 1 - beta) * series;
  return isnan(result) ? static_cast<scalar_t>( 0.f ) : result;
}
410
+
411
// Approximate reparameterized gradient of Beta(x,alpha,beta) wrt alpha.
// Assumes alpha and beta are both large and uses a Rice saddle point expansion.
// To ensure numerical stability, this computation is performed at higher precision.
template<typename scalar_t, typename accscalar_t>
C10_DEVICE static inline scalar_t _beta_grad_alpha_mid(accscalar_t x, accscalar_t alpha, accscalar_t beta) {
  const accscalar_t total = alpha + beta;
  const accscalar_t mean = alpha / total;
  const accscalar_t std = compat_sqrt(alpha * beta / (total + 1)) / total;
  if (mean - 0.1 * std <= x && x <= mean + 0.1 * std) {
    // Avoid the singularity at x = mean.
    const accscalar_t poly = 47 * x * (beta * beta) * (beta * beta) + alpha * (
        (43 + 20 * (16 + 27 * beta) * x) * (beta * beta) * beta + alpha * (
        3 * (59 + 180 * beta - 90 * x) * (beta * beta) + alpha * (
        (453 + 1620 * beta * (1 - x) - 455 * x) * beta + alpha * (
        8 * (1 - x) * (135 * beta - 11)))));
    const accscalar_t prefactor_num = (1 + 12 * alpha) * (1 + 12 * beta) / (total * total);
    const accscalar_t prefactor_den = 12960 * alpha * alpha * alpha * beta * beta * (1 + 12 * total);
    return prefactor_num / (1 - x) * poly / prefactor_den;
  }
  // General saddle point expansion away from x = mean.
  const accscalar_t prefactor = -x / compat_sqrt(2 * alpha * beta / total);
  // Ratio of Stirling series corrections for alpha, beta and total.
  const accscalar_t stirling = (1 + 1 / (12 * alpha) + 1 / (288 * alpha * alpha))
      * (1 + 1 / (12 * beta) + 1 / (288 * beta * beta))
      / (1 + 1 / (12 * total) + 1 / (288 * total * total));
  const accscalar_t term1_num = 2 * (alpha * alpha) * (x - 1) + alpha * beta * (x - 1) - x * (beta * beta);
  const accscalar_t axbx = alpha * (x - 1) + beta * x;
  const accscalar_t term1_den = compat_sqrt(2 * alpha / beta) * compat_pow(total, static_cast<accscalar_t>(1.5f)) * axbx * axbx;
  const accscalar_t term1 = term1_num / term1_den;
  const accscalar_t term2 = 0.5f * compat_log(alpha / (total * x));
  const accscalar_t term3_num = compat_sqrt(8 * alpha * beta / total);
  const accscalar_t term3_den = beta * x + alpha * (x - 1);
  const accscalar_t term3 = term3_num / term3_den;
  const accscalar_t term4_base = beta * compat_log(beta / (total * (1 - x))) +
      alpha * compat_log(alpha / (total * x));
  const accscalar_t term4 = compat_pow(term4_base, static_cast<accscalar_t>(-1.5f));
  // term4's sign flips across the mean.
  const accscalar_t term1234 = term1 + term2 * (term3 + (x < mean ? term4 : -term4));
  return static_cast<scalar_t>(stirling * prefactor * term1234);
}
448
+
449
// Computes a scaled reparameterized gradient
//   -(d/dalpha cdf(x;alpha,beta)) / pdf(x;alpha,beta) / (1-x)
// for random number x drawn from a Beta distribution Beta(alpha,beta).
// This function inputs total=alpha+beta to make it easy to implement
// Dirichlet reparameterized gradients in terms of Betas.
// Four regimes: Taylor expansions near x=0 and x=1, a saddle point expansion
// when both shape parameters are large, and a fitted rational correction
// to an analytic approximation elsewhere.
template<typename scalar_t, typename accscalar_t>
C10_HOST_DEVICE static inline scalar_t dirichlet_grad_one(scalar_t x, scalar_t alpha, scalar_t total) {
  accscalar_t x_ = static_cast<accscalar_t>(x);
  accscalar_t alpha_ = static_cast<accscalar_t>(alpha);
  accscalar_t total_ = static_cast<accscalar_t>(total);

  const scalar_t beta = total - alpha;
  const accscalar_t beta_ = total_ - alpha_;
  const scalar_t boundary = total * x * (1 - x);

  // Use an asymptotic approximation for x close to 0.
  if (x <= 0.5f && boundary < 2.5f) {
    return _beta_grad_alpha_small<scalar_t, accscalar_t>(x, alpha, beta);
  }

  // Use an asymptotic approximation for x close to 1.
  if (x >= 0.5f && boundary < 0.75f) {
    return -_beta_grad_beta_small<scalar_t, accscalar_t>(1 - x, beta, alpha);
  }

  // Use an asymptotic approximation when alpha and (total - alpha) are both large.
  if (alpha > 6 && beta > 6) {
    return _beta_grad_alpha_mid<scalar_t, accscalar_t>(x_, alpha_, beta_);
  }

  // Use a rational correction to an analytic approximation.
  // c[0] holds numerator coefficients, c[1] denominator coefficients, fitted
  // over the basis u^i * a^j * b^k (see p/q accumulation below).
  static const accscalar_t c[2][3][3][4] = {
    {{{1.003668233, -0.01061107488, -0.0657888334, 0.01201642863},
      {0.6336835991, -0.3557432599, 0.05486251648, -0.001465281033},
      {-0.03276231906, 0.004474107445, 0.002429354597, -0.0001557569013}},
     {{0.221950385, -0.3187676331, 0.01799915743, 0.01074823814},
      {-0.2951249643, 0.06219954479, 0.01535556598, 0.001550077057},
      {0.02155310298, 0.004170831599, 0.001292462449, 6.976601077e-05}},
     {{-0.05980841433, 0.008441916499, 0.01085618172, 0.002319392565},
      {0.02911413504, 0.01400243777, -0.002721828457, 0.000751041181},
      {0.005900514878, -0.001936558688, -9.495446725e-06, 5.385558597e-05}}},
    {{{1, -0.02924021934, -0.04438342661, 0.007285809825},
      {0.6357567472, -0.3473456711, 0.05454656494, -0.002407477521},
      {-0.03301322327, 0.004845219414, 0.00231480583, -0.0002307248149}},
     {{0.5925320577, -0.1757678135, 0.01505928619, 0.000564515273},
      {0.1014815858, -0.06589186703, 0.01272886114, -0.0007316646956},
      {-0.007258481865, 0.001096195486, 0.0003934994223, -4.12701925e-05}},
     {{0.06469649321, -0.0236701437, 0.002902096474, -5.896963079e-05},
      {0.001925008108, -0.002869809258, 0.0008000589141, -6.063713228e-05},
      {-0.0003477407336, 6.959756487e-05, 1.097287507e-05, -1.650964693e-06}}},
  };
  const accscalar_t u = compat_log(x_);
  const accscalar_t a = compat_log(alpha_) - u;
  const accscalar_t b = compat_log(total_) - a;
  const accscalar_t pow_u[3] = {1, u, u * u};
  const accscalar_t pow_a[3] = {1, a, a * a};
  accscalar_t p = 0.0;
  accscalar_t q = 0.0;
  for (int i = 0; i < 3; ++i) {
    for (int j = 0; j < 3; ++j) {
      const accscalar_t ua = pow_u[i] * pow_a[j];
      // Inner polynomial in b evaluated via Horner's method.
      p += ua * (c[0][i][j][0] + b * (c[0][i][j][1] + b * (c[0][i][j][2] + b * c[0][i][j][3])));
      q += ua * (c[1][i][j][0] + b * (c[1][i][j][1] + b * (c[1][i][j][2] + b * c[1][i][j][3])));
    }
  }
  const accscalar_t approx = x_ * (digamma_one<scalar_t, accscalar_t>(total_) - digamma_one<scalar_t, accscalar_t>(alpha_)) / beta_;
  return static_cast<scalar_t>(p / q * approx);
}
517
+
518
+ } // namespace
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/Math.h ADDED
The diff for this file is too large to render. See raw diff
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/NonSymbolicBC.h ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ #include <ATen/core/Tensor.h>
3
+ #include <c10/util/irange.h>
4
+ #include <ATen/core/IListRef.h>
5
+
6
namespace at::native {
// This file contains non-symbolic signatures for ops that we have sym-intified the signature of.
// However, in certain cases (such as static runtime), we call the native versions of the ops directly.
// In those cases, we will duplicate the signature here with non-symbolic ints, and also duplicate the C++ implementation.
TORCH_API at::Tensor reshape(const at::Tensor& self, at::IntArrayRef proposed_shape);
TORCH_API at::Tensor narrow(const at::Tensor& self, int64_t dim, int64_t start, int64_t length);
TORCH_API at::Tensor _sparse_coo_tensor_unsafe(const at::Tensor & indices, const at::Tensor & values, at::IntArrayRef size, c10::optional<at::ScalarType> dtype=c10::nullopt, c10::optional<at::Layout> layout=c10::nullopt, c10::optional<at::Device> device=c10::nullopt, c10::optional<bool> pin_memory=c10::nullopt, c10::optional<bool> is_coalesced=c10::nullopt);
TORCH_API at::Tensor nll_loss(const at::Tensor & self, const at::Tensor & target, const c10::optional<at::Tensor>& weight_opt, int64_t reduction, int64_t ignore_index);
TORCH_API at::Tensor nll_loss2d(const at::Tensor & self, const at::Tensor & target, const c10::optional<at::Tensor>& weight_opt, int64_t reduction, int64_t ignore_index);
// The below ops don't get a duplicated C++ implementation.
// They are backward ops, which make them very unlikely to be called directly
// by external code (at::native::trace_backward).
// They get their own declaration for BC purposes however.
TORCH_API at::Tensor _embedding_bag_backward(const at::Tensor & grad, const at::Tensor & indices, const at::Tensor & offsets, const at::Tensor & offset2bag, const at::Tensor & bag_size, const at::Tensor & maximum_indices, int64_t num_weights, bool scale_grad_by_freq, int64_t mode, bool sparse, const c10::optional<at::Tensor> & per_sample_weights, int64_t padding_idx=-1);
TORCH_API at::Tensor _embedding_bag_sparse_backward(const at::Tensor & grad, const at::Tensor & indices, const at::Tensor & offsets, const at::Tensor & offset2bag, const at::Tensor & bag_size, int64_t num_weights, bool scale_grad_by_freq, int64_t mode, const c10::optional<at::Tensor> & per_sample_weights, int64_t padding_idx=-1);
TORCH_API at::Tensor value_selecting_reduction_backward(const at::Tensor & grad, int64_t dim, const at::Tensor & indices, at::IntArrayRef sizes, bool keepdim);
TORCH_API at::Tensor trace_backward(const at::Tensor & grad, at::IntArrayRef sizes);
TORCH_API at::Tensor index_select_backward(const at::Tensor & grad, at::IntArrayRef self_sizes, int64_t dim, const at::Tensor & index);
TORCH_API at::Tensor select(const at::Tensor& self, int64_t dim, int64_t index);
TORCH_API std::vector<Tensor> tensor_split(const Tensor& self, IntArrayRef indices, int64_t dim);
} // namespace at::native
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/PixelShuffle.h ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <ATen/core/Tensor.h>
2
+ #include <c10/util/Exception.h>
3
+
4
+ namespace at {
5
+ namespace native {
6
+
7
+ inline void check_pixel_shuffle_shapes(const Tensor& self, int64_t upscale_factor) {
8
+ TORCH_CHECK(self.dim() >= 3,
9
+ "pixel_shuffle expects input to have at least 3 dimensions, but got input with ",
10
+ self.dim(), " dimension(s)");
11
+ TORCH_CHECK(upscale_factor > 0,
12
+ "pixel_shuffle expects a positive upscale_factor, but got ",
13
+ upscale_factor);
14
+ int64_t c = self.size(-3);
15
+ int64_t upscale_factor_squared = upscale_factor * upscale_factor;
16
+ TORCH_CHECK(c % upscale_factor_squared == 0,
17
+ "pixel_shuffle expects its input's 'channel' dimension to be divisible by the square of "
18
+ "upscale_factor, but input.size(-3)=", c, " is not divisible by ", upscale_factor_squared);
19
+ }
20
+
21
+ inline void check_pixel_unshuffle_shapes(const Tensor& self, int64_t downscale_factor) {
22
+ TORCH_CHECK(
23
+ self.dim() >= 3,
24
+ "pixel_unshuffle expects input to have at least 3 dimensions, but got input with ",
25
+ self.dim(),
26
+ " dimension(s)");
27
+ TORCH_CHECK(
28
+ downscale_factor > 0,
29
+ "pixel_unshuffle expects a positive downscale_factor, but got ",
30
+ downscale_factor);
31
+ int64_t h = self.size(-2);
32
+ int64_t w = self.size(-1);
33
+ TORCH_CHECK(
34
+ h % downscale_factor == 0,
35
+ "pixel_unshuffle expects height to be divisible by downscale_factor, but input.size(-2)=",
36
+ h,
37
+ " is not divisible by ",
38
+ downscale_factor);
39
+ TORCH_CHECK(
40
+ w % downscale_factor == 0,
41
+ "pixel_unshuffle expects width to be divisible by downscale_factor, but input.size(-1)=",
42
+ w,
43
+ " is not divisible by ",
44
+ downscale_factor);
45
+ }
46
+
47
+ }} // namespace at::native
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/RangeFactories.h ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <ATen/native/DispatchStub.h>
2
+ #include <c10/core/Scalar.h>
3
+
4
namespace at {
// Forward declaration only; the full TensorIterator header is not needed
// for these stub signatures.
struct TensorIterator;

namespace native {

// Per-backend dispatch stubs for the range factory kernels.
// The three Scalar arguments are presumably (start, end, step) — confirm at
// the kernel registration sites before relying on the ordering.
DECLARE_DISPATCH(void(*)(TensorIterator&, const Scalar&, const Scalar&, const Scalar&), arange_stub);
// Scalar pair plus an int64_t count — presumably (start, end, steps); verify
// against the registered kernels.
DECLARE_DISPATCH(void(*)(TensorIterator&, const Scalar&, const Scalar&, int64_t), linspace_stub);

}} // namespace at::native
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/ReduceAllOps.h ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/native/DispatchStub.h>
4
+
5
namespace at {
// Forward declaration only; the full Tensor header is not needed for these
// function-pointer signatures.
class Tensor;
}

namespace at::native {

// Kernel signature for whole-tensor reductions: writes the result of
// reducing `self` into `result`.
using reduce_all_fn = void (*)(Tensor & result, const Tensor & self);
// Kernel signature for a fused min+max reduction over the whole tensor.
using reduce_min_max_fn = void (*)(Tensor & max_result, Tensor & min_result, const Tensor & self);
// Per-backend dispatch stubs for the all-reduce kernels.
DECLARE_DISPATCH(reduce_all_fn, min_all_stub);
DECLARE_DISPATCH(reduce_all_fn, max_all_stub);

} // namespace at::native
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/ReduceOpsUtils.h ADDED
@@ -0,0 +1,449 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <limits>
4
+ #include <ATen/core/Tensor.h>
5
+ #include <ATen/native/Resize.h>
6
+ #include <ATen/native/TensorIterator.h>
7
+ #include <ATen/native/NonEmptyUtils.h>
8
+ #include <ATen/WrapDimUtilsMulti.h>
9
+ #include <c10/core/ScalarType.h>
10
+ #include <c10/util/irange.h>
11
+
12
+ #ifndef AT_PER_OPERATOR_HEADERS
13
+ #include <ATen/Functions.h>
14
+ #else
15
+ #include <ATen/ops/empty.h>
16
+ #include <ATen/ops/scalar_tensor.h>
17
+ #endif
18
+
19
+ namespace at::native {
20
+
21
+ // Maximum and minimum possible scalar values, including infinities
22
+ template <typename scalar_t>
23
+ constexpr scalar_t upper_bound() {
24
+ using lim = std::numeric_limits<scalar_t>;
25
+ return lim::has_infinity ? lim::infinity() : lim::max();
26
+ }
27
+
28
+ template <typename scalar_t>
29
+ constexpr scalar_t lower_bound() {
30
+ using lim = std::numeric_limits<scalar_t>;
31
+ return lim::has_infinity ? -lim::infinity() : lim::lowest();
32
+ }
33
+
34
+ static inline Tensor restride_dim(
35
+ const Tensor& src, int64_t dim,
36
+ IntArrayRef replacement_shape
37
+ ) {
38
+ auto strides = ensure_nonempty_vec(src.strides().vec());
39
+ strides[dim] = 0;
40
+ return src.as_strided(replacement_shape, strides);
41
+ }
42
+
43
+ inline void _dimreduce_setup(const Tensor &result, const Tensor &self,
44
+ int64_t dim) {
45
+ IntArrayRef self_sizes = self.sizes();
46
+ std::vector<int64_t> result_sizes;
47
+ result_sizes.insert(result_sizes.end(), self_sizes.begin(), self_sizes.end());
48
+ result_sizes[dim] = 1;
49
+ result.resize_(result_sizes);
50
+ }
51
+
52
+ inline bool _dimreduce_return_trivial(const Tensor &result, const Tensor &self,
53
+ const Scalar& ident, int64_t dim, bool keepdim) {
54
+ if (self.numel() == 1 && self.ndimension() == 0) {
55
+ result.resize_({});
56
+ result.fill_(self);
57
+ return true;
58
+ }
59
+ // Return identity
60
+ if (self.numel() == 0) {
61
+ _dimreduce_setup(result, self, dim);
62
+ result.fill_(ident);
63
+ if (!keepdim) result.squeeze_(dim);
64
+ return true;
65
+ }
66
+ return false;
67
+ }
68
+
69
+ inline bool _dimreduce_return_trivial_no_ident(Tensor &result, const Tensor &self,
70
+ int64_t /*dim*/, bool /*keepdim*/, const char* /*fn_name*/) {
71
+ if (self.numel() == 1 && self.ndimension() == 0) {
72
+ result.resize_({});
73
+ result.fill_(self);
74
+ return true;
75
+ }
76
+
77
+ return false;
78
+ }
79
+
80
+ inline c10::optional<Tensor> _allreduce_return_trivial(
81
+ const Tensor& self,
82
+ const Scalar& ident) {
83
+ // Return identity
84
+ if (self.numel() == 0) {
85
+ return at::scalar_tensor(ident, self.options());
86
+ }
87
+ return c10::nullopt;
88
+ }
89
+
90
+ #define OPTION_TYPE_EQUALITY_CHECK(option, out, self) \
91
+ { \
92
+ TORCH_CHECK(\
93
+ out.option() == self.option(),\
94
+ "expected ", #option, " ",\
95
+ self.option(),\
96
+ " but found ", out.option())\
97
+ }
98
+
99
+ static inline void check_scalar_type_device_layout_equal(const Tensor& out, const Tensor& self) {
100
+ OPTION_TYPE_EQUALITY_CHECK(scalar_type, out, self);
101
+ OPTION_TYPE_EQUALITY_CHECK(device, out.options(), self.options());
102
+ OPTION_TYPE_EQUALITY_CHECK(layout, out.options(), self.options());
103
+ }
104
+
105
+ static inline Tensor integer_upcast(const Tensor& self, c10::optional<ScalarType> dtype) {
106
+ ScalarType scalarType = self.scalar_type();
107
+ TORCH_CHECK(!isBarebonesUnsignedType(scalarType), "integer upcasting for uint16, uint32 and uint64 is not currently implemented");
108
+ ScalarType upcast_scalarType = dtype.value_or(at::isIntegralType(scalarType, /*includeBool=*/true) ? ScalarType::Long : scalarType);
109
+ return self.toType(upcast_scalarType);
110
+ }
111
+
112
+ using DimMask = TensorIterator::DimMask;
113
+
114
+ static DimVector make_dim_vector(OptionalIntArrayRef opt_dims, int64_t ndim) {
115
+ if (opt_dims.has_value()) {
116
+ return DimVector(opt_dims.value());
117
+ } else {
118
+ std::vector<int64_t> all_dims(ndim);
119
+ std::iota(all_dims.begin(), all_dims.end(), 0);
120
+ return DimVector(all_dims);
121
+ }
122
+ }
123
+
124
+ static DimMask make_dim_mask(OptionalIntArrayRef opt_dims, int64_t ndim, bool allow_empty_dims=false) {
125
+ DimMask mask;
126
+ if (opt_dims.has_value()) {
127
+ auto dims = opt_dims.value();
128
+ if (dims.empty() && !allow_empty_dims) {
129
+ mask = DimMask().flip();
130
+ } else {
131
+ mask = at::dim_list_to_bitset(dims, ndim);
132
+ }
133
+ } else {
134
+ mask = DimMask().flip();
135
+ }
136
+ return mask;
137
+ }
138
+
139
+ inline DimVector shape_from_dim_mask(const Tensor& self, DimMask mask, bool keepdim) {
140
+ auto shape = DimVector(self.sizes());
141
+ for (int dim = shape.size() - 1; dim >= 0; dim--) {
142
+ if (mask[dim]) {
143
+ if (keepdim) {
144
+ shape[dim] = 1;
145
+ } else {
146
+ shape.erase(shape.begin() + dim);
147
+ }
148
+ }
149
+ }
150
+ return shape;
151
+ }
152
+
153
+ static void resize_reduction_result(
154
+ Tensor& result, const Tensor& self, DimMask mask, bool keepdim,
155
+ ScalarType /*dtype*/)
156
+ {
157
+ auto shape = shape_from_dim_mask(self, mask, keepdim);
158
+ TORCH_CHECK(result.defined(), "Cannot create a new tensor inside a reduction op. You likely tried to call an operator with an out argument but the out argument was an undefined tensor.");
159
+ at::native::resize_output(result, shape);
160
+ }
161
+
162
+ inline Tensor create_reduction_result(
163
+ const Tensor& self, at::OptionalIntArrayRef dim, bool keepdim, ScalarType dtype
164
+ ) {
165
+ DimMask mask = make_dim_mask(dim, self.dim());
166
+ auto shape = shape_from_dim_mask(self, mask, keepdim);
167
+ return at::empty(shape, self.options().dtype(dtype));
168
+ }
169
+
170
+ static Tensor review_reduce_result(const Tensor& result, int ndim, DimMask mask, bool keepdim) {
171
+ if (keepdim) {
172
+ return result;
173
+ }
174
+ auto shape = DimVector(result.sizes());
175
+ auto stride = DimVector(result.strides());
176
+ for (const auto dim : c10::irange(ndim)) {
177
+ if (mask[dim]) {
178
+ shape.insert(shape.begin() + dim, 1);
179
+ stride.insert(stride.begin() + dim, 0);
180
+ }
181
+ }
182
+ return result.as_strided(shape, stride);
183
+ }
184
+
185
+ static TensorIterator make_reduction(
186
+ const char* name, Tensor& result, const Tensor& self,
187
+ at::OptionalIntArrayRef dim_opt,
188
+ bool keepdim, ScalarType in_dtype, ScalarType out_dtype) {
189
+ // check that result type and dtype match if provided
190
+ TORCH_CHECK(
191
+ !result.defined() || result.scalar_type() == out_dtype,
192
+ name, ": provided dtype must match dtype of result. Got ",
193
+ toString(result.scalar_type()),
194
+ " and ",
195
+ toString(out_dtype),
196
+ ".");
197
+ // dim={} performs an all-reduce, same as dim=None
198
+ IntArrayRef dim = dim_opt.value_or(IntArrayRef{});
199
+ int64_t ndim = self.dim();
200
+ auto mask = make_dim_mask(dim, ndim);
201
+ resize_reduction_result(result, self, mask, keepdim, out_dtype);
202
+ auto viewed_result = review_reduce_result(result, ndim, mask, keepdim);
203
+ namedinference::propagate_names_for_reduction(result, self, dim, keepdim);
204
+ if (self.scalar_type() == in_dtype) {
205
+ return TensorIterator::reduce_op(viewed_result, self);
206
+ }
207
+ return TensorIterator::reduce_op(viewed_result, self.to(in_dtype));
208
+ }
209
+
210
+ static C10_UNUSED TensorIterator make_reduction(
211
+ const char* name, Tensor& result, const Tensor& self,
212
+ at::OptionalIntArrayRef dim, bool keepdim, ScalarType out_dtype) {
213
+ // special case for type promotion in mixed precision, improves computational
214
+ // efficiency.
215
+ // not generalize this to common mismatched input/output types to avoid cross
216
+ // product of templated kernel launches.
217
+ const bool gpu_lowp_to_f32 = (
218
+ self.is_cuda() && (self.scalar_type() == kHalf || self.scalar_type() == kBFloat16) && out_dtype == kFloat);
219
+ auto in_dtype = gpu_lowp_to_f32 ? self.scalar_type()
220
+ : self.is_complex() ? c10::toComplexType(out_dtype)
221
+ : out_dtype;
222
+ return make_reduction(name, result, self, dim, keepdim, in_dtype, out_dtype);
223
+ }
224
+
225
+ static TensorIterator make_reduction(
226
+ const char* name, Tensor& result1, Tensor& result2, const Tensor& self,
227
+ at::OptionalIntArrayRef dim_opt, bool keepdim, ScalarType dtype1,
228
+ ScalarType dtype2) {
229
+ // check that result type and dtype match if provided
230
+ TORCH_CHECK(
231
+ (!result1.defined() || result1.scalar_type() == dtype1) && (!result2.defined() || result2.scalar_type() == dtype2),
232
+ name, ": provided dtype must match dtype of result. Got ",
233
+ toString(result1.scalar_type()), toString(result2.scalar_type()),
234
+ " and ",
235
+ toString(dtype1), toString(dtype2),
236
+ ".");
237
+
238
+ // dim={} performs an all-reduce, same as dim=None
239
+ auto dim = dim_opt.value_or(IntArrayRef{});
240
+ int64_t ndim = self.dim();
241
+ DimMask mask = make_dim_mask(dim, ndim);
242
+ resize_reduction_result(result1, self, mask, keepdim, dtype1);
243
+ auto viewed_result1 = review_reduce_result(result1, ndim, mask, keepdim);
244
+
245
+ resize_reduction_result(result2, self, mask, keepdim, dtype2);
246
+ auto viewed_result2 = review_reduce_result(result2, ndim, mask, keepdim);
247
+
248
+ namedinference::propagate_names_for_reduction(result1, self, dim, keepdim);
249
+ namedinference::propagate_names_for_reduction(result2, self, dim, keepdim);
250
+
251
+ // special case for type promotion in mixed precision, improves computational
252
+ // efficiency.
253
+ // We don't generalize this to common mismatched input/output types to avoid cross
254
+ // product of templated kernel launches.
255
+ if (self.scalar_type() == dtype1 ||
256
+ (self.is_cuda() && self.scalar_type() == kHalf && dtype1 == kFloat)) {
257
+ return TensorIterator::reduce_op(viewed_result1, viewed_result2, self);
258
+ }
259
+ return TensorIterator::reduce_op(viewed_result1, viewed_result2, self.to(dtype1));
260
+ }
261
+
262
+ static C10_UNUSED TensorIterator make_reduction(
263
+ const char* name, Tensor& result1, Tensor& result2, const Tensor& self,
264
+ at::OptionalIntArrayRef dim, bool keepdim, ScalarType dtype) {
265
+ return make_reduction(name, result1, result2, self, dim, keepdim, dtype, dtype);
266
+ }
267
+
268
+ static void zero_numel_check_dims(const Tensor& self, const int64_t dim, const char *fn_name) {
269
+ if (self.ndimension() == 0) {
270
+ TORCH_CHECK_INDEX(dim == 0 || dim == -1, fn_name,
271
+ ": Expected reduction dim -1 or 0 for scalar but got ", dim);
272
+ }
273
+ else {
274
+ TORCH_CHECK_INDEX(self.size(dim) != 0, fn_name,
275
+ ": Expected reduction dim ", dim, " to have non-zero size.");
276
+ }
277
+ }
278
+
279
+ static void zero_numel_check_dims(const Tensor& self, const IntArrayRef dim, const char *fn_name) {
280
+ TORCH_CHECK(
281
+ !dim.empty(),
282
+ fn_name, ": Expected reduction dim to be specified for input.numel() == 0. ",
283
+ "Specify the reduction dim with the 'dim' argument.");
284
+ for (const int64_t d : dim) {
285
+ zero_numel_check_dims(self, d, fn_name);
286
+ }
287
+ }
288
+
289
+ static std::vector<int64_t> get_zero_numel_tensor_size(
290
+ const Tensor& self,
291
+ const int64_t dim,
292
+ const bool keepdim,
293
+ const char* fn_name) {
294
+ TORCH_INTERNAL_ASSERT(self.numel() == 0, fn_name, ": Expected self.numel() == 0.");
295
+ zero_numel_check_dims(self, dim, fn_name);
296
+ std::vector<int64_t> sizes;
297
+ if (keepdim) {
298
+ sizes = self.sizes().vec();
299
+ sizes[dim] = 1;
300
+ }
301
+ else {
302
+ for (const auto d : c10::irange(self.dim())) {
303
+ if (d != dim) {
304
+ sizes.push_back(self.sizes()[d]);
305
+ }
306
+ }
307
+ }
308
+ return sizes;
309
+ }
310
+
311
+ // Resize the result tensor and indices when result.numel() == 0 depending on values of
312
+ // dim and keepdim for returning tensors containing reduction results.
313
+ // This function should be called when you are reducing a zero-numel tensor and want to
314
+ // resize the output and return it. This function exists for resizing zero-numel
315
+ // tensors when the size of the reduction dimension is non-zero.
316
+ static C10_UNUSED void zero_numel_tensor_resize(Tensor& result, Tensor& result_indices,
317
+ const Tensor& self, const int64_t dim,
318
+ const bool keepdim, const char *fn_name) {
319
+ auto sizes = get_zero_numel_tensor_size(self, dim, keepdim, fn_name);
320
+ at::native::resize_output(result, sizes);
321
+ at::native::resize_output(result_indices, sizes);
322
+ }
323
+
324
+ inline ScalarType get_dtype_from_self(
325
+ const Tensor& self,
326
+ const c10::optional<ScalarType>& dtype,
327
+ bool promote_integers) {
328
+ if (dtype.has_value()) {
329
+ return dtype.value();
330
+ }
331
+ ScalarType src_type = self.scalar_type();
332
+ if (promote_integers && at::isIntegralType(src_type, /*includeBool=*/true)) {
333
+ return kLong;
334
+ }
335
+ return src_type;
336
+ }
337
+
338
+ inline ScalarType get_dtype_from_result(Tensor& result, c10::optional<ScalarType> dtype) {
339
+ TORCH_CHECK(result.defined(), "Cannot create a new tensor inside a reduction op. You likely tried to call an operator with an out argument but the out argument was an undefined tensor.");
340
+ if (dtype.has_value()) {
341
+ return dtype.value();
342
+ } else {
343
+ return result.scalar_type();
344
+ }
345
+ }
346
+
347
+
348
+ } // namespace at::native
349
+
350
+ namespace at::meta {
351
+
352
+ static C10_UNUSED DimVector get_reduction_shape(
353
+ const Tensor& self,
354
+ IntArrayRef dims,
355
+ bool keepdim,
356
+ bool allow_empty_dims=false) {
357
+ auto mask = native::make_dim_mask(dims, self.dim(), allow_empty_dims);
358
+ return native::shape_from_dim_mask(self, mask, keepdim);
359
+ }
360
+
361
+ static void resize_reduction(
362
+ impl::MetaBase& meta,
363
+ const Tensor& self,
364
+ OptionalIntArrayRef opt_dims,
365
+ bool keepdim,
366
+ ScalarType out_dtype,
367
+ bool allow_empty_dims=false) {
368
+ DimVector dims_ = at::native::make_dim_vector(opt_dims, self.dim());
369
+ maybe_wrap_dims(dims_, self.dim());
370
+ auto shape = get_reduction_shape(self, dims_, keepdim, allow_empty_dims);
371
+ meta.set_output_raw_strided(0, shape, {}, self.options().dtype(out_dtype));
372
+ namedinference::propagate_names_for_reduction(
373
+ meta.maybe_get_output(), self, dims_, keepdim);
374
+ }
375
+
376
+ static void resize_reduction_with_indices(
377
+ impl::MetaBase& meta,
378
+ const Tensor& self,
379
+ IntArrayRef dims,
380
+ bool keepdim,
381
+ ScalarType out_dtype) {
382
+ DimVector dims_(dims);
383
+ maybe_wrap_dims(dims_, self.dim());
384
+ auto shape = get_reduction_shape(self, dims_, keepdim);
385
+ meta.set_output_raw_strided(0, shape, {}, self.options().dtype(out_dtype));
386
+ meta.set_output_raw_strided(1, shape, {}, self.options().dtype(kLong));
387
+ namedinference::propagate_names_for_reduction(
388
+ meta.maybe_get_output(0), self, dims_, keepdim);
389
+ namedinference::propagate_names_for_reduction(
390
+ meta.maybe_get_output(1), self, dims_, keepdim);
391
+ }
392
+
393
+ static TensorIterator make_reduction(
394
+ const Tensor& self,
395
+ const Tensor& result,
396
+ OptionalIntArrayRef opt_dims,
397
+ bool keepdim,
398
+ ScalarType in_dtype) {
399
+ int64_t ndim = self.dim();
400
+ auto mask = at::native::make_dim_mask(opt_dims, ndim);
401
+ auto viewed_result =
402
+ at::native::review_reduce_result(result, ndim, mask, keepdim);
403
+ if (self.scalar_type() == in_dtype) {
404
+ return TensorIterator::reduce_op(viewed_result, self);
405
+ }
406
+ return TensorIterator::reduce_op(viewed_result, self.to(in_dtype));
407
+ }
408
+
409
+ static TensorIterator make_reduction(
410
+ const Tensor& self,
411
+ const Tensor& result1,
412
+ const Tensor& result2,
413
+ IntArrayRef dims,
414
+ bool keepdim,
415
+ ScalarType dtype1,
416
+ ScalarType /*dtype2*/) {
417
+ int64_t ndim = self.dim();
418
+ auto mask = at::native::make_dim_mask(dims, ndim);
419
+ auto viewed_result1 = at::native::review_reduce_result(result1, ndim, mask, keepdim);
420
+ auto viewed_result2 = at::native::review_reduce_result(result2, ndim, mask, keepdim);
421
+ // special case for type promotion in mixed precision, improves computational efficiency.
422
+ // We don't generalize this to common mismatched input/output types to avoid cross product
423
+ // of templated kernel launches.
424
+ if (self.scalar_type() == dtype1 ||
425
+ (self.is_cuda() && self.scalar_type() == kHalf && dtype1 == kFloat)) {
426
+ return TensorIterator::reduce_op(viewed_result1, viewed_result2, self);
427
+ }
428
+ return TensorIterator::reduce_op(viewed_result1, viewed_result2, self.to(dtype1));
429
+ }
430
+
431
+ static C10_UNUSED TensorIterator make_reduction_from_out_ty(
432
+ const Tensor& self,
433
+ const Tensor& result,
434
+ OptionalIntArrayRef opt_dims,
435
+ bool keepdim,
436
+ ScalarType out_dtype) {
437
+ // special case for type promotion in mixed precision, improves computational
438
+ // efficiency.
439
+ // not generalize this to common mismatched input/output types to avoid cross
440
+ // product of templated kernel launches.
441
+ const bool gpu_lowp_to_f32 =
442
+ (self.is_cuda() &&
443
+ (self.scalar_type() == kHalf || self.scalar_type() == kBFloat16) &&
444
+ out_dtype == kFloat);
445
+ auto in_dtype = gpu_lowp_to_f32 ? self.scalar_type() : out_dtype;
446
+ return make_reduction(self, result, opt_dims, keepdim, in_dtype);
447
+ }
448
+
449
+ } // namespace at::meta
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/ReductionType.h ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <c10/core/Scalar.h>
4
+
5
+ namespace at::native {
6
+
7
+ enum class ReductionType {MAX, MEAN, MIN, SUM, PROD};
8
+
9
+ static inline ReductionType get_reduction_enum(const c10::string_view& reduce) {
10
+ if (reduce == "max" || reduce == "amax") {
11
+ return ReductionType::MAX;
12
+ } else if (reduce == "mean") {
13
+ return ReductionType::MEAN;
14
+ } else if (reduce == "min" || reduce == "amin") {
15
+ return ReductionType::MIN;
16
+ } else if (reduce == "sum") {
17
+ return ReductionType::SUM;
18
+ } else if (reduce == "prod") {
19
+ return ReductionType::PROD;
20
+ } else {
21
+ TORCH_CHECK(false, "reduce argument must be either sum, prod, mean, amax or amin, got ", reduce);
22
+ }
23
+ }
24
+
25
+ // used for `scatter_reduce`, old options for BC.
26
+ static inline ReductionType get_operator_enum(const c10::string_view reduce, bool use_new_options) {
27
+ if (use_new_options) {
28
+ return get_reduction_enum(reduce);
29
+ } else {
30
+ if (reduce == "add") {
31
+ return ReductionType::SUM;
32
+ } else if (reduce == "multiply") {
33
+ return ReductionType::PROD;
34
+ } else {
35
+ TORCH_CHECK(false, "reduce argument must be either add or multiply.")
36
+ }
37
+ }
38
+ }
39
+
40
+ } // at::native
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/ScatterGatherChecks.h ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <vector>
4
+ #include <ATen/core/Tensor.h>
5
+ #include <ATen/native/ReduceOpsUtils.h>
6
+ #include <c10/util/irange.h>
7
+
8
+ namespace at::native {
9
+
10
+ namespace {
11
+
12
+ // checks whether index.dtype == int64
13
+ // and self.dtype == src.dtype if src is a Tensor
14
+ static void scatter_gather_dtype_check(
15
+ const std::string& method_name,
16
+ const Tensor& self,
17
+ const Tensor& index,
18
+ const c10::optional<Tensor>& src_opt = c10::nullopt
19
+ ) {
20
+ if (index.numel() != 0) {
21
+ TORCH_CHECK(
22
+ index.scalar_type() == at::ScalarType::Long,
23
+ method_name, "(): Expected dtype int64 for index"
24
+ );
25
+ }
26
+
27
+ if (src_opt.has_value()) {
28
+ const auto& src = src_opt.value();
29
+ TORCH_CHECK(
30
+ self.scalar_type() == src.scalar_type(),
31
+ method_name, "(): Expected self.dtype to be equal to src.dtype"
32
+ );
33
+ }
34
+ }
35
+
36
+ // Used for `gather`-like methods
37
+ // Note: self means the input tensor here
38
+ // Test:
39
+ // 1. index.size(d) <= self.size(d) for all d != dim
40
+ // 2. index.dim() == self.dim()
41
+ static C10_UNUSED void gather_shape_check(const Tensor& self, int64_t dim,
42
+ const Tensor& index
43
+ ) {
44
+ auto self_dims = ensure_nonempty_dim(self.dim());
45
+ TORCH_CHECK(self_dims == ensure_nonempty_dim(index.dim()),
46
+ "Index tensor must have the same number of dimensions as input tensor"
47
+ );
48
+
49
+ for (const auto i : c10::irange(self_dims)) {
50
+ if (i != dim) {
51
+ TORCH_CHECK(
52
+ ensure_nonempty_size(index, i) <= ensure_nonempty_size(self, i),
53
+ "Size does not match at dimension ", i,
54
+ " expected index ", index.sizes(),
55
+ " to be smaller than self ", self.sizes(),
56
+ " apart from dimension ", dim
57
+ );
58
+ }
59
+ }
60
+ }
61
+
62
+ // Used for `scatter` and `scatter_add`
63
+ // Tests:
64
+ // 1. index.size(d) <= self.size(d) for all d != dim
65
+ // 2. index.size(d) <= src.size(d) for all d if src is a Tensor
66
+ // 3. index.dim() == self.dim() == src.dim()
67
+ static C10_UNUSED void scatter_shape_check(
68
+ const Tensor& self, int64_t dim, const Tensor& index,
69
+ const c10::optional<Tensor>& src_opt = c10::nullopt
70
+ ) {
71
+ if (index.numel() == 0) return;
72
+ TORCH_CHECK(
73
+ ensure_nonempty_dim(self.dim()) == ensure_nonempty_dim(index.dim()),
74
+ "Index tensor must have the same number of dimensions as self tensor"
75
+ );
76
+
77
+ bool is_wrong_shape = false;
78
+ int64_t self_dims = ensure_nonempty_dim(self.dim());
79
+
80
+ // Check: index.size(d) <= self.size(d) for all d != dim
81
+ for (const auto d : c10::irange(self_dims)) {
82
+ int64_t index_d_size = ensure_nonempty_size(index, d);
83
+ if (d == dim) continue;
84
+ if (index_d_size > ensure_nonempty_size(self, d)) {
85
+ is_wrong_shape = true;
86
+ break;
87
+ }
88
+ }
89
+
90
+ // Check: index.size(d) <= src.size(d) for all d if src is Tensor
91
+ if (!is_wrong_shape && src_opt.has_value()) {
92
+ const auto& src = src_opt.value();
93
+ for (const auto d : c10::irange(self_dims)) {
94
+ int64_t index_d_size = ensure_nonempty_size(index, d);
95
+ if (index_d_size > ensure_nonempty_size(src, d)) {
96
+ is_wrong_shape = true;
97
+ break;
98
+ }
99
+ }
100
+ }
101
+
102
+ if (src_opt.has_value()) {
103
+ const auto& src = src_opt.value();
104
+
105
+ TORCH_CHECK(
106
+ ensure_nonempty_dim(src.dim()) == ensure_nonempty_dim(index.dim()),
107
+ "Index tensor must have the same number of dimensions as src tensor"
108
+ );
109
+
110
+ TORCH_CHECK(!is_wrong_shape,
111
+ "Expected index ", index.sizes(),
112
+ " to be smaller than self ", self.sizes(),
113
+ " apart from dimension ", dim,
114
+ " and to be smaller size than src ", src.sizes()
115
+ );
116
+ }
117
+ else {
118
+ TORCH_CHECK(!is_wrong_shape,
119
+ "Expected index ", index.sizes(),
120
+ " to be smaller than self ", self.sizes(),
121
+ " apart from dimension ", dim
122
+ );
123
+ }
124
+ }
125
+
126
+ } // anonymous namespace
127
+
128
+ } // namespace at::native
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/SegmentReduce.h ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/native/DispatchStub.h>
4
+ #include <ATen/native/ReductionType.h>
5
+ #include <c10/core/Scalar.h>
6
+ #include <c10/util/Optional.h>
7
+
8
+ namespace at {
9
+ class Tensor;
10
+
11
+ namespace native {
12
+
13
+ using segment_reduce_lengths_fn = Tensor (*)(
14
+ ReductionType,
15
+ const Tensor&,
16
+ const Tensor&,
17
+ int64_t,
18
+ const c10::optional<Scalar>&);
19
+ DECLARE_DISPATCH(segment_reduce_lengths_fn, _segment_reduce_lengths_stub);
20
+
21
+ using segment_reduce_offsets_fn = Tensor (*)(
22
+ ReductionType,
23
+ const Tensor&,
24
+ const Tensor&,
25
+ int64_t,
26
+ const c10::optional<Scalar>&);
27
+ DECLARE_DISPATCH(segment_reduce_offsets_fn, _segment_reduce_offsets_stub);
28
+
29
+ using segment_reduce_lengths_backward_fn = Tensor (*)(
30
+ const Tensor&,
31
+ const Tensor&,
32
+ const Tensor&,
33
+ ReductionType,
34
+ const Tensor&,
35
+ int64_t,
36
+ const c10::optional<Scalar>&);
37
+ DECLARE_DISPATCH(segment_reduce_lengths_backward_fn, _segment_reduce_lengths_backward_stub);
38
+
39
+ using segment_reduce_offsets_backward_fn = Tensor (*)(
40
+ const Tensor&,
41
+ const Tensor&,
42
+ const Tensor&,
43
+ ReductionType,
44
+ const Tensor&,
45
+ int64_t,
46
+ const c10::optional<Scalar>&);
47
+ DECLARE_DISPATCH(segment_reduce_offsets_backward_fn, _segment_reduce_offsets_backward_stub);
48
+
49
+ } // namespace native
50
+ } // namespace at
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/TensorAdvancedIndexing.h ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ // Indexing tensors by tensors
4
+
5
+ #include <ATen/core/List.h>
6
+ #include <ATen/core/Tensor.h>
7
+ #include <ATen/native/DispatchStub.h>
8
+ #include <ATen/native/ReductionType.h>
9
+
10
+ namespace at {
11
+ struct TensorIterator;
12
+ }
13
+
14
+ namespace at::native {
15
+
16
+ using index_put_with_sort_fn = void(*)(Tensor &, const c10::List<c10::optional<Tensor>> &, const Tensor &, bool accumulate, bool unsafe);
17
+ using index_put_with_sort_quantized_fn = void(*)(Tensor& self, const c10::List<c10::optional<Tensor>>& indices, const Tensor& value, double scale, int zero_point, bool unsafe);
18
+ using gather_fn = void (*)(const Tensor & result, const Tensor & self, int64_t dim, const Tensor & index);
19
+ using scatter_fn = void(*)(const Tensor& self, int64_t dim, const Tensor& index, const Tensor& src);
20
+ using scatter_fill_fn = void(*)(const Tensor& self, int64_t dim, const Tensor& index, const Scalar& src);
21
+ using scatter_add_fn = void(*)(const Tensor& self, int64_t dim, const Tensor& index, const Tensor& src);
22
+ using scatter_reduce_fn = void(*)(const Tensor& self, const int64_t dim, const Tensor& index,
23
+ const Tensor& src, const ReductionType& reduce);
24
+ using scatter_scalar_reduce_fn = void(*)(const Tensor& self, const int64_t dim, const Tensor& index,
25
+ const Scalar& value, const ReductionType& reduce);
26
+ using scatter_reduce_two_fn = void(*)(const Tensor& self, const int64_t dim, const Tensor& index,
27
+ const Tensor& src, const ReductionType& reduce);
28
+
29
+ DECLARE_DISPATCH(index_put_with_sort_fn, index_put_with_sort_stub);
30
+ DECLARE_DISPATCH(index_put_with_sort_quantized_fn, index_put_with_sort_quantized_stub);
31
+ DECLARE_DISPATCH(gather_fn, gather_stub);
32
+ DECLARE_DISPATCH(scatter_fn, scatter_stub);
33
+ DECLARE_DISPATCH(scatter_fill_fn, scatter_fill_stub);
34
+ DECLARE_DISPATCH(scatter_add_fn, scatter_add_stub);
35
+ DECLARE_DISPATCH(scatter_reduce_fn, scatter_reduce_stub);
36
+ DECLARE_DISPATCH(scatter_scalar_reduce_fn, scatter_scalar_reduce_stub);
37
+ DECLARE_DISPATCH(scatter_reduce_two_fn, scatter_reduce_two_stub);
38
+
39
+ TORCH_API Tensor& index_out(Tensor& result, const Tensor & self, const c10::List<c10::optional<at::Tensor>>& indices);
40
+
41
+ using scatter_add_expanded_index_fn = void(*)(const Tensor&, const Tensor&, const Tensor&);
42
+ using scatter_reduce_expanded_index_fn = void(*)(const Tensor&, const Tensor&, const Tensor&, const ReductionType& reduce, bool);
43
+ using gather_expanded_index_fn = void (*)(const Tensor&, const Tensor&, const Tensor&);
44
+
45
+ DECLARE_DISPATCH(scatter_add_expanded_index_fn, scatter_add_expanded_index_stub);
46
+ DECLARE_DISPATCH(scatter_reduce_expanded_index_fn, scatter_reduce_expanded_index_stub);
47
+ DECLARE_DISPATCH(gather_expanded_index_fn, gather_expanded_index_stub);
48
+
49
+ } // namespace at::native
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/TypeProperties.h ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/core/Tensor.h>
4
+ #include <ATen/core/IListRef.h>
5
+
6
+ namespace at::native {
7
+
8
+ struct ResultTypeState {
9
+ c10::ScalarType dimResult = ScalarType::Undefined;
10
+ c10::ScalarType wrappedResult = ScalarType::Undefined;
11
+ c10::ScalarType zeroResult = ScalarType::Undefined;
12
+ };
13
+
14
+ TORCH_API ResultTypeState update_result_type_state(const Tensor& tensor, const ResultTypeState& in_state);
15
+ TORCH_API ResultTypeState update_result_type_state(const Scalar& scalar, const ResultTypeState& in_state);
16
+ TORCH_API ScalarType result_type(const ResultTypeState& state);
17
+
18
+ TORCH_API ScalarType result_type(ITensorListRef tensors);
19
+
20
+ } // namespace at::native
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/Unfold2d.h ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/native/DispatchStub.h>
4
+ #include <c10/core/ScalarType.h>
5
+ #include <cstdint>
6
+
7
+ namespace at::native {
8
+
9
+ using unfold2d_fn = void (*)(
10
+ ScalarType dtype,
11
+ void *finput,
12
+ void *input,
13
+ int64_t kH,
14
+ int64_t kW,
15
+ int64_t dH,
16
+ int64_t dW,
17
+ int64_t padH,
18
+ int64_t padW,
19
+ int64_t n_input_plane,
20
+ int64_t input_height,
21
+ int64_t input_width,
22
+ int64_t output_height,
23
+ int64_t output_width,
24
+ bool is_channels_last
25
+ );
26
+
27
+ DECLARE_DISPATCH(unfold2d_fn, unfolded2d_copy_stub);
28
+ DECLARE_DISPATCH(unfold2d_fn, unfolded2d_acc_stub);
29
+
30
+ } // namespace at::native
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/batch_norm.h ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/core/Tensor.h>
4
+ #include <ATen/native/DispatchStub.h>
5
+
6
+ namespace at::native {
7
+
8
+ using batch_norm_fn = void (*)(Tensor&, const Tensor&, const Tensor&,
9
+ const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, bool, double);
10
+ using batch_norm_collect_stats_fn = void (*)(Tensor&, Tensor&, const Tensor&);
11
+ using batch_norm_backward_fn = void(*)(Tensor&, Tensor&, Tensor&, const Tensor&,
12
+ const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, bool, double);
13
+
14
+ DECLARE_DISPATCH(batch_norm_fn, batch_norm_cpu_stub);
15
+ DECLARE_DISPATCH(batch_norm_collect_stats_fn, batch_norm_cpu_collect_stats_stub);
16
+ DECLARE_DISPATCH(batch_norm_backward_fn, batch_norm_cpu_backward_stub);
17
+
18
+ // TensorAccessor when it is defined to work around undefined...
19
+ template <typename scalar_t>
20
+ static TensorAccessor<scalar_t, 1> conditional_accessor_1d(const Tensor& t) {
21
+ if (! t.defined()) {
22
+ return TensorAccessor<scalar_t, 1>(nullptr, nullptr, nullptr);
23
+ }
24
+ return t.accessor<scalar_t, 1>();
25
+ }
26
+
27
+ template <typename scalar_t>
28
+ static scalar_t* conditional_data_ptr(const Tensor& t) {
29
+ return t.defined() ? t.contiguous().data_ptr<scalar_t>()
30
+ : nullptr;
31
+ }
32
+
33
+ } // namespace at::native
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/im2col_shape_check.h ADDED
@@ -0,0 +1,232 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ #include <ATen/core/Tensor.h>
3
+ #include <ATen/TensorUtils.h>
4
+ #include <ATen/div_rtn.h>
5
+
6
+ namespace at::native {
7
+
8
+ static inline void col2im_shape_check(
9
+ const Tensor& input,
10
+ const Tensor& grad_output,
11
+ int64_t output_height,
12
+ int64_t output_width,
13
+ int64_t kernel_height,
14
+ int64_t kernel_width,
15
+ int64_t dilation_height,
16
+ int64_t dilation_width,
17
+ int64_t pad_height,
18
+ int64_t pad_width,
19
+ int64_t stride_height,
20
+ int64_t stride_width) {
21
+ TORCH_CHECK(
22
+ kernel_width > 0 && kernel_height > 0,
23
+ "kernel size should be greater than zero, but got kernel_height: ",
24
+ kernel_height,
25
+ " kernel_width: ",
26
+ kernel_width);
27
+ TORCH_CHECK(
28
+ stride_width > 0 && stride_height > 0,
29
+ "stride should be greater than zero, but got stride_height: ",
30
+ stride_height,
31
+ " stride_width: ",
32
+ stride_width);
33
+ TORCH_CHECK(
34
+ dilation_width > 0 && dilation_height > 0,
35
+ "dilation should be greater than zero, but got dilation_height: ",
36
+ dilation_height,
37
+ " dilation_width: ",
38
+ dilation_width);
39
+ TORCH_CHECK(
40
+ pad_width >= 0 && pad_height >= 0,
41
+ "padding should be non-negative, but got pad_height: ",
42
+ pad_height,
43
+ " pad_width: ",
44
+ pad_width);
45
+
46
+
47
+ int64_t ndim = input.ndimension();
48
+ // allow dim=0 only the batch dimension.
49
+ TORCH_CHECK(
50
+ (ndim == 2 && input.size(0) != 0 && input.size(1) != 0) ||
51
+ (ndim == 3 && input.size(1) != 0 && input.size(2) != 0),
52
+ "Expected 2D or 3D (batch mode) tensor for input with possibly 0 batch size and non-zero dimensions for input, but got: ",
53
+ input.sizes());
54
+
55
+ int64_t batch_dim = (ndim == 3) ? 0 : -1;
56
+ int64_t n_input_plane = input.size(batch_dim + 1);
57
+
58
+ if (n_input_plane % (kernel_width * kernel_height) != 0) {
59
+ AT_ERROR(
60
+ "Expected size of input's dimension 1 to be divisible by the "
61
+ "product of kernel_size, but got input.size(1)=",
62
+ n_input_plane,
63
+ " and kernel_size=(",
64
+ kernel_height,
65
+ ", ",
66
+ kernel_width,
67
+ ").");
68
+ }
69
+
70
+ int64_t input_length = input.size(batch_dim + 2);
71
+ int64_t n_blocks_height =
72
+ div_rtn<int64_t>(
73
+ output_height + 2 * pad_height -
74
+ dilation_height * (kernel_height - 1) - 1,
75
+ stride_height) +
76
+ 1;
77
+ int64_t n_blocks_width = div_rtn<int64_t>(
78
+ output_width + 2 * pad_width -
79
+ dilation_width * (kernel_width - 1) - 1,
80
+ stride_width) +
81
+ 1;
82
+
83
+ if (input_length != (n_blocks_height * n_blocks_width)) {
84
+ AT_ERROR(
85
+ "Given output_size=(",
86
+ output_height,
87
+ ", ",
88
+ output_width,
89
+ "), kernel_size=(",
90
+ kernel_height,
91
+ ", ",
92
+ kernel_width,
93
+ "), dilation=(",
94
+ dilation_height,
95
+ ", ",
96
+ dilation_width,
97
+ "), padding=(",
98
+ pad_height,
99
+ ", ",
100
+ pad_width,
101
+ "), stride=(",
102
+ stride_height,
103
+ ", ",
104
+ stride_width,
105
+ "), expected size of input's dimension 2 to match the calculated number of ",
106
+ "sliding blocks ",
107
+ n_blocks_height,
108
+ " * ",
109
+ n_blocks_width,
110
+ " = ",
111
+ (n_blocks_height * n_blocks_width),
112
+ ", but got input.size(2)=",
113
+ input_length,
114
+ ".");
115
+ }
116
+
117
+ TORCH_CHECK(
118
+ n_blocks_height >= 1 && n_blocks_width >= 1,
119
+ "Given output_size=(", output_height, ", ", output_width, "), ",
120
+ "kernel_size=(", kernel_height, ", ", kernel_width, "), ",
121
+ "dilation=(", dilation_height, ", ", dilation_width, "), ",
122
+ "padding=(", pad_height, ", ", pad_width, "), ",
123
+ "stride=(", stride_height, ", ", stride_width, "), ",
124
+ "calculated shape of the array of sliding blocks as ",
125
+ "(", n_blocks_height, ", ", n_blocks_width, "), ",
126
+ "which is too small (non-positive)");
127
+
128
+ if (output_width < 1 || output_height < 1) {
129
+ AT_ERROR(
130
+ "Expected output spatial size to be positive, but got: output_size=(",
131
+ output_height,
132
+ ", ",
133
+ output_width,
134
+ ").");
135
+ }
136
+ }
137
+
138
+ static inline void im2col_shape_check(
139
+ const Tensor& input,
140
+ const Tensor& grad_output,
141
+ int64_t kernel_height,
142
+ int64_t kernel_width,
143
+ int64_t dilation_height,
144
+ int64_t dilation_width,
145
+ int64_t pad_height,
146
+ int64_t pad_width,
147
+ int64_t stride_height,
148
+ int64_t stride_width) {
149
+ TORCH_CHECK(
150
+ kernel_width > 0 && kernel_height > 0,
151
+ "kernel size should be greater than zero, but got kernel_height: ",
152
+ kernel_height,
153
+ " kernel_width: ",
154
+ kernel_width);
155
+
156
+ TORCH_CHECK(
157
+ dilation_width > 0 && dilation_height > 0,
158
+ "dilation should be greater than zero, but got dilation_height: ",
159
+ dilation_height,
160
+ " dilation_width: ",
161
+ dilation_width);
162
+
163
+ TORCH_CHECK(
164
+ pad_width >= 0 && pad_height >= 0,
165
+ "padding should be non-negative, but got pad_height: ",
166
+ pad_height,
167
+ " pad_width: ",
168
+ pad_width);
169
+
170
+ TORCH_CHECK(
171
+ stride_width > 0 && stride_height > 0,
172
+ "stride should be greater than zero, but got stride_height: ",
173
+ stride_height,
174
+ " stride_width: ",
175
+ stride_width);
176
+
177
+ int64_t ndim = input.ndimension();
178
+
179
+ // allow dim=0 only the batch dimension.
180
+ bool valid_dims = input.size(1) != 0 && input.size(2) != 0;
181
+ TORCH_CHECK(
182
+ (ndim == 3 && input.size(0) && valid_dims) ||
183
+ (ndim == 4 && valid_dims && input.size(3) != 0),
184
+ "Expected 3D or 4D (batch mode) tensor with possibly 0 batch size and other non-zero dimensions for input, but got: ",
185
+ input.sizes());
186
+
187
+ int64_t dim_batch = 0;
188
+
189
+ if (ndim == 3) {
190
+ dim_batch = -1;
191
+ }
192
+
193
+ int64_t input_height = input.size(dim_batch + 2);
194
+ int64_t input_width = input.size(dim_batch + 3);
195
+ int64_t output_height = div_rtn<int64_t>(
196
+ input_height + 2 * pad_height -
197
+ (dilation_height * (kernel_height - 1) + 1),
198
+ stride_height) +
199
+ 1;
200
+ int64_t output_width = div_rtn<int64_t>(
201
+ input_width + 2 * pad_width -
202
+ (dilation_width * (kernel_width - 1) + 1),
203
+ stride_width) +
204
+ 1;
205
+
206
+ if (output_height < 1 || output_width < 1) {
207
+ AT_ERROR(
208
+ "Given input with spatial size (",
209
+ input_height,
210
+ ", ",
211
+ input_height,
212
+ "), kernel_size=(",
213
+ kernel_height,
214
+ ", ",
215
+ kernel_width,
216
+ "), dilation=(",
217
+ dilation_height,
218
+ ", ",
219
+ dilation_width,
220
+ "), padding=(",
221
+ pad_height,
222
+ ", ",
223
+ pad_width,
224
+ "), calculated shape of the array of sliding blocks as (",
225
+ output_height,
226
+ ", ",
227
+ output_width,
228
+ "), but its components must be at least one.");
229
+ }
230
+ }
231
+
232
+ } // namespace at::native
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_cholesky_solve_helper_cpu_dispatch.h ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ // @generated by torchgen/gen.py from DispatchKeyFunction.h
3
+
4
+ // NB: The implementing C++ file is RegisterDispatchKey.cpp
5
+
6
+ // The only #includes we need are for custom classes that have defaults in the C++ API
7
+ #include <c10/core/MemoryFormat.h>
8
+ #include <c10/core/Scalar.h>
9
+ #include <ATen/core/Reduction.h>
10
+
11
+ // Forward declarations of any types needed in the operator signatures.
12
+ // We can't directly include these classes because it will cause circular include dependencies.
13
+ // This file is included by TensorBody.h, which defines the Tensor class.
14
+ #include <ATen/core/ATen_fwd.h>
15
+
16
+ namespace at {
17
+
18
+ namespace cpu {
19
+
20
+ TORCH_API at::Tensor _cholesky_solve_helper(const at::Tensor & self, const at::Tensor & A, bool upper);
21
+
22
+ } // namespace cpu
23
+ } // namespace at
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_foreach_lgamma_ops.h ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ // @generated by torchgen/gen.py from Operator.h
4
+
5
+ #include <tuple>
6
+ #include <vector>
7
+
8
+ // Forward declarations of any types needed in the operator signatures.
9
+ // We can't directly include these classes because it will cause circular include dependencies.
10
+ // This file is included by TensorBody.h, which defines the Tensor class.
11
+ #include <ATen/core/ATen_fwd.h>
12
+
13
+ namespace at {
14
+ namespace _ops {
15
+
16
+
17
+ struct TORCH_API _foreach_lgamma {
18
+ using schema = ::std::vector<at::Tensor> (at::TensorList);
19
+ using ptr_schema = schema*;
20
+ // See Note [static constexpr char* members for windows NVCC]
21
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::_foreach_lgamma")
22
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "")
23
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "_foreach_lgamma(Tensor[] self) -> Tensor[]")
24
+ static ::std::vector<at::Tensor> call(at::TensorList self);
25
+ static ::std::vector<at::Tensor> redispatch(c10::DispatchKeySet dispatchKeySet, at::TensorList self);
26
+ };
27
+
28
+ struct TORCH_API _foreach_lgamma_ {
29
+ using schema = void (at::TensorList);
30
+ using ptr_schema = schema*;
31
+ // See Note [static constexpr char* members for windows NVCC]
32
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::_foreach_lgamma_")
33
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "")
34
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "_foreach_lgamma_(Tensor(a!)[] self) -> ()")
35
+ static void call(at::TensorList self);
36
+ static void redispatch(c10::DispatchKeySet dispatchKeySet, at::TensorList self);
37
+ };
38
+
39
+ struct TORCH_API _foreach_lgamma_out {
40
+ using schema = void (at::TensorList, at::TensorList);
41
+ using ptr_schema = schema*;
42
+ // See Note [static constexpr char* members for windows NVCC]
43
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::_foreach_lgamma")
44
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "out")
45
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "_foreach_lgamma.out(Tensor[] self, *, Tensor(a!)[] out) -> ()")
46
+ static void call(at::TensorList self, at::TensorList out);
47
+ static void redispatch(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList out);
48
+ };
49
+
50
+ }} // namespace at::_ops
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_scaled_dot_product_cudnn_attention_cuda_dispatch.h ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ // @generated by torchgen/gen.py from DispatchKeyFunction.h
3
+
4
+ // NB: The implementing C++ file is RegisterDispatchKey.cpp
5
+
6
+ // The only #includes we need are for custom classes that have defaults in the C++ API
7
+ #include <c10/core/MemoryFormat.h>
8
+ #include <c10/core/Scalar.h>
9
+ #include <ATen/core/Reduction.h>
10
+
11
+ // Forward declarations of any types needed in the operator signatures.
12
+ // We can't directly include these classes because it will cause circular include dependencies.
13
+ // This file is included by TensorBody.h, which defines the Tensor class.
14
+ #include <ATen/core/ATen_fwd.h>
15
+
16
+ namespace at {
17
+
18
+ namespace cuda {
19
+
20
+ TORCH_API ::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor> _scaled_dot_product_cudnn_attention(const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, double dropout_p=0.0, bool is_causal=false, bool return_debug_mask=false, c10::optional<double> scale=c10::nullopt);
21
+
22
+ } // namespace cuda
23
+ } // namespace at
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_triton_scaled_dot_attention_compositeexplicitautograd_dispatch.h ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ // @generated by torchgen/gen.py from DispatchKeyFunction.h
3
+
4
+ // NB: The implementing C++ file is RegisterDispatchKey.cpp
5
+
6
+ // The only #includes we need are for custom classes that have defaults in the C++ API
7
+ #include <c10/core/MemoryFormat.h>
8
+ #include <c10/core/Scalar.h>
9
+ #include <ATen/core/Reduction.h>
10
+
11
+ // Forward declarations of any types needed in the operator signatures.
12
+ // We can't directly include these classes because it will cause circular include dependencies.
13
+ // This file is included by TensorBody.h, which defines the Tensor class.
14
+ #include <ATen/core/ATen_fwd.h>
15
+
16
+ namespace at {
17
+
18
+ namespace compositeexplicitautograd {
19
+
20
+ TORCH_API at::Tensor & _triton_scaled_dot_attention_out(at::Tensor & out, const at::Tensor & q, const at::Tensor & k, const at::Tensor & v, double dropout_p=0.0);
21
+ TORCH_API at::Tensor & _triton_scaled_dot_attention_outf(const at::Tensor & q, const at::Tensor & k, const at::Tensor & v, double dropout_p, at::Tensor & out);
22
+
23
+ } // namespace compositeexplicitautograd
24
+ } // namespace at
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_values_native.h ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ // @generated by torchgen/gen.py from NativeFunction.h
4
+
5
+ #include <c10/core/Scalar.h>
6
+ #include <c10/core/Storage.h>
7
+ #include <c10/core/TensorOptions.h>
8
+ #include <c10/util/Deprecated.h>
9
+ #include <c10/util/Optional.h>
10
+ #include <c10/core/QScheme.h>
11
+ #include <ATen/core/Reduction.h>
12
+ #include <ATen/core/Tensor.h>
13
+ #include <tuple>
14
+ #include <vector>
15
+
16
+
17
+ namespace at {
18
+ namespace native {
19
+ TORCH_API at::Tensor _values_sparse(const at::Tensor & self);
20
+ } // namespace native
21
+ } // namespace at
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/adaptive_max_pool3d_backward_native.h ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ // @generated by torchgen/gen.py from NativeFunction.h
4
+
5
+ #include <c10/core/Scalar.h>
6
+ #include <c10/core/Storage.h>
7
+ #include <c10/core/TensorOptions.h>
8
+ #include <c10/util/Deprecated.h>
9
+ #include <c10/util/Optional.h>
10
+ #include <c10/core/QScheme.h>
11
+ #include <ATen/core/Reduction.h>
12
+ #include <ATen/core/Tensor.h>
13
+ #include <tuple>
14
+ #include <vector>
15
+ #include <ATen/ops/adaptive_max_pool3d_backward_meta.h>
16
+
17
+ namespace at {
18
+ namespace native {
19
+ struct TORCH_API structured_adaptive_max_pool3d_backward_out_cpu : public at::meta::structured_adaptive_max_pool3d_backward {
20
+ void impl(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & indices, const at::Tensor & grad_input);
21
+ };
22
+ struct TORCH_API structured_adaptive_max_pool3d_backward_out_cuda : public at::meta::structured_adaptive_max_pool3d_backward {
23
+ void impl(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & indices, const at::Tensor & grad_input);
24
+ };
25
+ } // namespace native
26
+ } // namespace at
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/amin_compositeexplicitautogradnonfunctional_dispatch.h ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ // @generated by torchgen/gen.py from DispatchKeyFunction.h
3
+
4
+ // NB: The implementing C++ file is RegisterDispatchKey.cpp
5
+
6
+ // The only #includes we need are for custom classes that have defaults in the C++ API
7
+ #include <c10/core/MemoryFormat.h>
8
+ #include <c10/core/Scalar.h>
9
+ #include <ATen/core/Reduction.h>
10
+
11
+ // Forward declarations of any types needed in the operator signatures.
12
+ // We can't directly include these classes because it will cause circular include dependencies.
13
+ // This file is included by TensorBody.h, which defines the Tensor class.
14
+ #include <ATen/core/ATen_fwd.h>
15
+
16
+ namespace at {
17
+
18
+ namespace compositeexplicitautogradnonfunctional {
19
+
20
+ TORCH_API at::Tensor amin(const at::Tensor & self, at::IntArrayRef dim={}, bool keepdim=false);
21
+
22
+ } // namespace compositeexplicitautogradnonfunctional
23
+ } // namespace at
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/atleast_2d_compositeimplicitautograd_dispatch.h ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ // @generated by torchgen/gen.py from DispatchKeyFunction.h
3
+
4
+ // NB: The implementing C++ file is RegisterDispatchKey.cpp
5
+
6
+ // The only #includes we need are for custom classes that have defaults in the C++ API
7
+ #include <c10/core/MemoryFormat.h>
8
+ #include <c10/core/Scalar.h>
9
+ #include <ATen/core/Reduction.h>
10
+
11
+ // Forward declarations of any types needed in the operator signatures.
12
+ // We can't directly include these classes because it will cause circular include dependencies.
13
+ // This file is included by TensorBody.h, which defines the Tensor class.
14
+ #include <ATen/core/ATen_fwd.h>
15
+
16
+ namespace at {
17
+
18
+ namespace compositeimplicitautograd {
19
+
20
+ TORCH_API at::Tensor atleast_2d(const at::Tensor & self);
21
+ TORCH_API ::std::vector<at::Tensor> atleast_2d(at::TensorList tensors);
22
+
23
+ } // namespace compositeimplicitautograd
24
+ } // namespace at
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/atleast_2d_native.h ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ // @generated by torchgen/gen.py from NativeFunction.h
4
+
5
+ #include <c10/core/Scalar.h>
6
+ #include <c10/core/Storage.h>
7
+ #include <c10/core/TensorOptions.h>
8
+ #include <c10/util/Deprecated.h>
9
+ #include <c10/util/Optional.h>
10
+ #include <c10/core/QScheme.h>
11
+ #include <ATen/core/Reduction.h>
12
+ #include <ATen/core/Tensor.h>
13
+ #include <tuple>
14
+ #include <vector>
15
+
16
+
17
+ namespace at {
18
+ namespace native {
19
+ TORCH_API at::Tensor atleast_2d(const at::Tensor & self);
20
+ TORCH_API ::std::vector<at::Tensor> atleast_2d(at::TensorList tensors);
21
+ } // namespace native
22
+ } // namespace at
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/avg_pool3d_backward_compositeexplicitautogradnonfunctional_dispatch.h ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ // @generated by torchgen/gen.py from DispatchKeyFunction.h
3
+
4
+ // NB: The implementing C++ file is RegisterDispatchKey.cpp
5
+
6
+ // The only #includes we need are for custom classes that have defaults in the C++ API
7
+ #include <c10/core/MemoryFormat.h>
8
+ #include <c10/core/Scalar.h>
9
+ #include <ATen/core/Reduction.h>
10
+
11
+ // Forward declarations of any types needed in the operator signatures.
12
+ // We can't directly include these classes because it will cause circular include dependencies.
13
+ // This file is included by TensorBody.h, which defines the Tensor class.
14
+ #include <ATen/core/ATen_fwd.h>
15
+
16
+ namespace at {
17
+
18
+ namespace compositeexplicitautogradnonfunctional {
19
+
20
+ TORCH_API at::Tensor avg_pool3d_backward(const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, c10::optional<int64_t> divisor_override);
21
+
22
+ } // namespace compositeexplicitautogradnonfunctional
23
+ } // namespace at