Add files using upload-large-folder tool
This view is limited to 50 files because it contains too many changes.
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__init__.py +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/reinplace.py +537 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__init__.py +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/__init__.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_17.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_10.py +213 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_11.py +212 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_16.py +635 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_17.py +256 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_2.py +182 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_dimV_ops.h +28 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_foreach_exp.h +44 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_linalg_slogdet_meta_dispatch.h +25 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_nested_tensor_from_mask_compositeexplicitautograd_dispatch.h +24 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_print_native.h +21 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_sample_dirichlet.h +39 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_sparse_csr_sum.h +39 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_sparse_softmax_backward_data_ops.h +39 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_standard_gamma_grad_compositeexplicitautograd_dispatch.h +24 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_test_warn_in_autograd_native.h +22 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_upsample_nearest_exact2d_cuda_dispatch.h +28 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_upsample_nearest_exact3d_backward_meta_dispatch.h +28 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/acosh_meta_dispatch.h +26 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/all_cuda_dispatch.h +31 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/atanh_native.h +29 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/atleast_3d_ops.h +39 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/binary_cross_entropy.h +39 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/binary_cross_entropy_with_logits_native.h +22 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/bitwise_xor_ops.h +105 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/clip_ops.h +83 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/cosine_similarity_compositeimplicitautograd_dispatch.h +23 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/cumprod_backward_native.h +21 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/digamma.h +39 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/divide_native.h +30 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/embedding_dense_backward_compositeexplicitautograd_dispatch.h +26 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/expm1_compositeexplicitautogradnonfunctional_dispatch.h +24 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/fake_quantize_per_channel_affine_cachemask_backward.h +30 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/fft_ihfftn_compositeimplicitautograd_dispatch.h +28 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/gcd_cuda_dispatch.h +26 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/grid_sampler_2d.h +39 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/hardswish_backward.h +39 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/hstack.h +39 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/huber_loss_backward.h +39 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/lift_fresh.h +30 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/linalg_eigh_ops.h +39 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/linalg_lu_cpu_dispatch.h +25 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/linalg_pinv_compositeexplicitautogradnonfunctional_dispatch.h +23 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/linear.h +39 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/log_sigmoid_forward_cpu_dispatch.h +25 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/logaddexp2_ops.h +39 -0
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__init__.py
ADDED
File without changes
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/reinplace.py
ADDED
@@ -0,0 +1,537 @@
import operator
from collections import defaultdict
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Tuple

import torch
from torch._higher_order_ops.triton_kernel_wrap import triton_kernel_wrapper_functional
from torch._inductor import inductor_prims
from torch._inductor.fx_utils import get_node_storage, is_node_realized
from torch._inductor.lowering import (
    inplaceable_foreach_ops as inplaceable_foreach_ops_lowerings,
)
from torch._inductor.virtualized import V
from torch.fx.immutable_collections import immutable_dict
from torch.fx.passes.reinplace import _is_view_op
from torch.utils import _pytree as pytree

aten = torch.ops.aten


@dataclass(frozen=True)
class InplaceableOp:
    inplace_op: Callable[..., Any]
    mutated_arg: int
    extra_check: Callable[[torch.fx.Node], bool] = lambda node: True


_SCATTER_OP_TO_VIEW = {
    torch.ops.aten.diagonal_scatter.default: torch.ops.aten.diagonal.default,
    torch.ops.aten.select_scatter.default: torch.ops.aten.select.int,
    torch.ops.aten.slice_scatter.default: torch.ops.aten.slice.Tensor,
    torch.ops.aten.as_strided_scatter.default: torch.ops.aten.as_strided.default,
}
_VIEW_OP_TO_SCATTER = {v: k for k, v in _SCATTER_OP_TO_VIEW.items()}


def graph_call_function(graph: torch.fx.Graph, fn, *args, **kwargs):
    fake_args, fake_kwargs = pytree.tree_map(
        lambda node: node.meta["val"] if isinstance(node, torch.fx.Node) else node,
        (args, kwargs),
    )
    with V.fake_mode:
        fake_result = fn(*fake_args, **fake_kwargs)

    node = graph.call_function(fn, args, kwargs)
    node.meta["val"] = fake_result
    return node


@dataclass
class ViewOp:
    target: torch._ops.OpOverload
    args: Tuple[Any, ...]
    kwargs: Dict[str, Any]


def _inplace_generalized_scatter(
    inp: torch.Tensor, src: torch.Tensor, view_ops: List[ViewOp]
) -> torch.Tensor:
    tmp = inp
    for view in view_ops:
        fake_args, fake_kwargs = pytree.tree_map(
            lambda node: node.meta["val"] if isinstance(node, torch.fx.Node) else node,
            (view.args, view.kwargs),
        )
        tmp = view.target(tmp, *fake_args, **fake_kwargs)
    tmp.copy_(src)
    return inp


def _generalized_scatter(
    inp: torch.Tensor, src: torch.Tensor, view_ops: List[ViewOp]
) -> torch.Tensor:
    out = inp.clone()
    return _inplace_generalized_scatter(out, src, view_ops)


def _decompose_scatter_functional_helper(
    graph: torch.fx.Graph,
    inp: torch.Tensor,
    src: torch.Tensor,
    view_ops: List[ViewOp],
) -> torch.fx.Node:
    view_op, view_ops_tail = view_ops[0], view_ops[1:]

    if view_ops_tail:
        view = graph_call_function(
            graph, view_op.target, inp, *view_op.args, **view_op.kwargs
        )
        src = _decompose_scatter_functional_helper(graph, view, src, view_ops[1:])  # type: ignore[assignment]

    return graph_call_function(
        graph,
        _VIEW_OP_TO_SCATTER[view_op.target],
        inp,
        src,
        *view_op.args,
        **view_op.kwargs,
    )


def _decompose_scatter_functional(
    graph: torch.fx.Graph, node: torch.fx.Node
) -> torch.fx.Node:
    """Decompose _generalized_scatter to a sequence of view_scatter operations

    e.g. _generalized_scatter(inp, src, [(aten.slice, 0, 0, 10), (aten.slice, 1, 10, -10)])

    will become

        view = aten.slice(inp, 0, 0, 10)
        view_updated = aten.slice_scatter(view, src, 1, 10, -10)
        inp_updated = aten.slice_scatter(inp, view_updated, 0, 0, 10)
    """
    assert node.target is _generalized_scatter
    inp, src, view_ops = node.args
    return _decompose_scatter_functional_helper(graph, *node.args)  # type: ignore[arg-type]


def _decompose_scatter_mutating(
    graph: torch.fx.Graph, node: torch.fx.Node
) -> torch.fx.Node:
    """Decompose _generalized_scatter using mutations

    e.g. _generalized_scatter(inp, src, [(aten.slice, 0, 0, 10), (aten.slice, 1, 10, -10)])

    will become

        inp_updated = aten.clone(inp)
        slice1 = aten.slice(inp_updated, 0, 0, 10)
        slice2 = aten.slice(slice1, 1, 10, -10)
        slice2.copy_(src)

    """
    assert node.target in (_generalized_scatter, _inplace_generalized_scatter)
    inp, src, view_ops = node.args
    assert not node.kwargs

    if node.target is _generalized_scatter:
        inp = graph_call_function(graph, aten.clone, inp)

    tmp = inp
    for view in view_ops:  # type: ignore[union-attr]
        tmp = graph_call_function(graph, view.target, tmp, *view.args, **view.kwargs)  # type: ignore[union-attr]

    graph_call_function(graph, aten.copy_.default, tmp, src)
    return inp  # type: ignore[return-value]


# View ops whose view_scatter op is lowered into mutations anyway,
# so is never a pessimisation to decompose.
_ALWAYS_MUTATING_SCATTER_OPS = {
    aten.as_strided.default,
    aten.diagonal.default,
}


def scatter_always_uses_mutation(node: torch.fx.Node) -> bool:
    _, _, view_ops = node.args
    return any(view.target in _ALWAYS_MUTATING_SCATTER_OPS for view in view_ops)  # type: ignore[union-attr]


def should_reinplace_scatter(node: torch.fx.Node) -> bool:
    """Choose between mutating and functional scatter decompositions

    Reinplacing view scatter ops can be pessimising as it blocks fusion with the
    input or output tensor computations. However, it is still profitable if the
    input and output would have been realized anyway.

    """
    inp, src, view_ops = node.args

    # Mutating scatter ops unconditionally realize input and output
    if scatter_always_uses_mutation(node):
        return True

    if is_node_realized(inp) and is_node_realized(node):  # type: ignore[arg-type]
        return True

    # If the output is copied back into the input, this forces both to be
    # realized as the output is a user of the input
    if inp.op == "placeholder" and any(  # type: ignore[union-attr]
        user.target is aten.copy_.default and user.args[0] is inp for user in node.users
    ):
        return True

    # Otherwise, assume fusions will make functional variants profitable
    return False


def decompose_generalized_scatter(graph: torch.fx.Graph) -> None:
    """Replace _generalized_scatter with normal aten ops"""
    for node in graph.nodes:
        if node.target not in (_generalized_scatter, _inplace_generalized_scatter):
            continue

        use_mutation = (
            node.target is _inplace_generalized_scatter
            or scatter_always_uses_mutation(node)
        )

        with graph.inserting_before(node):
            if use_mutation:
                new_node = _decompose_scatter_mutating(graph, node)
            else:
                new_node = _decompose_scatter_functional(graph, node)

        node.replace_all_uses_with(new_node)
        graph.erase_node(node)


def canonicalize_view_scatter_ops(graph: torch.fx.Graph) -> None:
    """
    This canonicalizes view scatter ops into a generalized form, defined as:
      def scatter(inp, src, views):
        tmp = inp.clone()
        for view in views:
          tmp = view(tmp)
        tmp.copy_(src)

    We also fuse consecutive view scatter ops of the form
        a = scatter(view2(self), src, [view1])
        b = scatter(self, a, [view2])
    which can be rewritten as
        b = scatter(self, src, [view2, view1])
        a = view2(b)

    This is both more efficient as we only do a single scatter, and also
    easier to reinplace since there is only one use of `self`
    """

    node_to_view_base: Dict[torch.fx.Node, torch.fx.Node] = {}
    node_to_view_op: Dict[torch.fx.Node, List[ViewOp]] = defaultdict(list)

    def handle_views(node: torch.fx.Node):
        inp = node.args[0]
        node_to_view_base[node] = node_to_view_base.get(inp, inp)  # type: ignore[arg-type]
        node_to_view_op[node] = [
            *node_to_view_op[inp],  # type: ignore[index]
            ViewOp(
                node.target,  # type: ignore[arg-type]
                args=node.args[1:],
                kwargs=node.kwargs,
            ),
        ]

    def handle_view_scatter(node: torch.fx.Node):
        assert len(node.args) >= 2
        inp, src = node.args[:2]

        scatter_view_op = ViewOp(
            _SCATTER_OP_TO_VIEW[node.target],
            args=node.args[2:],
            kwargs=node.kwargs,
        )

        def can_fuse():
            if src.target is not _generalized_scatter:  # type: ignore[union-attr]
                return False
            src_inp, src_src, src_scatter_view_op = src.args  # type: ignore[union-attr]

            inp_base = node_to_view_base.get(inp, inp)  # type: ignore[arg-type]
            src_base = node_to_view_base.get(src_inp, src_inp)  # type: ignore[arg-type]
            return inp_base is src_base and node_to_view_op[src_inp] == [  # type: ignore[index]
                *node_to_view_op[inp],  # type: ignore[index]
                scatter_view_op,
            ]

        if not can_fuse():
            with graph.inserting_before(node):
                new_node = graph_call_function(
                    graph,
                    _generalized_scatter,
                    inp,
                    src,
                    [scatter_view_op],
                )
                node.replace_all_uses_with(new_node)
                graph.erase_node(node)
            return

        src_inp, src_src, src_scatter_view_op = src.args  # type: ignore[union-attr]
        with graph.inserting_before(src):
            new_node = graph_call_function(
                graph,
                _generalized_scatter,
                inp,
                src_src,
                [scatter_view_op, *src_scatter_view_op],  # type: ignore[misc]
            )
            node.replace_all_uses_with(new_node)
            graph.erase_node(node)

            if src.users:  # type: ignore[union-attr]
                new_src = graph_call_function(
                    graph,
                    _SCATTER_OP_TO_VIEW[node.target],
                    new_node,
                    *node.args[2:],
                    **node.kwargs,
                )

                handle_views(new_src)
                src.replace_all_uses_with(new_src)  # type: ignore[union-attr]

            graph.erase_node(src)

    for node in graph.nodes:
        if _is_view_op(node.target):
            handle_views(node)
        elif node.target in _SCATTER_OP_TO_VIEW:
            handle_view_scatter(node)


inplaceable_ops = {
    aten.index_put.default: InplaceableOp(aten.index_put_.default, 0),
    aten._unsafe_index_put.default: InplaceableOp(inductor_prims._unsafe_index_put_, 0),
    _generalized_scatter: InplaceableOp(
        _inplace_generalized_scatter,
        0,
        extra_check=should_reinplace_scatter,
    ),
}

try:
    c10d_functional = torch.ops._c10d_functional
    inplaceable_collective_ops = {
        c10d_functional.all_reduce.default: InplaceableOp(
            c10d_functional.all_reduce_.default, 0
        ),
        c10d_functional.all_reduce_coalesced.default: InplaceableOp(
            c10d_functional.all_reduce_coalesced_.default, 0
        ),
    }
    inplaceable_ops.update(inplaceable_collective_ops)
except AttributeError:
    # _c10d_functional ops are only available when torch
    # is built with USE_DISTRIBUTED=1.
    pass

inplaceable_foreach_ops: Dict[torch._ops.OpOverload, InplaceableOp] = {}
for outplace_op, inplace_op in inplaceable_foreach_ops_lowerings.items():
    inplaceable_foreach_ops[outplace_op] = InplaceableOp(inplace_op, 0)


inplaceable_triton_ops = {triton_kernel_wrapper_functional}


# Operators that don't depend on the tensor data
META_ONLY_OPS = {
    aten.sym_size.int,
    aten.sym_stride.int,
    aten.sym_numel.default,
    aten.sym_storage_offset.default,
}


def reinplace_inplaceable_ops_core(graph: torch.fx.Graph) -> None:
    """
    Reinplaces in-placeable operations.
    If there are no uses of a view of the mutated arg after the current node,
    it is possible to inplace the op.
    This above algorithm could be justified by observing side effects. While
    we traverse the graph in forwards direction, only latter nodes could view
    side effects of the current node. If the current node is not used later as
    well as no view of this node is used later in the graph, then it is safe to
    inplace as there would be no way to observe the side effects.
    This condition is slightly different for graph inputs where they can only
    be inplaced if the above condition is true and there's a copy_ in the
    epilogue that signals that the caller wants to observe the mutation.
    """

    copy_args_to_copy_nodes = {}
    mutated_inputs = set()
    storage_to_nodes = defaultdict(list)
    node_order: Dict[Any, int] = {}
    for i, node in enumerate(reversed(graph.nodes)):
        node_order[node] = len(graph.nodes) - i - 1
        storage_to_nodes[get_node_storage(node)].append(node)
        if node.target == aten.copy_.default and node.args[0].op == "placeholder":
            dst = node.args[0]
            src = node.args[1]
            # If the target is a getitem and it indexes a possible clone,
            # then skip over it
            if src.target == operator.getitem and (
                (
                    src.args[0].target == triton_kernel_wrapper_functional
                    and src.args[0].kwargs["kwargs"][src.args[1]] == node.args[0]
                )
                or (src.args[0].target in inplaceable_foreach_ops)
                or (src.args[0].target == torch.ops.higher_order.auto_functionalized)
            ):
                src = src.args[0]

            copy_args_to_copy_nodes[(dst, src)] = node

            mutated_inputs.add(node.args[0])

    def any_use_of_views_after_node(node, shared_view_nodes, *, copy_node):
        node_loc = node_order[node]
        copy_node_loc = node_order[copy_node] if copy_node is not None else None

        def is_meta_only_user(node):
            if _is_view_op(node.target):
                return all(is_meta_only_user(u) for u in node.users)
            return node.target in META_ONLY_OPS

        for view in shared_view_nodes:
            for user in view.users:
                user_loc = node_order[user]
                # Skip all users before node
                if user_loc <= node_loc:
                    continue
                # Ignore uses after the copy_ epilogue node, where the input
                # has already been mutated anyway
                if copy_node_loc is not None and copy_node_loc <= user_loc:
                    continue
                # Reinplacing does not change shape metadata
                if is_meta_only_user(user):
                    continue
                return True
        return False

    def can_inplace(node, mutated_arg):
        if isinstance(mutated_arg, (list, tuple)):
            return all(can_inplace(node, arg) for arg in mutated_arg)

        if get_node_storage(mutated_arg) is None:
            return False
        shared_view_nodes = storage_to_nodes[get_node_storage(mutated_arg)]
        if mutated_arg.op == "placeholder":
            if not (
                copy_node := copy_args_to_copy_nodes.get((mutated_arg, node), False)
            ):
                return False

            if any_use_of_views_after_node(
                node, shared_view_nodes, copy_node=copy_node
            ):
                return False

            return True
        elif any(view.op == "placeholder" for view in shared_view_nodes):
            # If mutated arg is view of any of the inputs of the graph,
            # do not allow for inplacing.
            # This would require more sophisticated algorithm to handle
            return False
        else:
            return not any_use_of_views_after_node(
                node, shared_view_nodes, copy_node=None
            )

    replace_dict: Dict[torch.fx.Node, torch.fx.Node] = {}

    def reinplace_and_refine_tensors_to_clone(old_tensors_to_clone, kwargs):
        tensors_to_clone: List[str] = []
        for arg in old_tensors_to_clone:
            assert arg in kwargs
            mutated_arg = kwargs[arg]
            if can_inplace(node, mutated_arg):
                copy_node = copy_args_to_copy_nodes.get((mutated_arg, node))
                if copy_node is not None:
                    replace_dict[copy_node] = copy_node.args[0]
                for user in node.users:
                    if user.target == operator.getitem and user.args[1] == arg:
                        replace_dict[user] = mutated_arg
            else:
                tensors_to_clone.append(arg)
        return tensors_to_clone

    for node in graph.nodes:
        if (inplaceable_op := inplaceable_ops.get(node.target, None)) is not None:
            mutated_arg = node.args[inplaceable_op.mutated_arg]
            if can_inplace(node, mutated_arg) and inplaceable_op.extra_check(node):
                # TODO(yifu): this doesn't properly remove copy epilogues for
                # ops that mutate multiple inputs. Need to revise the copy
                # node tracking logic to support the case.
                copy_node = copy_args_to_copy_nodes.get((mutated_arg, node))
                if copy_node is not None:
                    replace_dict[copy_node] = copy_node.args[0]
                node.target = inplaceable_op.inplace_op
        elif node.target == torch.ops.higher_order.auto_functionalized:
            _mutable_op = node.args[0]
            from torch._higher_order_ops.auto_functionalize import get_mutable_arg_names

            tensors_to_clone = get_mutable_arg_names(_mutable_op)
            # Don't try to reinplace Optional[Tensor] args that are None.
            tensors_to_clone = [
                t for t in tensors_to_clone if node.kwargs[t] is not None
            ]
            tensors_to_clone = reinplace_and_refine_tensors_to_clone(
                tensors_to_clone, node.kwargs
            )

            # Stash the metadata. There is a pass later on where we decompose
            # auto_functionalized into clones + a mutable op; this metadata
            # tells the decomp to only clone the following inputs
            node.meta["only_clone_these_tensors"] = tensors_to_clone
        elif node.target in inplaceable_triton_ops:
            # inplaceable_triton_ops take an additional argument called
            # tensors_to_clone which contain a list of tensors to clone
            # This pass iterates over them and sees which ones are safe
            # to eliminate (i.e. no longer need the clones)
            tensors_to_clone = reinplace_and_refine_tensors_to_clone(
                node.kwargs["tensors_to_clone"], node.kwargs["kwargs"]
            )

            kwargs = dict(node.kwargs)
            kwargs["tensors_to_clone"] = tensors_to_clone
            node.kwargs = immutable_dict(kwargs)
        elif (
            inplaceable_op := inplaceable_foreach_ops.get(node.target, None)
        ) is not None:
            mutated_args = node.args[inplaceable_op.mutated_arg]

            if not all((arg, node) in copy_args_to_copy_nodes for arg in mutated_args):
                continue

            if can_inplace(node, mutated_args):
                for arg in mutated_args:
                    copy_node = copy_args_to_copy_nodes[(arg, node)]
                    replace_dict[copy_node] = copy_node.args[0]

                node.target = inplaceable_op.inplace_op
    for node, replacement in replace_dict.items():
        while replacement in replace_dict:
            replacement = replace_dict[replacement]
        replace_dict[node] = replacement

        node.replace_all_uses_with(replacement)
        graph.erase_node(node)


def reinplace_inplaceable_ops(graph: torch.fx.Graph) -> None:
    canonicalize_view_scatter_ops(graph)
    reinplace_inplaceable_ops_core(graph)
    decompose_generalized_scatter(graph)
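
The pass above rewrites functional mutation ops back into their in-place variants when no later node can observe the original value. A minimal eager-mode sketch of the semantics it relies on (an editor's illustration, not part of the uploaded file; `aten.index_put` is the first entry in the `inplaceable_ops` table above):

import torch

aten = torch.ops.aten

inp = torch.zeros(8)
indices = [torch.tensor([1, 3])]
values = torch.ones(2)

# Functional form, as produced by functionalization: allocates and
# returns a fresh tensor, leaving `inp` untouched.
out = aten.index_put.default(inp, indices, values)

# In-place form, which reinplace_inplaceable_ops_core swaps in when
# can_inplace(node, mutated_arg) holds: mutates `inp` and saves the copy.
aten.index_put_.default(inp, indices, values)

assert torch.equal(out, inp)  # same values; the extra copy was the only difference

For a graph input, the rewrite additionally requires the copy-back epilogue (`aten.copy_.default(arg, out)`) that functionalization emits, since without it the caller never observes the mutation.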
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__init__.py
ADDED
File without changes
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (246 Bytes)
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_17.cpython-311.pyc
ADDED
Binary file (21.8 kB)
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_10.py
ADDED
@@ -0,0 +1,213 @@
# mypy: ignore-errors

# noqa: F401, E501
# This is an auto-generated file. Please do not modify it by hand.
# To re-generate, run:
# cd ~/pytorch && python
# torchgen/fuse_attention_patterns/gen_attention_patterns.py

import torch
import torch._inductor

aten = torch.ops.aten
prims = torch.ops.prims

from torch._inductor.pattern_matcher import (
   Arg,
   CallFunction,
   CallFunctionVarArgs,
   CallMethod,
   CallMethodVarArgs,
   CallModule,
   CallModuleVarArgs,
   ExclusiveKeywordArg,
   Ignored,
   KeywordArg,
   ListOf,
   MultiOutputPattern,
   PatternExpr,
   RepeatedExpr,
   _TargetArgsExpr,
   _TargetExpr,
   _TargetExprVarArgs,
)
permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
div_Tensor = CallFunction(aten.div.Tensor, permute_default, Ignored())
expand_default = CallFunction(aten.expand.default, div_Tensor, Ignored())
clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2)
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2)
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored(), _users=2)
amax_default = CallFunction(aten.amax.default, view_default_2, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, view_default_2, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2)
convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2)
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, bmm_default_2, Ignored())
view_default_7 = CallFunction(aten.view.default, convert_element_type_default_1, Ignored())
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored())
alias_default = CallFunction(aten.alias.default, div_Tensor_1)
alias_default_1 = CallFunction(aten.alias.default, alias_default)
alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2)
mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_2, alias_default_3, _users=2)
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True)
mul_Tensor_1 = CallFunction(aten.mul.Tensor, alias_default_3, sum_dim_IntList_1)
sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor, mul_Tensor_1)
view_default_8 = CallFunction(aten.view.default, sub_Tensor_1, Ignored(), _users=2)
permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored())
bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5)
view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
div_Tensor_2 = CallFunction(aten.div.Tensor, view_default_9, Ignored())
permute_default_6 = CallFunction(aten.permute.default, div_Tensor_2, Ignored())
permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored())
bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8)
view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored())
permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored())
permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored())
bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6)
view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored())
_sfdp_pattern_10_training = MultiOutputPattern([view_default_5,
  permute_default_6,
  permute_default_9,
  permute_default_11
])


permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
div_Tensor = CallFunction(aten.div.Tensor, permute_default, Ignored())
expand_default = CallFunction(aten.expand.default, div_Tensor, Ignored())
clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
view_default = CallFunction(aten.view.default, clone_default, Ignored())
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored())
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored(), _users=2)
amax_default = CallFunction(aten.amax.default, view_default_2, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, view_default_2, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored())
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
_sfdp_pattern_10_inference = CallFunction(aten.view.default, bmm_default_1, Ignored())


permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
div_Tensor = CallFunction(aten.div.Tensor, permute_default, Ignored())
expand_default = CallFunction(aten.expand.default, div_Tensor, Ignored())
clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2)
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2)
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
convert_element_type_default = CallFunction(prims.convert_element_type.default, view_default_2, Ignored(), _users=2)
amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2)
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2)
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored())
alias_default = CallFunction(aten.alias.default, div_Tensor_1)
alias_default_1 = CallFunction(aten.alias.default, alias_default)
alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2)
mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_2, alias_default_3, _users=2)
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True)
mul_Tensor_1 = CallFunction(aten.mul.Tensor, alias_default_3, sum_dim_IntList_1)
sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor, mul_Tensor_1)
convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, sub_Tensor_1, Ignored())
view_default_8 = CallFunction(aten.view.default, convert_element_type_default_3, Ignored(), _users=2)
permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored())
bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5)
view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
div_Tensor_2 = CallFunction(aten.div.Tensor, view_default_9, Ignored())
permute_default_6 = CallFunction(aten.permute.default, div_Tensor_2, Ignored())
permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored())
bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8)
view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored())
permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored())
permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored())
bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6)
view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored())
_sfdp_pattern_10_half_training = MultiOutputPattern([view_default_5,
  permute_default_6,
  permute_default_9,
  permute_default_11
])


permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
div_Tensor = CallFunction(aten.div.Tensor, permute_default, Ignored())
expand_default = CallFunction(aten.expand.default, div_Tensor, Ignored())
clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
view_default = CallFunction(aten.view.default, clone_default, Ignored())
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored())
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
convert_element_type_default = CallFunction(prims.convert_element_type.default, view_default_2, Ignored(), _users=2)
amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored())
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
_sfdp_pattern_10_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored())
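
As the header comments note, these serialized pattern files are auto-generated: each module spells out the FX graph of one eager-mode attention variant so that Inductor's SFDP fusion pass can recognize and rewrite it. A hedged sketch of the same DSL on a tiny fragment (an editor's illustration only; it constructs pattern objects but omits the matcher registration that the fusion pass performs):

import torch
from torch._inductor.pattern_matcher import CallFunction, Ignored, KeywordArg

aten = torch.ops.aten

# KeywordArg binds a graph input under a name usable by the replacement,
# Ignored() matches any argument (shapes, dims), and _users pins the
# expected number of consumers of the matched node.
k_t = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
scores = CallFunction(aten.bmm.default, KeywordArg('query'), k_t, _users=2)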
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_11.py
ADDED
@@ -0,0 +1,212 @@
| 1 |
+
# mypy: ignore-errors
|
| 2 |
+
|
| 3 |
+
# noqa: F401, E501
|
| 4 |
+
# This is an auto-generated file. Please do not modify it by hand.
|
| 5 |
+
# To re-generate, run:
|
| 6 |
+
# cd ~/pytorch && python
|
| 7 |
+
# torchgen/fuse_attention_patterns/gen_attention_patterns.py
|
| 8 |
+
|
| 9 |
+
import torch
import torch._inductor

aten = torch.ops.aten
prims = torch.ops.prims

from torch._inductor.pattern_matcher import (
   Arg,
   CallFunction,
   CallFunctionVarArgs,
   CallMethod,
   CallMethodVarArgs,
   CallModule,
   CallModuleVarArgs,
   ExclusiveKeywordArg,
   Ignored,
   KeywordArg,
   ListOf,
   MultiOutputPattern,
   PatternExpr,
   RepeatedExpr,
   _TargetArgsExpr,
   _TargetExpr,
   _TargetExprVarArgs,
)
permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2)
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2)
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'), _users=2)
amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2)
expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2)
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
alias_default = CallFunction(aten.alias.default, div_Tensor_1)
alias_default_1 = CallFunction(aten.alias.default, alias_default)
alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2)
mul_Tensor = CallFunction(aten.mul.Tensor, view_default_7, alias_default_3, _users=2)
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True)
mul_Tensor_1 = CallFunction(aten.mul.Tensor, alias_default_3, sum_dim_IntList_1)
sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor, mul_Tensor_1)
div_Tensor_2 = CallFunction(aten.div.Tensor, sub_Tensor_1, KeywordArg('inv_scale'))
view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored())
bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5)
view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored())
permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored())
bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8)
view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored())
permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored())
permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored())
bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6)
view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored())
_sfdp_pattern_11_training = MultiOutputPattern([view_default_5,
  permute_default_6,
  permute_default_9,
  permute_default_11,
  None
])
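# Annotation (hand-written, not generator output): the tree above matches both
# halves of scaled-dot-product attention at once. view_default_5 is the forward
# output softmax(Q @ K^T / inv_scale) @ V, while every node consuming
# KeywordArg('tangents_1') belongs to the corresponding backward graph. The
# _users=N hints appear to pin how many consumers a matched node may have, so
# the rewrite only fires on self-contained subgraphs.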


permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
view_default = CallFunction(aten.view.default, clone_default, Ignored())
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored())
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'), _users=2)
amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored())
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
_sfdp_pattern_11_inference = CallFunction(aten.view.default, bmm_default_1, Ignored())


permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2)
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2)
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'))
convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2)
amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2)
expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2)
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored())
alias_default = CallFunction(aten.alias.default, convert_element_type_default_1)
alias_default_1 = CallFunction(aten.alias.default, alias_default)
alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
alias_default_3 = CallFunction(aten.alias.default, alias_default_2)
convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, alias_default_3, Ignored(), _users=2)
mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_2, convert_element_type_default_3, _users=2)
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True)
mul_Tensor_1 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, sum_dim_IntList_1)
sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor, mul_Tensor_1)
convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, sub_Tensor_1, Ignored())
div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_4, KeywordArg('inv_scale'))
view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored())
bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5)
view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored())
permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored())
bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8)
view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored())
permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored())
permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored())
bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6)
view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored())
_sfdp_pattern_11_half_training = MultiOutputPattern([view_default_5,
  permute_default_6,
  permute_default_9,
  permute_default_11,
  None
])


permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
view_default = CallFunction(aten.view.default, clone_default, Ignored())
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored())
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'))
convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2)
amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored())
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
_sfdp_pattern_11_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored())
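
# The serialized trees in these files are easier to read with the matching
# semantics in mind. Below is a small hand-written sketch (a toy stand-in,
# independent of torch internals; the real matcher in
# torch._inductor.pattern_matcher is considerably richer): CallFunction-style
# nodes walk a call tree, Ignored matches anything, and KeywordArg captures an
# argument under a name such as 'query' or 'inv_scale'.

class _Ignored:
    """Toy stand-in for Ignored(): matches any argument, captures nothing."""
    def match(self, node, binding):
        return True

class _KeywordArg:
    """Toy stand-in for KeywordArg(name): captures the argument by name."""
    def __init__(self, name):
        self.name = name
    def match(self, node, binding):
        binding[self.name] = node
        return True

class _CallFunction:
    """Toy stand-in for CallFunction(target, *args) over (target, args) tuples."""
    def __init__(self, target, *args):
        self.target, self.args = target, args
    def match(self, node, binding):
        if not (isinstance(node, tuple) and node[0] == self.target):
            return False
        actual = node[1]
        if len(actual) != len(self.args):
            return False
        return all(
            p.match(a, binding) if hasattr(p, "match") else p == a
            for p, a in zip(self.args, actual)
        )

if __name__ == "__main__":
    # The QK^T scaling prefix shared by the patterns above:
    # div(bmm(query, key), inv_scale)
    pattern = _CallFunction(
        "aten.div.Tensor",
        _CallFunction("aten.bmm.default", _KeywordArg("query"), _KeywordArg("key")),
        _KeywordArg("inv_scale"),
    )
    node = ("aten.div.Tensor", (("aten.bmm.default", ("q", "k")), 8.0))
    binding = {}
    assert pattern.match(node, binding)
    print(binding)  # -> {'query': 'q', 'key': 'k', 'inv_scale': 8.0}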
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_16.py
ADDED
@@ -0,0 +1,635 @@
# mypy: ignore-errors

# noqa: F401, E501
# This is an auto-generated file. Please do not modify it by hand.
# To re-generate, run:
# cd ~/pytorch && python torchgen/fuse_attention_patterns/gen_attention_patterns.py

import torch
import torch._inductor

aten = torch.ops.aten
prims = torch.ops.prims

from torch._inductor.pattern_matcher import (
   Arg,
   CallFunction,
   CallFunctionVarArgs,
   CallMethod,
   CallMethodVarArgs,
   CallModule,
   CallModuleVarArgs,
   ExclusiveKeywordArg,
   Ignored,
   KeywordArg,
   ListOf,
   MultiOutputPattern,
   PatternExpr,
   RepeatedExpr,
   _TargetArgsExpr,
   _TargetExpr,
   _TargetExprVarArgs,
)
rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2)
permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2)
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2)
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'))
add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2)
amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2)
mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1)
mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored())
expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2)
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored())
mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2)
clone_default_3 = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format)
alias_default = CallFunction(aten.alias.default, div_Tensor_1)
alias_default_1 = CallFunction(aten.alias.default, alias_default)
alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2)
mul_Tensor_4 = CallFunction(aten.mul.Tensor, clone_default_3, alias_default_3, _users=2)
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
mul_Tensor_5 = CallFunction(aten.mul.Tensor, alias_default_3, sum_dim_IntList_1)
sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor_4, mul_Tensor_5)
div_Tensor_2 = CallFunction(aten.div.Tensor, sub_Tensor_1, KeywordArg('inv_scale'))
view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored())
bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5)
view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored())
permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored())
bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8)
view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored())
permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored())
permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored())
bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6)
view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored())
_sfdp_pattern_16_training = MultiOutputPattern([view_default_5,
  permute_default_6,
  permute_default_9,
  permute_default_11,
  None,
  None,
  None
])
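# Annotation (hand-written, not generator output): relative to _sfdp_pattern_11,
# this pattern additionally matches an additive attention mask
# (KeywordArg('attn_mask') added to the scaled scores before the softmax) and
# dropout in decomposed form: aten.rand followed by a gt.Scalar comparison
# against KeywordArg('dropout_p') and a multiply. The MultiOutputPattern lists
# the outputs of the fused forward+backward region; the trailing None entries
# appear to be placeholder slots for outputs the match does not need to capture.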


permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
view_default = CallFunction(aten.view.default, clone_default, Ignored())
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored())
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'))
add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2)
amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
clone_default_2 = CallFunction(aten.clone.default, div_Tensor_1)
expand_default_2 = CallFunction(aten.expand.default, clone_default_2, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
clone_default_3 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
view_default_4 = CallFunction(aten.view.default, clone_default_3, Ignored())
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
_sfdp_pattern_16_inference = CallFunction(aten.view.default, bmm_default_1, Ignored())


rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2)
permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2)
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2)
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'))
add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2)
amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2)
mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1)
mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored())
expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored())
mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2)
clone_default = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format)
alias_default = CallFunction(aten.alias.default, div_Tensor_1)
alias_default_1 = CallFunction(aten.alias.default, alias_default)
alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2)
mul_Tensor_4 = CallFunction(aten.mul.Tensor, clone_default, alias_default_3, _users=2)
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
mul_Tensor_5 = CallFunction(aten.mul.Tensor, alias_default_3, sum_dim_IntList_1)
sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor_4, mul_Tensor_5)
div_Tensor_2 = CallFunction(aten.div.Tensor, sub_Tensor_1, KeywordArg('inv_scale'))
view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored())
bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5)
view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored())
permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored())
bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8)
view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored())
permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored())
permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored())
bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6)
view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored())
_sfdp_pattern_16_bs1_training = MultiOutputPattern([view_default_5,
  permute_default_6,
  permute_default_9,
  permute_default_11,
  None,
  None,
  None
])


permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
view_default = CallFunction(aten.view.default, expand_default, Ignored())
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored())
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'))
add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2)
amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
clone_default = CallFunction(aten.clone.default, div_Tensor_1)
expand_default_2 = CallFunction(aten.expand.default, clone_default, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored())
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
_sfdp_pattern_16_bs1_inference = CallFunction(aten.view.default, bmm_default_1, Ignored())


rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2)
permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2)
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2)
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'))
add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'))
convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=2)
amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2)
mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, convert_element_type_default_1)
mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored())
expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2)
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, Ignored())
mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2)
clone_default_3 = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format)
convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, clone_default_3, Ignored())
alias_default = CallFunction(aten.alias.default, convert_element_type_default_1)
alias_default_1 = CallFunction(aten.alias.default, alias_default)
alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
alias_default_3 = CallFunction(aten.alias.default, alias_default_2)
convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, alias_default_3, Ignored(), _users=2)
mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, convert_element_type_default_4, _users=2)
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
mul_Tensor_5 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, sum_dim_IntList_1)
sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor_4, mul_Tensor_5)
convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, sub_Tensor_1, Ignored())
div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_5, KeywordArg('inv_scale'))
view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored())
bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5)
view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored())
permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored())
bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8)
view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored())
permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored())
permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored())
bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6)
view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored())
_sfdp_pattern_16_half_training = MultiOutputPattern([view_default_5,
  permute_default_6,
  permute_default_9,
  permute_default_11,
  None,
  None,
  None
])


permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
view_default = CallFunction(aten.view.default, clone_default, Ignored())
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored())
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'))
add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'))
convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=2)
amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
clone_default_2 = CallFunction(aten.clone.default, convert_element_type_default_1)
expand_default_2 = CallFunction(aten.expand.default, clone_default_2, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
clone_default_3 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
view_default_4 = CallFunction(aten.view.default, clone_default_3, Ignored())
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
_sfdp_pattern_16_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored())
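# Annotation (hand-written, not generator output): the *_half_* patterns differ
# from their fp32 counterparts only by the prims.convert_element_type nodes
# bracketing the softmax, i.e. graphs where reduced-precision (e.g. fp16/bf16)
# scores are upcast for the softmax and cast back afterwards. The *_bs1_*
# variants drop the aten.clone(..., memory_format=torch.contiguous_format)
# steps, apparently because the batch-size-1 layout needs no materialization
# after the expand.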


rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2)
permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2)
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2)
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'))
add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'))
convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=2)
amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2)
mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, convert_element_type_default_1)
mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored())
expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, Ignored())
mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2)
clone_default = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format)
convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, clone_default, Ignored())
alias_default = CallFunction(aten.alias.default, convert_element_type_default_1)
alias_default_1 = CallFunction(aten.alias.default, alias_default)
alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
alias_default_3 = CallFunction(aten.alias.default, alias_default_2)
convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, alias_default_3, Ignored(), _users=2)
mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, convert_element_type_default_4, _users=2)
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
mul_Tensor_5 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, sum_dim_IntList_1)
sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor_4, mul_Tensor_5)
convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, sub_Tensor_1, Ignored())
div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_5, KeywordArg('inv_scale'))
view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored())
bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5)
view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored())
permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored())
bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8)
view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored())
permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored())
permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored())
bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6)
view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored())
_sfdp_pattern_16_half_bs1_training = MultiOutputPattern([view_default_5,
  permute_default_6,
  permute_default_9,
  permute_default_11,
  None,
  None,
  None
])


permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
view_default = CallFunction(aten.view.default, expand_default, Ignored())
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored())
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'))
add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'))
convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=2)
amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
clone_default = CallFunction(aten.clone.default, convert_element_type_default_1)
expand_default_2 = CallFunction(aten.expand.default, clone_default, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored())
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
_sfdp_pattern_16_half_bs1_inference = CallFunction(aten.view.default, bmm_default_1, Ignored())


rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2)
permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2)
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2)
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'))
add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2)
amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2)
mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1)
mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored())
convert_element_type_default = CallFunction(prims.convert_element_type.default, mul_Tensor_1, Ignored())
expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2)
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored())
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, Ignored())
mul_Tensor_3 = CallFunction(aten.mul.Tensor, convert_element_type_default_1, mul_Tensor_2)
clone_default_3 = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format)
alias_default = CallFunction(aten.alias.default, div_Tensor_1)
alias_default_1 = CallFunction(aten.alias.default, alias_default)
alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2)
mul_Tensor_4 = CallFunction(aten.mul.Tensor, clone_default_3, alias_default_3, _users=2)
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
mul_Tensor_5 = CallFunction(aten.mul.Tensor, alias_default_3, sum_dim_IntList_1)
sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor_4, mul_Tensor_5)
convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, sub_Tensor_1, Ignored())
div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_3, KeywordArg('inv_scale'))
view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored())
bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5)
view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored())
permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored())
bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8)
view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored())
permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored())
permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored())
bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6)
view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored())
_sfdp_pattern_16_half_mask_fp32_training = MultiOutputPattern([view_default_5,
  permute_default_6,
  permute_default_9,
  permute_default_11,
  None,
  None,
  None
])


permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
view_default = CallFunction(aten.view.default, clone_default, Ignored())
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored())
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
|
| 521 |
+
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'))
|
| 522 |
+
add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2)
|
| 523 |
+
amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True)
|
| 524 |
+
sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default)
|
| 525 |
+
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
|
| 526 |
+
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
|
| 527 |
+
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
|
| 528 |
+
clone_default_2 = CallFunction(aten.clone.default, div_Tensor_1)
|
| 529 |
+
convert_element_type_default = CallFunction(prims.convert_element_type.default, clone_default_2, Ignored())
|
| 530 |
+
expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored())
|
| 531 |
+
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
|
| 532 |
+
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
|
| 533 |
+
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
|
| 534 |
+
clone_default_3 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
|
| 535 |
+
view_default_4 = CallFunction(aten.view.default, clone_default_3, Ignored())
|
| 536 |
+
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
|
| 537 |
+
_sfdp_pattern_16_half_mask_fp32_inference = CallFunction(aten.view.default, bmm_default_1, Ignored())
|
| 538 |
+
|
| 539 |
+
|
| 540 |
+
rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
|
| 541 |
+
gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2)
|
| 542 |
+
permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
|
| 543 |
+
expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
|
| 544 |
+
view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2)
|
| 545 |
+
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
|
| 546 |
+
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
|
| 547 |
+
expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
|
| 548 |
+
view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2)
|
| 549 |
+
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
|
| 550 |
+
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
|
| 551 |
+
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'))
|
| 552 |
+
add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2)
|
| 553 |
+
amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True)
|
| 554 |
+
sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default)
|
| 555 |
+
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
|
| 556 |
+
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
|
| 557 |
+
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2)
|
| 558 |
+
mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1)
|
| 559 |
+
mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored())
|
| 560 |
+
convert_element_type_default = CallFunction(prims.convert_element_type.default, mul_Tensor_1, Ignored())
|
| 561 |
+
expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored())
|
| 562 |
+
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
|
| 563 |
+
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
|
| 564 |
+
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
|
| 565 |
+
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
|
| 566 |
+
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
|
| 567 |
+
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
|
| 568 |
+
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
|
| 569 |
+
permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
|
| 570 |
+
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
|
| 571 |
+
view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
|
| 572 |
+
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored())
|
| 573 |
+
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
|
| 574 |
+
mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, Ignored())
|
| 575 |
+
mul_Tensor_3 = CallFunction(aten.mul.Tensor, convert_element_type_default_1, mul_Tensor_2)
|
| 576 |
+
clone_default = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format)
|
| 577 |
+
alias_default = CallFunction(aten.alias.default, div_Tensor_1)
|
| 578 |
+
alias_default_1 = CallFunction(aten.alias.default, alias_default)
|
| 579 |
+
alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
|
| 580 |
+
alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2)
|
| 581 |
+
mul_Tensor_4 = CallFunction(aten.mul.Tensor, clone_default, alias_default_3, _users=2)
|
| 582 |
+
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
|
| 583 |
+
mul_Tensor_5 = CallFunction(aten.mul.Tensor, alias_default_3, sum_dim_IntList_1)
|
| 584 |
+
sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor_4, mul_Tensor_5)
|
| 585 |
+
convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, sub_Tensor_1, Ignored())
|
| 586 |
+
div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_3, KeywordArg('inv_scale'))
|
| 587 |
+
view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
|
| 588 |
+
permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored())
|
| 589 |
+
bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5)
|
| 590 |
+
view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
|
| 591 |
+
permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored())
|
| 592 |
+
permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored())
|
| 593 |
+
bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8)
|
| 594 |
+
view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
|
| 595 |
+
permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored())
|
| 596 |
+
permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored())
|
| 597 |
+
permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored())
|
| 598 |
+
bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6)
|
| 599 |
+
view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
|
| 600 |
+
permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored())
|
| 601 |
+
_sfdp_pattern_16_half_mask_fp32_bs1_training = MultiOutputPattern([view_default_5,
|
| 602 |
+
permute_default_6,
|
| 603 |
+
permute_default_9,
|
| 604 |
+
permute_default_11,
|
| 605 |
+
None,
|
| 606 |
+
None,
|
| 607 |
+
None
|
| 608 |
+
])
|
| 609 |
+
|
| 610 |
+
|
| 611 |
+
permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
|
| 612 |
+
expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
|
| 613 |
+
view_default = CallFunction(aten.view.default, expand_default, Ignored())
|
| 614 |
+
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
|
| 615 |
+
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
|
| 616 |
+
expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
|
| 617 |
+
view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored())
|
| 618 |
+
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
|
| 619 |
+
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
|
| 620 |
+
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'))
|
| 621 |
+
add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2)
|
| 622 |
+
amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True)
|
| 623 |
+
sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default)
|
| 624 |
+
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
|
| 625 |
+
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
|
| 626 |
+
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
|
| 627 |
+
clone_default = CallFunction(aten.clone.default, div_Tensor_1)
|
| 628 |
+
convert_element_type_default = CallFunction(prims.convert_element_type.default, clone_default, Ignored())
|
| 629 |
+
expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored())
|
| 630 |
+
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
|
| 631 |
+
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
|
| 632 |
+
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
|
| 633 |
+
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored())
|
| 634 |
+
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
|
| 635 |
+
_sfdp_pattern_16_half_mask_fp32_bs1_inference = CallFunction(aten.view.default, bmm_default_1, Ignored())
|
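
For orientation: stripped of the expand/clone/view plumbing and of the serialized backward graph, the forward computation the `_sfdp_pattern_16_*` variants above match is ordinary additive-mask scaled-dot-product attention. A rough eager-mode sketch (the function name is ours, not part of the generated file):

import torch
import torch.nn.functional as F

def sdpa_pattern16_like(query, key, value, attn_mask, inv_scale, dropout_p=0.0):
    # aten.bmm of q and k^T, scaled by dividing with the 'inv_scale' KeywordArg
    scores = torch.matmul(query, key.transpose(-2, -1)) / inv_scale
    # additive mask, then the numerically stable softmax spelled out above
    # as amax / sub / exp / sum / div
    attn = F.softmax(scores + attn_mask, dim=-1)
    # the rand / gt.Scalar / mul chain in the training variants is dropout
    attn = F.dropout(attn, p=dropout_p, training=True)
    return torch.matmul(attn.to(value.dtype), value)
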
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_17.py
ADDED
@@ -0,0 +1,256 @@
# mypy: ignore-errors

# noqa: F401, E501
# This is an auto-generated file. Please do not modify it by hand.
# To re-generate, run:
# cd ~/pytorch && python
# torchgen/fuse_attention_patterns/gen_attention_patterns.py

import torch
import torch._inductor

aten = torch.ops.aten
prims = torch.ops.prims

from torch._inductor.pattern_matcher import (
   Arg,
   CallFunction,
   CallFunctionVarArgs,
   CallMethod,
   CallMethodVarArgs,
   CallModule,
   CallModuleVarArgs,
   ExclusiveKeywordArg,
   Ignored,
   KeywordArg,
   ListOf,
   MultiOutputPattern,
   PatternExpr,
   RepeatedExpr,
   _TargetArgsExpr,
   _TargetExpr,
   _TargetExprVarArgs,
)
rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2)
eq_Scalar = CallFunction(aten.eq.Scalar, KeywordArg('attn_mask'), Ignored())
expand_default = CallFunction(aten.expand.default, eq_Scalar, Ignored(), _users=2)
full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
clone_default = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2)
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
expand_default_2 = CallFunction(aten.expand.default, permute_default_2, Ignored())
clone_default_1 = CallFunction(aten.clone.default, expand_default_2, memory_format=torch.contiguous_format)
view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2)
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'))
where_self = CallFunction(aten.where.self, expand_default, full_default, div_Tensor, _users=2)
amax_default = CallFunction(aten.amax.default, where_self, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, where_self, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2)
mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1)
mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored())
expand_default_3 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
expand_default_4 = CallFunction(aten.expand.default, permute_default_3, Ignored())
clone_default_2 = CallFunction(aten.clone.default, expand_default_4, memory_format=torch.contiguous_format)
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2)
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
scalar_tensor_default = CallFunction(aten.scalar_tensor.default, Ignored(), dtype=Ignored(), layout=torch.strided, device=Ignored())
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored())
mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2)
clone_default_3 = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format)
alias_default = CallFunction(aten.alias.default, div_Tensor_1)
alias_default_1 = CallFunction(aten.alias.default, alias_default)
alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2)
mul_Tensor_4 = CallFunction(aten.mul.Tensor, clone_default_3, alias_default_3, _users=2)
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
mul_Tensor_5 = CallFunction(aten.mul.Tensor, alias_default_3, sum_dim_IntList_1)
sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor_4, mul_Tensor_5)
where_self_1 = CallFunction(aten.where.self, expand_default, scalar_tensor_default, sub_Tensor_1)
div_Tensor_2 = CallFunction(aten.div.Tensor, where_self_1, KeywordArg('inv_scale'))
view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored())
bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5)
view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored())
permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored())
bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8)
view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored())
permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored())
permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored())
bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6)
view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored())
_sfdp_pattern_17_training = MultiOutputPattern([view_default_5,
  permute_default_6,
  permute_default_9,
  permute_default_11,
  None,
  None,
  None
])


eq_Scalar = CallFunction(aten.eq.Scalar, KeywordArg('attn_mask'), Ignored())
view_default = CallFunction(aten.view.default, eq_Scalar, Ignored())
expand_default = CallFunction(aten.expand.default, view_default, Ignored())
full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
clone_default = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
view_default_1 = CallFunction(aten.view.default, clone_default, Ignored())
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
expand_default_2 = CallFunction(aten.expand.default, permute_default_2, Ignored())
clone_default_1 = CallFunction(aten.clone.default, expand_default_2, memory_format=torch.contiguous_format)
view_default_2 = CallFunction(aten.view.default, clone_default_1, Ignored())
bmm_default = CallFunction(aten.bmm.default, view_default_1, view_default_2)
view_default_3 = CallFunction(aten.view.default, bmm_default, Ignored())
div_Tensor = CallFunction(aten.div.Tensor, view_default_3, KeywordArg('inv_scale'))
where_self = CallFunction(aten.where.self, expand_default, full_default, div_Tensor, _users=2)
amax_default = CallFunction(aten.amax.default, where_self, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, where_self, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
clone_default_2 = CallFunction(aten.clone.default, div_Tensor_1)
expand_default_3 = CallFunction(aten.expand.default, clone_default_2, Ignored())
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored())
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
expand_default_4 = CallFunction(aten.expand.default, permute_default_3, Ignored())
clone_default_3 = CallFunction(aten.clone.default, expand_default_4, memory_format=torch.contiguous_format)
view_default_5 = CallFunction(aten.view.default, clone_default_3, Ignored())
bmm_default_1 = CallFunction(aten.bmm.default, view_default_4, view_default_5)
_sfdp_pattern_17_inference = CallFunction(aten.view.default, bmm_default_1, Ignored())


rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2)
eq_Scalar = CallFunction(aten.eq.Scalar, KeywordArg('attn_mask'), Ignored())
expand_default = CallFunction(aten.expand.default, eq_Scalar, Ignored(), _users=2)
full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
clone_default = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2)
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
expand_default_2 = CallFunction(aten.expand.default, permute_default_2, Ignored())
clone_default_1 = CallFunction(aten.clone.default, expand_default_2, memory_format=torch.contiguous_format)
view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2)
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'))
where_self = CallFunction(aten.where.self, expand_default, full_default, div_Tensor)
convert_element_type_default = CallFunction(prims.convert_element_type.default, where_self, Ignored(), _users=2)
amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2)
mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, convert_element_type_default_1)
mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored())
expand_default_3 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
expand_default_4 = CallFunction(aten.expand.default, permute_default_3, Ignored())
clone_default_2 = CallFunction(aten.clone.default, expand_default_4, memory_format=torch.contiguous_format)
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2)
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
scalar_tensor_default = CallFunction(aten.scalar_tensor.default, Ignored(), dtype=Ignored(), layout=torch.strided, device=Ignored())
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, Ignored())
mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2)
clone_default_3 = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format)
convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, clone_default_3, Ignored())
alias_default = CallFunction(aten.alias.default, convert_element_type_default_1)
alias_default_1 = CallFunction(aten.alias.default, alias_default)
alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
alias_default_3 = CallFunction(aten.alias.default, alias_default_2)
convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, alias_default_3, Ignored(), _users=2)
mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, convert_element_type_default_4, _users=2)
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
mul_Tensor_5 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, sum_dim_IntList_1)
sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor_4, mul_Tensor_5)
convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, sub_Tensor_1, Ignored())
where_self_1 = CallFunction(aten.where.self, expand_default, scalar_tensor_default, convert_element_type_default_5)
div_Tensor_2 = CallFunction(aten.div.Tensor, where_self_1, KeywordArg('inv_scale'))
view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored())
bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5)
view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored())
permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored())
bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8)
view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored())
permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored())
permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored())
bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6)
view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored())
_sfdp_pattern_17_half_training = MultiOutputPattern([view_default_5,
  permute_default_6,
  permute_default_9,
  permute_default_11,
  None,
  None,
  None
])


eq_Scalar = CallFunction(aten.eq.Scalar, KeywordArg('attn_mask'), Ignored())
view_default = CallFunction(aten.view.default, eq_Scalar, Ignored())
expand_default = CallFunction(aten.expand.default, view_default, Ignored())
full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
clone_default = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
view_default_1 = CallFunction(aten.view.default, clone_default, Ignored())
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
expand_default_2 = CallFunction(aten.expand.default, permute_default_2, Ignored())
clone_default_1 = CallFunction(aten.clone.default, expand_default_2, memory_format=torch.contiguous_format)
view_default_2 = CallFunction(aten.view.default, clone_default_1, Ignored())
bmm_default = CallFunction(aten.bmm.default, view_default_1, view_default_2)
view_default_3 = CallFunction(aten.view.default, bmm_default, Ignored())
div_Tensor = CallFunction(aten.div.Tensor, view_default_3, KeywordArg('inv_scale'))
where_self = CallFunction(aten.where.self, expand_default, full_default, div_Tensor)
convert_element_type_default = CallFunction(prims.convert_element_type.default, where_self, Ignored(), _users=2)
amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
clone_default_2 = CallFunction(aten.clone.default, convert_element_type_default_1)
expand_default_3 = CallFunction(aten.expand.default, clone_default_2, Ignored())
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored())
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
expand_default_4 = CallFunction(aten.expand.default, permute_default_3, Ignored())
clone_default_3 = CallFunction(aten.clone.default, expand_default_4, memory_format=torch.contiguous_format)
view_default_5 = CallFunction(aten.view.default, clone_default_3, Ignored())
bmm_default_1 = CallFunction(aten.bmm.default, view_default_4, view_default_5)
_sfdp_pattern_17_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored())
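
Pattern 17 differs from pattern 16 in how the mask enters: instead of adding `attn_mask` to the scores, it matches a boolean `eq.Scalar` plus `where.self` that replaces masked positions with a filled scalar (typically a very large negative value) before the softmax. A rough eager-mode sketch, with a hypothetical function name and `masked_fill` standing in for the `eq`/`where` pair:

import torch
import torch.nn.functional as F

def sdpa_pattern17_like(query, key, value, attn_mask, inv_scale, dropout_p=0.0):
    scores = torch.matmul(query, key.transpose(-2, -1)) / inv_scale
    # eq.Scalar + where.self above: masked positions are overwritten with a
    # fill value so they vanish after the softmax
    scores = scores.masked_fill(attn_mask == 0, torch.finfo(scores.dtype).min)
    attn = F.dropout(F.softmax(scores, dim=-1), p=dropout_p, training=True)
    return torch.matmul(attn, value)
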
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_2.py
ADDED
@@ -0,0 +1,182 @@
# mypy: ignore-errors

# noqa: F401, E501
# This is an auto-generated file. Please do not modify it by hand.
# To re-generate, run:
# cd ~/pytorch && python
# torchgen/fuse_attention_patterns/gen_attention_patterns.py

import torch
import torch._inductor

aten = torch.ops.aten
prims = torch.ops.prims

from torch._inductor.pattern_matcher import (
   Arg,
   CallFunction,
   CallFunctionVarArgs,
   CallMethod,
   CallMethodVarArgs,
   CallModule,
   CallModuleVarArgs,
   ExclusiveKeywordArg,
   Ignored,
   KeywordArg,
   ListOf,
   MultiOutputPattern,
   PatternExpr,
   RepeatedExpr,
   _TargetArgsExpr,
   _TargetExpr,
   _TargetExprVarArgs,
)
expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored())
view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2)
permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2)
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
mul_Tensor = CallFunction(aten.mul.Tensor, view_default_2, KeywordArg('scale_factor'), _users=2)
amax_default = CallFunction(aten.amax.default, mul_Tensor, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, mul_Tensor, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2)
expand_default_2 = CallFunction(aten.expand.default, div_Tensor, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored())
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1)
view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
alias_default = CallFunction(aten.alias.default, div_Tensor)
alias_default_1 = CallFunction(aten.alias.default, alias_default)
alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2)
mul_Tensor_1 = CallFunction(aten.mul.Tensor, view_default_7, alias_default_3, _users=2)
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_1, Ignored(), True)
mul_Tensor_2 = CallFunction(aten.mul.Tensor, alias_default_3, sum_dim_IntList_1)
sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor_1, mul_Tensor_2)
mul_Tensor_3 = CallFunction(aten.mul.Tensor, sub_Tensor_1, KeywordArg('scale_factor'))
view_default_8 = CallFunction(aten.view.default, mul_Tensor_3, Ignored(), _users=2)
permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored())
bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2)
view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored())
bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8)
view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored())
permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored())
bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6)
view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
_sfdp_pattern_2_training = MultiOutputPattern([view_default_5,
  view_default_9,
  permute_default_4,
  view_default_11,
  None
])


expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored())
view_default = CallFunction(aten.view.default, expand_default, Ignored())
permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored())
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
mul_Tensor = CallFunction(aten.mul.Tensor, view_default_2, KeywordArg('scale_factor'), _users=2)
amax_default = CallFunction(aten.amax.default, mul_Tensor, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, mul_Tensor, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
expand_default_2 = CallFunction(aten.expand.default, div_Tensor, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored())
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
_sfdp_pattern_2_inference = CallFunction(aten.view.default, bmm_default_1, Ignored())


expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored())
view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2)
permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2)
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
mul_Tensor = CallFunction(aten.mul.Tensor, view_default_2, KeywordArg('scale_factor'))
convert_element_type_default = CallFunction(prims.convert_element_type.default, mul_Tensor, Ignored(), _users=2)
amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2)
expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored())
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1)
view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored())
alias_default = CallFunction(aten.alias.default, convert_element_type_default_1)
alias_default_1 = CallFunction(aten.alias.default, alias_default)
alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
alias_default_3 = CallFunction(aten.alias.default, alias_default_2)
convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, alias_default_3, Ignored(), _users=2)
mul_Tensor_1 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, convert_element_type_default_3, _users=2)
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_1, Ignored(), True)
mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, sum_dim_IntList_1)
sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor_1, mul_Tensor_2)
convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, sub_Tensor_1, Ignored())
mul_Tensor_3 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, KeywordArg('scale_factor'))
view_default_8 = CallFunction(aten.view.default, mul_Tensor_3, Ignored(), _users=2)
permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored())
bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2)
view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored())
bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8)
view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored())
permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored())
bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6)
view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
_sfdp_pattern_2_half_training = MultiOutputPattern([view_default_5,
  view_default_9,
  permute_default_4,
  view_default_11,
  None
])


expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored())
view_default = CallFunction(aten.view.default, expand_default, Ignored())
permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored())
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
mul_Tensor = CallFunction(aten.mul.Tensor, view_default_2, KeywordArg('scale_factor'))
convert_element_type_default = CallFunction(prims.convert_element_type.default, mul_Tensor, Ignored(), _users=2)
amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored())
expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored())
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
_sfdp_pattern_2_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored())
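
Pattern 2 is the simplest of the three files: no mask, no dropout, and the scores are scaled by multiplying with `scale_factor` rather than dividing by `inv_scale`. A rough eager-mode sketch (function name ours):

import torch
import torch.nn.functional as F

def sdpa_pattern2_like(query, key, value, scale_factor):
    # aten.mul.Tensor with the 'scale_factor' KeywordArg, then softmax
    scores = torch.matmul(query, key.transpose(-2, -1)) * scale_factor
    return torch.matmul(F.softmax(scores, dim=-1), value)
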
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_dimV_ops.h
ADDED
@@ -0,0 +1,28 @@
#pragma once

// @generated by torchgen/gen.py from Operator.h

#include <tuple>
#include <vector>

// Forward declarations of any types needed in the operator signatures.
// We can't directly include these classes because it will cause circular include dependencies.
// This file is included by TensorBody.h, which defines the Tensor class.
#include <ATen/core/ATen_fwd.h>

namespace at {
namespace _ops {


struct TORCH_API _dimV {
  using schema = int64_t (const at::Tensor &);
  using ptr_schema = schema*;
  // See Note [static constexpr char* members for windows NVCC]
  STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::_dimV")
  STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "")
  STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "_dimV(Tensor self) -> int")
  static int64_t call(const at::Tensor & self);
  static int64_t redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self);
};

}} // namespace at::_ops
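
Per its schema, `_dimV(Tensor self) -> int` returns a dimension count. To the best of our understanding this is the legacy sparse-COO accessor for the number of trailing dense dimensions of the values, surfaced publicly as `Tensor.dense_dim()`; a small illustration with the public API (treat the mapping to `_dimV` as an assumption):

import torch

indices = torch.tensor([[0, 1]])
values = torch.randn(2, 3)  # each stored value is itself a 1-d block
s = torch.sparse_coo_tensor(indices, values, size=(4, 3))
print(s.sparse_dim())       # 1 sparse dimension
print(s.dense_dim())        # 1 trailing dense dimension
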
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_foreach_exp.h
ADDED
@@ -0,0 +1,44 @@
#pragma once

// @generated by torchgen/gen.py from Function.h

#include <ATen/Context.h>
#include <ATen/DeviceGuard.h>
#include <ATen/TensorUtils.h>
#include <ATen/TracerMode.h>
#include <ATen/core/Generator.h>
#include <ATen/core/Reduction.h>
#include <ATen/core/Tensor.h>
#include <c10/core/Scalar.h>
#include <c10/core/Storage.h>
#include <c10/core/TensorOptions.h>
#include <c10/util/Deprecated.h>
#include <c10/util/Optional.h>



#include <ATen/ops/_foreach_exp_ops.h>

namespace at {


// aten::_foreach_exp(Tensor[] self) -> Tensor[]
inline ::std::vector<at::Tensor> _foreach_exp(at::TensorList self) {
    return at::_ops::_foreach_exp::call(self);
}

// aten::_foreach_exp_(Tensor(a!)[] self) -> ()
inline void _foreach_exp_(at::TensorList self) {
    return at::_ops::_foreach_exp_::call(self);
}

// aten::_foreach_exp.out(Tensor[] self, *, Tensor(a!)[] out) -> ()
inline void _foreach_exp_out(at::TensorList out, at::TensorList self) {
    return at::_ops::_foreach_exp_out::call(self, out);
}
// aten::_foreach_exp.out(Tensor[] self, *, Tensor(a!)[] out) -> ()
inline void _foreach_exp_outf(at::TensorList self, at::TensorList out) {
    return at::_ops::_foreach_exp_out::call(self, out);
}

}
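
This header exposes the C++ entry points; the same `_foreach` ops are reachable from Python, where they apply `exp` across a whole list of tensors in one call (using fused multi-tensor kernels where available). A quick usage sketch:

import torch

ts = [torch.randn(3), torch.randn(5)]
outs = torch._foreach_exp(ts)   # out-of-place: returns a new list of exp'd tensors
torch._foreach_exp_(ts)         # in-place variant: mutates each tensor in the list
for a, b in zip(outs, ts):
    torch.testing.assert_close(a, b)
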
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_linalg_slogdet_meta_dispatch.h
ADDED
@@ -0,0 +1,25 @@
#pragma once
// @generated by torchgen/gen.py from DispatchKeyFunction.h

// NB: The implementing C++ file is RegisterDispatchKey.cpp

// The only #includes we need are for custom classes that have defaults in the C++ API
#include <c10/core/MemoryFormat.h>
#include <c10/core/Scalar.h>
#include <ATen/core/Reduction.h>

// Forward declarations of any types needed in the operator signatures.
// We can't directly include these classes because it will cause circular include dependencies.
// This file is included by TensorBody.h, which defines the Tensor class.
#include <ATen/core/ATen_fwd.h>

namespace at {

namespace meta {

TORCH_API ::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor> _linalg_slogdet(const at::Tensor & A);
TORCH_API ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &,at::Tensor &> _linalg_slogdet_out(at::Tensor & sign, at::Tensor & logabsdet, at::Tensor & LU, at::Tensor & pivots, const at::Tensor & A);
TORCH_API ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &,at::Tensor &> _linalg_slogdet_outf(const at::Tensor & A, at::Tensor & sign, at::Tensor & logabsdet, at::Tensor & LU, at::Tensor & pivots);

} // namespace meta
} // namespace at
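
The private `_linalg_slogdet` returns `(sign, logabsdet, LU, pivots)`; the public `torch.linalg.slogdet` exposes the first two. A small numerical check of the identity det(A) = sign * exp(logabsdet), which the log form computes without overflow or underflow:

import torch

A = torch.randn(3, 3, dtype=torch.float64)
sign, logabsdet = torch.linalg.slogdet(A)
torch.testing.assert_close(sign * torch.exp(logabsdet), torch.linalg.det(A))
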
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_nested_tensor_from_mask_compositeexplicitautograd_dispatch.h
ADDED
@@ -0,0 +1,24 @@
+#pragma once
+// @generated by torchgen/gen.py from DispatchKeyFunction.h
+
+// NB: The implementing C++ file is RegisterDispatchKey.cpp
+
+// The only #includes we need are for custom classes that have defaults in the C++ API
+#include <c10/core/MemoryFormat.h>
+#include <c10/core/Scalar.h>
+#include <ATen/core/Reduction.h>
+
+// Forward declarations of any types needed in the operator signatures.
+// We can't directly include these classes because it will cause circular include dependencies.
+// This file is included by TensorBody.h, which defines the Tensor class.
+#include <ATen/core/ATen_fwd.h>
+
+namespace at {
+
+namespace compositeexplicitautograd {
+
+TORCH_API at::Tensor & _nested_tensor_from_mask_out(at::Tensor & out, const at::Tensor & t, const at::Tensor & mask, bool mask_check=true);
+TORCH_API at::Tensor & _nested_tensor_from_mask_outf(const at::Tensor & t, const at::Tensor & mask, bool mask_check, at::Tensor & out);
+
+} // namespace compositeexplicitautograd
+} // namespace at

tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_print_native.h
ADDED
@@ -0,0 +1,21 @@
+#pragma once
+
+// @generated by torchgen/gen.py from NativeFunction.h
+
+#include <c10/core/Scalar.h>
+#include <c10/core/Storage.h>
+#include <c10/core/TensorOptions.h>
+#include <c10/util/Deprecated.h>
+#include <c10/util/Optional.h>
+#include <c10/core/QScheme.h>
+#include <ATen/core/Reduction.h>
+#include <ATen/core/Tensor.h>
+#include <tuple>
+#include <vector>
+
+
+namespace at {
+namespace native {
+TORCH_API void _print(c10::string_view s);
+} // namespace native
+} // namespace at

tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_sample_dirichlet.h
ADDED
@@ -0,0 +1,39 @@
+#pragma once
+
+// @generated by torchgen/gen.py from Function.h
+
+#include <ATen/Context.h>
+#include <ATen/DeviceGuard.h>
+#include <ATen/TensorUtils.h>
+#include <ATen/TracerMode.h>
+#include <ATen/core/Generator.h>
+#include <ATen/core/Reduction.h>
+#include <ATen/core/Tensor.h>
+#include <c10/core/Scalar.h>
+#include <c10/core/Storage.h>
+#include <c10/core/TensorOptions.h>
+#include <c10/util/Deprecated.h>
+#include <c10/util/Optional.h>
+
+
+
+#include <ATen/ops/_sample_dirichlet_ops.h>
+
+namespace at {
+
+
+// aten::_sample_dirichlet(Tensor self, Generator? generator=None) -> Tensor
+inline at::Tensor _sample_dirichlet(const at::Tensor & self, c10::optional<at::Generator> generator=c10::nullopt) {
+    return at::_ops::_sample_dirichlet::call(self, generator);
+}
+
+// aten::_sample_dirichlet.out(Tensor self, Generator? generator=None, *, Tensor(a!) out) -> Tensor(a!)
+inline at::Tensor & _sample_dirichlet_out(at::Tensor & out, const at::Tensor & self, c10::optional<at::Generator> generator=c10::nullopt) {
+    return at::_ops::_sample_dirichlet_out::call(self, generator, out);
+}
+// aten::_sample_dirichlet.out(Tensor self, Generator? generator=None, *, Tensor(a!) out) -> Tensor(a!)
+inline at::Tensor & _sample_dirichlet_outf(const at::Tensor & self, c10::optional<at::Generator> generator, at::Tensor & out) {
+    return at::_ops::_sample_dirichlet_out::call(self, generator, out);
+}
+
+}

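A short sketch of the sampler declared above, assuming a standard libtorch build; each row of the result is a draw from a Dirichlet distribution with the given concentration parameters:

    #include <ATen/ATen.h>

    int main() {
      // 5 independent draws from Dirichlet(1, 1, 1); each row sums to 1.
      at::Tensor alpha = at::ones({5, 3});
      at::Tensor sample = at::_sample_dirichlet(alpha);  // generator defaults to nullopt
      return 0;
    }
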
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_sparse_csr_sum.h
ADDED
@@ -0,0 +1,39 @@
+#pragma once
+
+// @generated by torchgen/gen.py from Function.h
+
+#include <ATen/Context.h>
+#include <ATen/DeviceGuard.h>
+#include <ATen/TensorUtils.h>
+#include <ATen/TracerMode.h>
+#include <ATen/core/Generator.h>
+#include <ATen/core/Reduction.h>
+#include <ATen/core/Tensor.h>
+#include <c10/core/Scalar.h>
+#include <c10/core/Storage.h>
+#include <c10/core/TensorOptions.h>
+#include <c10/util/Deprecated.h>
+#include <c10/util/Optional.h>
+
+
+
+#include <ATen/ops/_sparse_csr_sum_ops.h>
+
+namespace at {
+
+
+// aten::_sparse_csr_sum.dim_dtype(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
+inline at::Tensor _sparse_csr_sum(const at::Tensor & self, at::IntArrayRef dim, bool keepdim=false, c10::optional<at::ScalarType> dtype=c10::nullopt) {
+    return at::_ops::_sparse_csr_sum_dim_dtype::call(self, dim, keepdim, dtype);
+}
+
+// aten::_sparse_csr_sum.dim_dtype_out(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)
+inline at::Tensor & _sparse_csr_sum_out(at::Tensor & out, const at::Tensor & self, at::IntArrayRef dim, bool keepdim=false, c10::optional<at::ScalarType> dtype=c10::nullopt) {
+    return at::_ops::_sparse_csr_sum_dim_dtype_out::call(self, dim, keepdim, dtype, out);
+}
+// aten::_sparse_csr_sum.dim_dtype_out(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)
+inline at::Tensor & _sparse_csr_sum_outf(const at::Tensor & self, at::IntArrayRef dim, bool keepdim, c10::optional<at::ScalarType> dtype, at::Tensor & out) {
+    return at::_ops::_sparse_csr_sum_dim_dtype_out::call(self, dim, keepdim, dtype, out);
+}
+
+}

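A sketch of the reduction declared above, assuming a standard libtorch build and that the `Tensor::to_sparse_csr()` conversion is available in this torch version:

    #include <ATen/ATen.h>

    int main() {
      // Sum a sparse CSR matrix over its columns, keeping the row dimension.
      at::Tensor dense = at::rand({4, 6});
      at::Tensor csr = dense.to_sparse_csr();
      at::Tensor row_sums = at::_sparse_csr_sum(csr, /*dim=*/{1}, /*keepdim=*/true);
      return 0;
    }
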
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_sparse_softmax_backward_data_ops.h
ADDED
@@ -0,0 +1,39 @@
+#pragma once
+
+// @generated by torchgen/gen.py from Operator.h
+
+#include <tuple>
+#include <vector>
+
+// Forward declarations of any types needed in the operator signatures.
+// We can't directly include these classes because it will cause circular include dependencies.
+// This file is included by TensorBody.h, which defines the Tensor class.
+#include <ATen/core/ATen_fwd.h>
+
+namespace at {
+namespace _ops {
+
+
+struct TORCH_API _sparse_softmax_backward_data {
+  using schema = at::Tensor (const at::Tensor &, const at::Tensor &, int64_t, const at::Tensor &);
+  using ptr_schema = schema*;
+  // See Note [static constexpr char* members for windows NVCC]
+  STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::_sparse_softmax_backward_data")
+  STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "")
+  STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "_sparse_softmax_backward_data(Tensor grad_output, Tensor output, int dim, Tensor self) -> Tensor")
+  static at::Tensor call(const at::Tensor & grad_output, const at::Tensor & output, int64_t dim, const at::Tensor & self);
+  static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & output, int64_t dim, const at::Tensor & self);
+};
+
+struct TORCH_API _sparse_softmax_backward_data_out {
+  using schema = at::Tensor & (const at::Tensor &, const at::Tensor &, int64_t, const at::Tensor &, at::Tensor &);
+  using ptr_schema = schema*;
+  // See Note [static constexpr char* members for windows NVCC]
+  STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::_sparse_softmax_backward_data")
+  STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "out")
+  STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "_sparse_softmax_backward_data.out(Tensor grad_output, Tensor output, int dim, Tensor self, *, Tensor(a!) out) -> Tensor(a!)")
+  static at::Tensor & call(const at::Tensor & grad_output, const at::Tensor & output, int64_t dim, const at::Tensor & self, at::Tensor & out);
+  static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & output, int64_t dim, const at::Tensor & self, at::Tensor & out);
+};
+
+}} // namespace at::_ops

tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_standard_gamma_grad_compositeexplicitautograd_dispatch.h
ADDED
@@ -0,0 +1,24 @@
+#pragma once
+// @generated by torchgen/gen.py from DispatchKeyFunction.h
+
+// NB: The implementing C++ file is RegisterDispatchKey.cpp
+
+// The only #includes we need are for custom classes that have defaults in the C++ API
+#include <c10/core/MemoryFormat.h>
+#include <c10/core/Scalar.h>
+#include <ATen/core/Reduction.h>
+
+// Forward declarations of any types needed in the operator signatures.
+// We can't directly include these classes because it will cause circular include dependencies.
+// This file is included by TensorBody.h, which defines the Tensor class.
+#include <ATen/core/ATen_fwd.h>
+
+namespace at {
+
+namespace compositeexplicitautograd {
+
+TORCH_API at::Tensor & _standard_gamma_grad_out(at::Tensor & out, const at::Tensor & self, const at::Tensor & output);
+TORCH_API at::Tensor & _standard_gamma_grad_outf(const at::Tensor & self, const at::Tensor & output, at::Tensor & out);
+
+} // namespace compositeexplicitautograd
+} // namespace at

tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_test_warn_in_autograd_native.h
ADDED
@@ -0,0 +1,22 @@
+#pragma once
+
+// @generated by torchgen/gen.py from NativeFunction.h
+
+#include <c10/core/Scalar.h>
+#include <c10/core/Storage.h>
+#include <c10/core/TensorOptions.h>
+#include <c10/util/Deprecated.h>
+#include <c10/util/Optional.h>
+#include <c10/core/QScheme.h>
+#include <ATen/core/Reduction.h>
+#include <ATen/core/Tensor.h>
+#include <tuple>
+#include <vector>
+
+
+namespace at {
+namespace native {
+TORCH_API at::Tensor _test_warn_in_autograd(const at::Tensor & self);
+TORCH_API at::Tensor & _test_warn_in_autograd_out(const at::Tensor & self, at::Tensor & out);
+} // namespace native
+} // namespace at

tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_upsample_nearest_exact2d_cuda_dispatch.h
ADDED
@@ -0,0 +1,28 @@
+#pragma once
+// @generated by torchgen/gen.py from DispatchKeyFunction.h
+
+// NB: The implementing C++ file is RegisterDispatchKey.cpp
+
+// The only #includes we need are for custom classes that have defaults in the C++ API
+#include <c10/core/MemoryFormat.h>
+#include <c10/core/Scalar.h>
+#include <ATen/core/Reduction.h>
+
+// Forward declarations of any types needed in the operator signatures.
+// We can't directly include these classes because it will cause circular include dependencies.
+// This file is included by TensorBody.h, which defines the Tensor class.
+#include <ATen/core/ATen_fwd.h>
+
+namespace at {
+
+namespace cuda {
+
+TORCH_API at::Tensor _upsample_nearest_exact2d(const at::Tensor & self, at::IntArrayRef output_size, c10::optional<double> scales_h=c10::nullopt, c10::optional<double> scales_w=c10::nullopt);
+TORCH_API at::Tensor _upsample_nearest_exact2d_symint(const at::Tensor & self, c10::SymIntArrayRef output_size, c10::optional<double> scales_h=c10::nullopt, c10::optional<double> scales_w=c10::nullopt);
+TORCH_API at::Tensor & _upsample_nearest_exact2d_out(at::Tensor & out, const at::Tensor & self, at::IntArrayRef output_size, c10::optional<double> scales_h=c10::nullopt, c10::optional<double> scales_w=c10::nullopt);
+TORCH_API at::Tensor & _upsample_nearest_exact2d_outf(const at::Tensor & self, at::IntArrayRef output_size, c10::optional<double> scales_h, c10::optional<double> scales_w, at::Tensor & out);
+TORCH_API at::Tensor & _upsample_nearest_exact2d_symint_out(at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef output_size, c10::optional<double> scales_h=c10::nullopt, c10::optional<double> scales_w=c10::nullopt);
+TORCH_API at::Tensor & _upsample_nearest_exact2d_symint_outf(const at::Tensor & self, c10::SymIntArrayRef output_size, c10::optional<double> scales_h, c10::optional<double> scales_w, at::Tensor & out);
+
+} // namespace cuda
+} // namespace at

tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_upsample_nearest_exact3d_backward_meta_dispatch.h
ADDED
@@ -0,0 +1,28 @@
+#pragma once
+// @generated by torchgen/gen.py from DispatchKeyFunction.h
+
+// NB: The implementing C++ file is RegisterDispatchKey.cpp
+
+// The only #includes we need are for custom classes that have defaults in the C++ API
+#include <c10/core/MemoryFormat.h>
+#include <c10/core/Scalar.h>
+#include <ATen/core/Reduction.h>
+
+// Forward declarations of any types needed in the operator signatures.
+// We can't directly include these classes because it will cause circular include dependencies.
+// This file is included by TensorBody.h, which defines the Tensor class.
+#include <ATen/core/ATen_fwd.h>
+
+namespace at {
+
+namespace meta {
+
+TORCH_API at::Tensor _upsample_nearest_exact3d_backward(const at::Tensor & grad_output, at::IntArrayRef output_size, at::IntArrayRef input_size, c10::optional<double> scales_d=c10::nullopt, c10::optional<double> scales_h=c10::nullopt, c10::optional<double> scales_w=c10::nullopt);
+TORCH_API at::Tensor _upsample_nearest_exact3d_backward_symint(const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, c10::optional<double> scales_d=c10::nullopt, c10::optional<double> scales_h=c10::nullopt, c10::optional<double> scales_w=c10::nullopt);
+TORCH_API at::Tensor & _upsample_nearest_exact3d_backward_out(at::Tensor & grad_input, const at::Tensor & grad_output, at::IntArrayRef output_size, at::IntArrayRef input_size, c10::optional<double> scales_d=c10::nullopt, c10::optional<double> scales_h=c10::nullopt, c10::optional<double> scales_w=c10::nullopt);
+TORCH_API at::Tensor & _upsample_nearest_exact3d_backward_outf(const at::Tensor & grad_output, at::IntArrayRef output_size, at::IntArrayRef input_size, c10::optional<double> scales_d, c10::optional<double> scales_h, c10::optional<double> scales_w, at::Tensor & grad_input);
+TORCH_API at::Tensor & _upsample_nearest_exact3d_backward_symint_out(at::Tensor & grad_input, const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, c10::optional<double> scales_d=c10::nullopt, c10::optional<double> scales_h=c10::nullopt, c10::optional<double> scales_w=c10::nullopt);
+TORCH_API at::Tensor & _upsample_nearest_exact3d_backward_symint_outf(const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, c10::optional<double> scales_d, c10::optional<double> scales_h, c10::optional<double> scales_w, at::Tensor & grad_input);
+
+} // namespace meta
+} // namespace at

tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/acosh_meta_dispatch.h
ADDED
@@ -0,0 +1,26 @@
+#pragma once
+// @generated by torchgen/gen.py from DispatchKeyFunction.h
+
+// NB: The implementing C++ file is RegisterDispatchKey.cpp
+
+// The only #includes we need are for custom classes that have defaults in the C++ API
+#include <c10/core/MemoryFormat.h>
+#include <c10/core/Scalar.h>
+#include <ATen/core/Reduction.h>
+
+// Forward declarations of any types needed in the operator signatures.
+// We can't directly include these classes because it will cause circular include dependencies.
+// This file is included by TensorBody.h, which defines the Tensor class.
+#include <ATen/core/ATen_fwd.h>
+
+namespace at {
+
+namespace meta {
+
+TORCH_API at::Tensor acosh(const at::Tensor & self);
+TORCH_API at::Tensor & acosh_out(at::Tensor & out, const at::Tensor & self);
+TORCH_API at::Tensor & acosh_outf(const at::Tensor & self, at::Tensor & out);
+TORCH_API at::Tensor & acosh_(at::Tensor & self);
+
+} // namespace meta
+} // namespace at

tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/all_cuda_dispatch.h
ADDED
@@ -0,0 +1,31 @@
+#pragma once
+// @generated by torchgen/gen.py from DispatchKeyFunction.h
+
+// NB: The implementing C++ file is RegisterDispatchKey.cpp
+
+// The only #includes we need are for custom classes that have defaults in the C++ API
+#include <c10/core/MemoryFormat.h>
+#include <c10/core/Scalar.h>
+#include <ATen/core/Reduction.h>
+
+// Forward declarations of any types needed in the operator signatures.
+// We can't directly include these classes because it will cause circular include dependencies.
+// This file is included by TensorBody.h, which defines the Tensor class.
+#include <ATen/core/ATen_fwd.h>
+
+namespace at {
+
+namespace cuda {
+
+TORCH_API at::Tensor all(const at::Tensor & self, int64_t dim, bool keepdim=false);
+TORCH_API at::Tensor & all_out(at::Tensor & out, const at::Tensor & self, int64_t dim, bool keepdim=false);
+TORCH_API at::Tensor & all_outf(const at::Tensor & self, int64_t dim, bool keepdim, at::Tensor & out);
+TORCH_API at::Tensor all(const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim=false);
+TORCH_API at::Tensor & all_out(at::Tensor & out, const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim=false);
+TORCH_API at::Tensor & all_outf(const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim, at::Tensor & out);
+TORCH_API at::Tensor all(const at::Tensor & self);
+TORCH_API at::Tensor & all_out(at::Tensor & out, const at::Tensor & self);
+TORCH_API at::Tensor & all_outf(const at::Tensor & self, at::Tensor & out);
+
+} // namespace cuda
+} // namespace at

tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/atanh_native.h
ADDED
@@ -0,0 +1,29 @@
+#pragma once
+
+// @generated by torchgen/gen.py from NativeFunction.h
+
+#include <c10/core/Scalar.h>
+#include <c10/core/Storage.h>
+#include <c10/core/TensorOptions.h>
+#include <c10/util/Deprecated.h>
+#include <c10/util/Optional.h>
+#include <c10/core/QScheme.h>
+#include <ATen/core/Reduction.h>
+#include <ATen/core/Tensor.h>
+#include <tuple>
+#include <vector>
+#include <ATen/ops/atanh_meta.h>
+
+namespace at {
+namespace native {
+struct TORCH_API structured_atanh_out : public at::meta::structured_atanh {
+void impl(const at::Tensor & self, const at::Tensor & out);
+};
+TORCH_API at::Tensor atanh_sparse(const at::Tensor & self);
+TORCH_API at::Tensor & atanh_sparse_out(const at::Tensor & self, at::Tensor & out);
+TORCH_API at::Tensor & atanh_sparse_(at::Tensor & self);
+TORCH_API at::Tensor atanh_sparse_csr(const at::Tensor & self);
+TORCH_API at::Tensor & atanh_sparse_csr_out(const at::Tensor & self, at::Tensor & out);
+TORCH_API at::Tensor & atanh_sparse_csr_(at::Tensor & self);
+} // namespace native
+} // namespace at

tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/atleast_3d_ops.h
ADDED
@@ -0,0 +1,39 @@
+#pragma once
+
+// @generated by torchgen/gen.py from Operator.h
+
+#include <tuple>
+#include <vector>
+
+// Forward declarations of any types needed in the operator signatures.
+// We can't directly include these classes because it will cause circular include dependencies.
+// This file is included by TensorBody.h, which defines the Tensor class.
+#include <ATen/core/ATen_fwd.h>
+
+namespace at {
+namespace _ops {
+
+
+struct TORCH_API atleast_3d {
+  using schema = at::Tensor (const at::Tensor &);
+  using ptr_schema = schema*;
+  // See Note [static constexpr char* members for windows NVCC]
+  STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::atleast_3d")
+  STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "")
+  STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "atleast_3d(Tensor self) -> Tensor")
+  static at::Tensor call(const at::Tensor & self);
+  static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self);
+};
+
+struct TORCH_API atleast_3d_Sequence {
+  using schema = ::std::vector<at::Tensor> (at::TensorList);
+  using ptr_schema = schema*;
+  // See Note [static constexpr char* members for windows NVCC]
+  STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::atleast_3d")
+  STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "Sequence")
+  STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "atleast_3d.Sequence(Tensor[] tensors) -> Tensor[]")
+  static ::std::vector<at::Tensor> call(at::TensorList tensors);
+  static ::std::vector<at::Tensor> redispatch(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors);
+};
+
+}} // namespace at::_ops

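The two schema structs above back `at::atleast_3d` for a single tensor and for a tensor list. A short sketch, assuming a standard libtorch build:

    #include <ATen/ATen.h>
    #include <vector>

    int main() {
      at::Tensor v = at::arange(4);                    // shape [4]
      at::Tensor v3 = at::atleast_3d(v);               // shape [1, 4, 1]
      std::vector<at::Tensor> ts = {v, at::rand({2, 3})};
      auto all3 = at::atleast_3d(at::TensorList(ts));  // Sequence overload
      return 0;
    }
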
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/binary_cross_entropy.h
ADDED
@@ -0,0 +1,39 @@
+#pragma once
+
+// @generated by torchgen/gen.py from Function.h
+
+#include <ATen/Context.h>
+#include <ATen/DeviceGuard.h>
+#include <ATen/TensorUtils.h>
+#include <ATen/TracerMode.h>
+#include <ATen/core/Generator.h>
+#include <ATen/core/Reduction.h>
+#include <ATen/core/Tensor.h>
+#include <c10/core/Scalar.h>
+#include <c10/core/Storage.h>
+#include <c10/core/TensorOptions.h>
+#include <c10/util/Deprecated.h>
+#include <c10/util/Optional.h>
+
+
+
+#include <ATen/ops/binary_cross_entropy_ops.h>
+
+namespace at {
+
+
+// aten::binary_cross_entropy(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean) -> Tensor
+inline at::Tensor binary_cross_entropy(const at::Tensor & self, const at::Tensor & target, const c10::optional<at::Tensor> & weight={}, int64_t reduction=at::Reduction::Mean) {
+    return at::_ops::binary_cross_entropy::call(self, target, weight, reduction);
+}
+
+// aten::binary_cross_entropy.out(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, *, Tensor(a!) out) -> Tensor(a!)
+inline at::Tensor & binary_cross_entropy_out(at::Tensor & out, const at::Tensor & self, const at::Tensor & target, const c10::optional<at::Tensor> & weight={}, int64_t reduction=at::Reduction::Mean) {
+    return at::_ops::binary_cross_entropy_out::call(self, target, weight, reduction, out);
+}
+// aten::binary_cross_entropy.out(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, *, Tensor(a!) out) -> Tensor(a!)
+inline at::Tensor & binary_cross_entropy_outf(const at::Tensor & self, const at::Tensor & target, const c10::optional<at::Tensor> & weight, int64_t reduction, at::Tensor & out) {
+    return at::_ops::binary_cross_entropy_out::call(self, target, weight, reduction, out);
+}
+
+}

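A minimal sketch of the functional form declared above, assuming a standard libtorch build; inputs must already be probabilities in (0, 1):

    #include <ATen/ATen.h>

    int main() {
      at::Tensor input = at::sigmoid(at::randn({4}));  // probabilities in (0, 1)
      at::Tensor target = at::ones({4});
      // weight defaults to an undefined (None) tensor, reduction to Mean.
      at::Tensor loss = at::binary_cross_entropy(input, target);
      return 0;
    }
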
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/binary_cross_entropy_with_logits_native.h
ADDED
@@ -0,0 +1,22 @@
+#pragma once
+
+// @generated by torchgen/gen.py from NativeFunction.h
+
+#include <c10/core/Scalar.h>
+#include <c10/core/Storage.h>
+#include <c10/core/TensorOptions.h>
+#include <c10/util/Deprecated.h>
+#include <c10/util/Optional.h>
+#include <c10/core/QScheme.h>
+#include <ATen/core/Reduction.h>
+#include <ATen/core/Tensor.h>
+#include <tuple>
+#include <vector>
+
+
+namespace at {
+namespace native {
+TORCH_API at::Tensor binary_cross_entropy_with_logits(const at::Tensor & self, const at::Tensor & target, const c10::optional<at::Tensor> & weight={}, const c10::optional<at::Tensor> & pos_weight={}, int64_t reduction=at::Reduction::Mean);
+TORCH_API at::Tensor & binary_cross_entropy_with_logits_out(const at::Tensor & self, const at::Tensor & target, const c10::optional<at::Tensor> & weight, const c10::optional<at::Tensor> & pos_weight, int64_t reduction, at::Tensor & out);
+} // namespace native
+} // namespace at

tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/bitwise_xor_ops.h
ADDED
@@ -0,0 +1,105 @@
+#pragma once
+
+// @generated by torchgen/gen.py from Operator.h
+
+#include <tuple>
+#include <vector>
+
+// Forward declarations of any types needed in the operator signatures.
+// We can't directly include these classes because it will cause circular include dependencies.
+// This file is included by TensorBody.h, which defines the Tensor class.
+#include <ATen/core/ATen_fwd.h>
+
+namespace at {
+namespace _ops {
+
+
+struct TORCH_API bitwise_xor_Tensor_out {
+  using schema = at::Tensor & (const at::Tensor &, const at::Tensor &, at::Tensor &);
+  using ptr_schema = schema*;
+  // See Note [static constexpr char* members for windows NVCC]
+  STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::bitwise_xor")
+  STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "Tensor_out")
+  STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "bitwise_xor.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)")
+  static at::Tensor & call(const at::Tensor & self, const at::Tensor & other, at::Tensor & out);
+  static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out);
+};
+
+struct TORCH_API bitwise_xor_Scalar_out {
+  using schema = at::Tensor & (const at::Tensor &, const at::Scalar &, at::Tensor &);
+  using ptr_schema = schema*;
+  // See Note [static constexpr char* members for windows NVCC]
+  STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::bitwise_xor")
+  STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "Scalar_out")
+  STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "bitwise_xor.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)")
+  static at::Tensor & call(const at::Tensor & self, const at::Scalar & other, at::Tensor & out);
+  static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other, at::Tensor & out);
+};
+
+struct TORCH_API bitwise_xor_Scalar {
+  using schema = at::Tensor (const at::Tensor &, const at::Scalar &);
+  using ptr_schema = schema*;
+  // See Note [static constexpr char* members for windows NVCC]
+  STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::bitwise_xor")
+  STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "Scalar")
+  STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "bitwise_xor.Scalar(Tensor self, Scalar other) -> Tensor")
+  static at::Tensor call(const at::Tensor & self, const at::Scalar & other);
+  static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other);
+};
+
+struct TORCH_API bitwise_xor_Scalar_Tensor {
+  using schema = at::Tensor (const at::Scalar &, const at::Tensor &);
+  using ptr_schema = schema*;
+  // See Note [static constexpr char* members for windows NVCC]
+  STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::bitwise_xor")
+  STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "Scalar_Tensor")
+  STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "bitwise_xor.Scalar_Tensor(Scalar self, Tensor other) -> Tensor")
+  static at::Tensor call(const at::Scalar & self, const at::Tensor & other);
+  static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Scalar & self, const at::Tensor & other);
+};
+
+struct TORCH_API bitwise_xor_Tensor {
+  using schema = at::Tensor (const at::Tensor &, const at::Tensor &);
+  using ptr_schema = schema*;
+  // See Note [static constexpr char* members for windows NVCC]
+  STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::bitwise_xor")
+  STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "Tensor")
+  STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "bitwise_xor.Tensor(Tensor self, Tensor other) -> Tensor")
+  static at::Tensor call(const at::Tensor & self, const at::Tensor & other);
+  static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other);
+};
+
+struct TORCH_API bitwise_xor__Scalar {
+  using schema = at::Tensor & (at::Tensor &, const at::Scalar &);
+  using ptr_schema = schema*;
+  // See Note [static constexpr char* members for windows NVCC]
+  STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::bitwise_xor_")
+  STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "Scalar")
+  STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "bitwise_xor_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)")
+  static at::Tensor & call(at::Tensor & self, const at::Scalar & other);
+  static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & other);
+};
+
+struct TORCH_API bitwise_xor__Tensor {
+  using schema = at::Tensor & (at::Tensor &, const at::Tensor &);
+  using ptr_schema = schema*;
+  // See Note [static constexpr char* members for windows NVCC]
+  STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::bitwise_xor_")
+  STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "Tensor")
+  STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "bitwise_xor_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)")
+  static at::Tensor & call(at::Tensor & self, const at::Tensor & other);
+  static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other);
+};
+
+struct TORCH_API bitwise_xor_Scalar_Tensor_out {
+  using schema = at::Tensor & (const at::Scalar &, const at::Tensor &, at::Tensor &);
+  using ptr_schema = schema*;
+  // See Note [static constexpr char* members for windows NVCC]
+  STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::bitwise_xor")
+  STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "Scalar_Tensor_out")
+  STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "bitwise_xor.Scalar_Tensor_out(Scalar self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)")
+  static at::Tensor & call(const at::Scalar & self, const at::Tensor & other, at::Tensor & out);
+  static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, const at::Scalar & self, const at::Tensor & other, at::Tensor & out);
+};
+
+}} // namespace at::_ops

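The structs above are the unboxed entry points behind `at::bitwise_xor` and its variants; the public function and the `_ops` struct reach the same kernel. A sketch, assuming a standard libtorch build:

    #include <ATen/ATen.h>

    int main() {
      at::Tensor a = at::arange(4);                  // int64 tensor {0, 1, 2, 3}
      at::Tensor b = at::full({4}, 5, a.options());
      at::Tensor c = at::bitwise_xor(a, b);          // public API
      at::Tensor d = at::_ops::bitwise_xor_Tensor::call(a, b);  // direct schema call
      return 0;
    }
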
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/clip_ops.h
ADDED
@@ -0,0 +1,83 @@
+#pragma once
+
+// @generated by torchgen/gen.py from Operator.h
+
+#include <tuple>
+#include <vector>
+
+// Forward declarations of any types needed in the operator signatures.
+// We can't directly include these classes because it will cause circular include dependencies.
+// This file is included by TensorBody.h, which defines the Tensor class.
+#include <ATen/core/ATen_fwd.h>
+
+namespace at {
+namespace _ops {
+
+
+struct TORCH_API clip {
+  using schema = at::Tensor (const at::Tensor &, const c10::optional<at::Scalar> &, const c10::optional<at::Scalar> &);
+  using ptr_schema = schema*;
+  // See Note [static constexpr char* members for windows NVCC]
+  STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::clip")
+  STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "")
+  STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "clip(Tensor self, Scalar? min=None, Scalar? max=None) -> Tensor")
+  static at::Tensor call(const at::Tensor & self, const c10::optional<at::Scalar> & min, const c10::optional<at::Scalar> & max);
+  static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const c10::optional<at::Scalar> & min, const c10::optional<at::Scalar> & max);
+};
+
+struct TORCH_API clip_Tensor {
+  using schema = at::Tensor (const at::Tensor &, const c10::optional<at::Tensor> &, const c10::optional<at::Tensor> &);
+  using ptr_schema = schema*;
+  // See Note [static constexpr char* members for windows NVCC]
+  STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::clip")
+  STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "Tensor")
+  STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "clip.Tensor(Tensor self, Tensor? min=None, Tensor? max=None) -> Tensor")
+  static at::Tensor call(const at::Tensor & self, const c10::optional<at::Tensor> & min, const c10::optional<at::Tensor> & max);
+  static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const c10::optional<at::Tensor> & min, const c10::optional<at::Tensor> & max);
+};
+
+struct TORCH_API clip_ {
+  using schema = at::Tensor & (at::Tensor &, const c10::optional<at::Scalar> &, const c10::optional<at::Scalar> &);
+  using ptr_schema = schema*;
+  // See Note [static constexpr char* members for windows NVCC]
+  STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::clip_")
+  STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "")
+  STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "clip_(Tensor(a!) self, Scalar? min=None, Scalar? max=None) -> Tensor(a!)")
+  static at::Tensor & call(at::Tensor & self, const c10::optional<at::Scalar> & min, const c10::optional<at::Scalar> & max);
+  static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const c10::optional<at::Scalar> & min, const c10::optional<at::Scalar> & max);
+};
+
+struct TORCH_API clip__Tensor {
+  using schema = at::Tensor & (at::Tensor &, const c10::optional<at::Tensor> &, const c10::optional<at::Tensor> &);
+  using ptr_schema = schema*;
+  // See Note [static constexpr char* members for windows NVCC]
+  STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::clip_")
+  STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "Tensor")
+  STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "clip_.Tensor(Tensor(a!) self, Tensor? min=None, Tensor? max=None) -> Tensor(a!)")
+  static at::Tensor & call(at::Tensor & self, const c10::optional<at::Tensor> & min, const c10::optional<at::Tensor> & max);
+  static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const c10::optional<at::Tensor> & min, const c10::optional<at::Tensor> & max);
+};
+
+struct TORCH_API clip_out {
+  using schema = at::Tensor & (const at::Tensor &, const c10::optional<at::Scalar> &, const c10::optional<at::Scalar> &, at::Tensor &);
+  using ptr_schema = schema*;
+  // See Note [static constexpr char* members for windows NVCC]
+  STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::clip")
+  STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "out")
+  STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "clip.out(Tensor self, Scalar? min=None, Scalar? max=None, *, Tensor(a!) out) -> Tensor(a!)")
+  static at::Tensor & call(const at::Tensor & self, const c10::optional<at::Scalar> & min, const c10::optional<at::Scalar> & max, at::Tensor & out);
+  static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const c10::optional<at::Scalar> & min, const c10::optional<at::Scalar> & max, at::Tensor & out);
+};
+
+struct TORCH_API clip_Tensor_out {
+  using schema = at::Tensor & (const at::Tensor &, const c10::optional<at::Tensor> &, const c10::optional<at::Tensor> &, at::Tensor &);
+  using ptr_schema = schema*;
+  // See Note [static constexpr char* members for windows NVCC]
+  STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::clip")
+  STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "Tensor_out")
+  STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "clip.Tensor_out(Tensor self, Tensor? min=None, Tensor? max=None, *, Tensor(a!) out) -> Tensor(a!)")
+  static at::Tensor & call(const at::Tensor & self, const c10::optional<at::Tensor> & min, const c10::optional<at::Tensor> & max, at::Tensor & out);
+  static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const c10::optional<at::Tensor> & min, const c10::optional<at::Tensor> & max, at::Tensor & out);
+};
+
+}} // namespace at::_ops

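The schema structs above back `at::clip`/`at::clip_` with both Scalar and Tensor bounds. A sketch of the Scalar overload, assuming a standard libtorch build:

    #include <ATen/ATen.h>

    int main() {
      at::Tensor x = at::randn({5});
      at::Tensor y = at::clip(x, 0.0, 1.0);  // clamp into [0, 1]
      at::clip_(x, -1.0, 1.0);               // in-place variant
      return 0;
    }
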
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/cosine_similarity_compositeimplicitautograd_dispatch.h
ADDED
@@ -0,0 +1,23 @@
+#pragma once
+// @generated by torchgen/gen.py from DispatchKeyFunction.h
+
+// NB: The implementing C++ file is RegisterDispatchKey.cpp
+
+// The only #includes we need are for custom classes that have defaults in the C++ API
+#include <c10/core/MemoryFormat.h>
+#include <c10/core/Scalar.h>
+#include <ATen/core/Reduction.h>
+
+// Forward declarations of any types needed in the operator signatures.
+// We can't directly include these classes because it will cause circular include dependencies.
+// This file is included by TensorBody.h, which defines the Tensor class.
+#include <ATen/core/ATen_fwd.h>
+
+namespace at {
+
+namespace compositeimplicitautograd {
+
+TORCH_API at::Tensor cosine_similarity(const at::Tensor & x1, const at::Tensor & x2, int64_t dim=1, double eps=1e-08);
+
+} // namespace compositeimplicitautograd
+} // namespace at

tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/cumprod_backward_native.h
ADDED
@@ -0,0 +1,21 @@
+#pragma once
+
+// @generated by torchgen/gen.py from NativeFunction.h
+
+#include <c10/core/Scalar.h>
+#include <c10/core/Storage.h>
+#include <c10/core/TensorOptions.h>
+#include <c10/util/Deprecated.h>
+#include <c10/util/Optional.h>
+#include <c10/core/QScheme.h>
+#include <ATen/core/Reduction.h>
+#include <ATen/core/Tensor.h>
+#include <tuple>
+#include <vector>
+
+
+namespace at {
+namespace native {
+TORCH_API at::Tensor cumprod_backward(const at::Tensor & grad, const at::Tensor & input, int64_t dim, const at::Tensor & output);
+} // namespace native
+} // namespace at

tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/digamma.h
ADDED
@@ -0,0 +1,39 @@
+#pragma once
+
+// @generated by torchgen/gen.py from Function.h
+
+#include <ATen/Context.h>
+#include <ATen/DeviceGuard.h>
+#include <ATen/TensorUtils.h>
+#include <ATen/TracerMode.h>
+#include <ATen/core/Generator.h>
+#include <ATen/core/Reduction.h>
+#include <ATen/core/Tensor.h>
+#include <c10/core/Scalar.h>
+#include <c10/core/Storage.h>
+#include <c10/core/TensorOptions.h>
+#include <c10/util/Deprecated.h>
+#include <c10/util/Optional.h>
+
+
+
+#include <ATen/ops/digamma_ops.h>
+
+namespace at {
+
+
+// aten::digamma.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+inline at::Tensor & digamma_out(at::Tensor & out, const at::Tensor & self) {
+    return at::_ops::digamma_out::call(self, out);
+}
+// aten::digamma.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+inline at::Tensor & digamma_outf(const at::Tensor & self, at::Tensor & out) {
+    return at::_ops::digamma_out::call(self, out);
+}
+
+// aten::digamma(Tensor self) -> Tensor
+inline at::Tensor digamma(const at::Tensor & self) {
+    return at::_ops::digamma::call(self);
+}
+
+}

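A sketch of both call forms declared above, assuming a standard libtorch build (`digamma_out` takes the output first, `digamma_outf` takes it last):

    #include <ATen/ATen.h>

    int main() {
      at::Tensor x = at::rand({4}) + 0.5;
      at::Tensor y = at::digamma(x);        // functional form
      at::Tensor out = at::empty_like(x);
      at::digamma_out(out, x);              // out= form
      return 0;
    }
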
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/divide_native.h
ADDED
@@ -0,0 +1,30 @@
+#pragma once
+
+// @generated by torchgen/gen.py from NativeFunction.h
+
+#include <c10/core/Scalar.h>
+#include <c10/core/Storage.h>
+#include <c10/core/TensorOptions.h>
+#include <c10/util/Deprecated.h>
+#include <c10/util/Optional.h>
+#include <c10/core/QScheme.h>
+#include <ATen/core/Reduction.h>
+#include <ATen/core/Tensor.h>
+#include <tuple>
+#include <vector>
+
+
+namespace at {
+namespace native {
+TORCH_API at::Tensor divide(const at::Tensor & self, const at::Tensor & other);
+TORCH_API at::Tensor & divide_out(const at::Tensor & self, const at::Tensor & other, at::Tensor & out);
+TORCH_API at::Tensor & divide_(at::Tensor & self, const at::Tensor & other);
+TORCH_API at::Tensor divide(const at::Tensor & self, const at::Scalar & other);
+TORCH_API at::Tensor & divide_(at::Tensor & self, const at::Scalar & other);
+TORCH_API at::Tensor divide(const at::Tensor & self, const at::Tensor & other, c10::optional<c10::string_view> rounding_mode);
+TORCH_API at::Tensor & divide_out(const at::Tensor & self, const at::Tensor & other, c10::optional<c10::string_view> rounding_mode, at::Tensor & out);
+TORCH_API at::Tensor & divide_(at::Tensor & self, const at::Tensor & other, c10::optional<c10::string_view> rounding_mode);
+TORCH_API at::Tensor divide(const at::Tensor & self, const at::Scalar & other, c10::optional<c10::string_view> rounding_mode);
+TORCH_API at::Tensor & divide_(at::Tensor & self, const at::Scalar & other, c10::optional<c10::string_view> rounding_mode);
+} // namespace native
+} // namespace at

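The declarations above include the `rounding_mode` overloads of `divide`. A sketch, assuming a standard libtorch build:

    #include <ATen/ATen.h>

    int main() {
      at::Tensor a = at::arange(1, 7).to(at::kFloat);
      at::Tensor b = at::full({6}, 4.0);
      at::Tensor q  = at::divide(a, b);           // true division
      at::Tensor qf = at::divide(a, b, "floor");  // floor-rounded quotient
      return 0;
    }
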
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/embedding_dense_backward_compositeexplicitautograd_dispatch.h
ADDED
@@ -0,0 +1,26 @@
+#pragma once
+// @generated by torchgen/gen.py from DispatchKeyFunction.h
+
+// NB: The implementing C++ file is RegisterDispatchKey.cpp
+
+// The only #includes we need are for custom classes that have defaults in the C++ API
+#include <c10/core/MemoryFormat.h>
+#include <c10/core/Scalar.h>
+#include <ATen/core/Reduction.h>
+
+// Forward declarations of any types needed in the operator signatures.
+// We can't directly include these classes because it will cause circular include dependencies.
+// This file is included by TensorBody.h, which defines the Tensor class.
+#include <ATen/core/ATen_fwd.h>
+
+namespace at {
+
+namespace compositeexplicitautograd {
+
+TORCH_API at::Tensor & embedding_dense_backward_out(at::Tensor & out, const at::Tensor & grad_output, const at::Tensor & indices, int64_t num_weights, int64_t padding_idx, bool scale_grad_by_freq);
+TORCH_API at::Tensor & embedding_dense_backward_outf(const at::Tensor & grad_output, const at::Tensor & indices, int64_t num_weights, int64_t padding_idx, bool scale_grad_by_freq, at::Tensor & out);
+TORCH_API at::Tensor & embedding_dense_backward_symint_out(at::Tensor & out, const at::Tensor & grad_output, const at::Tensor & indices, c10::SymInt num_weights, c10::SymInt padding_idx, bool scale_grad_by_freq);
+TORCH_API at::Tensor & embedding_dense_backward_symint_outf(const at::Tensor & grad_output, const at::Tensor & indices, c10::SymInt num_weights, c10::SymInt padding_idx, bool scale_grad_by_freq, at::Tensor & out);
+
+} // namespace compositeexplicitautograd
+} // namespace at

tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/expm1_compositeexplicitautogradnonfunctional_dispatch.h
ADDED
@@ -0,0 +1,24 @@
+#pragma once
+// @generated by torchgen/gen.py from DispatchKeyFunction.h
+
+// NB: The implementing C++ file is RegisterDispatchKey.cpp
+
+// The only #includes we need are for custom classes that have defaults in the C++ API
+#include <c10/core/MemoryFormat.h>
+#include <c10/core/Scalar.h>
+#include <ATen/core/Reduction.h>
+
+// Forward declarations of any types needed in the operator signatures.
+// We can't directly include these classes because it will cause circular include dependencies.
+// This file is included by TensorBody.h, which defines the Tensor class.
+#include <ATen/core/ATen_fwd.h>
+
+namespace at {
+
+namespace compositeexplicitautogradnonfunctional {
+
+TORCH_API at::Tensor expm1(const at::Tensor & self);
+TORCH_API at::Tensor & expm1_(at::Tensor & self);
+
+} // namespace compositeexplicitautogradnonfunctional
+} // namespace at

tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/fake_quantize_per_channel_affine_cachemask_backward.h
ADDED
@@ -0,0 +1,30 @@
+#pragma once
+
+// @generated by torchgen/gen.py from Function.h
+
+#include <ATen/Context.h>
+#include <ATen/DeviceGuard.h>
+#include <ATen/TensorUtils.h>
+#include <ATen/TracerMode.h>
+#include <ATen/core/Generator.h>
+#include <ATen/core/Reduction.h>
+#include <ATen/core/Tensor.h>
+#include <c10/core/Scalar.h>
+#include <c10/core/Storage.h>
+#include <c10/core/TensorOptions.h>
+#include <c10/util/Deprecated.h>
+#include <c10/util/Optional.h>
+
+
+
+#include <ATen/ops/fake_quantize_per_channel_affine_cachemask_backward_ops.h>
+
+namespace at {
+
+
+// aten::fake_quantize_per_channel_affine_cachemask_backward(Tensor grad, Tensor mask) -> Tensor
+inline at::Tensor fake_quantize_per_channel_affine_cachemask_backward(const at::Tensor & grad, const at::Tensor & mask) {
+    return at::_ops::fake_quantize_per_channel_affine_cachemask_backward::call(grad, mask);
+}
+
+}

tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/fft_ihfftn_compositeimplicitautograd_dispatch.h
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
+#pragma once
+// @generated by torchgen/gen.py from DispatchKeyFunction.h
+
+// NB: The implementing C++ file is RegisterDispatchKey.cpp
+
+// The only #includes we need are for custom classes that have defaults in the C++ API
+#include <c10/core/MemoryFormat.h>
+#include <c10/core/Scalar.h>
+#include <ATen/core/Reduction.h>
+
+// Forward declarations of any types needed in the operator signatures.
+// We can't directly include these classes because it will cause circular include dependencies.
+// This file is included by TensorBody.h, which defines the Tensor class.
+#include <ATen/core/ATen_fwd.h>
+
+namespace at {
+
+namespace compositeimplicitautograd {
+
+TORCH_API at::Tensor fft_ihfftn(const at::Tensor & self, at::OptionalIntArrayRef s=c10::nullopt, at::OptionalIntArrayRef dim=c10::nullopt, c10::optional<c10::string_view> norm=c10::nullopt);
+TORCH_API at::Tensor fft_ihfftn_symint(const at::Tensor & self, at::OptionalSymIntArrayRef s=c10::nullopt, at::OptionalIntArrayRef dim=c10::nullopt, c10::optional<c10::string_view> norm=c10::nullopt);
+TORCH_API const at::Tensor & fft_ihfftn_out(const at::Tensor & out, const at::Tensor & self, at::OptionalIntArrayRef s=c10::nullopt, at::OptionalIntArrayRef dim=c10::nullopt, c10::optional<c10::string_view> norm=c10::nullopt);
+TORCH_API const at::Tensor & fft_ihfftn_outf(const at::Tensor & self, at::OptionalIntArrayRef s, at::OptionalIntArrayRef dim, c10::optional<c10::string_view> norm, const at::Tensor & out);
+TORCH_API const at::Tensor & fft_ihfftn_symint_out(const at::Tensor & out, const at::Tensor & self, at::OptionalSymIntArrayRef s=c10::nullopt, at::OptionalIntArrayRef dim=c10::nullopt, c10::optional<c10::string_view> norm=c10::nullopt);
+TORCH_API const at::Tensor & fft_ihfftn_symint_outf(const at::Tensor & self, at::OptionalSymIntArrayRef s, at::OptionalIntArrayRef dim, c10::optional<c10::string_view> norm, const at::Tensor & out);
+
+} // namespace compositeimplicitautograd
+} // namespace at
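The _out/_outf pairs differ only in where the out tensor sits in the argument list, and the _symint variants take symbolic sizes for tracing. A hedged sketch through the public API (demo function and shapes are illustrative):

    #include <ATen/ATen.h>

    void ihfftn_demo() {
      at::Tensor x = at::rand({8, 8});
      // Inverse of hfftn; the result is complex. The bare call uses the
      // s/dim/norm = nullopt defaults declared above.
      at::Tensor y = at::fft_ihfftn(x);
    }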
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/gcd_cuda_dispatch.h
ADDED
@@ -0,0 +1,26 @@
+#pragma once
+// @generated by torchgen/gen.py from DispatchKeyFunction.h
+
+// NB: The implementing C++ file is RegisterDispatchKey.cpp
+
+// The only #includes we need are for custom classes that have defaults in the C++ API
+#include <c10/core/MemoryFormat.h>
+#include <c10/core/Scalar.h>
+#include <ATen/core/Reduction.h>
+
+// Forward declarations of any types needed in the operator signatures.
+// We can't directly include these classes because it will cause circular include dependencies.
+// This file is included by TensorBody.h, which defines the Tensor class.
+#include <ATen/core/ATen_fwd.h>
+
+namespace at {
+
+namespace cuda {
+
+TORCH_API at::Tensor gcd(const at::Tensor & self, const at::Tensor & other);
+TORCH_API at::Tensor & gcd_out(at::Tensor & out, const at::Tensor & self, const at::Tensor & other);
+TORCH_API at::Tensor & gcd_outf(const at::Tensor & self, const at::Tensor & other, at::Tensor & out);
+TORCH_API at::Tensor & gcd_(at::Tensor & self, const at::Tensor & other);
+
+} // namespace cuda
+} // namespace at
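A hedged sketch of the declared variants via the public API (integer dtypes are required; the CUDA kernel above is selected when the inputs live on a CUDA device):

    #include <ATen/ATen.h>

    void gcd_demo() {
      at::Tensor a = at::arange(10, 14);   // int64 by default
      at::Tensor b = at::arange(4, 8);
      at::Tensor g = at::gcd(a, b);        // functional
      at::Tensor out = at::empty_like(a);
      at::gcd_out(out, a, b);              // out variant
      a.gcd_(b);                           // in-place variant
    }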
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/grid_sampler_2d.h
ADDED
@@ -0,0 +1,39 @@
+#pragma once
+
+// @generated by torchgen/gen.py from Function.h
+
+#include <ATen/Context.h>
+#include <ATen/DeviceGuard.h>
+#include <ATen/TensorUtils.h>
+#include <ATen/TracerMode.h>
+#include <ATen/core/Generator.h>
+#include <ATen/core/Reduction.h>
+#include <ATen/core/Tensor.h>
+#include <c10/core/Scalar.h>
+#include <c10/core/Storage.h>
+#include <c10/core/TensorOptions.h>
+#include <c10/util/Deprecated.h>
+#include <c10/util/Optional.h>
+
+
+
+#include <ATen/ops/grid_sampler_2d_ops.h>
+
+namespace at {
+
+
+// aten::grid_sampler_2d(Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners) -> Tensor
+inline at::Tensor grid_sampler_2d(const at::Tensor & input, const at::Tensor & grid, int64_t interpolation_mode, int64_t padding_mode, bool align_corners) {
+    return at::_ops::grid_sampler_2d::call(input, grid, interpolation_mode, padding_mode, align_corners);
+}
+
+// aten::grid_sampler_2d.out(Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners, *, Tensor(a!) out) -> Tensor(a!)
+inline at::Tensor & grid_sampler_2d_out(at::Tensor & out, const at::Tensor & input, const at::Tensor & grid, int64_t interpolation_mode, int64_t padding_mode, bool align_corners) {
+    return at::_ops::grid_sampler_2d_out::call(input, grid, interpolation_mode, padding_mode, align_corners, out);
+}
+// aten::grid_sampler_2d.out(Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners, *, Tensor(a!) out) -> Tensor(a!)
+inline at::Tensor & grid_sampler_2d_outf(const at::Tensor & input, const at::Tensor & grid, int64_t interpolation_mode, int64_t padding_mode, bool align_corners, at::Tensor & out) {
+    return at::_ops::grid_sampler_2d_out::call(input, grid, interpolation_mode, padding_mode, align_corners, out);
+}
+
+}
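At this level the mode arguments are raw integer codes (0 = bilinear interpolation, 0 = zeros padding), unlike the string API in Python. A hedged sketch with illustrative (N, C, H, W) / (N, H_out, W_out, 2) shapes:

    #include <ATen/ATen.h>

    void grid_sampler_demo() {
      at::Tensor input = at::rand({1, 3, 5, 5});
      at::Tensor grid  = at::rand({1, 4, 4, 2}) * 2 - 1;  // normalized coords in [-1, 1]
      at::Tensor out = at::grid_sampler_2d(input, grid,
                                           /*interpolation_mode=*/0,
                                           /*padding_mode=*/0,
                                           /*align_corners=*/false);
    }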
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/hardswish_backward.h
ADDED
@@ -0,0 +1,39 @@
+#pragma once
+
+// @generated by torchgen/gen.py from Function.h
+
+#include <ATen/Context.h>
+#include <ATen/DeviceGuard.h>
+#include <ATen/TensorUtils.h>
+#include <ATen/TracerMode.h>
+#include <ATen/core/Generator.h>
+#include <ATen/core/Reduction.h>
+#include <ATen/core/Tensor.h>
+#include <c10/core/Scalar.h>
+#include <c10/core/Storage.h>
+#include <c10/core/TensorOptions.h>
+#include <c10/util/Deprecated.h>
+#include <c10/util/Optional.h>
+
+
+
+#include <ATen/ops/hardswish_backward_ops.h>
+
+namespace at {
+
+
+// aten::hardswish_backward(Tensor grad_output, Tensor self) -> Tensor
+inline at::Tensor hardswish_backward(const at::Tensor & grad_output, const at::Tensor & self) {
+    return at::_ops::hardswish_backward::call(grad_output, self);
+}
+
+// aten::hardswish_backward.out(Tensor grad_output, Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+inline at::Tensor & hardswish_backward_out(at::Tensor & out, const at::Tensor & grad_output, const at::Tensor & self) {
+    return at::_ops::hardswish_backward_out::call(grad_output, self, out);
+}
+// aten::hardswish_backward.out(Tensor grad_output, Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+inline at::Tensor & hardswish_backward_outf(const at::Tensor & grad_output, const at::Tensor & self, at::Tensor & out) {
+    return at::_ops::hardswish_backward_out::call(grad_output, self, out);
+}
+
+}
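A hedged sketch of the functional form (names and shapes are illustrative; grad_output and self must broadcast, here they simply share a shape):

    #include <ATen/ATen.h>

    void hardswish_bwd_demo() {
      at::Tensor self = at::rand({4});
      at::Tensor grad_output = at::ones_like(self);
      at::Tensor grad_input = at::hardswish_backward(grad_output, self);
    }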
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/hstack.h
ADDED
@@ -0,0 +1,39 @@
+#pragma once
+
+// @generated by torchgen/gen.py from Function.h
+
+#include <ATen/Context.h>
+#include <ATen/DeviceGuard.h>
+#include <ATen/TensorUtils.h>
+#include <ATen/TracerMode.h>
+#include <ATen/core/Generator.h>
+#include <ATen/core/Reduction.h>
+#include <ATen/core/Tensor.h>
+#include <c10/core/Scalar.h>
+#include <c10/core/Storage.h>
+#include <c10/core/TensorOptions.h>
+#include <c10/util/Deprecated.h>
+#include <c10/util/Optional.h>
+
+
+
+#include <ATen/ops/hstack_ops.h>
+
+namespace at {
+
+
+// aten::hstack(Tensor[] tensors) -> Tensor
+inline at::Tensor hstack(at::TensorList tensors) {
+    return at::_ops::hstack::call(tensors);
+}
+
+// aten::hstack.out(Tensor[] tensors, *, Tensor(a!) out) -> Tensor(a!)
+inline at::Tensor & hstack_out(at::Tensor & out, at::TensorList tensors) {
+    return at::_ops::hstack_out::call(tensors, out);
+}
+// aten::hstack.out(Tensor[] tensors, *, Tensor(a!) out) -> Tensor(a!)
+inline at::Tensor & hstack_outf(at::TensorList tensors, at::Tensor & out) {
+    return at::_ops::hstack_out::call(tensors, out);
+}
+
+}
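at::TensorList accepts a braced initializer, so the functional form is the natural call site. A hedged sketch:

    #include <ATen/ATen.h>

    void hstack_demo() {
      at::Tensor a = at::rand({2, 2});
      at::Tensor b = at::rand({2, 3});
      // For >=2-D inputs this concatenates along dim 1, giving shape (2, 5).
      at::Tensor c = at::hstack({a, b});
    }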
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/huber_loss_backward.h
ADDED
@@ -0,0 +1,39 @@
+#pragma once
+
+// @generated by torchgen/gen.py from Function.h
+
+#include <ATen/Context.h>
+#include <ATen/DeviceGuard.h>
+#include <ATen/TensorUtils.h>
+#include <ATen/TracerMode.h>
+#include <ATen/core/Generator.h>
+#include <ATen/core/Reduction.h>
+#include <ATen/core/Tensor.h>
+#include <c10/core/Scalar.h>
+#include <c10/core/Storage.h>
+#include <c10/core/TensorOptions.h>
+#include <c10/util/Deprecated.h>
+#include <c10/util/Optional.h>
+
+
+
+#include <ATen/ops/huber_loss_backward_ops.h>
+
+namespace at {
+
+
+// aten::huber_loss_backward.out(Tensor grad_output, Tensor self, Tensor target, int reduction, float delta, *, Tensor(a!) grad_input) -> Tensor(a!)
+inline at::Tensor & huber_loss_backward_out(at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, int64_t reduction, double delta) {
+    return at::_ops::huber_loss_backward_out::call(grad_output, self, target, reduction, delta, grad_input);
+}
+// aten::huber_loss_backward.out(Tensor grad_output, Tensor self, Tensor target, int reduction, float delta, *, Tensor(a!) grad_input) -> Tensor(a!)
+inline at::Tensor & huber_loss_backward_outf(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, int64_t reduction, double delta, at::Tensor & grad_input) {
+    return at::_ops::huber_loss_backward_out::call(grad_output, self, target, reduction, delta, grad_input);
+}
+
+// aten::huber_loss_backward(Tensor grad_output, Tensor self, Tensor target, int reduction, float delta) -> Tensor
+inline at::Tensor huber_loss_backward(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, int64_t reduction, double delta) {
+    return at::_ops::huber_loss_backward::call(grad_output, self, target, reduction, delta);
+}
+
+}
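Note how the out-style wrapper takes grad_input first while _outf takes it last; both forward to the same stub. A hedged sketch of the functional form, using the at::Reduction enum for the int argument (shapes are illustrative):

    #include <ATen/ATen.h>

    void huber_bwd_demo() {
      at::Tensor self = at::rand({4});
      at::Tensor target = at::rand({4});
      at::Tensor grad_out = at::scalar_tensor(1.0);  // scalar grad for a mean-reduced loss
      at::Tensor gi = at::huber_loss_backward(grad_out, self, target,
                                              at::Reduction::Mean, /*delta=*/1.0);
    }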
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/lift_fresh.h
ADDED
@@ -0,0 +1,30 @@
+#pragma once
+
+// @generated by torchgen/gen.py from Function.h
+
+#include <ATen/Context.h>
+#include <ATen/DeviceGuard.h>
+#include <ATen/TensorUtils.h>
+#include <ATen/TracerMode.h>
+#include <ATen/core/Generator.h>
+#include <ATen/core/Reduction.h>
+#include <ATen/core/Tensor.h>
+#include <c10/core/Scalar.h>
+#include <c10/core/Storage.h>
+#include <c10/core/TensorOptions.h>
+#include <c10/util/Deprecated.h>
+#include <c10/util/Optional.h>
+
+
+
+#include <ATen/ops/lift_fresh_ops.h>
+
+namespace at {
+
+
+// aten::lift_fresh(Tensor(a) self) -> Tensor(a)
+inline at::Tensor lift_fresh(const at::Tensor & self) {
+    return at::_ops::lift_fresh::call(self);
+}
+
+}
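Per its Tensor(a) -> Tensor(a) schema, lift_fresh returns an alias; functionalization and tracing use it to mark a freshly materialized constant, and in plain eager code it is effectively a pass-through. A hedged sketch:

    #include <ATen/ATen.h>

    void lift_fresh_demo() {
      at::Tensor t = at::ones({2});
      at::Tensor lifted = at::lift_fresh(t);  // aliases t outside of tracing
    }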
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/linalg_eigh_ops.h
ADDED
@@ -0,0 +1,39 @@
+#pragma once
+
+// @generated by torchgen/gen.py from Operator.h
+
+#include <tuple>
+#include <vector>
+
+// Forward declarations of any types needed in the operator signatures.
+// We can't directly include these classes because it will cause circular include dependencies.
+// This file is included by TensorBody.h, which defines the Tensor class.
+#include <ATen/core/ATen_fwd.h>
+
+namespace at {
+namespace _ops {
+
+
+struct TORCH_API linalg_eigh {
+  using schema = ::std::tuple<at::Tensor,at::Tensor> (const at::Tensor &, c10::string_view);
+  using ptr_schema = schema*;
+  // See Note [static constexpr char* members for windows NVCC]
+  STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::linalg_eigh")
+  STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "")
+  STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "linalg_eigh(Tensor self, str UPLO=\"L\") -> (Tensor eigenvalues, Tensor eigenvectors)")
+  static ::std::tuple<at::Tensor,at::Tensor> call(const at::Tensor & self, c10::string_view UPLO);
+  static ::std::tuple<at::Tensor,at::Tensor> redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::string_view UPLO);
+};
+
+struct TORCH_API linalg_eigh_eigvals {
+  using schema = ::std::tuple<at::Tensor &,at::Tensor &> (const at::Tensor &, c10::string_view, at::Tensor &, at::Tensor &);
+  using ptr_schema = schema*;
+  // See Note [static constexpr char* members for windows NVCC]
+  STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::linalg_eigh")
+  STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "eigvals")
+  STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "linalg_eigh.eigvals(Tensor self, str UPLO=\"L\", *, Tensor(a!) eigvals, Tensor(b!) eigvecs) -> (Tensor(a!) eigenvalues, Tensor(b!) eigenvectors)")
+  static ::std::tuple<at::Tensor &,at::Tensor &> call(const at::Tensor & self, c10::string_view UPLO, at::Tensor & eigvals, at::Tensor & eigvecs);
+  static ::std::tuple<at::Tensor &,at::Tensor &> redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::string_view UPLO, at::Tensor & eigvals, at::Tensor & eigvecs);
+};
+
+}} // namespace at::_ops
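These Operator.h structs carry the schema string plus the call/redispatch entry points; ordinary code never touches them directly and instead uses the Function.h wrapper. A hedged sketch (the symmetrization step is illustrative):

    #include <ATen/ATen.h>

    void eigh_demo() {
      at::Tensor A = at::rand({3, 3});
      A = A + A.transpose(0, 1);              // make it symmetric
      auto [w, V] = at::linalg_eigh(A, "L");  // eigenvalues, eigenvectors
    }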
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/linalg_lu_cpu_dispatch.h
ADDED
@@ -0,0 +1,25 @@
+#pragma once
+// @generated by torchgen/gen.py from DispatchKeyFunction.h
+
+// NB: The implementing C++ file is RegisterDispatchKey.cpp
+
+// The only #includes we need are for custom classes that have defaults in the C++ API
+#include <c10/core/MemoryFormat.h>
+#include <c10/core/Scalar.h>
+#include <ATen/core/Reduction.h>
+
+// Forward declarations of any types needed in the operator signatures.
+// We can't directly include these classes because it will cause circular include dependencies.
+// This file is included by TensorBody.h, which defines the Tensor class.
+#include <ATen/core/ATen_fwd.h>
+
+namespace at {
+
+namespace cpu {
+
+TORCH_API ::std::tuple<at::Tensor,at::Tensor,at::Tensor> linalg_lu(const at::Tensor & A, bool pivot=true);
+TORCH_API ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &> linalg_lu_out(at::Tensor & P, at::Tensor & L, at::Tensor & U, const at::Tensor & A, bool pivot=true);
+TORCH_API ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &> linalg_lu_outf(const at::Tensor & A, bool pivot, at::Tensor & P, at::Tensor & L, at::Tensor & U);
+
+} // namespace cpu
+} // namespace at
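A hedged sketch of the functional form, which factors A into P, L, U with A = P L U when pivot=true:

    #include <ATen/ATen.h>

    void lu_demo() {
      at::Tensor A = at::rand({4, 4});
      auto [P, L, U] = at::linalg_lu(A, /*pivot=*/true);
      at::Tensor recon = at::matmul(P, at::matmul(L, U));  // approximately equals A
    }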
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/linalg_pinv_compositeexplicitautogradnonfunctional_dispatch.h
ADDED
@@ -0,0 +1,23 @@
+#pragma once
+// @generated by torchgen/gen.py from DispatchKeyFunction.h
+
+// NB: The implementing C++ file is RegisterDispatchKey.cpp
+
+// The only #includes we need are for custom classes that have defaults in the C++ API
+#include <c10/core/MemoryFormat.h>
+#include <c10/core/Scalar.h>
+#include <ATen/core/Reduction.h>
+
+// Forward declarations of any types needed in the operator signatures.
+// We can't directly include these classes because it will cause circular include dependencies.
+// This file is included by TensorBody.h, which defines the Tensor class.
+#include <ATen/core/ATen_fwd.h>
+
+namespace at {
+
+namespace compositeexplicitautogradnonfunctional {
+
+TORCH_API at::Tensor linalg_pinv(const at::Tensor & self, const c10::optional<at::Tensor> & atol={}, const c10::optional<at::Tensor> & rtol={}, bool hermitian=false);
+
+} // namespace compositeexplicitautogradnonfunctional
+} // namespace at
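Only the atol/rtol-tensor overload is declared for this dispatch key, and its defaults make a bare call legal. A hedged sketch (shapes illustrative):

    #include <ATen/ATen.h>

    void pinv_demo() {
      at::Tensor A = at::rand({3, 5});
      at::Tensor Ap = at::linalg_pinv(A);                   // default tolerances
      at::Tensor recon = at::matmul(A, at::matmul(Ap, A));  // approximately equals A
    }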
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/linear.h
ADDED
@@ -0,0 +1,39 @@
+#pragma once
+
+// @generated by torchgen/gen.py from Function.h
+
+#include <ATen/Context.h>
+#include <ATen/DeviceGuard.h>
+#include <ATen/TensorUtils.h>
+#include <ATen/TracerMode.h>
+#include <ATen/core/Generator.h>
+#include <ATen/core/Reduction.h>
+#include <ATen/core/Tensor.h>
+#include <c10/core/Scalar.h>
+#include <c10/core/Storage.h>
+#include <c10/core/TensorOptions.h>
+#include <c10/util/Deprecated.h>
+#include <c10/util/Optional.h>
+
+
+
+#include <ATen/ops/linear_ops.h>
+
+namespace at {
+
+
+// aten::linear(Tensor input, Tensor weight, Tensor? bias=None) -> Tensor
+inline at::Tensor linear(const at::Tensor & input, const at::Tensor & weight, const c10::optional<at::Tensor> & bias={}) {
+    return at::_ops::linear::call(input, weight, bias);
+}
+
+// aten::linear.out(Tensor input, Tensor weight, Tensor? bias=None, *, Tensor(a!) out) -> Tensor(a!)
+inline at::Tensor & linear_out(at::Tensor & out, const at::Tensor & input, const at::Tensor & weight, const c10::optional<at::Tensor> & bias={}) {
+    return at::_ops::linear_out::call(input, weight, bias, out);
+}
+// aten::linear.out(Tensor input, Tensor weight, Tensor? bias=None, *, Tensor(a!) out) -> Tensor(a!)
+inline at::Tensor & linear_outf(const at::Tensor & input, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, at::Tensor & out) {
+    return at::_ops::linear_out::call(input, weight, bias, out);
+}
+
+}
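The bias is a c10::optional<Tensor> defaulting to an absent tensor, mirroring the Tensor? bias=None schema. A hedged sketch:

    #include <ATen/ATen.h>

    void linear_demo() {
      at::Tensor x = at::rand({2, 4});
      at::Tensor W = at::rand({3, 4});
      at::Tensor b = at::rand({3});
      at::Tensor y = at::linear(x, W, b);   // y = x @ W.T + b, shape (2, 3)
      at::Tensor y2 = at::linear(x, W);     // bias omitted
    }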
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/log_sigmoid_forward_cpu_dispatch.h
ADDED
@@ -0,0 +1,25 @@
+#pragma once
+// @generated by torchgen/gen.py from DispatchKeyFunction.h
+
+// NB: The implementing C++ file is RegisterDispatchKey.cpp
+
+// The only #includes we need are for custom classes that have defaults in the C++ API
+#include <c10/core/MemoryFormat.h>
+#include <c10/core/Scalar.h>
+#include <ATen/core/Reduction.h>
+
+// Forward declarations of any types needed in the operator signatures.
+// We can't directly include these classes because it will cause circular include dependencies.
+// This file is included by TensorBody.h, which defines the Tensor class.
+#include <ATen/core/ATen_fwd.h>
+
+namespace at {
+
+namespace cpu {
+
+TORCH_API ::std::tuple<at::Tensor,at::Tensor> log_sigmoid_forward(const at::Tensor & self);
+TORCH_API ::std::tuple<at::Tensor &,at::Tensor &> log_sigmoid_forward_out(at::Tensor & output, at::Tensor & buffer, const at::Tensor & self);
+TORCH_API ::std::tuple<at::Tensor &,at::Tensor &> log_sigmoid_forward_outf(const at::Tensor & self, at::Tensor & output, at::Tensor & buffer);
+
+} // namespace cpu
+} // namespace at
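The forward returns (output, buffer), where the buffer stashes intermediates consumed by the matching log_sigmoid_backward on CPU. A hedged sketch:

    #include <ATen/ATen.h>

    void log_sigmoid_demo() {
      at::Tensor x = at::rand({4});
      auto [output, buffer] = at::log_sigmoid_forward(x);
      // output holds log(sigmoid(x)); buffer is an implementation detail.
    }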
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/logaddexp2_ops.h
ADDED
@@ -0,0 +1,39 @@
+#pragma once
+
+// @generated by torchgen/gen.py from Operator.h
+
+#include <tuple>
+#include <vector>
+
+// Forward declarations of any types needed in the operator signatures.
+// We can't directly include these classes because it will cause circular include dependencies.
+// This file is included by TensorBody.h, which defines the Tensor class.
+#include <ATen/core/ATen_fwd.h>
+
+namespace at {
+namespace _ops {
+
+
+struct TORCH_API logaddexp2_out {
+  using schema = at::Tensor & (const at::Tensor &, const at::Tensor &, at::Tensor &);
+  using ptr_schema = schema*;
+  // See Note [static constexpr char* members for windows NVCC]
+  STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::logaddexp2")
+  STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "out")
+  STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "logaddexp2.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)")
+  static at::Tensor & call(const at::Tensor & self, const at::Tensor & other, at::Tensor & out);
+  static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out);
+};
+
+struct TORCH_API logaddexp2 {
+  using schema = at::Tensor (const at::Tensor &, const at::Tensor &);
+  using ptr_schema = schema*;
+  // See Note [static constexpr char* members for windows NVCC]
+  STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::logaddexp2")
+  STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "")
+  STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "logaddexp2(Tensor self, Tensor other) -> Tensor")
+  static at::Tensor call(const at::Tensor & self, const at::Tensor & other);
+  static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other);
+};
+
+}} // namespace at::_ops
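Same Operator.h pattern as linalg_eigh_ops.h above: one struct per overload, with the out and functional schemas side by side. A hedged sketch through the public wrapper:

    #include <ATen/ATen.h>

    void logaddexp2_demo() {
      at::Tensor a = at::rand({4});
      at::Tensor b = at::rand({4});
      at::Tensor r = at::logaddexp2(a, b);  // log2(2^a + 2^b), computed stably
    }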