diff --git a/.gitattributes b/.gitattributes index 8f38b0d65d1d6f01f6bddf690ebfbef3696e9f9d..8c9edf9b3920f496bfd13070c9efa60592961d6d 100644 --- a/.gitattributes +++ b/.gitattributes @@ -76,3 +76,4 @@ tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cufft/lib/ tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/__pycache__/ModuleNode.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cublas/lib/libnvblas.so.11 filter=lfs diff=lfs merge=lfs -text tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/distlib/t64.exe filter=lfs diff=lfs merge=lfs -text +tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cudnn/lib/libcudnn.so.8 filter=lfs diff=lfs merge=lfs -text diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cudnn/lib/libcudnn.so.8 b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cudnn/lib/libcudnn.so.8 new file mode 100644 index 0000000000000000000000000000000000000000..5a061bc9172d01c1657eaca390327573c50759e9 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cudnn/lib/libcudnn.so.8 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:26a7288b7315d658acab1073f02c4f18cd1d27eeadde102958f0317dad6656e0 +size 150200 diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/comm_analysis.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/comm_analysis.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f13ef0cd23ba0e9b0cdb2c226a896ba92c7a78bc Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/comm_analysis.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/comms.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/comms.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..af388789a44d3f68e82531f0369a28078eb86941 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/comms.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/cudagraph_utils.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/cudagraph_utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2e20a2f28745fe0a2d88e332bd62305feebac1af Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/cudagraph_utils.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/exc.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/exc.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2e9be2294c2b262cdd3330bac19ecad58d034844 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/exc.cpython-311.pyc differ diff --git 
a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/graph.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/graph.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c8bf477635b8235c9d8cd99e91f20a8615386ba3 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/graph.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/index_propagation.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/index_propagation.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6e1b50a3de908ae153518af596efc61febff370f Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/index_propagation.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/optimize_indexing.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/optimize_indexing.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ed9be357ba4a2f2c47455673050e5759f7f3ddc6 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/optimize_indexing.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/comms.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/comms.py new file mode 100644 index 0000000000000000000000000000000000000000..7c776a78b14ac16a5e06d77407090d0d92aed071 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/comms.py @@ -0,0 +1,363 @@ +# pyre-strict + +from typing import List + +import torch + +from . import config, ir, scheduler +from .dependencies import WeakDep +from .utils import tuple_sorted + +overlap_log = torch._logging.getArtifactLogger(__name__, "overlap") + + +def sink_waits( + snodes: List["scheduler.BaseSchedulerNode"], +) -> List["scheduler.BaseSchedulerNode"]: + """ + Greedily moves waits as late as possible (i.e. until we reach a use). Optimal in terms of + communication overlap. + """ + new_order = [] + cur_waits = set() + for snode in snodes: + if isinstance(snode.node, ir.Wait): + cur_waits.add(snode) + else: + for wait in tuple_sorted(cur_waits): + if snode in wait.node_users: + new_order.append(wait) + cur_waits.remove(wait) + new_order.append(snode) + new_order.extend(tuple_sorted(cur_waits)) + return new_order + + +def raise_comms( + snodes: List["scheduler.BaseSchedulerNode"], +) -> List["scheduler.BaseSchedulerNode"]: + """ + Greedily moves comms as early as possible (i.e. until we reach an input). + Optimal in terms of communication overlap. + + TODO: We might want to adjust this in the future to account for memory limitations. + e.g. when we are compiling FSDP, this heuristic will cause the all-gathers to be prefetched as soon as possible, + which is the beginning of the forward pass. We'll have to either do a special pass for FSDP, + or we'll want to redo this pass with memory considerations so we handle the FSDP case in a general way.
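+    For example, an input order [compute_1, compute_2, comm] in which comm only consumes the output of compute_1 is reordered to [compute_1, comm, compute_2].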
+ """ + new_order_reversed: List["scheduler.BaseSchedulerNode"] = [] + cur_comms: List["scheduler.BaseSchedulerNode"] = [] + for snode in reversed(snodes): + if isinstance(snode.node, ir.CollectiveKernel): + cur_comms.append(snode) + else: + for comm in cur_comms: + assert len(comm.inverse_users) > 0 + while len(cur_comms) > 0 and any( + snode in comm.inverse_users for comm in cur_comms + ): + comm = cur_comms.pop(0) + new_order_reversed.append(comm) + new_order_reversed.append(snode) + assert len(cur_comms) <= 1 + new_order_reversed.extend(tuple_sorted(cur_comms)) + return new_order_reversed[::-1] + + +def get_ancestors(node): + ancestors = set() + cur_nodes = [node] + while len(cur_nodes) > 0: + new_nodes = [] + for node in cur_nodes: + for inp in node.inverse_users: + if inp not in ancestors: + ancestors.add(inp) + new_nodes.append(inp) + cur_nodes = new_nodes + return ancestors + + +def get_descendants(node): + descendants = set() + cur_nodes = [node] + while len(cur_nodes) > 0: + new_nodes = [] + for node in cur_nodes: + for inp in node.node_users: + if inp not in descendants: + descendants.add(inp) + new_nodes.append(inp) + cur_nodes = new_nodes + return descendants + + +def decide_global_ordering_of_comms(nodes: List["scheduler.BaseSchedulerNode"]): + """ + Decide global ordering of comms, by just enforcing the ordering that's in the input graph + (might not be the same ordering as the eager mode program). + TODO: Come up with a better approach + """ + comm_nodes = [n for n in nodes if isinstance(n.node, ir.CollectiveKernel)] + for i in range(1, len(comm_nodes)): + # Enforce ordering by making previous comm a `WeakDep` dependency of the next comm + comm_nodes[i].add_fake_dep(WeakDep(comm_nodes[i - 1].get_name())) + + +def assert_no_comm_nodes(snodes: List["scheduler.BaseSchedulerNode"]) -> None: + assert not any(isinstance(snode.node, ir.CollectiveKernel) for snode in snodes) + + +def estimate_op_runtime(snode: "scheduler.BaseSchedulerNode") -> float: + """ + Returns estimated op runtime in nanoseconds (ns) + """ + if config.estimate_op_runtime == "default": + runtime = snode.get_estimated_runtime() + else: + assert callable(config.estimate_op_runtime) + runtime = config.estimate_op_runtime(snode) + return runtime + + +def reorder_compute_for_overlap( + snodes: List["scheduler.BaseSchedulerNode"], +) -> List["scheduler.BaseSchedulerNode"]: + """ + Decides a global ordering of all compute and communication nodes, + assuming that we already have a global ordering of communication nodes. + + Overall scheduling procedure is: + Step 1: Given that we've currently scheduled comm N, we now schedule all compute nodes + that are required for comm N + 1 but do not depend on comm N, to run at the same time with comm N. + Step 2: If all those compute nodes are sufficient to overlap comm N, we're done. + Otherwise, we now need to look elsewhere to find compute that overlaps with comm N. + We prioritize compute nodes that are needed sooner. + Step 3: We schedule the compute nodes dependent on comm N and required for comm N + 1. + Step 4: We schedule comm N + 1. + Repeat this for subsequent comm nodes. 
+ """ + final_order = [] + + comm_nodes = [] + for snode in snodes: + if isinstance(snode.node, ir.CollectiveKernel): + comm_nodes.append(snode) + if len(comm_nodes) == 0: + # if there is no comm nodes, return the current order + return snodes + + comm_ancestors = {node: get_ancestors(node) for node in comm_nodes} + comm_descendants = {node: get_descendants(node) for node in comm_nodes} + + indeg = dict.fromkeys(snodes, 0) + for snode in snodes: + for user in snode.node_users: + if user in indeg: + indeg[user] += 1 + ready_to_schedule_nodes = {node for node in snodes if indeg[node] == 0} + + unscheduled_nodes = set() + unscheduled_nodes = set(snodes) + + def schedule_node(snode): + """ + Schedule a single node. + """ + assert snode in unscheduled_nodes + assert snode in ready_to_schedule_nodes + ready_to_schedule_nodes.remove(snode) + unscheduled_nodes.remove(snode) + final_order.append(snode) + for user in tuple_sorted(snode.node_users): + if user in indeg: + indeg[user] -= 1 + if indeg[user] == 0: + ready_to_schedule_nodes.add(user) + + def schedule_nodes(snodes): + """ + Schedules all nodes in `snodes` in an arbitrary topologically valid order. + """ + all_nodes = set(snodes) + assert all(node in unscheduled_nodes for node in all_nodes) + while len(all_nodes) > 0: + # NOTE: since model graph is always a DAG and does not have circular dependency inside, + # there should be at least one node that is a "free node" (i.e. indeg == 0), + # hence infinite loop is not possible. But we check here just to be safe. + progress = False + for node in tuple_sorted(all_nodes): + if node in ready_to_schedule_nodes: + schedule_node(node) + all_nodes.remove(node) + progress = True + if not progress: + raise Exception( + "Unable to find a free node (indeg == 0). This is an impossible state to reach. " + "Please report a bug to PyTorch." + ) + + # First, schedule all compute nodes that are required by first comm node, + # as well as the first comm node itself. + assert len(comm_nodes) > 0 + schedule_nodes( + list(comm_ancestors[comm_nodes[0]]) + [comm_nodes[0]], + ) + + rolled_over_compute_cost = 0 + for idx in range(1, len(comm_ancestors)): + # Step 1: Given that we've currently scheduled comm `idx-1`, we now schedule + # all compute nodes that are required for comm `idx` but do not depend on comm `idx-1`, + # to run at the same time with comm `idx-1`. + needed_by_next_comm_and_ready_compute_nodes = unscheduled_nodes & ( + comm_ancestors[comm_nodes[idx]] - comm_descendants[comm_nodes[idx - 1]] + ) + assert_no_comm_nodes(needed_by_next_comm_and_ready_compute_nodes) + + total_compute_runtime_cost = rolled_over_compute_cost + sum( + [ + estimate_op_runtime(node) + for node in needed_by_next_comm_and_ready_compute_nodes + ] + ) + prev_comm_runtime_cost = estimate_op_runtime(comm_nodes[idx - 1]) + schedule_nodes(tuple_sorted(needed_by_next_comm_and_ready_compute_nodes)) + + # Step 2: If all those compute nodes are sufficient to overlap comm `idx-1`, we're done. + # Otherwise, we now need to look elsewhere to find compute that overlaps with comm `idx`. + # We prioritize compute nodes that are needed sooner. + step1_runtime_cost = total_compute_runtime_cost + if step1_runtime_cost >= prev_comm_runtime_cost: + pass + else: + # Find all ready to schedule compute nodes that do not depend on comm `idx-1`. 
+ ready_to_schedule_compute_nodes = tuple_sorted( + ready_to_schedule_nodes - comm_descendants[comm_nodes[idx - 1]] + ) + assert_no_comm_nodes(ready_to_schedule_compute_nodes) + + def earliest_comm_descendant(node): + for idx in range(len(comm_nodes)): + if node in comm_ancestors[comm_nodes[idx]]: + return idx + return len(comm_nodes) + + # Prioritize compute nodes that are needed sooner. + ready_to_schedule_compute_nodes = sorted( + ready_to_schedule_compute_nodes, key=earliest_comm_descendant + ) + + for snode in ready_to_schedule_compute_nodes: + if total_compute_runtime_cost >= prev_comm_runtime_cost: + # If accumulated compute runtime cost is greater than comm `idx-1` runtime cost, + # it means we have maximized overlap for comm `idx-1`, and hence we stop looking + # for more compute to schedule. + break + compute_runtime_cost = estimate_op_runtime(snode) + # If we're not able to leverage more than half of this + # node's compute to overlap, we skip it. + # TODO: Smarter heuristics here + if ( + prev_comm_runtime_cost - total_compute_runtime_cost + ) <= compute_runtime_cost / 2: + continue + schedule_node(snode) + total_compute_runtime_cost += compute_runtime_cost + rollable_compute_cost = total_compute_runtime_cost - step1_runtime_cost + + # Step 3: We schedule the compute nodes dependent on comm `idx-1` and required for comm `idx`. + needed_by_next_comm_nodes = unscheduled_nodes & comm_ancestors[comm_nodes[idx]] + schedule_nodes(list(needed_by_next_comm_nodes)) + + # Step 4: We schedule comm `idx`. + schedule_nodes([comm_nodes[idx]]) + + is_prev_comm_blocking_next_comm = len(needed_by_next_comm_nodes) > 0 + # The idea here is that if there are no compute nodes from Step 3 + # (i.e. if prev comm is not blocking next comm), we can roll over the compute nodes + # in Step 2 to overlap with the next comm, since they're not required to finish + # before the next comm starts. + if is_prev_comm_blocking_next_comm: + rolled_over_compute_cost = 0 + else: + rolled_over_compute_cost = rollable_compute_cost # type: ignore[assignment] + + schedule_nodes(unscheduled_nodes) + return final_order + + +def node_summary(snode): + detail = "" + if isinstance(snode.node, ir.ExternKernelOut): + detail = f" ({snode.node.python_kernel_name})" + out_tensor_info = "" + if ( + hasattr(snode.node, "layout") + and hasattr(snode.node.layout, "size") + and hasattr(snode.node.layout, "stride") + ): + out_tensor_info = ( + f" (size={snode.node.layout.size}, stride={snode.node.layout.stride})" + ) + node_name = "" + if hasattr(snode.node, "name"): + node_name = snode.node.name + return f"{snode.node.__class__.__name__}{detail}{out_tensor_info} ({node_name})" + + +def visualize_overlap(order): + total_est_runtime: float = 0.0 + cur_comm_node = None + for snode in order: + if cur_comm_node is None: + if isinstance(snode.node, ir.CollectiveKernel): + total_est_runtime += estimate_op_runtime(snode) + cur_comm_node = snode.node + elif isinstance(snode.node, ir.Wait): + raise Exception( + "Wait is not expected when there is no collective running" + ) + else: # exposed compute op + total_est_runtime += estimate_op_runtime(snode) + overlap_log.debug(f"{node_summary(snode)}") # noqa: G004 + else: # cur_comm_node is not None + if isinstance(snode.node, ir.CollectiveKernel): + raise Exception( + "Found two collectives running at the same time. 
" + "`visualize_overlap` needs to be updated to handle this case" + ) + elif isinstance(snode.node, ir.Wait): # end of this comm op + overlap_log.debug(f"{node_summary(snode)}") # noqa: G004 + cur_comm_node = None + else: # overlapped compute op + overlap_log.debug(f"| {node_summary(snode)}") # noqa: G004 + overlap_log.debug( + f"Est. runtime (ms): {total_est_runtime / 1000 / 1000}" # noqa: G004 + ) + + +def reorder_compute_and_comm_for_overlap( + snodes: List["scheduler.BaseSchedulerNode"], +) -> List["scheduler.BaseSchedulerNode"]: + order = snodes + for p in config.reorder_for_compute_comm_overlap_passes: + if isinstance(p, str) and p in globals(): + p = globals()[p] # it is a builtin pass + if torch.distributed.get_rank() == 0: + overlap_log.debug( + f"==== Visualize overlap before reordering pass {p} ====" # noqa: G004 + ) + try: + visualize_overlap(order) + except Exception as e: + overlap_log.debug(str(e)) + order = p(order) # type: ignore[operator] + if torch.distributed.get_rank() == 0: + overlap_log.debug( + f"==== Visualize overlap after reordering pass {p} ====" # noqa: G004 + ) + try: + visualize_overlap(order) + except Exception as e: + overlap_log.debug(str(e)) + return order diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/config.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/config.py new file mode 100644 index 0000000000000000000000000000000000000000..2a8d9546f4fd7553bf78197226086f01d6ffbc8f --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/config.py @@ -0,0 +1,752 @@ +import os # noqa: C101 +import sys +from typing import Any, Callable, Dict, Optional, TYPE_CHECKING + +import torch + + +def is_fbcode(): + return not hasattr(torch.version, "git_version") + + +# add some debug printouts +debug = False + +# add inf and NaN checkers +debug_check_inf_and_nan = False + +# Whether to disable a progress bar for autotuning +disable_progress = True + +# Whether to enable printing the source code for each future +verbose_progress = False + +# use fx aot graph codegen cache +fx_graph_cache = os.environ.get("TORCHINDUCTOR_FX_GRAPH_CACHE") == "1" + +# use cpp wrapper instead of python wrapper +cpp_wrapper = os.environ.get("TORCHINDUCTOR_CPP_WRAPPER", "0") == "1" + +# codegen cpp wrapper code in an ABI compatible mode +abi_compatible = ( + os.environ.get("TORCHINDUCTOR_ABI_COMPATIBLE", "1" if is_fbcode() else "0") == "1" +) + +c_shim_version = os.environ.get( + "TORCHINDUCTOR_C_SHIM_VERSION", "1" if is_fbcode() else "2" +) + +# dead code elimination +dce = False + +# assume weight tensors are fixed size +static_weight_shapes = True + +# put correctness assertions in generated code +size_asserts = os.environ.get("TORCHINDUCTOR_SIZE_ASSERTS", "1") == "1" +nan_asserts = os.environ.get("TORCHINDUCTOR_NAN_ASSERTS") == "1" + +# enable loop reordering based on input orders +pick_loop_orders = True + +# reuse a kernel input as the output +inplace_buffers = True + +# reuse a buffer for an unrelated purpose +allow_buffer_reuse = True + +# Enable pooled allocations for non-output tensors +memory_planning = os.environ.get("TORCHINDUCTOR_MEMORY_PLANNING", "0") == "1" + +# How to organize memory under memory_planning=True: +# - "none": do not try to pool storage, just reuse +# - "intermediates": all non-outputs share storage, outputs each get unique storage +# - "outputs": two pools, one for intermediates (freed on return) and one for outputs +# - "combined": a single pool 
for both intermediates and outputs +memory_pool = os.environ.get("TORCHINDUCTOR_MEMORY_POOL", "intermediates") + +# codegen benchmark harness +benchmark_harness = True + +# fuse pointwise into templates +epilogue_fusion = True + +# do epilogue fusions before other fusions +epilogue_fusion_first = False + +# enable pattern match+replace optimizations +pattern_matcher = True + +# register custom graph optimization pass hook. so far, pre/post passes are +# only applied before/after pattern_matcher in post_grad_passes. +# +# def my_custom_pre_pass(graph: torch.fx.graph.Graph): +# # my custom graph optimization pass +# ... +# +# def my_custom_post_pass(graph: torch.fx.graph.Graph): +# # my custom graph optimization pass +# ... +# +# torch._inductor.config.post_grad_custom_pre_pass = my_custom_pre_pass +# torch._inductor.config.post_grad_custom_post_pass = my_custom_post_pass +post_grad_custom_pre_pass: Optional[Callable[[torch.fx.graph.Graph], None]] = None +post_grad_custom_post_pass: Optional[Callable[[torch.fx.graph.Graph], None]] = None + +# Registers a custom pregrad pass. Note that the pre-grad IR is 1. +# non-functional, 2. non-normalized, and 3. prone to change. Ideally we should +# use post-grad passes. +pre_grad_custom_pass: Optional[Callable[[torch.fx.graph.Graph], None]] = None + +# Optimize away split cat patterns (Experimental) +split_cat_fx_passes = True + +# Optimize conv-batchnorm if batchnorm is in eval mode. Slightly reduces numerical stability. +efficient_conv_bn_eval_fx_passes = False + +# Enable predispatch aten IR for export +is_predispatch = False + +# Deprecated +group_fusion = False + +# Deprecated +batch_fusion = True + +# Pre grad group/batch fusion and options in order, set to empty dict to disable fusion. +# Call `torch._inductor.fx_passes.group_batch_fusion.list_group_batch_fusions()` to see available fusions. +pre_grad_fusion_options: Dict[str, Dict[str, Any]] = { + "batch_linear": {}, + "batch_linear_lhs": {}, + "batch_layernorm": {}, + "batch_tanh": {}, + "batch_relu": {}, + "batch_sigmoid": {}, +} + +# Post grad group/batch fusion and options, set to empty dict to disable fusion. +# Call `torch._inductor.fx_passes.group_batch_fusion.list_group_batch_fusions(False)` to see available fusions. +post_grad_fusion_options: Dict[str, Dict[str, Any]] = {} + +# enable reordering pass for improving memory locality +reorder_for_locality = True + +# Scale down RBLOCK for better occupancy +dynamic_scale_rblock = os.environ.get("TORCHINDUCTOR_DYNAMIC_SCALE_RBLOCK", "1") == "1" + +# this forces fusion for int_mm with mul. Needed when you want to avoid realizing the int32 +# but the mul gets fused with other pointwise ops instead. +force_fuse_int_mm_with_mul = False + +# for pattern torch.mm(a, b.to(dtype)) with cuda tensors, +# enable torch._inductor.kernel.mm.tuned_mixed_mm fused kernel. +# Autotune will compare perf with normal cast->then->mm option +use_mixed_mm = False + +# enable runtime numeric check for pre/post grad fx passes +# floating point provides limited accuracy (about 7 decimal digits for single precision +# floating point numbers, about 16 decimal digits for double precision floating point numbers) +# according to PyTorch documentation.
+# https://pytorch.org/docs/stable/notes/numerical_accuracy.html#batched-computations-or-slice-computations +fx_passes_numeric_check: Dict[str, Any] = { + "pre_grad": False, + "precision": 1e-4, + "num_iterations": 1, + "requires_optimizer": True, +} + +# for pattern torch.mm(a, b.to(dtype)) with cuda tensors, always use +# torch._inductor.kernel.mm.tuned_mixed_mm's fused kernel. +# Autotune will not compare with normal cast->then->mm option. +# (if force_mixed_mm is true, the use_mixed_mm flag will be ignored) +force_mixed_mm = False + +# enable reordering pass for increasing overlap between compute and communication +reorder_for_compute_comm_overlap = False + +# passes (in execution order) for increasing overlap between compute and communication +# for built-in passes, use string name; for user-defined passes, pass in the function handle +reorder_for_compute_comm_overlap_passes = [ + "reorder_compute_for_overlap", + "sink_waits", + "raise_comms", +] + +# runtime estimation function for ops +# for built-in estimation function, pass in "default"; for user-defined estimation function, pass in the function handle +estimate_op_runtime = "default" + +# unit: GB/s, uni-directional P2P bandwidth per card +# default value is NVLink +intra_node_bw = 300 + +# unit: GB/s, uni-directional P2P bandwidth per node +# default value is InfiniBand +inter_node_bw = 25 + +# enable slow autotuning passes to select algorithms +max_autotune = os.environ.get("TORCHINDUCTOR_MAX_AUTOTUNE") == "1" + +# enable slow autotuning passes to select pointwise/reductions algorithms +max_autotune_pointwise = os.environ.get("TORCHINDUCTOR_MAX_AUTOTUNE_POINTWISE") == "1" + +# enable slow autotuning passes to select gemm algorithms +max_autotune_gemm = os.environ.get("TORCHINDUCTOR_MAX_AUTOTUNE_GEMM") == "1" + +# enable autotune local cache +use_autotune_local_cache = True + +# enable autotune remote cache +use_autotune_remote_cache = ( + os.environ.get("TORCH_INDUCTOR_AUTOTUNE_REMOTE_CACHE") == "1" +) + +# force cublas and triton to use the same precision; cublas supports TF32 for matmul operations +# when m, n, k are multiples of 16, 16, 8, whereas triton supports TF32 for matmul operations +# for any combinations of m, n, k, regardless of their alignment. setting this flag will ensure +# that triton does not use TF32 wherever cublas would not use TF32 +force_same_precision = ( + True if is_fbcode() else os.environ.get("TORCHINDUCTOR_FORCE_SAME_PRECISION") == "1" +) +# Specify candidate backends for gemm autotune. +# Possible choices are combinations of: ATen, Triton, CUTLASS. +# ATen: default Pytorch ATen kernels. +# Triton: Triton templates defined in torch inductor. +# CUTLASS: Cutlass templates and kernels. 
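+# A usage sketch, assuming the standard env-var mechanics shown above (train.py is a hypothetical entry script): +# TORCHINDUCTOR_MAX_AUTOTUNE_GEMM=1 TORCHINDUCTOR_MAX_AUTOTUNE_GEMM_BACKENDS=ATEN,TRITON,CUTLASS python train.py +# opts the CUTLASS templates into the gemm autotuning search alongside ATen and Triton.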
+max_autotune_gemm_backends = os.environ.get( + "TORCHINDUCTOR_MAX_AUTOTUNE_GEMM_BACKENDS", "ATEN,TRITON" +).upper() + +# the value used as a fallback for the unbacked SymInts +# that can appear in the input shapes (e.g., in autotuning) +unbacked_symint_fallback = 8192 + +# enable searching global and local cache regardless of `max_autotune` +search_autotune_cache = os.environ.get("TORCHINDUCTOR_SEARCH_AUTOTUNE_CACHE") == "1" + +save_args = os.environ.get("TORCHINDUCTOR_SAVE_ARGS") == "1" + +# We will disable creating subprocess for autotuning if this is False +autotune_in_subproc = os.environ.get("TORCHINDUCTOR_AUTOTUNE_IN_SUBPROC") == "1" + +# If autotuning in subprocess, whether to use multiple devices +autotune_multi_device = os.environ.get("TORCHINDUCTOR_AUTOTUNE_MULTI_DEVICE") == "1" + +coordinate_descent_tuning = ( + os.environ.get("TORCHINDUCTOR_COORDINATE_DESCENT_TUNING") == "1" +) +coordinate_descent_check_all_directions = ( + os.environ.get("TORCHINDUCTOR_COORDINATE_DESCENT_CHECK_ALL_DIRECTIONS") == "1" +) +coordinate_descent_search_radius = int( + os.environ.get("TORCHINDUCTOR_COORDINATE_DESCENT_RADIUS", "1") +) + +# Disabled by default on ROCm, opt-in if model utilises NHWC convolutions +layout_opt_default = "1" if not torch.version.hip else "0" +layout_optimization = ( + os.environ.get("TORCHINDUCTOR_LAYOUT_OPTIMIZATION", layout_opt_default) == "1" +) + +force_layout_optimization = os.environ.get("TORCHINDUCTOR_FORCE_LAYOUT_OPT", "0") == "1" + + +# Whether to keep the output strides the same as eager after layout optimization. +keep_output_stride = os.environ.get("TORCHINDUCTOR_KEEP_OUTPUT_STRIDE", "1") == "1" + +# Enabling this will let compiler print warning messages if a generated triton +# kernel has inputs with mixed layouts. This is helpful for perf debugging +# since a kernel with mixed layout inputs may run much slower than one whose inputs +# have uniform layouts. +warn_mix_layout = os.environ.get("TORCHINDUCTOR_WARN_MIX_LAYOUT") == "1" + +# control store vs recompute heuristic +# For fanouts, rematerialization can lead to exponential blowup. So, have +# smaller threshold +realize_reads_threshold = 4 +realize_opcount_threshold = 30 + +# Threshold to prevent excessive accumulation of ops in one buffer during lowering +realize_acc_reads_threshold = 8 + +# fallback to eager for random/dropout, this is slow but useful for debugging +fallback_random = False + +# automatically create fallbacks when encountering an unhandled op +implicit_fallbacks = True + +# fuse even in cases without common reads +aggressive_fusion = False + +# For each fused kernel in the wrapper, comment with the nodes that get fused. +# Useful for debugging fusion.
+debug_fusion = os.environ.get("TORCHINDUCTOR_DEBUG_FUSION") == "1" +benchmark_fusion = os.environ.get("TORCHINDUCTOR_BENCHMARK_FUSION") == "1" +enabled_metric_tables = os.environ.get("TORCHINDUCTOR_ENABLED_METRIC_TABLES", "") + +# how many nodes to allow into a single fusion +max_fusion_size = 64 + +# max number of inputs to generate cat as a pointwise op with masked loads +max_pointwise_cat_inputs = 8 + +# replace small reductions with pointwise, disable with `= 1` +unroll_reductions_threshold = 8 + +# Add extra comments to output code (causes compile cache misses) +comment_origin = False + +# Convert 1x1 convs into matmuls +conv_1x1_as_mm = False + +# Enable split reductions for better utilization when the dimension +# being reduced over is large (by splitting it) +split_reductions = True + +benchmark_kernel = os.environ.get("TORCHINDUCTOR_BENCHMARK_KERNEL", "0") == "1" + +# Enable constant and index_expr folding +constant_and_index_propagation = True + +# we always add constants into graph.constants without +# performing any constant-inlining optimization +always_keep_tensor_constants = False + +# assert that indirect indexing does not read / write out of bounds +assert_indirect_indexing = True + +# constant folding on the joint graph +joint_graph_constant_folding = True + +# Enable indirect_indexing asserts for decompositions and lowerings +debug_index_asserts = False + +# warnings intended for PyTorch developers, disable for point releases +is_nightly_or_source = "dev" in torch.__version__ or "git" in torch.__version__ +developer_warnings = is_fbcode() or is_nightly_or_source + +# The multiprocessing start method to use for inductor workers in the codecache. +# TODO: fork is not safe in a multithreaded environment, we should evaluate changing +# the default to spawn. +worker_start_method = "fork" + + +def decide_compile_threads(): + """ + Here is the precedence used to decide compile_threads + 1. The user can override it via TORCHINDUCTOR_COMPILE_THREADS. One may want to disable async compiling by + setting this to 1 to make pdb happy. + 2. Set to 1 if it's the win32 platform or an fbcode build + 3.
decide by the number of CPU cores + """ + if "TORCHINDUCTOR_COMPILE_THREADS" in os.environ: + return int(os.environ["TORCHINDUCTOR_COMPILE_THREADS"]) + elif sys.platform == "win32" or is_fbcode(): + return 1 + else: + cpu_count = ( + len(os.sched_getaffinity(0)) + if hasattr(os, "sched_getaffinity") + else os.cpu_count() + ) + assert cpu_count + return min(32, cpu_count) + + +compile_threads = decide_compile_threads() + +# gemm autotuning global cache dir +if is_fbcode(): + from libfb.py import parutil + + try: + if __package__: + global_cache_dir = parutil.get_dir_path( + os.path.join(__package__.replace(".", os.sep), "fb/cache") + ) + else: + global_cache_dir = parutil.get_dir_path("fb/cache") + except ValueError: + global_cache_dir = None +else: + global_cache_dir = None + +# If kernel is fused, the name is generated from the origin node op names +# for larger kernels limit this +kernel_name_max_ops = 10 + +# Pad input tensors of matmul/bmm/addmm to leverage Tensor Cores in NVIDIA GPUs +shape_padding = os.environ.get("TORCHINDUCTOR_SHAPE_PADDING", "1") == "1" + +# Fx-based linear/matmul/bmm + permute/transpose vertical fusion +permute_fusion = os.environ.get("TORCHINDUCTOR_PERMUTE_FUSION", "0") == "1" + +# Mark the wrapper call in PyTorch profiler +profiler_mark_wrapper_call = False + +# Generate hook calls to torch._inductor.hooks.run_intermediate_hooks for +# every intermediate for which we can correlate it with an intermediate +# from the original FX graph +generate_intermediate_hooks = False + +# Populate traceback field on IRNode; good for debugging why origin_node is +# not populated, or finding out where an IRNode was constructed +debug_ir_traceback = False + +# used for debugging to make sure config is properly set +_raise_error_for_testing = False + +_profile_var = os.environ.get("TORCHINDUCTOR_PROFILE", "") +profile_bandwidth = _profile_var != "" +profile_bandwidth_regex = "" if _profile_var == "1" else _profile_var +# Specify a file where we print out the profiling results. +# None means we do not dump results to a file. +profile_bandwidth_output = os.environ.get("TORCHINDUCTOR_PROFILE_OUTPUT", None) + +# TODO: remove later +disable_cpp_codegen = False + + +# Freezing will attempt to inline weights as constants in optimization +# and run constant folding and other optimizations on them. After freezing, weights +# can no longer be updated. +freezing: bool = os.environ.get("TORCHINDUCTOR_FREEZING", "0") == "1" + +# Make freezing invalidate the eager Parameters of nn modules, to avoid memory overhead +# of potentially keeping multiple copies of weights. +freezing_discard_parameters: bool = False + +# Kill switch for allowing temporary tensors to be allocated as stack arrays. Tests +# should be run with this flag both on and off to make sure we have coverage. +allow_stack_allocation: bool = ( + os.environ.get("TORCHINDUCTOR_STACK_ALLOCATION", "1") == "1" +) + +# Enables an alternate DSO interface (the "minimal ArrayRef interface") intended +# to maximize performance for use cases that it can accommodate at the expense of +# generality. In brief: +# - inputs and outputs are ArrayRefTensor (note that strides are required, but the +# tensor must be contiguous) +# - constant handling is unchanged because it is not a per-inference-iteration bottleneck +# +# When the DSO is generated in this mode, the usual interface will also be supported, +# but performance for that interface may be degraded. 
+use_minimal_arrayref_interface: bool = False + +# decompose some memory bound matmul/bmm to mul +decompose_mem_bound_mm: bool = False + + +# config specific to codegen/cpp.py +class cpp: + # set to torch.get_num_threads() + threads = -1 + + # Do not generate loops when the condition doesn't hold, like: + # for(long i0=4096; i0<4096; i0+=1) + no_redundant_loops = True + + # Assume number of threads is dynamic, don't specialize thread number. + # Kernels don't recompile on thread number changes with this flag on. + # For single-threaded workload, turning it on would incur a slight + # performance degradation. + dynamic_threads = False + + simdlen: Optional[int] = None + min_chunk_size = 4096 + cxx = ( + None, # download gcc12 from conda-forge if conda is installed + # "g++-12", + # "g++-11", + # "g++-10", + # "clang++", + os.environ.get("CXX", "clang++" if sys.platform == "darwin" else "g++"), + # "g++.par", + ) + # Allow kernel performance profiling via PyTorch profiler + enable_kernel_profile = False + + # enable weight prepacking to get better performance; may lead to a large memory footprint + weight_prepack = True + + # Inject a bug into our relu implementation; useful for testing our repro + # extraction and minification functionality. + # Valid values: "compile_error", "runtime_error", "accuracy" + inject_relu_bug_TESTING_ONLY: Optional[str] = None + inject_log1p_bug_TESTING_ONLY: Optional[str] = None + + # If None, autodetect whether or not AVX512/AVX2 can be used. Otherwise, + # force usage as specified, without testing. + vec_isa_ok: Optional[bool] = None + + # similar to config.triton.descriptive_names + descriptive_names = "original_aten" + + # how many nodes to allow into a single horizontal fusion + max_horizontal_fusion_size = 16 + + # Make scatter_reduce fall back when reduce is sum, to avoid a performance regression + # caused by using atomic_add. + fallback_scatter_reduce_sum = True + + # Use -funsafe-math-optimizations when compiling + enable_unsafe_math_opt_flag = False + + # Use -ffp-contract when compiling + enable_floating_point_contract_flag = False + + +# config specific to codegen/triton.py +class triton: + # Use cudagraphs on output code + cudagraphs = False + + # Use cudagraph trees for memory pooling if `cudagraphs` is True + cudagraph_trees = True + + # assertions not on the fast path, steady state + slow_path_cudagraph_asserts = True + + # TODO - need to debug why this prevents cleanup + cudagraph_trees_history_recording = False + + # assertions on the fast path + fast_path_cudagraph_asserts = False + + # skip warmup for cudagraph trees + skip_cudagraph_warmup = False + + # Synchronize before and after every compiled graph. + debug_sync_graph = False + + # Synchronize after every kernel launch, to help pinpoint bugs + debug_sync_kernel = False + + # Always load full blocks (rather than broadcasting inside the block) + dense_indexing = False + + # limit tiling dimensions + max_tiles = 2 + + # use triton.autotune for pointwise ops with complex layouts + # this should only be disabled for debugging/testing + autotune_pointwise = True + + # max autotune gemm with cublasLt + autotune_cublasLt = True + + # should we stop a fusion to allow better tiling? + tiling_prevents_pointwise_fusion = True + tiling_prevents_reduction_fusion = True + + # should we give different names to kernels + # Note: This is orthogonal to descriptive_names - this is deciding whether + # our triton kernel names should all be `triton_` (to maximize caching) or + # whether they should be unique.
+ unique_kernel_names = os.environ.get("TORCHINDUCTOR_UNIQUE_KERNEL_NAMES") == "1" + + # should we put op names in kernel names + # False: No special names (just triton__1, triton__2, etc.) + # "torch": Maps to the fx op in the Dynamo graph (module name, method name, etc.) + # "original_aten": Maps to the highest-level aten op (i.e. pre-decompositions) + # "inductor_node": Maps to the node name in the FX graph passed to Inductor + descriptive_names = "original_aten" + + # use alternate codegen for smaller reductions + persistent_reductions = ( + os.environ.get("TORCHINDUCTOR_PERSISTENT_REDUCTIONS", "1") == "1" + ) + + # 0/False: disable + # 1/True: enable, use tuning to pick between different subkernels + # 2: enable, force using persistent reduction (for debugging) + # 3: enable, force using non-persistent reduction (for debugging) + multi_kernel = int(os.environ.get("TORCHINDUCTOR_MULTI_KERNEL", "0")) + + # hint to Triton when arguments are divisible by 16 + divisible_by_16 = True + + # these are not enforced, but they are used by asserts in triton_heuristics.py + # NOTE: mobilevit_s in timm_models required X to be set to the higher value 2048 + + # Max RBLOCK will be large for multi-kernel since we do more aggressive + # persistent reduction. + max_block = { + "X": 2048, + "Y": 1024, + "Z": 1024, + "R": 4096 * (16 if multi_kernel else 1), + } + + # Minimum RBLOCK to be used for a TritonSplitScanKernel + # NOTE: This also indirectly controls the size of workspace buffer required + min_split_scan_rblock = 256 + + # Store the generated cubin files for cpp wrapper code to load + store_cubin = False + + # the max number of spills we allow for the configs we benchmark. + # Setting this to 0 means we skip a config if it spills even a single + # register. + # Setting it to a larger value allows a config that spills a small number + # of registers to be benchmarked. + # + # NOTE: triton will always report >0 register spills for kernels using sin/cos. + # (check this issue https://github.com/openai/triton/issues/1756 ) + # So far we see a fixed 8 spilled registers for kernels using sin/cos. + # Raise the threshold to 16 to be safe. + # We should revisit this once we understand more of the source of register spills. + spill_threshold: int = 16 + + # Generate code containing the newer tl.make_block_ptr() API for loads/stores + use_block_ptr = False + + # Inject a bug into our relu implementation; useful for testing our repro + # extraction and minification functionality. + # Valid values: "compile_error", "runtime_error", "accuracy" + inject_relu_bug_TESTING_ONLY: Optional[str] = None + + +class aot_inductor: + # AOTInductor output path + # If an absolute path is specified, the generated lib files will be stored under the directory; + # If a relative path is specified, it will be used as a subdirectory under the default caching path; + # If not specified, a temp directory will be created under the default caching path. + # If the specified path contains something like "model.so", the sub-string will be used + # to name the generated library. + output_path = "" + + debug_compile = os.environ.get("AOT_INDUCTOR_DEBUG_COMPILE", "0") == "1" + + # Serialized tree spec for flattening inputs + serialized_in_spec = "" + + # Serialized tree spec for flattening outputs + serialized_out_spec = "" + + # flag to decide whether to create a submodule for the constant graph. + use_runtime_constant_folding: bool = False + + +class cuda: + # CUDA arch to use for CUDA template kernel compilation. + # e.g.
"70", "75", "80", "90", etc. + # When arch is None, Inductor uses torch.cuda.get_device_capability(0). + arch: Optional[str] = None + + # CUDA version to use for CUDA template kernel compilation. + # e.g. "11.4", "12.1", etc. + # When version is None, Inductor uses torch.version.cuda. + version: Optional[str] = None + + # Optimization level for the host compiler. + compile_opt_level = "-O1" + + # Whether to enable device LTO (link-time-optimization). + enable_cuda_lto = False + + # Whether to keep intermediate files dring compilation. + enable_ptxas_info = False + + # Whether to enable debug info, e.g. line number, cutlass debug info. + enable_debug_info = False + + # Whether to use fast math. + use_fast_math = False + + # Path to the CUTLASS repo root directory. + # The default path only works under PyTorch local development environment. + cutlass_dir = os.environ.get( + "TORCHINDUCTOR_CUTLASS_DIR", + os.path.abspath( + os.path.join(os.path.dirname(torch.__file__), "../third_party/cutlass/") + ), + ) + + # Configures the maximum number of CUTLASS configs to profile in max_autotune. + # By default it's None, so that all CUTLASS configs are tuned. + # This is mainly used to reduce test time in CI. + cutlass_max_profiling_configs: Optional[int] = None + + # Path to CUDA NVCC. + # NVCC search order: + # 1) cuda_cxx set in this config + # 2)CUDACXX environment variable + # 3)CUDA_HOME environment variable + # 4) default system search PATH. + cuda_cxx: Optional[str] = None + + # If set to True, it will ensure that only GEMM ops capable of + # epilogue fusion via CUTLASS Epilogue Visitor Trees ( EVT ) + # are enabled for the CUTLASS backend. + cutlass_only_evt_capable_ops: bool = False + + +# create a directory containing lots of debug information +class trace: + # master switch for all debugging flags below + enabled = os.environ.get("TORCH_COMPILE_DEBUG", "0") == "1" + + # Save debug information to a temporary directory + # If not specified, a temp directory will be created by system + debug_dir: Optional[str] = None + + # Save python logger call >=logging.DEBUG + debug_log = False + + # Save python logger call >=logging.INFO + info_log = False + + # Save input FX graph (post decomps, pre optimization) + fx_graph = True + + # Save FX graph after transformations + fx_graph_transformed = True + + # Save TorchInductor IR before fusion pass + ir_pre_fusion = True + + # Save TorchInductor IR after fusion pass + ir_post_fusion = True + + # Copy generated code to trace dir + output_code = True + + # SVG figure showing post-fusion graph + graph_diagram = os.environ.get("INDUCTOR_POST_FUSION_SVG", "0") == "1" + + # SVG figure showing fx with fusion + draw_orig_fx_graph = os.environ.get("INDUCTOR_ORIG_FX_SVG", "0") == "1" + + # We draw our fx graphs with the "record" shape attribute by default. + # Sometimes, when the graph is very complex, we may hit dot errors like below: + # "flat edge between adjacent nodes one of which has a record shape - + # replace records with HTML-like labels" + # and thus fail to generate a graph. So, let's give the user an option + # to specify the shape attribute for the dot graph. For example, passing + # INDUCTOR_DOT_GRAPH_SHAPE_SVG = "none" would let us generate HTML-like lables + # to workaround the above failure. 
+ dot_graph_shape = os.environ.get("INDUCTOR_DOT_GRAPH_SHAPE_SVG", None) + + # Store cProfile (see snakeviz to view) + compile_profile = False + + # Upload the .tar.gz file + # Needs to be overridden based on specific environment needs + upload_tar: Optional[Callable[[str], None]] = None + + log_autotuning_results: bool = False + + +_save_config_ignore = { + # workaround: "Can't pickle " + "trace.upload_tar", +} + +if TYPE_CHECKING: + from torch.utils._config_typing import * # noqa: F401, F403 + +from torch.utils._config_module import install_config_module + +# adds patch, save_config, etc +install_config_module(sys.modules[__name__]) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/constant_folding.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/constant_folding.py new file mode 100644 index 0000000000000000000000000000000000000000..c1f031ee4a133e7621e2739bc0f447d5de516c24 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/constant_folding.py @@ -0,0 +1,264 @@ +import collections +from typing import Any, Callable, Dict, Optional + +import torch +import torch.utils._pytree as pytree + +aten = torch.ops.aten + +# We would like to split modules into two subgraphs for runtime weight updates to work correctly. +# The use case and more information can be found at: +# https://docs.google.com/document/d/1inZC-8KarJ6gKB7G9egmYLx1V_dKX_apxon0w4zPC0Q/edit?usp=sharing +META_TAG = "MODULE_TYPE" +MODULE_TAG = "_MAIN_MODULE" +CONST_MODULE_TAG = "_CONST_MODULE" + + +def replace_node_with_constant(gm, node, constant, name=None): + g = gm.graph + + if name: + qualname = name + else: + if not hasattr(gm, "_frozen_param_count"): + gm._frozen_param_count = 0 + i = gm._frozen_param_count + + while True: + qualname = f"_frozen_param{i}" + if not hasattr(gm, qualname): + break + i += 1 + + gm._frozen_param_count = i + 1 + + with g.inserting_before(node): + new_input_node = g.create_node("get_attr", qualname, (), {}) + node.replace_all_uses_with(new_input_node) + new_input_node.meta.update(node.meta) + g.erase_node(node) + + # needed to suppress `does not reference an nn.Module, nn.Parameter, or buffer` warning + gm.register_buffer(qualname, constant) + setattr(gm, qualname, constant) + + +class ConstantFolder(torch.fx.Interpreter): + def __init__( + self, + gm, + skip_constructors=False, + ): + super().__init__(gm) + self.node_replacements: Dict[torch.fx.Node, Any] = {} + self.replaced_uses: Dict[torch.fx.Node, int] = collections.Counter() + self.unknown_value = object() + self.skip_constructors: bool = skip_constructors + + # overwrite this to deallocate env values if their only remaining use + # is the output + self.user_to_last_uses = self.node_to_last_non_output_use() + + def is_impure(self, node: torch.fx.node.Node): + if node.target in [ + torch.ops.quantized_decomposed.dequantize_per_channel.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.tensor, + ]: + # For the pattern fp32_weight -> q -> dq, + # we fold only fp32_weight -> q into an + # int8_weight and leave dq in the graph to be fused + return True + return False + + def node_to_last_non_output_use(self): + last_non_output_use = collections.defaultdict(list) + seen_uses = set() + output_node = next(iter(reversed(self.module.graph.nodes))) + + for node in reversed(self.module.graph.nodes): + if node.target == "output": + continue + + def add_use(inp): + if
inp in seen_uses: + return + + seen_uses.add(inp) + last_non_output_use[node].append(inp) + + pytree.tree_map_only(torch.fx.Node, add_use, (node.args, node.kwargs)) + + # if this node is only used in output, we want to gc it right away + if len(node.users) == 1 and output_node in node.users: + last_non_output_use[node].append(node) + + return last_non_output_use + + def run_node(self, node): + if node.target == "output": + # because we remove nodes from env on last non output use, + # re-define them now or we'll get an error in the interpreter + def set_env(arg): + self.env[arg] = self.unknown_value + + pytree.tree_map_only(torch.fx.Node, set_env, node.args) + return super().run_node(node) + + args, kwargs = self.fetch_args_kwargs_from_env(node) + flattened_inputs = pytree.arg_tree_leaves(*args, **kwargs) + + if self.unknown_value in flattened_inputs: + return self.unknown_value + + # TODO - fix errors with this + if ( + node.op == "call_function" + and node.target == aten._efficientzerotensor.default + ): + return self.unknown_value + + # TODO - constant folding triton kernel returns the inputs -- fix this + if ( + node.op == "call_function" + and node.name == "triton_kernel_wrapper_functional_proxy" + ): + return self.unknown_value + + # skip constructors, since inductor generates optimal code for them already + # and turning them into tensors would result in an additional global memory read + # TODO - more complicated strategy + if ( + self.skip_constructors + and node.op != "get_attr" + and not any(isinstance(e, torch.Tensor) for e in flattened_inputs) + ): + return self.unknown_value + + # All mutations should either be removed or on inputs which we did not make constant + if ( + isinstance(node.target, torch._ops.OpOverload) + and torch.Tag.nondeterministic_seeded in node.target.tags + ): + return self.unknown_value + + out = super().run_node(node) + + if node.op != "get_attr" and isinstance(out, torch.Tensor): + if not self.insertable_tensor_check(out): + return out + + if self.is_impure(node): + return self.unknown_value + + self.add_node_replacement(node, out) + + flattened_node_inps = pytree.arg_tree_leaves(*node.args, **node.kwargs) + + for n in flattened_node_inps: + if not isinstance(n, torch.fx.Node): + continue + + self.replaced_uses[n] += 1 + + for to_delete in self.user_to_last_uses.get(node, []): + if self.replaced_uses[to_delete] == len(to_delete.users): + self.node_replacements.pop(to_delete, None) + + return out + + def insertable_tensor_check(self, tensor: torch.Tensor) -> bool: + return True + + def add_node_replacement(self, node: torch.fx.Node, tensor: torch.Tensor) -> None: + self.node_replacements[node] = tensor + + def run(self): + env = {} + for n in self.module.graph.nodes: + if n.op == "placeholder": + env[n] = self.unknown_value + return super().run(initial_env=env) + + +@torch.utils._python_dispatch._disable_current_modes() +def constant_fold(gm, constraint_fn: Optional[Callable[[torch.fx.Node], bool]] = None): + cf = ConstantFolder(gm, skip_constructors=True) + cf.run() + + for node, constant in cf.node_replacements.items(): + if constraint_fn is not None and not constraint_fn(node): + continue + replace_node_with_constant(gm, node, constant) + + erased_params = [] + for node in gm.graph.nodes: + if node.op == "get_attr" and len(node.users) == 0: + if hasattr(gm, node.target): + delattr(gm, node.target) + erased_params.append(node) + + for node in erased_params: + gm.graph.erase_node(node) + + gm.graph.eliminate_dead_code() + gm.graph.lint() + gm.recompile() + +
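A minimal usage sketch of the `constant_fold` entry point above, assuming a toy traced module (`ToyModule`, its names, and its shapes are illustrative and not part of the patch):

import torch
from torch._inductor.constant_folding import constant_fold

class ToyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(4, 4))

    def forward(self, x):
        # `self.weight + 1` does not depend on `x`, so it is a folding candidate
        return x @ (self.weight + 1)

gm = torch.fx.symbolic_trace(ToyModule())
constant_fold(gm)
# The add is folded away: the graph now matmuls `x` against a single
# `_frozen_param0` get_attr constant, and the unused `weight` attribute is erased.
print(gm.graph)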
+@torch.utils._python_dispatch._disable_current_modes() +def constant_graph_tag(gm: torch.fx.GraphModule): + cf = ConstantFolder(gm, skip_constructors=True) + cf.run() + + for node in gm.graph.nodes: + if ( + node.op == "get_attr" + or node in cf.node_replacements + or node in cf.replaced_uses + ): + node.meta[META_TAG] = CONST_MODULE_TAG + else: + node.meta[META_TAG] = MODULE_TAG + + +def run_and_get_constant_graph(gm: torch.fx.GraphModule) -> torch.fx.GraphModule: + """ + Construct a GraphModule which corresponds to the part which could be + constant folded in the provided gm. + """ + + constant_graph_tag(gm) + # We rewrite the tags: if a constant is directly consumed without + # any folding opportunity, we keep it in the main gm. + for node in gm.graph.nodes: + if node.op == "get_attr": + used_to_fold = False + for u in node.users: + if u.meta[META_TAG] == CONST_MODULE_TAG: + used_to_fold = True + break + if not used_to_fold: + node.meta[META_TAG] = MODULE_TAG + + new_graph = torch.fx.Graph() + + node_remapping: Dict[torch.fx.Node, torch.fx.Node] = {} + output_nodes = [] + for node in gm.graph.nodes: + if node.meta[META_TAG] == MODULE_TAG: + continue + + new_node = new_graph.node_copy(node, lambda x: node_remapping[x]) + node_remapping[node] = new_node + + for user in node.users: + if user.meta[META_TAG] == MODULE_TAG: + output_nodes.append(new_node) + break + + new_graph.output(tuple(output_nodes)) + new_graph.lint() + new_gm = torch.fx.GraphModule(gm, new_graph) + + return new_gm diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/graph.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/graph.py new file mode 100644 index 0000000000000000000000000000000000000000..605ea9c130416af3e47a03694a511b939608aeca --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/graph.py @@ -0,0 +1,1324 @@ +import itertools +import logging +import operator +import os +import re +import sys +import time +from collections import defaultdict +from contextlib import contextmanager +from typing import Any, Callable, DefaultDict, Dict, List, Optional, Set, Tuple + +import sympy + +import torch +import torch._logging +import torch.fx +from torch._decomp import get_decompositions +from torch._dynamo.utils import defake, dynamo_timed +from torch._logging import LazyString, trace_structured +from torch._subclasses.fake_tensor import FakeTensor +from torch.fx.experimental._backward_state import BackwardState +from torch.fx.experimental.sym_node import magic_methods, method_to_operator +from torch.fx.experimental.symbolic_shapes import has_free_symbols, ShapeEnv, SymTypes +from torch.utils._mode_utils import no_dispatch + +from .
import config, ir +from .codegen.common import ( + DeviceOpOverrides, + get_device_op_overrides, + get_scheduling_for_device, + get_wrapper_codegen_for_device, + register_backend_for_device, +) +from .codegen.cpp_wrapper_cpu import CppWrapperCpu +from .codegen.cpp_wrapper_cuda import CppWrapperCuda +from .codegen.wrapper import WrapperCodeGen +from .exc import ( + CppWrapperCodeGenError, + LoweringException, + MissingOperatorWithDecomp, + MissingOperatorWithoutDecomp, +) +from .ir import ( + Constant, + FixedLayout, + InputBuffer, + Pointwise, + Reduction, + StorageBox, + TensorBox, +) +from .lowering import ( + constrain_to_fx_strides, + FALLBACK_ALLOW_LIST, + fallback_handler, + fallback_node_due_to_unsupported_type, + layout_constraints, + lowerings, + make_fallback, + needs_realized_inputs, + unsupported_output_tensor, +) +from .sizevars import SizeVarAllocator +from .utils import convert_shape_to_inductor, gather_origins, get_sympy_Expr_dtype +from .virtualized import V + +log = logging.getLogger(__name__) +perf_hint_log = torch._logging.getArtifactLogger(__name__, "perf_hints") +output_code_log = torch._logging.getArtifactLogger(__name__, "output_code") + + +if config.is_fbcode(): + from torch._inductor.fb.utils import log_module_code +else: + + def log_module_code(*args, **kwargs): + pass + + +def supported_dtype_of_cpp_wrapper(dtype, cuda): + supported_dtype = { + torch.float32, + torch.float64, + torch.int64, + torch.int32, + torch.int16, + torch.int8, + torch.uint8, + torch.bool, + torch.bfloat16, + torch.complex32, + torch.complex64, + torch.complex128, + torch.float16, + } + if cuda: + supported_dtype.add(torch.float8_e4m3fn) + supported_dtype.add(torch.float8_e5m2) + supported_dtype.add(torch.float8_e4m3fnuz) + supported_dtype.add(torch.float8_e5m2fnuz) + + return dtype in supported_dtype + + +def may_get_constant_buffer_dtype(constant_buffer): + assert isinstance( + constant_buffer, (sympy.Symbol, sympy.Expr, sympy.core.numbers.Integer) + ), "get_constant_buffer_dtype only supports input of sympy.Symbol, sympy.Expr or sympy.core.numbers.Integer" + if isinstance(constant_buffer, sympy.core.numbers.Integer): + return torch.int64 + + if isinstance(constant_buffer, sympy.Expr): + return get_sympy_Expr_dtype(constant_buffer) + + if constant_buffer.is_integer: + return torch.int64 + elif constant_buffer.is_float: + return torch.float32 + else: + return None + + +def is_magic_method(op): + magic_ops = {method_to_operator(m) for m in magic_methods} + return op in magic_ops + + +def getattr_recursive(obj, target): + target_atoms = target.split(".") + attr_itr = obj + for i, atom in enumerate(target_atoms): + if not hasattr(attr_itr, atom): + raise RuntimeError( + f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}" + ) + attr_itr = getattr(attr_itr, atom) + return attr_itr + + +class GraphLowering(torch.fx.Interpreter): + graph_outputs: List[ir.IRNode] + + def symbolic_sizes_strides(self, ex: torch.Tensor): + """ + Support dynamic shapes and dynamic strides by assigning variables + to each dimension. We duck-shape tensors, so if two tensors + have the same size they get assigned the same symbolic variable. + """ + if self.reuse_shape_env: + return convert_shape_to_inductor(ex.size()), convert_shape_to_inductor( + ex.stride() + ) + else: + from torch._dynamo.source import ConstantSource + + # TODO: this should not be needed once #93059 lands + # https://github.com/pytorch/pytorch/pull/94031#discussion_r1096044816 + # TODO: make a dedicated UnknownSource for this? 
+            # NB: This is using the legacy default behavior from
+            # create_symbolic_sizes_strides_storage_offset but we hope we can
+            # just delete this entirely
+            source = ConstantSource(
+                f"__inductor_unknown_tensor_{len(self._shape_env.var_to_val)}"
+            )
+            (
+                size,
+                stride,
+                _,
+            ) = self._shape_env.create_symbolic_sizes_strides_storage_offset(
+                ex,
+                source,
+            )
+
+        size = [i.node.expr if isinstance(i, torch.SymInt) else i for i in size]
+        stride = [i.node.expr if isinstance(i, torch.SymInt) else i for i in stride]
+        return size, stride
+
+    def static_sizes_strides(self, ex: torch.Tensor):
+        """
+        Primarily used for weights
+        """
+        size = [sympy.Integer(i) for i in ex.size()]
+        stride = [sympy.Integer(i) for i in ex.stride()]
+        return size, stride
+
+    def init_backend_registration(self):
+        if get_scheduling_for_device("cpu") is None:
+            from .codegen.cpp import CppScheduling
+
+            register_backend_for_device("cpu", CppScheduling, WrapperCodeGen)
+
+        if get_scheduling_for_device("cuda") is None:
+            from .codegen.cuda_combined_scheduling import CUDACombinedScheduling
+
+            # CUDACombinedScheduling combines Triton and CUDA C++ scheduling for CUDA devices via delegation
+            register_backend_for_device("cuda", CUDACombinedScheduling, WrapperCodeGen)
+
+    def __init__(
+        self,
+        gm: torch.fx.GraphModule,
+        example_inputs: Optional[List[torch.Tensor]] = None,
+        shape_env=None,
+        num_static_inputs=None,
+        graph_id=None,
+        cpp_wrapper=False,
+        aot_mode=False,
+        user_visible_outputs=frozenset(),
+        layout_opt=None,
+        extern_node_serializer=None,
+        is_inference=False,
+        is_const_graph=False,
+        const_output_index=None,
+        const_code=None,
+        const_module=None,
+        name=None,
+    ):
+        super().__init__(gm)
+
+        self.example_inputs = example_inputs
+        self.layout_opt = (
+            layout_opt
+            if layout_opt is not None
+            else self.decide_layout_opt(gm, is_inference=is_inference)
+        )
+        self.num_channels_last_conv = 0
+        self.is_inference = is_inference
+        self.is_const_graph = is_const_graph
+        self.const_code = const_code
+        self.const_module = const_module
+
+        self.extra_traceback = False  # we do our own error wrapping
+        if shape_env is None:
+            shape_env = ShapeEnv()
+            self.reuse_shape_env = False
+        else:
+            self.reuse_shape_env = True
+        self._shape_env = shape_env
+        self.sizevars = SizeVarAllocator(shape_env)
+        self.graph_input_names: List[str] = []
+        self.graph_inputs: Dict[str, TensorBox] = {}
+        self.graph_inputs_original: Dict[str, InputBuffer] = {}
+        self.device_types: Set[str] = (
+            const_module.device_types if const_module else set()
+        )
+        self.device_idxs: Set[int] = const_module.device_idxs if const_module else set()
+        self.cuda = False
+        self.buffers: List[ir.Buffer] = []
+        self.const_output_index: Dict[str, int] = (
+            const_output_index if const_output_index else {}
+        )
+        self.folded_constants: Set[str] = (
+            set(const_output_index.keys()) if const_output_index else set()
+        )
+        self.constants: Dict[str, torch.Tensor] = (
+            const_module.constants if const_module else {}
+        )
+        self.constant_reprs: Dict[str, str] = {}
+        self.removed_buffers: Set[str] = set()
+        self.removed_inplace_buffers: Set[str] = set()
+        self.mutated_buffers: Set[str] = set()
+        self.never_reuse_buffers: Set[str] = set()
+        self.inplaced_to_remove: Set[str] = set()
+        self.device_ops: DeviceOpOverrides = None  # type: ignore[assignment]
+        self.wrapper_code: WrapperCodeGen = None  # type: ignore[assignment]
+        # See `ProxyExecutor Design Note` in ir.py for more details
+        self.extern_kernel_nodes: List[ir.ExternKernelNode] = []
+        self.extern_node_serializer: Optional[
+            Callable[[List[ir.ExternKernelNode]], Any]
+        ] = extern_node_serializer
+        self.current_node: torch.fx.Node = None  # type: ignore[assignment]
+        self.num_static_inputs = num_static_inputs
+        self.lists: Dict[str, List[str]] = {}
+        self.mutated_inputs: Set[str] = set()
+        self.mutated_input_idxs: List[int] = []
+        self.name_to_buffer: Dict[str, ir.Buffer] = {}
+        self.name_to_users: DefaultDict[str, List[ir.IRNode]] = defaultdict(list)
+        self.creation_time = time.time()
+        self.name = name
+        self.cpp_wrapper = cpp_wrapper
+
+        # record multi_kernel choice for cpp_wrapper so the second pass knows
+        # which sub-kernel is picked. Copy cpp_wrapper to another variable
+        # since the cpp_wrapper flag is set to false for the first pass of codegen.
+        self.record_multi_kernel_choice = cpp_wrapper
+        self.multi_kernel_to_choice: Dict[str, int] = {}
+
+        self.aot_mode = aot_mode
+        self.graph_id = graph_id
+        self.scheduler: "torch._inductor.scheduler.Scheduler" = None  # type: ignore[assignment]
+        self.nodes_prefer_channels_last = (
+            self.find_nodes_prefer_channels_last() if self.layout_opt else set()
+        )
+        self._warned_fallback = {"aten.convolution_backward"}
+        self.user_visible_outputs = user_visible_outputs
+        self.cache_key: str = ""  # This is the cache key for the compiled artifact
+        self.cache_path: str = ""  # This is the path in the filesystem where the compiled artifact is stored
+        self.cache_linemap: List[
+            Tuple[int, str]
+        ] = (
+            []
+        )  # This is the linemap used by the profiler to mark custom compiled kernels getting run
+        # Used if lowering encounters cases where cudagraphs are not supported
+        self.disable_cudagraphs_reason: Optional[str] = None
+
+        # only keeping one node per device for stack trace purposes
+        self.device_node_mapping: Dict[torch.device, torch.fx.Node] = {}
+        self.orig_gm: torch.fx.GraphModule = gm.__copy__()
+        self.dynamo_flat_name_to_original_fqn = self.module.meta.get(
+            "dynamo_flat_name_to_original_fqn", {}
+        )
+        self.allocated_constant_name = (
+            const_module.allocated_constant_name if const_module is not None else {}
+        )
+        self.init_backend_registration()
+
+    @staticmethod
+    def decide_layout_opt(gm, *, is_inference) -> bool:
+        """
+        Decide if we should enable layout optimization for this graph based on
+        heuristics.
+        """
+        if not config.layout_optimization:
+            return False
+
+        if config.force_layout_optimization:
+            return True
+
+        conv_nodes = [
+            n for n in gm.graph.nodes if n.target == torch.ops.aten.convolution.default
+        ]
+        nconv = len(conv_nodes)
+
+        if nconv == 0:
+            return False
+
+        # For the cpu backend with mkldnn enabled, we always use channels_last for better performance.
+        if (
+            torch.backends.mkldnn.enabled
+            and torch.backends.mkldnn.is_available()
+            and all(
+                n.args[idx].meta["val"].device == torch.device("cpu")
+                for n in conv_nodes
+                for idx in [0, 1]
+            )
+        ):
+            return True
+
+        # Following models are skipped due to this:
+        #   jx_nest_base
+        #   volo_d1_224
+        if len(list(gm.graph.nodes)) >= 300 * nconv:
+            log.debug("Skipped layout opt because there are only a few convs")
+            return False
+
+        if any(
+            has_free_symbols(n.args[idx].meta["val"])
+            for n in conv_nodes
+            for idx in [0, 1]
+        ):
+            log.debug(
+                "See perf regression with dynamic shape. "
+                "Follow up in https://github.com/pytorch/pytorch/issues/102670"
+            )
+            return False
+
+        def is_grouped(n):
+            return n.args[-1] > 1 and n.args[1].meta["val"].size(1) > 1
+
+        def is_in_out_channel(n):
+            return (
+                n.args[1].meta["val"].size(0) * 2 <= n.args[1].meta["val"].size(1)
+                and n.args[1].meta["val"].size(2) > 1
+            )
+
+        def is_small_channel(n):
+            return (
+                n.args[1].meta["val"].size(0) <= 64
+                and n.args[1].meta["val"].size(1) <= 64
+            )
+
+        # Only grouped convolutions benchmarked as slower in conv samples; applies to inference only.
+        if is_inference:
+            from torch.utils.flop_counter import FlopCounterMode
+
+            flop_counts: Dict[str, float] = defaultdict(float)
+            for node in conv_nodes:
+                success, args, kwargs = torch._inductor.fx_utils.get_fake_args_kwargs(
+                    node
+                )
+
+                if success:
+                    with FlopCounterMode(display=False) as flop_counter_mode:
+                        with V.fake_mode:
+                            node.target(*args, **kwargs)
+
+                    counted_flops = flop_counter_mode.get_total_flops()
+                    if is_grouped(node):
+                        node_type = "grouped"
+                    elif is_small_channel(node):
+                        node_type = "small"
+                    elif is_in_out_channel(node):
+                        node_type = "in_out"
+                    else:
+                        node_type = "default"
+
+                    flop_counts[node_type] += counted_flops
+                else:
+                    log.debug("Conv inputs meta not found")
+
+            # average benchmarked channels last speedup / slowdown, < 1 is speedup.
+            # taken from the set of convolution inputs in benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/
+            # To regenerate these numbers follow https://gist.github.com/eellison/55d7a6ed6f39829d68ac56f95f4df5bb
+            GROUPED_MULTIPLIER = 1.358
+            DEFAULT_MULTIPLIER = 0.823
+            IN_OUT_MULTIPLIER = 0.725
+            SMALL_MULTIPLIER = 0.783
+
+            total_flops = sum(flop_counts.values())
+            # TODO - get different values per hardware
+            weighted_flops = (
+                flop_counts["grouped"] * GROUPED_MULTIPLIER
+                + flop_counts["small"] * SMALL_MULTIPLIER
+                + flop_counts["in_out"] * IN_OUT_MULTIPLIER
+                + flop_counts["default"] * DEFAULT_MULTIPLIER
+            )
+            do_layout_opt = weighted_flops <= total_flops
+            if not do_layout_opt:
+                log.debug(
+                    "Skipped layout opt in inference because weighted flops indicate slowdown, default: %d, channels last: %d",
+                    total_flops,
+                    weighted_flops,
+                )
+            return do_layout_opt
+
+        # Channels last layout can dramatically hurt grouped conv perf. E.g.
+        # Conv with arguments like
+        #   {"input_shape": [32, 224, 112, 112], "weight_shape": [224, 112, 3, 3],
+        #    "stride": [2, 2], "padding": [1, 1], "groups": 2}
+        # slows down 31x using channels last.
+
+        # But a lot of timm models use depthwise separable convolution which will
+        # result in grouped convolution with in-channel size == 1.
+        # For those grouped convolutions, channels last still helps a lot.
+        # E.g.
+        # Conv with arguments
+        #   {"input_shape": [128, 58, 56, 56], "weight_shape": [58, 1, 3, 3],
+        #    "stride": [2, 2], "padding": [1, 1], "groups": 58}
+        # gets a 1.86x speedup with the channels last layout.
+        #
+        # The following heuristics skip using channels-last if the model contains
+        # grouped convolution with in-channels > 1.
+        if any(map(is_grouped, conv_nodes)):
+            log.debug(
+                "Skip layout opt because found grouped convolution with >1 in_channels!"
+            )
+            return False
+
+        # For some models that contain convolutions with more in-channels than out-channels,
+        # applying channels last hurts performance.
+        # Following models are skipped due to this:
+        #   - pytorch_unet
+        #   - phlippe_densenet (slightly worse)
+        #   - Background_Matting (1.22x -> 0.821x)
+        #   - pytorch_CycleGAN_and_pix2pix (1.597x -> 1.294x)
+        if any(map(is_in_out_channel, conv_nodes)):
+            log.debug(
+                "Skip layout opt because some convolutions have smaller out_channel"
+            )
+            return False
+
+        # Following models are skipped due to this:
+        #   - functorch_maml_omniglot
+        if all(map(is_small_channel, conv_nodes)):
+            log.debug("Skip layout opt because all convolution channels are too small")
+            return False
+
+        return True
+
+    def qualify_name(self, name: str) -> str:
+        """Prepend the given name with the graph name if any."""
+        if self.name is not None:
+            return f"{self.name}_{name}"
+        return name
+
+    def make_subgraph(
+        self,
+        gm: torch.fx.GraphModule,
+        example_inputs: List[torch.Tensor],
+        subgraph_name: str,
+    ) -> "GraphLowering":
+        """
+        Make a subgraph of the current graph with all inherited
+        parts, except the graph module (`gm`) and `example_inputs`.
+        The subgraphs are lowered separately, but intended to be
+        inlined in the parent graph's code generation. Hence the need
+        for maintaining the same `shape_env` and other properties.
+        The subgraph name is qualified by the parent graph's name.
+        """
+        return GraphLowering(
+            gm=gm,
+            example_inputs=example_inputs,
+            shape_env=self._shape_env,
+            cpp_wrapper=self.cpp_wrapper,
+            aot_mode=self.aot_mode,
+            extern_node_serializer=self.extern_node_serializer,
+            is_inference=self.is_inference,
+            name=self.qualify_name(subgraph_name),
+        )
+
+    def find_nodes_prefer_channels_last(self):
+        """
+        The rules to decide if a node prefers channels last are simple:
+        1. it is an input/output of a convolution
+        2. one of its users prefers channels last
+
+        We have rule 1 because cudnn runs a faster convolution kernel for channels last inputs;
+        rule 2 is also important. It makes sure that indirect inputs to convolution also prefer
+        channels last.
+
+        Consider the scenario: conv -> batch-norm -> relu -> conv
+        Without rule 2, the batch-norm output may use a contiguous layout. That will cause 2 extra copies:
+        1. the output of batch-norm should be channels last initially, since its input is a conv's output.
+           Forcing the batch-norm's output to be contiguous results in the first copy
+        2. the second conv's input is initially contiguous. This layout is propagated from the batch-norm's output.
+           We need to convert it to the channels last layout, which results in the second copy.
+        With rule 2, we make sure all the tensors in the chain use the channels last layout, so both
+        copies can be saved.
+        """
+        output_set = set()
+        for n in reversed(self.module.graph.nodes):
+            if n.target == torch.ops.aten.convolution.default:
+                output_set.add(n)
+                continue
+
+            for user in n.users:
+                if user in output_set:
+                    output_set.add(n)
+                    break
+
+        # We need a second pass to add downstream nodes of those channels-last nodes to the set.
+        # This pass is especially needed to avoid mix-layout kernel inputs in the backward pass.
+        #
+        # Let's say a conv-batchnorm's output is passed to relu whose output is in turn returned
+        # from the fwd graph. Without this second pass, we will force relu's output to be contiguous.
+        # Then in the kernel in the backward pass, the contiguous output of relu may be mixed with other
+        # channels last tensors and passed to a kernel.
+        #
+        # This pass improves yolov3 training speedup from 1.116x (worse than disabling layout optimization speedup 1.196x) to 1.457x.
+        # It also improves dla102 training speedup from 1.240x (worse than disabling layout optimization speedup 1.523x) to 1.835x.
+        # This also helps the following models:
+        #   - res2net101_26w_4s
+        #   - res2net50_14w_8s
+        #   - sebotnet33ts_256
+        for n in self.module.graph.nodes:
+            if n in output_set:
+                for child in n.users:
+                    output_set.add(child)
+
+        return output_set
+
+    def warn_fallback(self, name):
+        if name not in self._warned_fallback:
+            self._warned_fallback.add(name)
+            perf_hint_log.info("Using FallbackKernel: %s", name)
+
+    def add_device_info(self, device: torch.device):
+        self.device_types.add(device.type)
+        if device.index is not None:
+            self.device_idxs.add(device.index)
+        if V.graph.current_node and device not in self.device_node_mapping:
+            self.device_node_mapping[device] = V.graph.current_node
+
+    @property
+    def fake_mode(self):
+        return V.fake_mode
+
+    def get_buffer(self, buffer_name: str):
+        if buffer_name in self.name_to_buffer:
+            return self.name_to_buffer[buffer_name]
+        if buffer_name in self.graph_inputs:
+            return self.graph_inputs[buffer_name]
+        return None
+
+    def get_dtype(self, buffer_name: str):
+        if buffer_name in self.constants:
+            return self.constants[buffer_name].dtype
+        if buffer_name in self.name_to_buffer:
+            return self.name_to_buffer[buffer_name].get_dtype()
+        if buffer_name in self.graph_inputs:
+            return self.graph_inputs[buffer_name].get_dtype()
+        m = re.match(r"(as_strided|reinterpret_tensor)\(([a-zA-Z0-9_]+),", buffer_name)
+        if m:
+            return self.get_dtype(m.group(1))
+        raise KeyError(f"could not find {buffer_name}")
+
+    def get_numel(self, buffer_name: str):
+        from .ir import MultiOutputLayout
+
+        if buffer_name in self.constants:
+            return self.constants[buffer_name].numel()
+        if buffer_name in self.name_to_buffer:
+            buf = self.name_to_buffer[buffer_name]
+            if isinstance(getattr(buf, "layout", None), MultiOutputLayout):
+                return 1
+            return buf.get_numel()
+        if buffer_name in self.graph_inputs:
+            return self.graph_inputs[buffer_name].get_numel()
+        raise KeyError(f"could not find {buffer_name}")
+
+    @dynamo_timed
+    def run(self, *args):
+        return super().run(*args)
+
+    def register_buffer(self, buffer: ir.Buffer):
+        name = self.qualify_name(f"buf{len(self.buffers)}")
+        self.buffers.append(buffer)
+        self.name_to_buffer[name] = buffer
+        # Skip empty CPU tensor so that CUDA graphs can succeed, see https://github.com/pytorch/pytorch/pull/114144
+        if not isinstance(buffer, ir.ComputedBuffer) or not buffer.is_zero_elements():
+            self.add_device_info(buffer.get_device())
+        return name
+
+    def register_list(self, buffer_names: List[str]):
+        name = self.qualify_name("list_" + "_".join(buffer_names))
+        self.lists[name] = buffer_names
+        return name
+
+    def register_users_of(self, node_output):
+        def register(value):
+            if isinstance(value, (list, tuple)):
+                for x in value:
+                    register(x)
+            if isinstance(value, ir.IRNode):
+                if (
+                    not hasattr(value, "data")
+                    or not isinstance(value.data, ir.IRNode)
+                    or not (
+                        hasattr(value.data, "data")
+                        and isinstance(value.data.data, ir.IRNode)
+                    )
+                ):
+                    return
+
+                for read_name in value.get_read_names():
+                    self.name_to_users[read_name].append(value)
+
+        register(node_output)
+
+    def mark_buffer_mutated(self, name: str):
+        """
+        When a buffer is mutated we need to make sure all the reads to
+        the old version are realized before the mutation happens.
+ """ + assert isinstance(name, str) + self.mutated_buffers.add(name) + + if name not in self.name_to_users: + return + + for user in self.name_to_users[name]: + user.realize() + + def add_tensor_constant(self, data, name=None): + def allocate(name): + if not config.aot_inductor.use_runtime_constant_folding: + for constant_name, value in self.constants.items(): + if ( + not data.is_mkldnn + and data.size() == value.size() + and data.stride() == value.stride() + and data.dtype == value.dtype + and data.device == value.device + and torch.eq(data, value).all() + ): + return constant_name + + if name is None: + name = f"constant{len(self.constants)}" + if name[0].isdigit(): + name = f"constant_{name}" + name = self.qualify_name(name) + # We may generate a var name for each constant in the codegen. + # Let's only keep sane characters. + prefix = re.sub(r"[^a-zA-Z0-9_]", "_", name) + name = prefix + cnt = 0 + while name in self.constants: + name = f"{prefix}_{cnt}" + cnt += 1 + self.constants[name] = data + self.constant_reprs[name] = ( + f"{data.device!r} {data.dtype!r} " + f"{tuple(data.size())!r} {tuple(data.stride())!r} " + f"{hash(data):x}" + ) + return name + + new_name = allocate(name) + self.allocated_constant_name[new_name] = name + + return TensorBox.create( + ir.ConstantBuffer( + new_name, + FixedLayout(data.device, data.dtype, *self.static_sizes_strides(data)), + ) + ) + + def constant_name(self, name: str, device_override: Optional[torch.device]): + """ + We AOT copy constants to the devices they are needed on. + If device_override doesn't match the constant's device, then + copy it and return a different name. + """ + if self.constants[name].device == device_override or device_override is None: + return name + alt_name = f"{name}_{device_override.type}{device_override.index or 0}" + if alt_name not in self.constants: + self.constants[alt_name] = self.constants[name].to(device_override) + return alt_name + + def placeholder(self, target: str, args, kwargs): + example = super().placeholder(target, args, kwargs) + self.graph_input_names.append(target) + if isinstance(example, SymTypes): + expr = example.node.expr + self.graph_inputs[target] = expr + return expr + elif isinstance(example, (int, bool, float)): + expr = sympy.sympify(example) + self.graph_inputs[target] = expr + return expr + if isinstance(example, BackwardState): + # Ignored arg, must be unused + # Alternately we could filter this out in AotAutograd + return None + assert isinstance(example, torch.Tensor), example + # todo(chilli): We can remove the last check once we turn buffers into + # static shape tensors. That's a hack to workaround Inductor believing + # the buffer should be static but us passing in a fake tensor with + # symbolic shapes. 
+ if not example._has_symbolic_sizes_strides: + # the first N inputs are weights + sizes, strides = self.static_sizes_strides(example) + else: + sizes, strides = self.symbolic_sizes_strides(example) + # TODO(jansel): handle input aliasing + target = self.qualify_name(target) + tensor = TensorBox.create( + InputBuffer( + target, + FixedLayout(example.device, example.dtype, sizes, strides), + ) + ) + self.graph_inputs[target] = tensor + self.graph_inputs_original[target] = tensor.data.data + self.add_device_info(example.device) + return tensor + + def call_function(self, target, args, kwargs): + if target is operator.getitem and isinstance(args[0], (list, tuple, dict)): + return super().call_function(target, args, kwargs) + + if hasattr(target, "_inductor_lowering_function"): + # passthrough lowerings from .pattern_matcher + return target(*args, **kwargs) + + def get_custom_op_layout_constraints(target, args, kwargs): + # Custom operations that require preserving stride order + # which run through implicit fallback must constrain their + # arguments' fx strides + layout_constraint = None + if torch._C.Tag.needs_fixed_stride_order in target.tags: + # We have to set the current args because call_function will immediately + # evaluate this lowering after creating the fallback, without evaluating + # the layout constraint + args, kwargs = constrain_to_fx_strides( + self.current_node, *args, **kwargs + ) + # Also register the layout constraint so when the fallback + # is used again, we can constrain the args to the same layout + layout_constraint = constrain_to_fx_strides + return layout_constraint, args, kwargs + + if target not in lowerings: + assert isinstance( + target, torch._ops.OpOverload + ), f"{target} is not an OpOverload" + base_name = target.name().split(".")[0] + if base_name in FALLBACK_ALLOW_LIST: + make_fallback(target) + elif config.implicit_fallbacks: + layout_constraint, args, kwargs = get_custom_op_layout_constraints( + target, args, kwargs + ) + error = ( + MissingOperatorWithDecomp + if get_decompositions([target]) + else MissingOperatorWithoutDecomp + ) + log.info( + "Creating implicit fallback for:\n%s", + error.operator_str(target, args, kwargs), + ) + make_fallback(target, layout_constraint) + + elif get_decompositions([target]): + # There isn't a good way to dynamically patch this in + # since AOT Autograd already ran. The error message tells + # the user how to fix it. + raise MissingOperatorWithDecomp(target, args, kwargs) + else: + raise MissingOperatorWithoutDecomp(target, args, kwargs) + + try: + log.debug(" via %s", lowerings[target]) + out = lowerings[target](*args, **kwargs) + return out + except Exception as e: + raise LoweringException(e, target, args, kwargs).with_traceback( + e.__traceback__ + ) from None + + @staticmethod + def can_inline_constant(t: torch.Tensor) -> bool: + """ + True if this is a small constant attr that will be inlined. 
+ """ + return len(t.shape) == 1 and t.shape[0] <= 8 + + def get_attr(self, target, args, kwargs): + # this is a constant + value = getattr_recursive(self.module, target) + + if isinstance(value, torch.fx.GraphModule): + return ir.Subgraph(name=target, graph_module=value) + + if ( + config.aot_inductor.use_runtime_constant_folding + or config.always_keep_tensor_constants + or unsupported_output_tensor(value) + ): + return self.add_tensor_constant(value, target) + + with no_dispatch(): + if value.shape == (): + return Constant(value.item(), value.dtype, value.device) + if self.can_inline_constant(value): + # tensor lowering has constant inlining logic + from .lowering import tensor + + return tensor(value.tolist(), dtype=value.dtype, device=value.device) + + return self.add_tensor_constant(value, target) + + def call_module(self, target, args, kwargs): + raise AssertionError() + + def call_method(self, target, args, kwargs): + raise AssertionError() + + def output(self, target, args, kwargs): + result = super().output(target, args, kwargs) + assert isinstance(result, (tuple, list)), type(result) + assert all( + isinstance( + x, + ( + TensorBox, + ir.Constant, + type(None), + ir.ConstantBuffer, + sympy.Expr, + sympy.logic.boolalg.Boolean, + int, + ), + ) + for x in result + ), result + self.graph_outputs = [ir.ExternKernel.realize_input(x) for x in result] + value: ir.IRNode + for name, value in self.graph_inputs.items(): + assert isinstance( + value, (TensorBox, sympy.Expr) + ), f"Unsupported inductor graph input type: {type(value)}" + if not isinstance(value, TensorBox): + continue + value.realize() + assert isinstance(value, TensorBox) + value = value.data + assert isinstance(value, ir.StorageBox) + value_storage_box = value + value = value.data + if not isinstance(value, InputBuffer) or value.get_name() != name: + # one of our inputs was mutated, need to turn that into a copy + ir.MutationLayout.realize_into(value, self.graph_inputs_original[name]) + # replace output with mutated input + try: + ind = self.graph_outputs.index(value_storage_box) + self.graph_outputs[ind] = self.graph_inputs_original[name] + except ValueError: + pass + + self.finalize() + log.debug( + "Force channels last inputs for %d conv for the current graph with id %d", + self.num_channels_last_conv, + self.graph_id if self.graph_id is not None else -1, + ) + + def finalize(self): + for buf in self.buffers: + buf.decide_layout() + + @contextmanager + def set_current_node(self, node: torch.fx.Node): + old = self.current_node + try: + self.current_node = node + yield + finally: + self.current_node = old + + def run_node(self, n: torch.fx.Node): + def debug(msg): + log.debug("lowering %s %s", LazyString(n.format_node), msg) + + origins = {n} + if n.op == "call_function": + args, kwargs = self.fetch_args_kwargs_from_env(n) + origins |= gather_origins(args, kwargs) + with ir.IRNode.current_origins(origins), self.set_current_node( + n + ), V.set_current_node(n): + if ( + n.op == "call_function" + and n.target is not operator.getitem + and fallback_node_due_to_unsupported_type(n) + ): + debug("fallback_handler") + result = fallback_handler(n.target, add_to_fallback_set=False)( + *args, **kwargs # type: ignore[possibly-undefined] + ) + elif n.op == "call_function" and n.target in layout_constraints: + debug("layout_constraints") + args, kwargs = layout_constraints[n.target](n, *args, **kwargs) # type: ignore[index] + result = self.call_function(n.target, args, kwargs) + elif is_magic_method(n.target): + # TODO: this is sus, it 
probably should be handled in the
+                # lowerings themselves, similarly to sym_size/sym_stride
+                debug("is_magic_method")
+                if isinstance(n.meta["val"], torch.SymInt):
+                    result = n.meta["val"].node.expr
+                else:
+                    result = super().run_node(n)
+            else:
+                debug("")
+                result = super().run_node(n)
+
+            # Require the same stride order for dense outputs:
+            # 1. user-land view() will not throw, because inductor may output
+            #    different strides than eager; long term the solution is to
+            #    make view() always succeed with infallible strides.
+            # 2. for as_strided ops, we need to make sure the input has the same
+            #    size/stride as in the eager model to align with eager behavior.
+            as_strided_ops = [
+                torch.ops.aten.as_strided.default,
+                torch.ops.aten.as_strided_.default,
+                torch.ops.aten.as_strided_scatter.default,
+            ]
+            is_output = any(user.op == "output" for user in n.users)
+            is_input_for_as_strided = any(
+                user.target in as_strided_ops for user in n.users
+            )
+            if (
+                is_output
+                and isinstance(result, TensorBox)
+                and isinstance(result.data, ir.BaseView)
+            ):
+                # Realize so that outputs are correctly aliased
+                result.realize()
+
+            if (is_output or is_input_for_as_strided) and isinstance(
+                n.meta["val"], torch.Tensor
+            ):
+                strides = n.meta["val"].stride()
+                dense = torch._prims_common.is_non_overlapping_and_dense(n.meta["val"])
+                # requiring a stride order for a non-dense output wouldn't
+                # recreate the same strides, and would fail with view; defer for now.
+                if dense and len(strides):
+                    stride_order = ir.get_stride_order(strides)
+                    if (
+                        len(result.get_size()) == 4
+                        and n in self.nodes_prefer_channels_last
+                        and n.name not in self.user_visible_outputs
+                        and not is_input_for_as_strided
+                    ):
+                        stride_order = ir.NHWC_STRIDE_ORDER
+                    result = ir.ExternKernel.require_stride_order(result, stride_order)
+
+            # Realize if (1) any user needs inputs realized, or (2) there are
+            # already too many reads and rematerializing can be bad.
+            num_users = len(set(n.users))
+            if num_users > 1 and isinstance(result, TensorBox):
+                for user in n.users:
+                    if user.target in needs_realized_inputs:
+                        result.realize_hint()
+                        # This inclusion is somewhat controversial (from
+                        # discussion between Horace, Natalia, and Elias).
+                        # Currently, it's not very clear why this is helpful.
+                        # The general idea here is that even though a node may
+                        # have FlexibleLayout, we still often *treat* it as if
+                        # it was contiguous. This appears to sometimes result in
+                        # suboptimal behavior.
+                        #
+                        # When we do a better job selecting layout, we should
+                        # revisit this.
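+                        # [Editor's note] Illustrative sketch, not part of the
+                        # vendored file: how the stride-order machinery above
+                        # behaves on plain tensors. ir.get_stride_order ranks
+                        # dims by stride, and only non-overlapping-and-dense
+                        # outputs are constrained. Helper name is hypothetical:
+                        def _stride_order_demo():
+                            from torch._prims_common import (
+                                is_non_overlapping_and_dense,
+                            )
+
+                            t = torch.empty(2, 3, 4, 5).to(
+                                memory_format=torch.channels_last
+                            )
+                            strides = t.stride()  # (60, 1, 15, 3)
+                            # same computation as ir.get_stride_order(strides)
+                            sorted_idx = sorted(range(4), key=lambda i: strides[i])
+                            order = [0] * 4
+                            for rank, dim in enumerate(sorted_idx):
+                                order[dim] = rank
+                            assert order == [3, 0, 2, 1]  # ir.NHWC_STRIDE_ORDER
+                            # a transpose is permuted but still dense; a strided
+                            # slice leaves gaps in storage and is not
+                            u = torch.empty(4, 6)
+                            assert is_non_overlapping_and_dense(u.t())
+                            assert not is_non_overlapping_and_dense(u[:, ::2])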
+ need_fixed_layout = [ + torch.ops.aten.convolution_backward.default, + torch.ops.aten.mm.default, + torch.ops.aten._int_mm.default, + ] + if not self.layout_opt: + need_fixed_layout.append(torch.ops.aten.convolution.default) + if torch._C._has_mkldnn: + need_fixed_layout += [ + torch.ops.mkldnn._convolution_pointwise.default, + torch.ops.mkldnn._convolution_pointwise.binary, + torch.ops.mkldnn._convolution_pointwise_.binary, + torch.ops.mkldnn._convolution_transpose_pointwise.default, + torch.ops.mkldnn._linear_pointwise.default, + torch.ops.mkldnn._linear_pointwise.binary, + torch.ops.aten.mkldnn_rnn_layer.default, + torch.ops.onednn.qconv2d_pointwise.default, + torch.ops.onednn.qconv2d_pointwise.binary, + torch.ops.onednn.qlinear_pointwise.default, + torch.ops.onednn.qlinear_pointwise.tensor, + ] + if torch._C.has_mkl: + need_fixed_layout += [torch.ops.mkl._mkl_linear.default] + if user.target in need_fixed_layout: + result = ir.ExternKernel.require_stride_order( + result, ir.get_stride_order(n.meta["val"].stride()) + ) + if user.op == "output": + if isinstance(result.data.data, (Pointwise, Reduction)): + result.realize() + + # TODO(jansel): introduce a store vs inline choice + result.mark_reuse(len(n.users)) + + # Realize if the IRNode already has accumulated lots of reads + if isinstance(result, TensorBox) and result.has_exceeded_max_reads(): + # Prevent excessive accumulation in a computed buffer, when + # there are multiple branches each with small number of memory + # reads, but they converge to a user. + result.realize_hint() + + # Realize if a Pointwise has too much stuff to be inlined. + # As this may cause RecursionError during Inductor's evaluation. + if isinstance(result, TensorBox) and isinstance(result.data, StorageBox): + curr = result.data.data + if isinstance(curr, Pointwise): + # Use inner fn as a rough proxy. Good enough. + if curr.has_large_inner_fn(): + result.realize() + + # This is not complete, but it doesn't have to be: origin_node + # tracking is best effort. The logic here critically relies on direct + # TensorBox -> StorageBox denoting a non-view; we don't bother trying + # to get views to work. Feel free to add any extra cases as needed. + # + # Note: we can't YOLO tree_map over this result, because if there are + # buffers or a view involved, we might not be able to validly assign + # the origin_node here. 
+        if isinstance(result, TensorBox) and isinstance(result.data, ir.StorageBox):
+            if isinstance(result.data.data, ir.Loops):
+                result.data.data.origin_node = n
+            elif isinstance(result.data.data, ir.Buffer):
+                result.data.data.origin_node = n
+                if isinstance(result.data.data, ir.ComputedBuffer) and isinstance(
+                    result.data.data.data, ir.Loops
+                ):
+                    result.data.data.data.origin_node = n
+            # Not really multi-output; we can straightforwardly recurse into it
+            elif (
+                isinstance(result.data.data, ir.MultiOutput)
+                and not result.data.data.indices
+            ):
+                if isinstance(result.data.data.inputs[0], ir.Buffer):
+                    result.data.data.inputs[0].origin_node = n
+
+        self.register_users_of(result)
+
+        return result
+
+    def validate_can_generate_cpp_wrapper(self):
+        if config.disable_cpp_codegen:
+            raise CppWrapperCodeGenError("C++ codegen is disabled")
+
+        if sys.platform not in ["linux", "darwin"]:
+            raise CppWrapperCodeGenError(f"Unsupported platform {sys.platform}")
+
+        for value in self.graph_inputs.values():
+            dtype = None
+            if isinstance(value, TensorBox):
+                dtype = value.get_dtype()
+            elif isinstance(
+                value, (sympy.Symbol, sympy.Expr, sympy.core.numbers.Integer)
+            ):
+                dtype = may_get_constant_buffer_dtype(value)
+
+            if not supported_dtype_of_cpp_wrapper(dtype, self.cuda):
+                raise CppWrapperCodeGenError(f"Unsupported input dtype {dtype}")
+
+    def init_wrapper_code(self):
+        self.cuda = "cuda" in self.device_types
+        if self.cpp_wrapper:
+            self.validate_can_generate_cpp_wrapper()
+            self.wrapper_code = CppWrapperCuda() if self.cuda else CppWrapperCpu()
+        else:
+            device_types = self.device_types.copy()
+            device_types.discard("cpu")
+            # TODO(Eikan): Only mixing cpu and one other device is supported for now.
+            assert len(device_types) <= 1, "Does not support mixing {}".format(
+                "+".join(device_types)
+            )
+            only_cpu = len(device_types) == 0
+            device_type = "cpu" if only_cpu else device_types.pop()
+
+            self.device_ops = get_device_op_overrides(device_type)
+            wrapper_code_gen_cls = get_wrapper_codegen_for_device(device_type)
+            assert (
+                wrapper_code_gen_cls is not None
+            ), f"Device {device_type} not supported"
+            self.wrapper_code = wrapper_code_gen_cls()
+
+        if self.const_module:
+            # If we have a const module, we can reuse its kernels.
+            # This avoids duplication and saves recompilation time (when using Triton).
+            self.wrapper_code._names_iter = self.const_module.wrapper_code._names_iter
+            self.wrapper_code.src_to_kernel = (
+                self.const_module.wrapper_code.src_to_kernel
+            )
+
+    def codegen_with_cpp_wrapper(self):
+        """
+        For CPU, the cpp wrapper codegen is done in one pass.
+        For GPU, the cpp wrapper codegen is done in two steps: JIT-compile the model with python
+        wrapper code and run it to generate autotuned kernel binaries in the first pass; and then
+        generate cpp wrapper code and compile it to a dynamic library in the second pass.
+ """ + if "cuda" in self.device_types: + # first pass + self.cpp_wrapper = False + compiled = self.compile_to_module().call + + def materialize(x): + if isinstance(x, (torch.SymInt, torch.SymFloat)): + # Need concrete value to run dynamic shapes and tune the result + return x.node.hint + elif isinstance(x, FakeTensor): + return defake(x) + else: + assert isinstance( + x, torch.Tensor + ), "Unknown type when creating real inputs" + str(type(x)) + return x + + if tracing_context := torch._guards.TracingContext.try_get(): + if tracing_context.output_strides: + tracing_context.output_strides.clear() + + params_flat = [ + param + for param in tracing_context.params_flat # type: ignore[union-attr] + if param is not None + ] + real_inputs = [ + materialize(x) for x in itertools.chain(params_flat, V.real_inputs) + ] + else: + real_inputs = [materialize(x) for x in V.real_inputs] + + with torch.utils._python_dispatch._disable_current_modes(): + assert self.example_inputs is not None + compiled(real_inputs) + del real_inputs + + # second pass + # TODO: reuse self.scheduler from the first pass to speed up the second pass + self.cpp_wrapper = True + self.removed_buffers.clear() + self.inplaced_to_remove.clear() + return self.codegen() + else: + # cpu + return self.codegen() + + def codegen(self): + from .scheduler import Scheduler + + self.init_wrapper_code() + + self.scheduler = Scheduler(self.buffers) + V.debug.draw_orig_fx_graph(self.orig_gm, self.scheduler.nodes) + + self.scheduler.codegen() + return self.wrapper_code.generate(self.is_inference) + + def codegen_subgraph(self, parent_graph): + """ + This is a more compact version of the `codegen()` above + where we codegen this graph as a subgraph of some parent + graph. The parent graph is passed as an argument: the + intention is to inline codegening of the subgraph in + the parent graph's wrapper code (including the generated + kerenls). The wrapper code is not finalized (via `.generate()` + call), as this will be done in the parent graph's `codegen()`. + """ + from .scheduler import Scheduler + + self.wrapper_code = parent_graph.wrapper_code + self.device_ops = parent_graph.device_ops + self.cpp_wrapper = parent_graph.cpp_wrapper + + self.scheduler = Scheduler(self.buffers) + self.scheduler.codegen() + + def count_bytes(self): + from .scheduler import Scheduler + + scheduler = Scheduler(self.buffers) + + total_bytes = 0 + node_counts = [] + node_runtimes = [] + for node in scheduler.nodes: + num_bytes = node.get_read_write_buffers_sizes() + total_bytes += num_bytes + node_counts.append((node, num_bytes // 4)) + node_runtimes.append((node, node.get_estimated_runtime())) + return total_bytes, node_counts, node_runtimes + + @dynamo_timed(phase_name="code_gen") + def compile_to_module(self): + from .codecache import PyCodeCache + + code, linemap = ( + self.codegen_with_cpp_wrapper() if self.cpp_wrapper else self.codegen() + ) + linemap = [(line_no, node.stack_trace) for line_no, node in linemap] + key, path = PyCodeCache.write(code) + mod = PyCodeCache.load_by_key_path( + key, path, linemap=linemap, attrs=self.constants + ) + self.cache_key = key + self.cache_path = path + self.cache_linemap = linemap + + # Logged twice as per https://github.com/pytorch/pytorch/pull/99038#discussion_r1167826029 + # TODO. 
Revisit this once the logging API is more mature + assert mod.__file__ is not None + + log_module_code(mod.__file__) + log.debug("Output code written to: %s", mod.__file__) + output_code_log.debug("Output code: \n%s", code) + trace_structured( + "inductor_output_code", + lambda: {"filename": mod.__file__}, + payload_fn=lambda: code, + ) + output_code_log.info("Output code written to: %s", mod.__file__) + if config.benchmark_kernel: + print(f"Compiled module path: {mod.__file__}", file=sys.stderr) + V.debug.output_code(mod.__file__) + V.debug.copy(os.path.splitext(mod.__file__)[0] + ".debug") + return mod + + def compile_to_fn(self): + if self.aot_mode: + from .codecache import AotCodeCompiler + + assert self.cpp_wrapper, "AOT mode only supports C++ wrapper" + code, linemap = self.codegen_with_cpp_wrapper() + output_code_log.debug("Output code: \n%s", code) + + serialized_extern_kernel_nodes = None + if ( + config.is_fbcode() + and self.extern_kernel_nodes + and self.extern_node_serializer + ): + serialized_extern_kernel_nodes = self.extern_node_serializer( + self.extern_kernel_nodes + ) + output_code_log.debug( + "Serialized Extern Kernel Nodes: \n%s", + serialized_extern_kernel_nodes, + ) + + # Directly return the file path with the compiled code + return AotCodeCompiler.compile( + self, code, serialized_extern_kernel_nodes, cuda=self.cuda + ) + else: + return self.compile_to_module().call + + def get_output_names(self): + return [ + node.get_name() + for node in self.graph_outputs + if not isinstance(node, ir.NoneAsConstantBuffer) + and not isinstance(node, ir.ShapeAsConstantBuffer) + ] + + def is_unspec_arg(self, name: str): + # dynamo wraps unspec variable as 0d CPU tensor, + # need to convert to scalar during codegen (triton only) + return ( + name in self.graph_inputs.keys() + and self.graph_inputs[name].get_numel() == 1 + and self.graph_inputs[name].get_device().type == "cpu" + ) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/ir.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/ir.py new file mode 100644 index 0000000000000000000000000000000000000000..a2785e347cb1ef16d89184e56902f41af207c510 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/ir.py @@ -0,0 +1,8064 @@ +import collections +import contextlib +import dataclasses +import functools +import itertools +import logging +import re +import textwrap +import traceback +from contextlib import nullcontext +from enum import Enum +from functools import partial +from typing import ( + Any, + Callable, + ClassVar, + Dict, + Iterable, + List, + Optional, + Sequence, + Set, + Tuple, + TYPE_CHECKING, + Union, +) +from unittest.mock import patch + +import sympy +from sympy import Expr, Integer + +import torch._export.serde.schema as export_schema + +import torch._logging + +import torch.fx +import torch.utils._pytree as pytree +from torch._dynamo.device_interface import get_interface_for_device +from torch._dynamo.utils import identity +from torch._export.serde.serialize import GraphModuleSerializer +from torch._higher_order_ops.auto_functionalize import can_auto_functionalize +from torch._prims_common import ( + compute_required_storage_length, + is_boolean_dtype, + is_float_dtype, + make_channels_last_strides_for, + make_contiguous_strides_for, + StrideType, +) +from torch._subclasses.fake_tensor import get_schema_info +from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols, SymTypes +from 
torch.utils._sympy.functions import CleanDiv, FloorDiv, ModularIndexing + +from . import config, dependencies +from .codegen.common import index_prevent_reordering +from .dependencies import ( + extract_free_unbacked_symbols, + extract_input_node_reduction_ranges, + extract_read_writes, + var_builder, +) +from .ops_handler import OpCounterCSE +from .utils import ( + argsort, + cache_on_self, + convert_shape_to_inductor, + convert_shape_to_symint, + developer_warning, + get_kernel_metadata, + is_dynamic, + pad_listlike, + sympy_dot, + sympy_index_symbol, + sympy_product, + sympy_subs, +) +from .virtualized import ops, V + +if TYPE_CHECKING: + from .graph import GraphLowering + +log = logging.getLogger(__name__) +indent = functools.partial(textwrap.indent, prefix=" ") +aten = torch.ops.aten + +""" [Note: Inductor IR] + +Inductor's IR is produced by executing 'lowering' code (see lowering.py). Each +lowering is registered to a particular aten operator, and expects inputs that +correspond to the aten schema. However, in place of torch Tensor inputs, lowerings +expect Inductor TensorBox inputs. + +TensorBox IR represents torch tensors. Tensors are sometimes single objects owning +storage, and sometimes views of another Tensor's storage. Mutating tensor operations +(such as add_()) affect the underlying storage and any associated views. Other operations +(such as .t_()) update metadata about the current view but don't modify the underlying storage. + +To model this in Inductor, the IR distinguishes between TensorBox, View, StorageBox and Buffer. + +TensorBox is the top level IR construct that any lowering should produce and maps to a torch.Tensor +output from an operation. But just as torch.Tensors take different forms, TensorBox IR can +reference View IR or directly reference StorageBox IRs. + +Some Inductor lowerings produce new sets of 'Box'es, while others (such as .t() or other view ops) +may take an existing TensorBox and point it to a new underlying View IR. + +Tensors that directly own storage are represented as a chain of: +TensorBox -> StorageBox -> Buffer +where Buffer is a simple (1D) allocation, and StorageBox introduces the concept of a Layout. + +If you mutate the data of such a tensor, we swing the StorageBox pointer to point to a new buffer +(leaving the old buffer unmodified and functionalizing the operation). + +Tensors backed by views add one more indirection to the IR. +TensorBox -> View -> StorageBox -> Buffer +In these cases, the underlying StorageBox/Buffer will be shared with the pre-view TensorBox. +""" + + +def validate_ir(node_or_nodes): + def _check_tensorbox(nodes): + # Could expand this to check deeper properties + # (e.g. TensorBox points to View or StorageBox) + if isinstance(nodes, (list, tuple)): + for node in nodes: + _check_tensorbox(node) + elif isinstance(nodes, dict): + for node in nodes.values(): + _check_tensorbox(node) + else: + assert isinstance( + nodes, + ( + torch._inductor.ir.ExpandView, + DynamicScalar, + AssertScalar, + TensorBox, + sympy.logic.boolalg.Boolean, + Expr, + ), + ), f"Found {type(nodes)}, which is not a supported top level IR node. 
See [Note: Inductor IR]" + + # Be picky about the accepted data structure (don't use pytree here) + _check_tensorbox(node_or_nodes) + + +def ops_wrapper(name): + assert isinstance(name, str) + + def fn(*args, **kwargs): + return getattr(ops, name)(*args, **kwargs) + + return fn + + +def inverse_reorder(order): + inv_order = dict(zip(order, range(len(order)))) + + def reindex(index): + assert len(index) == len(inv_order) + return [index[inv_order[i]] for i in range(len(index))] + + return reindex + + +def same_reorder(order): + def reindex(index): + assert len(index) == len(order) + return [index[order[i]] for i in range(len(index))] + + return reindex + + +def fuse_reindexing(reindex1, reindex2): + def reindex(index): + return reindex1(reindex2(index)) + + return reindex + + +NHWC_STRIDE_ORDER = [3, 0, 2, 1] + + +def stride_order2fill_order(order): + """ + Convert stride order to fill order + For channel last format, + stride order = [3, 0, 2, 1] and fill order = [1, 3, 2, 0] + """ + lookup = {pos: idx for idx, pos in enumerate(order)} + fill_order = [lookup[i] for i in range(len(order))] + return fill_order + + +def get_stride_order(seq: Sequence[int]) -> List[int]: + """ + Convert strides to stride order + """ + sorted_idx: List[int] = argsort(seq) + out = [0 for _ in range(len(seq))] + for i, elem in enumerate(sorted_idx): + out[elem] = i + return out + + +def ir_node_to_tensor(x, guard_shape=True): + if x is None: + return None + + shape_fn: Callable[[Expr], Union[int, Expr]] + if not guard_shape: + shape_fn = V.graph.sizevars.size_hint + else: + shape_fn = identity + size = [shape_fn(s) for s in x.get_size()] + stride: StrideType + if is_storage_and_layout(x): + stride = [shape_fn(s) for s in x.get_layout().stride] # type: ignore[misc] + else: + stride = make_contiguous_strides_for(size) # type: ignore[arg-type] + dtype = x.get_dtype() + device = x.get_device() + size = convert_shape_to_symint(size) + stride = convert_shape_to_symint(stride) + t = torch.empty_strided( + size=size, stride=stride, dtype=dtype, device=device + ).zero_() + return t + + +def may_convert_to_optional(value): + if isinstance(value, list) and not value: + # [None] makes sure the cpp wrapper codegen will generate something like + # {c10::nullopt} instead of {} + return [None] + return value + + +def get_device_type(x): + if getattr(x, "get_device", None): + return get_device_type(x.get_device()) + if isinstance(x, torch.device): + return x.type + return None + + +def is_triton(x): + return get_device_type(x) == "cuda" + + +def is_cpu(x): + return get_device_type(x) == "cpu" + + +class IRNode: + _current_origins: ClassVar[Set[Any]] = set() + + @staticmethod + @contextlib.contextmanager + def current_origins(origins: Set[torch.fx.Node]): + old = IRNode._current_origins + IRNode._current_origins = old | origins + try: + yield + finally: + IRNode._current_origins = old + + def __post_init__(self): + self.origins = set(self._current_origins) + self.traceback = traceback.format_stack() if config.debug_ir_traceback else None + + def get_traceback(self): + return self.traceback + + def common_repr(self): + origins = f"origins={getattr(self, 'origins', '')}" + if len(origins) > 64: + # this can get *very* long + origins = f"{origins[:61]}..." 
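+        # [Editor's note] Illustrative sketch, not part of the vendored file:
+        # the reorder helpers above are pure index permutations. For example,
+        # stride_order2fill_order inverts a permutation, so the channels-last
+        # stride order [3, 0, 2, 1] yields fill order [1, 3, 2, 0], matching
+        # its docstring. Helper name is hypothetical:
+        def _fill_order_demo():
+            order = [3, 0, 2, 1]  # NHWC_STRIDE_ORDER
+            lookup = {pos: idx for idx, pos in enumerate(order)}
+            fill_order = [lookup[i] for i in range(len(order))]
+            assert fill_order == [1, 3, 2, 0]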
+ return [origins] + + def str_helper(self, lines): + lines = lines + self.common_repr() + lines = indent(",\n".join(map(str, lines))) + return f"{type(self).__name__}(\n{lines}\n)" + + def is_user_of(self, name): + return name in self.get_read_names() + + @cache_on_self + def get_read_names(self): + return {dep.name for dep in self.get_reads()} + + def get_dtype(self): + return self.dtype + + def get_layout(self): + raise NotImplementedError(f"get_layout() is not implemented by {type(self)}!") + + def get_size(self): + raise NotImplementedError(f"get_size() is not implemented by {type(self)}!") + + def get_numel(self): + return sympy_product(self.get_size()) + + def is_zero_elements(self): + return V.graph.sizevars.is_expr_static_and_true(sympy.Eq(self.get_numel(), 0)) # type: ignore[arg-type] + + def realize(self): + """ + If the IRNode refers to data which has not been materialized (e.g., + it is a Pointwise/Reduction that could potentially have more + compute fused into it), realize the IRNode into physical memory, + ending the possibility of fusing into it, but allowing, e.g., multiple + users to access the data without having to recompute. + + Check StorageBox.realize for a particularly notable implementation. + + TODO(ezyang): I think, in principle, every IRNode should have an + implementation of this, and most of the time no-op is OK, but you + really do have to audit each IRNode for this, so for now, raise + an error if it's not implemented. Note that some code in graph.py + will catch this thrown error and suppress it with a warning. + """ + raise NotImplementedError(f"realize NYI on {type(self)}") + + def codegen_reference(self, writer=None): + raise NotImplementedError(f"codegen_reference NYI on {type(self)}") + + # The abstract method declarations below serve to convince mypy that all IRNode instances have these functions + # defined, while having no effect at runtime. We cannot create stub implementations here because other parts of + # the code dynamically check for defined attributes. 
+ get_device: Callable[[], torch.device] + dtype: torch.dtype + get_name: Callable[[], str] + get_reads: Callable[[], Any] + get_stride: Callable[[], Any] + get_storage_numel: Callable[[], Any] + has_exceeded_max_reads: Callable[[], bool] + make_loader: Callable[[], Callable[[Any], Any]] + make_indexer: Callable[[], Callable[[Any], Any]] + mark_reuse: Callable[[int], None] + realize_hint: Callable[[], None] + get_unbacked_symbol_uses: Callable[[], Set[sympy.Symbol]] + + +@dataclasses.dataclass +class Loops(IRNode): + device: torch.device + dtype: torch.dtype + inner_fn: Callable[..., Any] + ranges: List[Expr] + + def get_unbacked_symbol_uses(self) -> Set[sympy.Symbol]: + return set().union( + *(free_unbacked_symbols(e) for e in self.ranges), + self.inner_fn_free_unbacked_symbols(), + ) + + def __str__(self, names=("ranges",)): + return self.str_helper( + [ + f"'{self.device.type}'", + str(self.dtype), + self.inner_fn_str(), + ] + + [f"{name}={getattr(self, name)}" for name in names] + + [f"origin_node={self.origin_node!r}"] + ) + + def __post_init__(self): + super().__post_init__() + self.origin_node = None + + __repr__ = __str__ + + def get_device(self): + return self.device + + def get_origin_node(self): + return self.origin_node + + def get_size(self): + return self.ranges + + def get_pointwise_size(self): + return self.ranges + + def is_extern(self): + return False + + @classmethod + def create(cls, *args, **kwargs): + origin_node = kwargs.pop("origin_node", None) + tb = kwargs.pop("traceback", None) + r = cls(*args, **kwargs) + r.origin_node = origin_node + r.traceback = ( + tb or traceback.format_stack() if config.debug_ir_traceback else None + ) + return TensorBox.create(r) + + @staticmethod + def _index(ranges, prefix="i"): + return [ + sympy.Integer(0) if s == 1 else sympy_index_symbol(f"{prefix}{n}") + for n, s in enumerate(ranges) + ] + + @cache_on_self + def inner_fn_opcount(self): + from .ir import FlexibleLayout + + opcounter = OpCounterCSE(V.MockHandler()) + + with V.set_ops_handler(opcounter), patch.object( + FlexibleLayout, "allow_indexing", True + ): + result = self.inner_fn(*self.inner_fn_args()) + return opcounter.op_count + + def inner_fn_args(self): + return (self._index(self.ranges),) + + def inner_fn_str(self): + return V.KernelFormatterHandler.ir_to_string( + self.inner_fn, *self.inner_fn_args() + ) + + def has_large_inner_fn(self): + return self.inner_fn_opcount() > config.realize_opcount_threshold + + def inner_fn_free_unbacked_symbols(self): + index = self._index(self.ranges) + return extract_free_unbacked_symbols(self.inner_fn, index) + + def get_reads(self): + with patch.object(FlexibleLayout, "allow_indexing", True): + if self.get_reduction_type(): + return extract_read_writes( + self.make_loader(), + self.get_size(), + self.get_reduction_size(), + ).reads + else: + return extract_read_writes( + self.make_loader(), + self.get_size(), + ).reads + + def get_reduction_size(self): + raise NotImplementedError( + f"get_reduction_size() is not implemented by {type(self)}!" + ) + + def get_reduction_type(self): + raise NotImplementedError( + f"get_reduction_type() is not implemented by {type(self)}!" + ) + + def constant_to_device(self, device): + raise NotImplementedError( + f"constant_to_device() is not implemented by {type(self)}!" 
+ ) + + +def nop_loader_fn(idx, *, dtype): + if dtype.is_floating_point: + return ops.constant(float("nan"), dtype) + else: + return ops.constant(0, dtype) + + +class Pointwise(Loops): + def make_loader(self): + # Make zero-element loops into a no-op + if self.is_zero_elements(): + return partial(nop_loader_fn, dtype=self.dtype) + + return self.inner_fn + + def get_reduction_size(self): + return [] + + def get_reduction_type(self): + return None + + def store_output(self, output_name, indexer, vars): + loader = self.make_loader() + return ops.store(output_name, indexer(vars), loader(vars)) + + def constant_to_device(self, device): + """Move this to a given device. Requires that all reads are to constants.""" + loader = self.make_loader() + loader = patch.object(ConstantBuffer, "override_device", device)(loader) + return Pointwise(device, self.dtype, loader, self.ranges) + + +@dataclasses.dataclass +class Scatter(Pointwise): + output_indexer: Callable[[List[Expr]], Expr] + scatter_mode: Optional[str] = None + + def constant_to_device(self, device): + """Move this to a given device. Requires that all reads are to constants.""" + loader = self.make_loader() + loader = patch.object(ConstantBuffer, "override_device", device)(loader) + return Scatter( + device, + self.dtype, + loader, + self.ranges, + self.output_indexer, + self.scatter_mode, + ) + + def store_output(self, output_name, indexer, vars): + loader = self.make_loader() + return ops.store( + output_name, + indexer(self.output_indexer(vars)), + loader(vars), + mode=self.scatter_mode, + ) + + +class ReductionHint(Enum): + INNER = 0 + OUTER = 1 + OUTER_TINY = 2 + DEFAULT = 3 + + +class TileHint(Enum): + SQUARE = 0 + DEFAULT = 1 + + +REDUCTION_COMBINE_FN = { + "any": ops_wrapper("logical_or"), + "max": ops_wrapper("maximum"), + "min": ops_wrapper("minimum"), + "prod": ops_wrapper("mul"), + "sum": ops_wrapper("add"), + "xor_sum": ops_wrapper("bitwise_xor"), +} + + +def get_reduction_combine_fn(reduction_type, dtype): + if reduction_type in REDUCTION_COMBINE_FN: + combine_fn = REDUCTION_COMBINE_FN[reduction_type] + elif reduction_type in {"argmax", "argmin"}: + + def combine_fn(a, b): + a_value, a_index = a + b_value, b_index = b + + if reduction_type == "argmin": + mask = ops.lt(a_value, b_value) + else: + mask = ops.gt(a_value, b_value) + + equal = ops.eq(a_value, b_value) + if is_float_dtype(dtype): + a_isnan = ops.ne(a_value, a_value) + b_isnan = ops.ne(b_value, b_value) + mask = ops.logical_or(mask, ops.gt(a_isnan, b_isnan)) + equal = ops.logical_or(equal, ops.logical_and(a_isnan, b_isnan)) + + mask = ops.logical_or( + mask, ops.logical_and(equal, ops.lt(a_index, b_index)) + ) + return ( + ops.where(mask, a_value, b_value), + ops.where(mask, a_index, b_index), + ) + + elif reduction_type == "welford_combine": + + def combine_fn(a, b): + a_mean, a_m2, a_weight = a + b_mean, b_m2, b_weight = b + + delta = b_mean - a_mean + new_weight = a_weight + b_weight + w2_over_w = b_weight / new_weight + return ( + a_mean + delta * w2_over_w, + a_m2 + b_m2 + delta * delta * a_weight * w2_over_w, + new_weight, + ) + + else: + raise NotImplementedError(f"unknown reduction_type={reduction_type}") + + return combine_fn + + +@dataclasses.dataclass +class Reduction(Loops): + reduction_ranges: List[Expr] + reduction_type: str + # self.dtype represents the dst dtype + src_dtype: torch.dtype + reduction_hint: ReductionHint + + def __str__(self): + return Loops.__str__( # type: ignore[call-arg] + self, names=("ranges", "reduction_ranges", "reduction_type") + 
) + + def __repr__(self): + return self.__str__() + + def get_unbacked_symbol_uses(self) -> Set[sympy.Symbol]: + return super().get_unbacked_symbol_uses() | set().union( + *(free_unbacked_symbols(e) for e in self.reduction_ranges) + ) + + def get_reduction_size(self): + return self.reduction_ranges + + def get_reduction_type(self): + return self.reduction_type + + def store_reduction(self, output_name, indexer, vars, reduction_vars): + value = ops.reduction( + self.dtype, + self.src_dtype, + self.reduction_type, + self.inner_fn(vars, reduction_vars), + ) + return ops.store_reduction(output_name, indexer(vars), value) + + def index_length(self): + return len(self.ranges) + len(self.reduction_ranges) + + def inner_fn_args(self): + index = self._index(self.ranges) + rindex = self._index(self.reduction_ranges, "r") + return (index, rindex) + + def inner_fn_free_unbacked_symbols(self): + index = self._index(self.ranges) + rindex = self._index(self.reduction_ranges, "r") + return extract_free_unbacked_symbols(self.inner_fn, index, rindex) + + def constant_to_device(self, device): + """Move this to a given device. Requires that all reads are to constants.""" + loader = self.make_loader() + loader = patch.object(ConstantBuffer, "override_device", device)(loader) + return Reduction( + device, + self.dtype, + loader, + self.ranges, + self.reduction_ranges, + self.reduction_type, + self.src_dtype, + ReductionHint.DEFAULT, + ) + + @staticmethod + def num_splits( + device, + dst_dtype, + src_dtype, + inner_fn, + ranges, + reduction_ranges, + reduction_type, + reduction_numel, + input_node: Optional[IRNode] = None, + ): + def _is_static(x): + return isinstance(x, (int, sympy.Integer)) + + reduction_numel_hint = V.graph.sizevars.symbolic_hint(reduction_numel) + numel_hint = V.graph.sizevars.symbolic_hint(sympy_product(ranges)) + + should_split = ( + is_triton(device) + and reduction_type + not in { + "argmax", + "argmin", + } + and config.split_reductions + # We don't support unbacked symints + and _is_static(reduction_numel_hint) + and _is_static(numel_hint) + ) + if not should_split: + return ReductionHint.DEFAULT, 1 + + device_interface = get_interface_for_device(get_device_type(device)) + num_sm = device_interface.Worker.get_device_properties( + device + ).multi_processor_count + min_elements_per_thread = 32 + max_elements_per_thread = 512 + threads_per_sm = 2048 + min_elements_per_device = min_elements_per_thread * num_sm * threads_per_sm + max_elements_per_device = max_elements_per_thread * num_sm * threads_per_sm + + def inner_reduction_splits(reduction_numel_hint, numel_hint): + # do heuristics that's close to eager mode for split inner reduction + # we leak reduction autotune configs here, and will need to refactor to avoid this later + num_warps = 8 + num_threads = 32 * num_warps + if numel_hint >= 2 * num_sm: # don't split if there are enough outputs + return 1 + if reduction_numel_hint <= 8192: + return 1 + if reduction_numel_hint * numel_hint <= min_elements_per_device: + split_size = min_elements_per_thread + elif reduction_numel_hint * numel_hint < max_elements_per_device: + target_blocks = num_sm * threads_per_sm // (2 * num_threads) + blocks_per_output = (target_blocks + numel_hint - 1) // numel_hint + tmp_split_size = ( + reduction_numel_hint + num_threads * blocks_per_output - 1 + ) // (num_threads * blocks_per_output) + divisors = sympy.divisors(reduction_numel_hint) + closest = min(divisors, key=lambda x: abs(x - tmp_split_size)) + if abs(closest - tmp_split_size) < 30: + # prefer 
even splits, but never smaller than min_elements_per_thread + split_size = max(closest, min_elements_per_thread) + else: + split_size = tmp_split_size + else: + divisors = sympy.divisors(reduction_numel_hint) + closest = min(divisors, key=lambda x: abs(x - max_elements_per_thread)) + if abs(closest - max_elements_per_thread) < 50: + # prefer even splits + split_size = closest + else: + split_size = max_elements_per_thread + return (reduction_numel_hint + split_size * num_threads - 1) // ( + split_size * num_threads + ) + + def outer_reduction_splits(reduction_numel_hint, numel_hint): + # TODO the best heuristic currently has XBLOCK (corresponding to numel_hint) = 128; + # extend to an even smaller number of outputs + num_warps = 8 + num_threads = num_warps * 32 + rvals_per_thread = 4 # comes from heuristics, refactor to not leak here + xvals_per_block = 128 + xblocks = (numel_hint + xvals_per_block - 1) // xvals_per_block + if reduction_numel_hint * numel_hint < min_elements_per_device: + split_size = min_elements_per_thread + elif reduction_numel_hint * numel_hint < max_elements_per_device: + target_blocks = num_sm * threads_per_sm // (num_threads) + target_blocks = (target_blocks + xblocks - 1) // xblocks + tmp_split_size = ( + reduction_numel_hint + rvals_per_thread * target_blocks - 1 + ) // (rvals_per_thread * target_blocks) + divisors = sympy.divisors(reduction_numel_hint) + closest = min(divisors, key=lambda x: abs(x - tmp_split_size)) + if abs(tmp_split_size - closest) < 20: + split_size = max(closest, min_elements_per_thread) + else: + split_size = tmp_split_size + else: + divisors = sympy.divisors(reduction_numel_hint) + closest = min(divisors, key=lambda x: abs(x - max_elements_per_thread)) + if abs(closest - max_elements_per_thread) < 50: + # prefer even splits + split_size = closest + else: + split_size = max_elements_per_thread + + return (reduction_numel_hint + rvals_per_thread * split_size - 1) // ( + rvals_per_thread * split_size + ) + + # easy cases + if numel_hint == 1: + split = inner_reduction_splits(reduction_numel_hint, numel_hint) + if split == 1: + # No need to split. + return ReductionHint.INNER, split + if ( + len(ranges) == 0 + and input_node is not None + and isinstance(input_node, TensorBox) + ): + # Only handles the case where keep_dim = False. + # Otherwise, we need to propagate reduction dim info to the stage where + # the intermediate loader of the first Reduction is generated. + new_ranges, new_reduction_ranges = extract_input_node_reduction_ranges( + input_node + ) + if new_ranges is not None and new_reduction_ranges is not None: + extracted_numel_hint = V.graph.sizevars.symbolic_hint( + sympy_product(new_ranges + new_reduction_ranges) + ) + if reduction_numel_hint == extracted_numel_hint: + log.debug( + "Use previous IRNode's range and reduction_ranges instead of split. " + "current ranges: %s, current reduction ranges: %s, current split: %d, " + "new ranges: %s, new reduction ranges: %s", + ranges, + reduction_ranges, + split, + new_ranges, + new_reduction_ranges, + ) + # If the input_node or its dependent nodes are also Reduction nodes, + # use reduction_sizes of this node or its dependent nodes directly.
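+ # (Editor's note, added) The split == -1 returned below is a sentinel: + # Reduction.create() checks for it and dispatches to + # create_multilayer_existing_ranges() instead of computing a fresh split factor.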
+ return ReductionHint.INNER, -1 + return ReductionHint.INNER, split + if ( + reduction_numel_hint <= min_elements_per_thread + or numel_hint >= num_sm * 2 * 32 + ): + return ReductionHint.DEFAULT, 1 + + r = Reduction( + device, + dst_dtype, + inner_fn, + ranges, + reduction_ranges, + reduction_type, + src_dtype, + ReductionHint.DEFAULT, + ) + + def get_read_indices(r): + cb = ComputedBuffer( + name=None, + layout=FlexibleLayout( + device=r.get_device(), + dtype=r.get_dtype(), + size=r.get_size(), + ), + data=r, + ) + read_writes = cb.get_read_writes() + # try finding the full size producer + # TODO this will fail for something like ((1, N) * (N, 1)).sum() + # this would also possibly be wrong for producers with different contiguity, but we hope those cases are rare + range_vars = [ + r + for r in read_writes.range_vars + if isinstance(r, sympy.Expr) and not isinstance(r, sympy.Number) + ] + indices = [] + changed = False + for md in sorted(read_writes.reads, key=lambda x: x.name): + if all(r in md.index.free_symbols for r in range_vars): + indices.append(md.index) + if md.name in V.graph.name_to_buffer: + buf = V.graph.name_to_buffer[md.name] + original_stride = buf.layout.stride + buf.decide_layout() + if buf.layout.stride != original_stride: + changed = True + return indices, changed + + indices, changed = get_read_indices(r) + if changed: + indices, _ = get_read_indices(r) + + if len(indices) == 0: + # TODO determine splits when all inputs are broadcast + return ReductionHint.DEFAULT, 1 + + (_, reduction_vars), ranges = dependencies.index_vars_squeeze( + r.get_size(), r.get_reduction_size() + ) + num_outer = 0 + num_inner = 0 + for i in indices: + i = V.graph.sizevars.simplify_with_ranges(i, ranges) + strides = V.graph.sizevars.stride_hints(i, reduction_vars, ranges.keys()) + outer = all(s > 1 for s in strides) + if outer: + num_outer += 1 + else: + num_inner += 1 + if num_inner > num_outer: + return ReductionHint.INNER, inner_reduction_splits( + reduction_numel_hint, numel_hint + ) + else: + return ReductionHint.OUTER, outer_reduction_splits( + reduction_numel_hint, numel_hint + ) + + @staticmethod + def _unroll_reduction_fn(inner_fn, reduction_ranges, reduction_type, src_dtype): + """Convert inner_fn from a reduction to a pointwise""" + reduction_ranges = [ + V.graph.sizevars.evaluate_static_shape(x) for x in reduction_ranges + ] + + combine_fn = get_reduction_combine_fn(reduction_type, src_dtype) + + def fn(index): + return functools.reduce( + combine_fn, + ( + value_fn(index, rindex) + for rindex in itertools.product( + *[range(x) for x in reduction_ranges] + ) + ), + ) + + if reduction_type in ("argmin", "argmax"): + flatten_index = FixedLayout( + None, # type: ignore[arg-type] + None, # type: ignore[arg-type] + reduction_ranges, + FlexibleLayout.contiguous_strides(reduction_ranges), + ).make_indexer() + + def value_fn(index, rindex): + rindex = [sympy.expand(i) for i in rindex] + return ( + inner_fn(index, rindex), + ops.index_expr(flatten_index(rindex), torch.int64), + ) + + return lambda index: fn(index)[1] + else: + value_fn = inner_fn + return fn + + @classmethod + def create( # type: ignore[override] + cls, + device: torch.device, + dst_dtype: torch.dtype, + src_dtype: torch.dtype, + inner_fn: Callable[..., Any], + ranges: List[Expr], + reduction_ranges: List[Expr], + reduction_type: str, + reduction_hint: ReductionHint = ReductionHint.DEFAULT, + input_node: Optional[IRNode] = None, + ): + reduction_numel = V.graph.sizevars.simplify(sympy_product(reduction_ranges)) + +
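+ # (Editor's note, added) Dispatch summary for create(), in order: + # reduction_numel == 0 -> Pointwise constant (the reduction identity), + # reduction_numel == 1 -> Pointwise copy (nothing to reduce), + # small static numel -> unrolled via _unroll_reduction_fn, + # split > 1 -> two kernels via create_multilayer, + # split == -1 -> reuse an upstream Reduction's ranges. + # As a rough illustration (assuming 32 elements per thread and 256 threads + # per block), a 1M-element sum with a single output may become 128 partial + # sums over 8192 elements each, then one final sum over the 128 partials.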
if reduction_numel == 0: + # N.B. This is a hack to generate the literal of the given type + # Ideally, we should be fixing `def constant` in triton.py + # but it breaks due to hardcoded dtypes in other places + def py_cnst(val): + return ( + bool(val) + if dst_dtype == torch.bool + else float(val) + if dst_dtype.is_floating_point + else int(val) + ) + + rtypes_to_inits = { + "sum": py_cnst(0), + "xor_sum": py_cnst(0), + "prod": py_cnst(1), + "any": py_cnst(0), + # "all" is desugared to `!any(!val)` + } + + assert ( + reduction_type in rtypes_to_inits.keys() + ), f"{reduction_type} not supported for zero-dimension tensors!" + + def const_fn(index): + return ops.constant(rtypes_to_inits[reduction_type], dst_dtype) + + return Pointwise.create( + device=device, + dtype=src_dtype, + inner_fn=const_fn, + ranges=list(ranges), + ) + + if reduction_numel == 1: + # this reduction is actually a pointwise op + if reduction_type in ("argmin", "argmax"): + + def fn(index): + return ops.constant(0, dst_dtype) + + else: + + def fn(index): + reduction_index = [sympy.Integer(0) for _ in reduction_ranges] + return inner_fn(index, reduction_index) + + return Pointwise.create(device, dst_dtype, fn, ranges) + + if ( + isinstance(reduction_numel, sympy.Integer) + and V.graph.sizevars.size_hint(reduction_numel) + < config.unroll_reductions_threshold + and sympy_product(ranges) != 1 + ): + return Pointwise.create( + device, + dst_dtype, + cls._unroll_reduction_fn( + inner_fn, reduction_ranges, reduction_type, src_dtype + ), + ranges, + ) + + # triton doesn't support reduce to single element well, so break it up + hint, split = cls.num_splits( + device, + dst_dtype, + src_dtype, + inner_fn, + ranges, + reduction_ranges, + reduction_type, + reduction_numel, + input_node, + ) + # intermediate reduction in split can contain complex indexing, + # and num_splits will fail to correctly set the hint + # reuse the passed hint if available + if reduction_hint == ReductionHint.DEFAULT: + reduction_hint = hint + if split == -1: + assert input_node is not None + new_ranges, new_reduction_ranges = extract_input_node_reduction_ranges( + input_node # type: ignore[arg-type] + ) + assert new_ranges is not None + assert new_reduction_ranges is not None + return cls.create_multilayer_existing_ranges( + device, + dst_dtype, + src_dtype, + inner_fn, + ranges, + reduction_ranges, + new_ranges, + new_reduction_ranges, + reduction_type, + reduction_hint, + ) + elif split > 1: + # triton doesn't support reduce to single element well, so break it up + return cls.create_multilayer( + device, + dst_dtype, + src_dtype, + inner_fn, + ranges, + reduction_ranges, + reduction_type, + split, + reduction_hint, + ) + + return TensorBox.create( + Reduction( + device, + dst_dtype, + inner_fn, + ranges, + reduction_ranges, + reduction_type, + src_dtype, + reduction_hint, + ) + ) + + @staticmethod + def default_accumulator(reduction_type, dtype): + if reduction_type in {"max", "argmax"}: + if is_float_dtype(dtype): + return float("-inf") + elif is_boolean_dtype(dtype): + return 0 + else: + return torch.iinfo(dtype).min + if reduction_type in {"min", "argmin"}: + if is_float_dtype(dtype): + return float("inf") + elif is_boolean_dtype(dtype): + return 1 + else: + return torch.iinfo(dtype).max + + return { + "sum": 0, + "prod": 1, + "xor_sum": 0, + "any": 0, + "welford_reduce": (0, 0, 0), + "welford_combine": (0, 0, 0), + }[reduction_type] + + @staticmethod + def default_value(reduction_type, dtype): + if reduction_type == "welford_reduce": + return 0 + 
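+ # (Editor's note, added) welford_reduce consumes a single value stream, so + # its masked-fill default is the scalar 0; the (0, 0, 0) triple above is the + # identity only once data is in (mean, m2, weight) accumulator form.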
return Reduction.default_accumulator(reduction_type, dtype) + + @staticmethod + def _multilayer_second_step_hint( + split: int, numel_hint: int, reduction_hint: ReductionHint + ) -> ReductionHint: + if split == -1: + return reduction_hint + if split <= 512 and numel_hint <= 512 and reduction_hint == ReductionHint.OUTER: + return ReductionHint.OUTER_TINY + if ( + split <= 1024 + and numel_hint <= 256 + and reduction_hint == ReductionHint.OUTER + ): + return ReductionHint.OUTER_TINY + + return reduction_hint + + @classmethod + def _multilayer_wrap_loader( + cls, + loader, + reduction_ranges, + reduction_numel, + split, + block_size, + default, + ): + reindex = View.dynamic_reshape_indexer(reduction_ranges, [reduction_numel]) + need_mask = not V.graph.sizevars.is_expr_static_and_true( + sympy.Eq(reduction_numel % split, 0) # type: ignore[arg-type] + ) + + def wrapper_fn(index, reduction_index): + (reduction_index,) = reduction_index + *new_index, reduction_block = index + indices = block_size * reduction_block + reduction_index + + def body(): + return loader(new_index, reindex([indices])) + + if need_mask: + mask = ops.lt( + ops.index_expr(indices, torch.int32), + ops.index_expr(reduction_numel, torch.int32), + ) + return ops.masked(mask, body, default) + else: + return body() + + return wrapper_fn + + @classmethod + def _multilayer_wrap_loader_existing_ranges( + cls, + loader, + original_ranges, + original_reduction_ranges, + new_ranges, + new_reduction_ranges, + default, + ): + assert len(original_ranges) == 0, f"{original_ranges} is not equal to []" + reindex = View.dynamic_reshape_indexer( + original_reduction_ranges, tuple(new_ranges) + tuple(new_reduction_ranges) + ) + + def wrapper_fn(index, reduction_index): + return loader([], reindex(tuple(index) + tuple(reduction_index))) + + return wrapper_fn + + @classmethod + def create_multilayer_helper( + cls, + device: torch.device, + dst_dtype: torch.dtype, + src_dtype: torch.dtype, + wrapper_fn: Callable[..., Any], + original_ranges: List[Expr], + original_reduction_ranges: List[Expr], + new_ranges: List[Expr], + new_reduction_ranges: List[Expr], + reduction_type: str, + split: int, + reduction_hint: ReductionHint, + ): + """ + Break a large reduction up into multiple smaller reductions + recursively + """ + # triton will automatically compute reductions in fp32 if reducing over fp16/bf16 + # within the kernel.
keep the intermediate in fp32 so as to keep the whole reduction + # in fp32 and not reduce precision by breaking up the kernel into multiple layers + intermediate_dtype = ( + dst_dtype + if dst_dtype not in (torch.float16, torch.bfloat16) + else torch.float + ) + intermediate = Reduction.create( + device, + intermediate_dtype, + src_dtype, + wrapper_fn, + new_ranges, + new_reduction_ranges, + reduction_type, + reduction_hint, + ) + intermediate.realize() + intermediate_loader = intermediate.make_loader() + + def intermediate_fn(index, reduction_index): + return intermediate_loader([*index, *reduction_index]) + + numel_hint = V.graph.sizevars.size_hint(sympy_product(original_ranges)) + reduction_hint = cls._multilayer_second_step_hint( + split, numel_hint, reduction_hint + ) + + assert original_ranges == new_ranges[: len(original_ranges)] + return TensorBox.create( + Reduction( + device, + dst_dtype, + intermediate_fn, + original_ranges, + new_ranges[len(original_ranges) :], + reduction_type, + src_dtype, + reduction_hint, + ) + ) + + @classmethod + def create_multilayer( + cls, + device: torch.device, + dst_dtype: torch.dtype, + src_dtype: torch.dtype, + inner_fn: Callable[..., Any], + ranges: List[Expr], + reduction_ranges: List[Expr], + reduction_type: str, + split: int, + reduction_hint: ReductionHint, + ): + """ + Break a large reduction up into multiple smaller reductions + recursively + """ + # TODO(jansel): realize the reduction so we can do dynamic indexing + reduction_numel = sympy_product(reduction_ranges) + block_size = FloorDiv(reduction_numel + (split - 1), split) + default = cls.default_value(reduction_type, dst_dtype) + wrapper_fn = cls._multilayer_wrap_loader( + inner_fn, reduction_ranges, reduction_numel, split, block_size, default + ) + + return cls.create_multilayer_helper( + device, + dst_dtype, + src_dtype, + wrapper_fn, + ranges, + reduction_ranges, + [*ranges, split], # type: ignore[list-item] + [block_size], + reduction_type, + split, + reduction_hint, + ) + + @classmethod + def create_multilayer_existing_ranges( + cls, + device: torch.device, + dst_dtype: torch.dtype, + src_dtype: torch.dtype, + inner_fn: Callable[..., Any], + original_ranges: List[Expr], + original_reduction_ranges: List[Expr], + new_ranges: List[Expr], + new_reduction_ranges: List[Expr], + reduction_type: str, + reduction_hint: ReductionHint, + ): + """ + Break a large reduction up into multiple smaller reductions + recursively + """ + default = cls.default_value(reduction_type, dst_dtype) + wrapper_fn = cls._multilayer_wrap_loader_existing_ranges( + inner_fn, + original_ranges, + original_reduction_ranges, + new_ranges, + new_reduction_ranges, + default, + ) + return cls.create_multilayer_helper( + device, + dst_dtype, + src_dtype, + wrapper_fn, + original_ranges, + original_reduction_ranges, + new_ranges, + new_reduction_ranges, + reduction_type, + -1, + reduction_hint, + ) + + +def num_reduction_outputs(reduction_type): + return 3 if "welford" in reduction_type else 1 + + +class WelfordReduction(Reduction): + output_index: int + + def __init__( + self, + device, + dtype, + inner_fns, + ranges, + reduction_ranges, + reduction_type, + reduction_hint, + output_index, + ): + if len(inner_fns) == 1: + loader = inner_fns[0] + else: + + def loader(idx, reduction_idx): + return tuple(fn(idx, reduction_idx) for fn in inner_fns) + + super().__init__( + device, + dtype, + loader, + ranges, + reduction_ranges, + reduction_type, + dtype, + reduction_hint, + ) + self.output_index = output_index + + def 
store_reduction(self, output_name, indexer, vars, reduction_vars): + values = ops.reduction( + self.dtype, + self.src_dtype, + self.reduction_type, + self.inner_fn(vars, reduction_vars), + ) + value = values[self.output_index] + return ops.store_reduction(output_name, indexer(vars), value) + + @classmethod + def create( # type: ignore[override] + cls, + device: torch.device, + dtype: torch.dtype, + inner_fns: Sequence[Callable[..., Any]], + ranges: List[Expr], + reduction_ranges: List[Expr], + reduction_type: str, + reduction_hint: ReductionHint = ReductionHint.DEFAULT, + ): + assert reduction_type in {"welford_reduce", "welford_combine"} + + reduction_numel = V.graph.sizevars.simplify(sympy_product(reduction_ranges)) + + def const(val): + def inner_fn(idx): + return ops.constant( + val, + dtype, + ) + + return Pointwise.create( + device=device, + dtype=dtype, + inner_fn=inner_fn, + ranges=list(ranges), + ) + + if reduction_numel == 0: + mean = const(0) + m2 = const(0) + weight = const(0) + return mean, m2, weight + + if reduction_numel == 1: + + def copy(loader): + def inner_fn(idx): + reduction_index = [sympy.Integer(0) for _ in reduction_ranges] + return loader(idx, reduction_index) + + return Pointwise.create( + device=device, + dtype=dtype, + inner_fn=inner_fn, + ranges=list(ranges), + ) + + if reduction_type == "welford_reduce": + return copy(inner_fns[0]), const(0), const(1) + else: + return tuple(copy(fn) for fn in inner_fns) + + # TODO: Unrolled reduction + # if ( + # isinstance(reduction_numel, sympy.Integer) + # and V.graph.sizevars.size_hint(reduction_numel) + # < config.unroll_reductions_threshold + # and sympy_product(ranges) != 1 + # ): + # return Pointwise.create( + # device, + # dst_dtype, + # cls._unroll_reduction_fn( + # inner_fn, reduction_ranges, reduction_type, src_dtype + # ), + # ranges, + # ) + + # triton doesn't support reduce to single element well, so break it up + hint, split = Reduction.num_splits( + device, + dtype, + dtype, + inner_fns[0], + ranges, + reduction_ranges, + reduction_type=reduction_type, + reduction_numel=reduction_numel, + ) + # intermediate reduction in split can contain complex indexing, + # and num_splits will fail to correctly set the hint + # reuse the passed hint if available + if reduction_hint == ReductionHint.DEFAULT: + reduction_hint = hint + if split > 1: + # triton doesn't support reduce to single element well, so break it up + return cls.create_multilayer( + device, + dtype, + inner_fns, + ranges, + reduction_ranges, + reduction_type, + split, + reduction_hint, + ) + + results = [ + TensorBox.create( + WelfordReduction( + device, + dtype, + inner_fns, + ranges, + reduction_ranges, + reduction_type, + reduction_hint, + output_idx, + ) + ) + for output_idx in range(3) + ] + for t in results: + t.realize() + return results + + @staticmethod + def default_value(reduction_type, dtype): + return (0, 0, 0) + + @classmethod + def create_multilayer( # type: ignore[override] + cls, + device: torch.device, + dtype: torch.dtype, + inner_fns: Sequence[Callable[..., Any]], + ranges: List[Expr], + reduction_ranges: List[Expr], + reduction_type: str, + split: int, + reduction_hint: ReductionHint, + ): + """ + Break a large reduction up into multiple smaller reductions + recursively + """ + reduction_numel = sympy_product(reduction_ranges) + need_mask = not V.graph.sizevars.is_expr_static_and_true( + sympy.Eq(reduction_numel % split, 0) # type: ignore[arg-type] + ) + + if need_mask and reduction_type != "welford_combine": + # If we need mask, then 
"welford_reduce" doesn't work because + # masked inputs shouldn't count towards the welford weight + + def constant(idx, reduction_idx, value): + return ops.constant(value, dtype) + + return cls.create_multilayer( + device=device, + dtype=dtype, + inner_fns=( + inner_fns[0], + partial(constant, value=0), + partial(constant, value=1), + ), + ranges=ranges, + reduction_ranges=reduction_ranges, + reduction_type="welford_combine", + split=split, + reduction_hint=reduction_hint, + ) + + block_size = FloorDiv(reduction_numel + (split - 1), split) + intermediates = WelfordReduction.create( + device, + dtype, + tuple( + cls._multilayer_wrap_loader( + loader, + reduction_ranges, + reduction_numel, + split, + block_size, + default=0, + ) + for loader in inner_fns + ), + [*ranges, split], # type: ignore[list-item] + [block_size], + reduction_type, + reduction_hint, + ) + for i in intermediates: + i.realize() + + i_loaders = [i.make_loader() for i in intermediates] + + def intermediate_loader_fn(index, reduction_index, loader): + return loader([*index, *reduction_index]) + + numel_hint = V.graph.sizevars.size_hint(sympy_product(ranges)) + reduction_hint = cls._multilayer_second_step_hint( + split, numel_hint, reduction_hint + ) + return WelfordReduction.create( + device, + dtype, + tuple( + partial(intermediate_loader_fn, loader=i.make_loader()) + for i in intermediates + ), + ranges, + [split], # type: ignore[list-item] + # welford_reduce turns one input into three outputs, which are combined with welford_combine + "welford_combine", + reduction_hint, + ) + + +@dataclasses.dataclass +class Scan(Loops): + scan_ranges: List[Expr] + size: List[Expr] + combine_fn: Callable[..., Any] + reindex: Callable[[List[Expr], List[Expr]], List[Expr]] + reduction_hint: ReductionHint + init: int + + # HACK we mimick reduction + + def get_unbacked_symbol_uses(self) -> Set[sympy.Symbol]: + # TODO: Can combine_fn/reindex close over unbacked symbols? 
If so, we + # need to explicitly represent the closure so we can pull out unbacked + # symbols here + return ( + super().get_unbacked_symbol_uses() + | set().union(*(free_unbacked_symbols(e) for e in self.scan_ranges)) + | set().union(*(free_unbacked_symbols(e) for e in self.size)) + ) + + def __post_init__(self): + assert len(self.ranges) + len(self.scan_ranges) == len(self.size) + super().__post_init__() + + def store_reduction(self, output_name, indexer, vars, scan_vars): + idx = self.reindex(vars, scan_vars) + value = self.inner_fn(idx) + result = ops.scan(self.dtype, self.combine_fn, value, self.init) + return ops.store(output_name, indexer(idx), result) + + def get_reduction_type(self): + # return self.scan_op + return "custom" + + def get_reduction_size(self): + return self.scan_ranges + + def get_size(self): + return self.size + + def get_pointwise_size(self): + return self.ranges + + def index_length(self): + return len(self.ranges) + len(self.scan_ranges) + + def inner_fn_args(self): + index = self._index(self.ranges) + rindex = self._index(self.scan_ranges, "r") + idx = self.reindex(index, rindex) + return (idx,) + + def inner_fn_free_unbacked_symbols(self): + index = self._index(self.ranges) + rindex = self._index(self.scan_ranges, "r") + idx = self.reindex(index, rindex) + return extract_free_unbacked_symbols(self.inner_fn, idx) + + @classmethod + def create( + cls, + device: torch.device, + dtype: torch.dtype, + inner_fn: Callable[[List[Expr]], Any], + size: List[Expr], + axis: int, + combine_fn: Callable[..., Any], + init: Any, + reduction_hint: ReductionHint = ReductionHint.DEFAULT, + ) -> Optional["TensorBox"]: + pointwise_ranges = [*size[:axis], *size[axis + 1 :]] + scan_ranges = [size[axis]] + + if device.type != "cuda": + # TODO: CPU support + return None + + sizevars = V.graph.sizevars + scan_numel = sizevars.simplify(sympy_product(scan_ranges)) + + # Scan with a single element is just a copy + if sizevars.is_expr_static_and_true(sympy.Le(scan_numel, 1)): # type: ignore[arg-type] + return Pointwise.create( + device=device, + dtype=dtype, + inner_fn=inner_fn, + ranges=size, + ) + + reduction_hint, num_splits = cls.num_splits( + device=device, + dtype=dtype, + inner_fn=inner_fn, + axis=axis, + pointwise_ranges=pointwise_ranges, + scan_ranges=scan_ranges, + combine_fn=combine_fn, + scan_numel=scan_numel, + ) + scan_type = Scan if num_splits <= 1 else SplitScan + + if num_splits > 1 and torch.version.hip is not None: + # Fallback for split-scan on ROCm + return None + + def reindex(index, scan_index): + assert len(scan_index) == len(scan_ranges) + assert len(index) == len(pointwise_ranges) + return [*index[:axis], *scan_index, *index[axis:]] + + result = TensorBox.create( + scan_type( + device=device, + dtype=dtype, + inner_fn=inner_fn, + size=size, + ranges=pointwise_ranges, + scan_ranges=scan_ranges, + combine_fn=combine_fn, + reindex=reindex, + init=init, + reduction_hint=reduction_hint, + ) + ) + result.realize() + return result + + @classmethod + def num_splits( + cls, + device: torch.device, + dtype: torch.dtype, + inner_fn: Callable[[List[Expr]], Any], + axis: int, + pointwise_ranges: List[Expr], + scan_ranges: List[Expr], + combine_fn: Callable[..., Any], + scan_numel: Expr, + ): + # TODO: custom splitting heuristic for scan + def wrapper_fn(idx, reduction_idx): + return inner_fn([*idx[:axis], *reduction_idx, *idx[axis:]]) + + return Reduction.num_splits( + device=device, + dst_dtype=dtype, + src_dtype=dtype, + inner_fn=wrapper_fn, + ranges=pointwise_ranges, + 
reduction_ranges=scan_ranges, + reduction_type="sum", + reduction_numel=scan_numel, + ) + + +# This signifies a scan op that should go through TritonSplitScanKernel codegen on CUDA. +@dataclasses.dataclass +class SplitScan(Scan): + pass + + +def is_storage_and_layout(x): + try: + as_storage_and_layout(x, freeze=False) + return True + except NotImplementedError: + return False + + +def is_contiguous_storage_and_layout(x): + try: + buffer, layout = as_storage_and_layout(x, freeze=False) + return layout.is_contiguous() + except NotImplementedError: + return False + + +def as_storage_and_layout(x, freeze=True, want_contiguous=False, stride_order=None): + """Try to simplify x into a StorageBox and a Layout""" + if isinstance(x, TensorBox): + return as_storage_and_layout( + x.data, + freeze=freeze, + want_contiguous=want_contiguous, + stride_order=stride_order, + ) + if isinstance(x, StorageBox) and isinstance(x.data, Buffer): + if freeze: + if want_contiguous: + x.data.freeze_layout() + assert x.data.layout.is_contiguous() + elif stride_order is not None: + x.data.freeze_layout_with_stride_order(stride_order) + else: + x.data.decide_layout() + return x, x.data.layout + if isinstance(x, ReinterpretView): + # making the base of x contiguous or stride_ordered will not necessarily make + # the ReinterpretView either, so don't pass along those arguments + buffer, _ = as_storage_and_layout( + x.data, + freeze=freeze, + ) + return buffer, x.layout + raise NotImplementedError + + +as_contiguous_storage_and_layout = functools.partial( + as_storage_and_layout, want_contiguous=True +) + + +def is_stride_order_storage_and_layout(x, stride_order): + try: + buffer, layout = as_storage_and_layout(x, freeze=False) + return layout.is_stride_ordered(stride_order) + except NotImplementedError: + return False + + +@dataclasses.dataclass +class BaseView(IRNode): + data: IRNode + + def get_unbacked_symbol_uses(self): + return self.data.get_unbacked_symbol_uses() + + def make_reindexer(self): + raise NotImplementedError(f"make_reindexer NYI on {self}") + + def make_indexer(self): + inner = self.data.make_indexer() + reindex = self.make_reindexer() + + def indexer(idx): + return inner(reindex(idx)) + + return indexer + + def make_loader(self): + inner = self.data.make_loader() + reindex = self.make_reindexer() + + def loader(idx): + return inner(reindex(idx)) + + return loader + + @property + def dtype(self): + return self.data.dtype + + def get_layout(self): + return self.data.get_layout() + + def get_device(self): + return self.data.get_device() + + def get_origin_node(self): + return None + + def get_name(self): + return self.data.get_name() + + def get_pointwise_size(self): + return self.get_size() + + def mark_reuse(self, users): + return self.data.mark_reuse(users) + + def has_exceeded_max_reads(self): + return self.data.has_exceeded_max_reads() + + def realize(self): + return self.data.realize() + + def realize_hint(self): + return self.data.realize_hint() + + def get_storage_numel(self): + return self.data.get_storage_numel() + + def is_extern(self): + return self.data.is_extern() # type: ignore[attr-defined] + + def get_reads(self): + with patch.object(FlexibleLayout, "allow_indexing", True): + return extract_read_writes( + self.make_loader(), + self.get_size(), + ).reads + + def unwrap_view(self): + x: IRNode = self + while isinstance(x, BaseView): + x = x.data + return x + + def constant_to_device(self, device): + """Move this to a given device.
Requires that all reads are to constants.""" + loader = self.make_loader() + loader = patch.object(ConstantBuffer, "override_device", device)(loader) + return Pointwise(device, self.get_dtype(), loader, self.get_size()) + + +@dataclasses.dataclass +class ExpandView(BaseView): + size: List[Expr] + + @staticmethod + def _normalize_size(x, new_size): + """Replace `-1` with correct sizes""" + new_size = list(map(sympy.expand, new_size)) + old_size = x.get_size() + old_size = [None] * (len(new_size) - len(old_size)) + list(old_size) + assert len(new_size) == len(old_size) + for i in range(len(new_size)): + if new_size[i] == -1: + assert old_size[i] is not None + new_size[i] = old_size[i] + elif old_size[i] is None or old_size[i] == 1: + pass + else: + # Expect broadcast compatibility + new_size[i] = V.graph.sizevars.expect_equals( + new_size[i], + old_size[i], + msg=f"Broadcast failed in ExpandView({x.get_size()}, {new_size}) on dimension {i}", + ) + return new_size + + @classmethod + def create(cls, x, new_size): + new_size = cls._normalize_size(x, new_size) + + if is_storage_and_layout(x): + storage, old_layout = as_storage_and_layout(x) + skip = len(new_size) - len(old_layout.size) + assert skip >= 0 + new_stride = [sympy.Integer(0)] * skip + for stride, size in zip(old_layout.stride, old_layout.size): + new_stride.append(stride if size != 1 else sympy.Integer(0)) + new_layout = FixedLayout( + old_layout.device, + old_layout.dtype, + list(new_size), + new_stride, + old_layout.offset, + ) + return ReinterpretView(storage, new_layout) + + return ExpandView(x, new_size) + + def get_size(self): + return self.size + + def make_reindexer(self): + target = self.get_size() + actual = self.data.get_size() + skip = len(target) - len(actual) + + def reindex(index): + index = list(index[skip:]) + assert len(index) == len(actual) + for i in range(len(actual)): + if actual[i] == 1: + # zero out broadcast dimension + index[i] = sympy.Integer(0) + return index + + return reindex + + +@dataclasses.dataclass +class PermuteView(BaseView): + dims: List[Expr] + + @classmethod + def create(cls, x, dims): + dims = cls._map_neg_dims(dims) + assert set(dims) == set(range(len(dims))) + + if is_storage_and_layout(x): + storage, old_layout = as_storage_and_layout(x) + new_layout = FixedLayout( + old_layout.device, + old_layout.dtype, + [old_layout.size[i] for i in dims], + [old_layout.stride[i] for i in dims], + old_layout.offset, + ) + return ReinterpretView(storage, new_layout) + + return PermuteView(x, dims) + + @classmethod + def _map_neg_dims(cls, dims): + return [dim if dim >= 0 else len(dims) + dim for dim in dims] + + def get_size(self): + assert set(self._map_neg_dims(self.dims)) == set(range(len(self.dims))) + size = self.data.get_size() + return [size[i] for i in self.dims] + + def make_reindexer(self): + inv = {j: i for i, j in enumerate(self.dims)} + inv = [inv[i] for i in range(len(self.dims))] # type: ignore[index] + assert set(inv) == set(range(len(self.dims))) + + def reindex(index): + return [index[i] for i in inv] + + return reindex + + +class SqueezeView(BaseView): + @classmethod + def create(cls, x, *, dim=None): + if is_storage_and_layout(x): + storage, old_layout = as_storage_and_layout(x) + new_size = [] + new_stride = [] + if dim is not None: + assert isinstance(dim, int), "expected integer dim argument" + assert 0 <= dim and dim < len(old_layout.size) + + for i, (size, stride) in enumerate(zip(old_layout.size, old_layout.stride)): + if dim is None: + if size != 1: + new_size.append(size) + 
new_stride.append(stride) + else: + if i != dim: + new_size.append(size) + new_stride.append(stride) + else: + assert size == 1, "expected squeezed size to be 1" + + new_layout = FixedLayout( + old_layout.device, + old_layout.dtype, + new_size, + new_stride, + old_layout.offset, + ) + return ReinterpretView(storage, new_layout) + + if dim is None: + # redirect to a generic view + return View.create(x, [s for s in x.get_size() if s != 1]) + else: + assert x.get_size()[dim] == 1 + return View.create(x, [s for i, s in enumerate(x.get_size()) if i != dim]) + + @staticmethod + def squeezer(size: Tuple[sympy.Expr, ...]): + new_size = [s for s in size if s != 1] + not_one = [i for i, s in enumerate(size) if s != 1] + length = len(size) + + def reindex(index: List[sympy.Expr]) -> Tuple[sympy.Expr, ...]: + assert len(index) == len(not_one), f"{index} {not_one}" + new_index = [sympy.Integer(0)] * length + for idx, s in zip(not_one, index): + new_index[idx] = s + return tuple(new_index) + + return new_size, reindex + + def __init__(self, data): + raise AssertionError("use SqueezeView.create()") + + +@dataclasses.dataclass +class GenericView(BaseView): + size: List[Expr] + reindex: Callable[..., Any] + + def make_reindexer(self): + return self.reindex + + def reindex_str(self): + index_old = [sympy_index_symbol(f"i{n}") for n in range(len(self.size))] + index_new = list(self.reindex(index_old)) + return f"lambda {', '.join(map(str, index_old))}: {index_new}" + + def __str__(self): + return self.str_helper( + [self.data, f"size={self.size}", f"reindex={self.reindex_str()}"] + ) + + __repr__ = __str__ + + @classmethod + def create(cls, x, new_size, reindex): + return cls(x, list(new_size), reindex) + + def get_size(self): + return self.size + + +@dataclasses.dataclass +class View(GenericView): + @staticmethod + def handle_negative_index(idx, size): + idx = sympy.expand(idx) + size = sympy.expand(size) + evaluate_expr = V.graph.sizevars.shape_env.evaluate_expr + if evaluate_expr(sympy.Lt(idx, 0)): + idx = idx + size + return idx + + @classmethod + def create(cls, x, new_size): + assert isinstance(new_size, (tuple, list)) + old_size, new_size = cls.resolve_negative_size(x.get_size(), new_size) + + # Skip pointless views + if V.graph.sizevars.statically_known_list_equals(old_size, new_size): + return x + + unbacked_symbols_in_sizes = False + if ( + len(free_unbacked_symbols(old_size)) > 0 + or len(free_unbacked_symbols(new_size)) > 0 + ): + unbacked_symbols_in_sizes = True + + if 0 in new_size: + + def fake_reindex(index): + return tuple([0] * len(old_size)) + + return cls(x, list(new_size), fake_reindex) + # TODO: a new class for FixedTransferLayout that output layout is constrained by input layout + elif is_contiguous_storage_and_layout(x) or unbacked_symbols_in_sizes: + if unbacked_symbols_in_sizes and (not is_contiguous_storage_and_layout(x)): + # realize x; otherwise, the dynamic_reshape_indexer below will fail + # due to the size_hint's inability to process unbacked SymInts + x = ExternKernel.realize_input(x) + + storage, old_layout = as_contiguous_storage_and_layout(x) + new_layout = FixedLayout( + old_layout.device, + old_layout.dtype, + new_size, + FlexibleLayout.contiguous_strides(new_size), + old_layout.offset, + ) + return ReinterpretView(storage, new_layout) + + reindex = cls.dynamic_reshape_indexer(old_size, new_size) + return cls(x, list(new_size), reindex) + + @staticmethod + def resolve_negative_size(old_size, new_size): + new_size = [V.graph.sizevars.simplify(x) for x in new_size] + 
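+ # (Editor's note, added) A single -1 entry is resolved below to + # prod(old_size) // prod(new_size with the -1 treated as 1); e.g. viewing a + # (6, 4) tensor as (-1, 8) resolves the -1 to 24 // 8 == 3.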
old_size = [V.graph.sizevars.simplify(x) for x in old_size] + + new_size = list(new_size) + for i in range(len(new_size)): + if new_size[i] == -1: + new_size[i] = sympy.Integer(1) + new_size[i] = CleanDiv(sympy_product(old_size), sympy_product(new_size)) + break + + V.graph.sizevars.guard_equals(sympy_product(old_size), sympy_product(new_size)) + return old_size, new_size + + @classmethod + def dynamic_reshape_indexer(cls, old_size, new_size): + try: + reindex = cls._dynamic_reshape_indexer(old_size, new_size) + except (AssertionError, IndexError): + # the optimistic algorithm failed, let's do a fallback + flat = [sympy_product(old_size)] + reindex1 = cls._dynamic_reshape_indexer(old_size, flat) + reindex2 = cls._dynamic_reshape_indexer(flat, new_size) + reindex = fuse_reindexing(reindex1, reindex2) + return reindex + + @staticmethod + def _dynamic_reshape_indexer(old_size, new_size): + """ + Perform a reshape entirely by modifying indexing math + """ + size_hint = V.graph.sizevars.size_hint + vars = [sympy_index_symbol(f"view{i}") for i in range(len(new_size))] + + stack_new = list(zip(vars, new_size)) + stack_old = list(old_size) + + view_expr = [] + while stack_new and stack_old: + size_old = stack_old.pop() + var, size_new = stack_new.pop() + if size_old == 1: + view_expr.append(sympy.Integer(0)) + stack_new.append((var, size_new))  # re-add + elif size_new == 1: + stack_old.append(size_old)  # re-add + elif size_hint(size_new) == size_hint(size_old): + view_expr.append(var) + V.graph.sizevars.guard_equals(size_new, size_old) + elif size_hint(size_new) < size_hint(size_old): + while size_hint(size_new) < size_hint(size_old): + var2, size_new2 = stack_new.pop() + var = var2 * size_new + var + size_new = size_new * size_new2 + view_expr.append(var) + V.graph.sizevars.guard_equals(size_new, size_old) + elif size_hint(size_new) > size_hint(size_old): + divisor = sympy.Integer(1) + modulus = size_old + view_expr.append(ModularIndexing(var, divisor, modulus)) + divisor = divisor * modulus + while size_hint(size_new) > size_hint(size_old): + modulus = stack_old.pop() + view_expr.append(ModularIndexing(var, divisor, modulus)) + divisor = divisor * modulus + size_old = size_old * modulus + V.graph.sizevars.guard_equals(size_new, size_old) + else: + raise AssertionError() + + while stack_old: + size_old = stack_old.pop() + V.graph.sizevars.guard_equals(size_old, 1)  # type: ignore[arg-type] + view_expr.append(sympy.Integer(0)) + + while stack_new: + var, size_new = stack_new.pop() + V.graph.sizevars.guard_equals(size_new, 1)  # type: ignore[arg-type] + + view_expr.reverse() + assert len(view_expr) == len(old_size) + + def reindex(index): + assert len(index) == len(vars), (len(index), len(vars)) + replacements = dict(zip(vars, index)) + return tuple(sympy_subs(x, replacements) for x in view_expr)  # type: ignore[arg-type] + + return reindex + + +@dataclasses.dataclass +class ReinterpretView(BaseView): + """Pretend our storage has a different layout""" + + layout: "Layout" + + def __post_init__(self): + super().__post_init__() + if isinstance(self.data, BaseView): + self.data = self.data.unwrap_view() + + def __str__(self): + return self.str_helper( + [ + self.data, + self.layout, + ] + ) + + __repr__ = __str__ + + def get_name(self): + return self.data.get_name() + + def get_device(self): + return self.layout.device + + def get_origin_node(self): + return None + + @property + def dtype(self): + return self.layout.dtype + + def get_size(self): + return list(self.layout.size) + + def get_stride(self):
return list(self.layout.stride) + + def make_loader(self): + def loader(index): + indexer = self.layout.make_indexer() + return ops.load(self.get_name(), indexer(index)) + + return loader + + def make_indexer(self): + return self.layout.make_indexer() + + def get_layout(self): + return self.layout + + def freeze_layout(self): + pass + + def get_unbacked_symbol_uses(self) -> Set[sympy.Symbol]: + return ( + free_unbacked_symbols(self.layout.size) + | free_unbacked_symbols(self.layout.stride) + | free_unbacked_symbols(self.layout.offset) + ) + + def codegen_reference(self, writer=None): + # reinterpret_tensor is similar to as_strided except: + # - offset is added to the existing offset (rather than replacing it) + # - view tracking is disabled similar to unsafe_view + return V.graph.wrapper_code.codegen_reinterpret_view( + self.data, + self.layout.size, + self.layout.stride, + self.layout.offset, + writer, + ) + + +class SliceView(View): + @classmethod + def normalize_start_end(cls, x, dim, start, end): + """ + Normalize start and end such that both are in the range + [0, x.get_size()[dim]] and start <= end. + """ + sizevars = V.graph.sizevars + dim_size = x.get_size()[dim] + + if any(free_unbacked_symbols(x) for x in (start, end, dim_size)): + + def clamp(x, lower, upper): + return sympy.Min(sympy.Max(x, lower), upper) + + else: + + def clamp(x, lower, upper): + return sizevars.evaluate_min(sizevars.evaluate_max(x, lower), upper) + + def clamp_wrap(val, lower, upper, default): + if val is None: + return default + val = cls.handle_negative_index(val, dim_size) + return clamp(val, lower, upper) + + start = clamp_wrap(start, 0, dim_size, 0) + end = clamp_wrap(end, start, dim_size, dim_size) + return start, end + + @classmethod + def create(cls, x, dim, start, end, step=1): + step = sympy.expand(step) + assert step > 0 + try: + if start == 0 and end >= 2**63 - 1 and step == 1: + return x + except TypeError: + pass + + sizevars = V.graph.sizevars + new_size = list(x.get_size()) + + start, end = cls.normalize_start_end(x, dim, start, end) + + new_size[dim] = FloorDiv(end - start + (step - 1), step) + + if is_storage_and_layout(x): + # Fast path + storage, old_layout = as_storage_and_layout(x) + new_stride = list(old_layout.stride) + new_stride[dim] = new_stride[dim] * step + new_layout = FixedLayout( + old_layout.device, + old_layout.dtype, + new_size, + new_stride, + old_layout.offset + old_layout.stride[dim] * start, + ) + return ReinterpretView(storage, new_layout) + + def reindex(index): + assert len(index) == len(new_size), f"wrong ndim {index} {new_size}" + index = list(index) + index[dim] = index[dim] * step + start + return index + + # redirect to a generic view + return SliceView(x, size=new_size, reindex=reindex) + + +class BaseConstant(IRNode): + dtype: torch.dtype + device: torch.device + + def get_size(self): + return () + + def get_device(self): + return self.device + + def get_origin_node(self): + return None + + def mark_reuse(self, users): + pass + + def has_exceeded_max_reads(self): + return False + + def get_reads(self): + return () + + def is_extern(self): + return False + + +@dataclasses.dataclass +class Constant(BaseConstant): + value: Any + dtype: torch.dtype + device: torch.device + + def make_loader(self): + def loader(index): + return ops.constant(self.value, self.dtype) + + return loader + + def realize(self): + pass + + def constant_to_device(self, device): + return Constant(self.value, self.dtype, device) + + +@dataclasses.dataclass +class 
IndexingConstant(BaseConstant): + index: Any + dtype: torch.dtype + device: torch.device + + def make_loader(self): + def loader(index): + return ops.index_expr(self.index, self.dtype) + + return loader + + def constant_to_device(self, device): + return IndexingConstant(self.index, self.dtype, device) + + +def is_contiguous_strides_for_shape(stride, shape): + return all( + size == 1 or left == right + for left, right, size in zip( + stride, FlexibleLayout.contiguous_strides(shape), shape + ) + ) + + +@dataclasses.dataclass +class Layout(IRNode): + def __init__( + self, + device: torch.device, + dtype: torch.dtype, + size: List[Expr], + stride: Optional[Sequence[Union[Expr, int]]], + offset: Expr = Integer(0), + ): + assert stride is None or len(size) == len( + stride + ), f"size={size}, stride={stride}" + self.device = device + self.dtype = dtype + assert all(isinstance(s, (Expr, int)) for s in size) + self.size = size + self._stride = stride + self.offset = offset + + @property + def stride(self): + return self._stride + + def __str__(self): + offset = "" + if self.offset != 0: + offset = f", offset={self.offset}" + return ( + f"{type(self).__name__}('{self.device.type}', {self.dtype}, " + f"size={self.size}, stride={self.stride}{offset})" + ) + + __repr__ = __str__ + + def is_contiguous(self): + return is_contiguous_strides_for_shape(self.stride, self.size) + + def is_channels_last_contiguous(self): + ndim = len(self.size) + if ndim not in [4, 5]: + return False + for left, right, size in zip( + self.stride, make_channels_last_strides_for(self.size), self.size  # type: ignore[arg-type] + ): + if size != 1 and left != right: + return False + return True + + def is_transposed(self): + for left, right, size in zip( + self.stride, + reversed(FlexibleLayout.contiguous_strides(self.size)), + self.size, + ): + if size != 1 and left != right: + return False + return True + + def is_stride_ordered(self, order): + assert len(self.stride) == len(order) + + # ignore dimensions of size 1, they don't affect layout + non_1_indices = [ + i + for i, dim in enumerate(self.size) + if V.graph.sizevars.size_hint(dim, fallback=2) != 1 + ] + + stride = [self.stride[i] for i in non_1_indices] + order = [order[i] for i in non_1_indices] + + def sorted_indices(arr): + sorted_arr = sorted(arr) + return [sorted_arr.index(element) for element in arr] + + # since we may have removed dimensions, need to re-sort & re-index order + order = sorted_indices(order) + + # reorder the stride given order + stride_ordered = [-1] * len(order) + for i in range(len(order)): + stride_ordered[order[i]] = V.graph.sizevars.size_hint(stride[i]) + # check if it is in ascending order + for i in range(len(order) - 1): + if stride_ordered[i] > stride_ordered[i + 1]: + return False + return True + + def is_channels_last_stride_ordered(self): + # create channels_last order (NCHW, NCDHW; the C dimension comes first in stride order).
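+ # (Editor's note, added) For a 4-D NCHW tensor this builds order = [3, 0, 2, 1]: + # N has the largest stride (rank 3), C the smallest (rank 0), with H and W + # in between, which is exactly the channels-last stride ranking.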
+ order = [0] + list(reversed(range(1, len(self.stride) - 1))) + order = [len(order)] + order + return self.is_stride_ordered(order) + + def as_fixed(self): + return FixedLayout( + self.device, + self.dtype, + self.size, + self.stride, + self.offset, + ) + + def make_indexer(self): + assert ( + FlexibleLayout.allow_indexing + ), f"convert {type(self).__name__} to FixedLayout first" + return self.as_fixed().make_indexer() + + def __eq__(self, other) -> bool: + return ( + self.device == other.device + and self.dtype == other.dtype + and self.size == other.size + and self.stride == other.stride + and self.offset == other.offset + ) + + def storage_size(self) -> sympy.Expr: + return compute_required_storage_length(self.size, self.stride, self.offset) # type: ignore[arg-type, return-value] + + +class FixedLayout(Layout): + """A Tensor layout we cannot change""" + + def __init__( + self, + device: torch.device, + dtype: torch.dtype, + size: Union[List[Expr], List[int]], + stride: Optional[Sequence[Union[Expr, int]]] = None, + offset: Union[Expr, int] = Integer(0), + ): + if stride is None: + stride = FlexibleLayout.contiguous_strides(size) + super().__init__( + device, + dtype, + size, # type: ignore[arg-type] + stride, + offset, # type: ignore[arg-type] + ) + + def make_indexer(self): + """A closure containing math to read a given element""" + + def indexer(index): + assert len(index) == len(self.stride) == len(self.size) + result = self.offset + for idx, stride, sz in zip(index, self.stride, self.size): + if sz != 1: + result = result + idx * stride + return result + + return indexer + + +class FlexibleLayout(Layout): + """A Tensor layout we are allowed to change""" + + allow_indexing = False + + @staticmethod + def contiguous_strides(sizes): + if len(sizes) == 0: + return [] + reversed_strides = [sympy.Integer(1)] + for size in reversed(sizes[1:]): + reversed_strides.append(size * reversed_strides[-1]) + return list(reversed(reversed_strides)) + + @staticmethod + def fill_ordered(sizes, order): + """ + Create a stride based on the order the dimensions should be filled in. + + In this format, channels last would be: + [1, 3, 2, 0] + """ + assert set(range(len(sizes))) == set(order) + next_stride = sympy.Integer(1) + strides = [None] * len(order) + + for i in order: + strides[i] = next_stride + next_stride = next_stride * sizes[i] + return strides + + @staticmethod + def stride_ordered(sizes, order): + """ + Create a stride based on the sorted order of a permuted range. 
+ + In this format, channels last would be: + [3, 0, 2, 1] + """ + assert set(range(len(sizes))) == set(order) + fill_order = stride_order2fill_order(order) + return FlexibleLayout.fill_ordered(sizes, fill_order) + + @staticmethod + def same_ordered(sizes, stride): + """ + Create a stride that has the same stride order as given stride + + For example, if given stride is [1000, 1, 100, 10], + the fill order should be [1, 3, 2, 0] + """ + assert len(sizes) == len(stride) + stride = [V.graph.sizevars.size_hint(x) for x in stride] + fill_order = sorted(range(len(stride)), key=stride.__getitem__) + return FlexibleLayout.fill_ordered(sizes, fill_order) + + def as_stride_order(self, order): + return FixedLayout( + self.device, + self.dtype, + self.size, + self.stride_ordered(self.size, order), + self.offset, + ) + + def as_fill_order(self, order): + return FixedLayout( + self.device, + self.dtype, + self.size, + self.fill_ordered(self.size, order), + self.offset, + ) + + def as_same_order(self, stride): + return FixedLayout( + self.device, + self.dtype, + self.size, + self.same_ordered(self.size, stride), + self.offset, + ) + + def __init__(self, device, dtype, size, stride_order=None): + if stride_order: + strides = FlexibleLayout.fill_ordered(size, stride_order) + else: + strides = FlexibleLayout.contiguous_strides(size) + super().__init__(device, dtype, size, strides) + + +class AliasedLayout(Layout): + """Shares the same storage as another tensor""" + + def __init__(self, view: Union[BaseView, "TensorBox"]): + layout = view.get_layout() + super().__init__( + layout.device, + layout.dtype, + layout.size, + layout.stride, + ) + self.view = view + + def make_indexer(self): + return self.as_fixed().make_indexer() + + def maybe_guard_aligned(self): + offset = self.view.get_layout().offset + if offset == 0: + return True + from .compile_fx import ALIGNMENT + + return V.graph.sizevars.statically_known_multiple_of(offset, ALIGNMENT) # type: ignore[arg-type] + + +class NoneLayout(IRNode): + # This is janky, I figured out what fields to populate by just running + # the model I was interested in and adding properties/methods as needed. + # This doesn't inherit from Layout because Layout assumes you have stuff + # like sizes, but I don't really have anything here. 
+ # + # If you have an ir.Node with NoneLayout, you probably need to set up + # dependencies manually in scheduler + + def __init__(self, device): + self.device = device + self.size = [0] + self.stride = [0] + + def storage_size(self): + return 0 + + def as_fixed(self): + return self + + +class MutationLayout(Layout): + def __init__(self, target: IRNode): + super().__init__( + target.get_device(), + target.get_dtype(), + target.get_size(), + None, + ) + self.target = target + name = self.get_buffer().get_name() + V.graph.mark_buffer_mutated(name) + + @Layout.stride.getter  # type: ignore[attr-defined] + def stride(self): + return self.real_layout().stride + + def storage_size(self) -> sympy.Expr: + return self.real_layout().storage_size() + + def get_buffer(self) -> "Buffer": + def unwrap_views(target): + if isinstance(target, MutationLayout): + return unwrap_views(target.target) + if isinstance(target, BaseView): + return unwrap_views(target.unwrap_view()) + if isinstance(target, MutableBox): + return unwrap_views(target.data) + return target + + result = unwrap_views(self.target) + assert isinstance(result, Buffer), "MutationLayout must refer to a buffer" + return result + + def real_layout(self): + return self.get_buffer().layout + + @classmethod + def realize_into(cls, src, dst, unsafe_alias=False): + dst.realize() + # NOTE: We must realize users of `dst` before we realize `src`, since + # realization order determines scheduling order. Otherwise, src's + # mutation would be scheduled before the existing users of dst! + V.graph.mark_buffer_mutated(dst.get_name()) + + if isinstance(src, TensorBox): + src = src.data + + # We copy the contents of src into dst. In most cases this should + # be fused into a single kernel by the scheduler. + # NOTE: We cannot change src's layout to mutate dst directly as this + # would alias src to dst, which is not correct as further mutations to + # dst would affect users of src. However if there are no more users of + # dst, we can alias src to dst. + src.realize_hint() + + if not unsafe_alias: + src = Pointwise.create( + device=src.get_device(), + dtype=src.get_dtype(), + inner_fn=src.make_loader(), + ranges=[ + V.graph.sizevars.guard_equals(a, b) + for a, b in zip(src.get_size(), dst.get_size()) + ], + ).data + + src.realize() + assert isinstance(src.data.layout, FlexibleLayout) + src.data.layout = MutationLayout(dst) + return src.data + + def as_fixed(self): + return self + + def make_indexer(self): + return self.target.make_indexer() + + +@dataclasses.dataclass +class Buffer(IRNode): + # Name is sometimes None; e.g., ForceInPlace, where there isn't + # a meaningful name + name: Optional[str] + layout: Layout + + # Multi-output buffers will define 'outputs: List[Buffer]'. Confusingly, + # MultiOutput does NOT define this!
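+ # (Editor's note, added) Sketch of how a Buffer is consumed: the layout's + # indexer flattens an N-D index into a linear offset and the loader emits a + # load at that offset. For a contiguous FixedLayout of size (2, 3), strides + # (3, 1): indexer([i, j]) == offset + i * 3 + j * 1.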
+ + def __post_init__(self): + super().__post_init__() + self.origin_node = None + + def make_indexer(self): + return self.layout.make_indexer() + + def get_name(self) -> str: + assert self.name + return self.name + + def get_device(self): + return self.layout.device + + def get_origin_node(self): + return self.origin_node + + @property + def dtype(self): + return getattr(self.layout, "dtype", None) + + def get_size(self): + return list(self.layout.size) + + def get_stride(self): + return list(self.layout.stride) + + def get_offset(self): + return self.layout.offset + + def get_layout(self): + return self.layout + + def get_storage_numel(self): + return self.get_numel() + + def is_extern(self): + return False + + def freeze_layout(self): + if not isinstance(self.layout, (MultiOutputLayout, AliasedLayout)): + self.layout = self.layout.as_fixed() + + def freeze_layout_with_stride_order(self, order): + assert isinstance(self.layout, FlexibleLayout) + self.layout = self.layout.as_stride_order(order) + + def freeze_layout_with_fill_order(self, order): + assert isinstance(self.layout, FlexibleLayout) + self.layout = self.layout.as_fill_order(order) + + def freeze_layout_with_same_order(self, stride): + assert isinstance(self.layout, FlexibleLayout) + self.layout = self.layout.as_same_order(stride) + + def is_zero_elements(self): + return V.graph.sizevars.is_expr_static_and_true(sympy.Eq(self.get_numel(), 0))  # type: ignore[arg-type] + + def make_loader(self): + # Loading from a zero-element buffer is a no-op + if self.is_zero_elements(): + return partial(nop_loader_fn, dtype=self.get_dtype()) + + def loader(index): + indexer = self.layout.make_indexer() + return ops.load(self.name, indexer(index)) + + return loader + + def is_no_op(self): + return False + + def codegen_reference(self, writer=None): + return self.get_name() + + def decide_layout(self): + pass + + def get_alias_names(self): + if isinstance(self.layout, AliasedLayout): + return [self.layout.view.get_name()] + return () + + def get_mutation_names(self): + if isinstance(self.layout, MutationLayout): + return [self.layout.target.get_name()] + return () + + def get_read_writes(self): + with patch.object(FlexibleLayout, "allow_indexing", True): + return extract_read_writes( + self.make_loader(), + self.get_size(), + ) + + def get_reads(self): + return self.get_read_writes().reads + + def get_unbacked_symbol_defs(self) -> Set[sympy.Symbol]: + """ + Returns the unbacked symbols which are defined by this IR node, + because this is a data-dependent IR node, or item() + """ + # So this is a little unusual. In principle, you could imagine + # defining a MultiOutputLayout buffer so that it DOES define + # unbacked symints. However, we can't easily tell what symints + # such a buffer defines, because MultiOutputLayout doesn't actually + # define any useful information about what it returns. + # + # An easier and better approach is to delay the symint allocation + # to the MultiOutput IR nodes, which are when we actually extract + # out the buffers and know what their sizes are. + # + # There are two subtleties here: + # + # 1. Suppose you have a kernel that produces out1: (i0,), out2: (i0,) + # Both of these actually count as defs! The scheduler will just + # arbitrarily pick one of these as the canonical definer and + # ensure it stays live. It's not a big deal if we pick the + # wrong one because tuple accesses are cheap, and all this means + # is we accidentally keep a MultiOutput node live when it wasn't + # strictly necessary. + # + # 2.
Suppose you have a MultiOutput buffer whose size is (i0,), but + # the MultiOutputLayout buffer it is projecting from isn't actually + # dynamic; it has i0 as one of the arguments. We cannot tell this + # directly from MultiOutput, we have to look at the input buffer's + # uses to work this out. No big deal. + if isinstance(self.layout, (NoneLayout, MultiOutputLayout)): + return set() + + # This kernel defines all unbacked symbols... that it didn't get in as + # arguments! + defs = ( + free_unbacked_symbols(self.get_size()) + | free_unbacked_symbols(self.get_stride()) + | free_unbacked_symbols(self.get_offset()) + ) + return defs - self.get_unbacked_symbol_uses() + + def get_unbacked_symbol_uses(self) -> Set[sympy.Symbol]: + """ + Returns the unbacked symbols which are required to be in scope in + order to successfully perform codegen for this buffer. For example, + a buffer that corresponds to an extern kernel call that takes i0 as + an argument would return {i0} here. This is used to generate necessary + dependencies that ensure we actually bind i0 in codegen before you + try to use it. + + Note that this is NOT transitive; in particular, if this buffer takes + in as input another buffer with dynamic shape (e.g., (i0,)), we will + not report it here, because you will already have a dependency + on that buffer, which will eventually have a dependency on i0 if + necessary. + """ + return set() + + def codegen_unbacked_symbol_defs(self, wrapper): + # NB: If it is possible for other ir node types to return unbacked + # symints, you need to make sure their codegen calls this method. + # Don't forget to update get_unbacked_symbol_defs too. + symbols_to_define = self.get_unbacked_symbol_defs() + for i, s in enumerate(self.get_size()): + if s in symbols_to_define: + wrapper.writeline( + f"{wrapper.codegen_unbacked_symbol_decl(s)} = {self.get_name()}.size({i}){wrapper.ending}" + ) + symbols_to_define.remove(s) + for i, s in enumerate(self.get_stride()): + if s in symbols_to_define: + wrapper.writeline( + f"{wrapper.codegen_unbacked_symbol_decl(s)} = {self.get_name()}.stride({i}){wrapper.ending}" + ) + symbols_to_define.remove(s) + if (s := self.get_offset()) in symbols_to_define: + wrapper.writeline( + f"{wrapper.codegen_unbacked_symbol_decl(s)} = {self.get_name()}.storage_offset(){wrapper.ending}" + ) + symbols_to_define.remove(s) + assert ( + not symbols_to_define + ), f"unbacked symint {s} not written out, check comment above" + + def realize(self): + pass + + def get_workspace_size(self): + """ + Gets extra global memory size needed by this buffer. + Some algorithms (e.g. group gemm) may require extra global memory in the generated code. + """ + return 0 + + def should_allocate(self): + # Returns False by default. 
+ return False + + +class InputBuffer(Buffer): + pass + + +class ConstantBuffer(InputBuffer): + override_device: Optional[torch.device] = None + + def make_loader(self): + def loader(index): + indexer = self.layout.make_indexer() + return ops.load( + V.graph.constant_name(self.get_name(), self.override_device), + indexer(index), + ) + + return loader + + def constant_to_device(self, device): + return ConstantBuffer( + V.graph.constant_name(self.get_name(), device), self.layout + ) + + +class NoneAsConstantBuffer(IRNode): + def get_unbacked_symbol_uses(self) -> Set[sympy.Symbol]: + return set() + + def codegen_reference(self, writer=None): + return V.graph.wrapper_code.none_str + + +class ShapeAsConstantBuffer(IRNode): + def __init__(self, shape): + super().__init__() + self.shape = shape + + def get_unbacked_symbol_uses(self) -> Set[sympy.Symbol]: + return free_unbacked_symbols(self.shape) + + def codegen_reference(self, writer=None): + return V.graph.wrapper_code.expr_printer(V.graph.sizevars.simplify(self.shape)) + + +@dataclasses.dataclass +class ComputedBuffer(Buffer): + data: Loops + + def get_computed_buffer_name(self): + """ + Returns self.name if it exists, otherwise returns the name of the data node if that exists. + If neither exists, returns None. + """ + if self.name is not None: + return self.name + if hasattr(self.data, "name"): + return self.data.name + return None + + @cache_on_self + def num_reads(self): + return len(self.get_read_writes().reads) + + def get_read_writes(self): + with patch.object(FlexibleLayout, "allow_indexing", True): + if self.data.get_reduction_type(): + return extract_read_writes( + self.get_store_function(), + self.data.get_pointwise_size(), + self.data.get_reduction_size(), + ) + else: + return extract_read_writes( + self.get_store_function(), + self.data.get_size(), + ) + + def get_unbacked_symbol_uses(self) -> Set[sympy.Symbol]: + # Ordinarily, we'd like to just peek at the arguments list, + # but ComputedBuffers have no argument list. + # + # Morally, this logic needs to be synchronized with the + # KernelArgs.size calls, which are responsible for making symbols make + # their way as kernel arguments (and it is precisely passing in one of + # those symbols that establishes a dependency). However, we haven't + # started codegen yet so we can't directly reuse that logic. + # + # For now, I'm just yoloing with the size of the buffer. Not sure if + # it is enough. + # + # One thing you might wonder is if this is enough for a ComputedBuffer + # denoting a reduction over i0. Empirically, it is enough, but for an + # unusual reason: we only need accurate dependencies for the item() call, + # but it's impossible to end up with a reduction over i0 from an + # item() call without a regular non-reduction buffer first.
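+ # A concrete illustration (hypothetical, assuming a nonzero()-produced size): if mask.nonzero() + # yields a buffer of size (u0, 1), a ComputedBuffer consuming it has u0 in get_size(), so the + # free_unbacked_symbols calls below report {u0} as a use.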
+ return ( + free_unbacked_symbols(self.get_size()) + | free_unbacked_symbols(self.get_stride()) + | free_unbacked_symbols(self.get_offset()) + | self.data.get_unbacked_symbol_uses() + ) + + def make_loader(self): + # Inline constants and index_expressions + if ( + hasattr(self.data, "make_loader") + and self.name not in V.graph.mutated_buffers + and self.num_reads() == 0 + ): + # can be inlined + return self.data.make_loader() + return super().make_loader() + + def get_store_function(self): + indexer = self.layout.as_fixed().make_indexer() + if isinstance(self.data, (Reduction, Scan)): + return partial(self.data.store_reduction, self.name, indexer) + else: + assert isinstance(self.data, Pointwise) + return partial(self.data.store_output, self.name, indexer) + + def get_fill_order(self): + """ + If our layout is still flexible, try to determine the stride order based on stride orders of reads. + + TODO(jansel): A better algorithm here would look at downstream consumers of this + value and try to do global graph-level layout optimization. + This is also something just begging to be autotuned. + """ + if isinstance(self.layout, FlexibleLayout): + (index_vars, reduction_vars), _ = dependencies.index_vars_squeeze( + self.data.get_pointwise_size(), self.data.get_reduction_size() + ) + reads = self.get_read_writes().reads + reads_bufs = [ + V.graph.name_to_buffer[r.name] + if r.name in V.graph.name_to_buffer.keys() + else None + for r in reads + ] + # only consider reads to buffers of the same size + # ignore StarDeps because they don't contribute stride information + assert all( + isinstance(r, (dependencies.StarDep, dependencies.MemoryDep)) + for r in reads + ) + reads = [ + sympy_subs( + r.index, {v: sympy.Integer(0) for v in reduction_vars if v != 0} + ) + for r in reads + if isinstance(r, dependencies.MemoryDep) + ] + + if reads: + if isinstance(self.data, Scan): + indices = self.data.reindex(index_vars, reduction_vars) + else: + indices = index_vars + stride_lengths = [ + V.graph.sizevars.stride_hints(expr, indices) for expr in reads # type: ignore[arg-type] + ] + from .scheduler import pick_loop_order + + return pick_loop_order(stride_lengths, self.get_size()) + + return None + + def decide_layout(self): + if isinstance(self.layout, FlexibleLayout): + order = self.get_fill_order() + if order: + self.freeze_layout_with_fill_order(order) + else: + self.freeze_layout() + + def get_default_sizes_body(self): + args, var_ranges = dependencies.index_vars_squeeze( + self.data.get_pointwise_size(), self.data.get_reduction_size(), prefix="q" + ) + with patch.object(ConstantBuffer, "override_device", self.get_device()): + body = LoopBody( + self.get_store_function(), + (args if self.get_reduction_type() else args[:1]), + var_ranges, + ) + index_vars = [] + reduce_vars: List[Any] = [] + index_size = [] + reduce_size = [] + for v, s in var_ranges.items(): + if v in args[0]: + assert not reduce_vars + index_vars.append(v) + index_size.append(s) + else: + assert v in args[1] + reduce_vars.append(v) + reduce_size.append(s) + return (index_size, reduce_size), body, (index_vars, reduce_vars) + + def simplify_and_reorder( + self, + extra_indexing_constraints: Optional[Tuple[Dict[Any, Any], List[Any]]] = None, + ): + """ + This is the main place where we do loop transformations in a + backend-agnostic way.
+ + Here we: + 1) Remove any size-1 dimensions + 2) Fuse contiguous dimensions together + 3) Reorder dimensions based on stride orders + + Optional argument extra_indexing_constraints can be used to append additional + indexing expressions to existing ones derived from the buffer's body. This can be useful + to fuse scheduler nodes with compatible ranges, e.g. (s0*s1*...,) and (s0, s1, s2, ...) + on CPU, by preventing indexing simplifications and obtaining index/reduce ranges for + the scheduler node that are compatible with other nodes. + """ + ( + (index_size, reduce_size), + body, + (index_vars, reduce_vars), + ) = self.get_default_sizes_body() + + index_formulas = [*body.indexing_exprs.values()] + if extra_indexing_constraints is not None: + assert ( + isinstance(extra_indexing_constraints, tuple) + and len(extra_indexing_constraints) == 2 + ) + extra_indexing_ranges, extra_indexing_expr = extra_indexing_constraints + assert isinstance(extra_indexing_ranges, dict) + assert isinstance(extra_indexing_expr, list) + assert all(isinstance(f, Expr) for f in extra_indexing_expr) + + expected_var_ranges = body.var_ranges + assert expected_var_ranges == extra_indexing_ranges, ( + expected_var_ranges, + extra_indexing_ranges, + ) + # remove already existing expressions + extra_indexing_expr = [ + e for e in extra_indexing_expr if e not in index_formulas + ] + index_formulas += extra_indexing_expr + + reads_bufs = [ + V.graph.name_to_buffer[reads_name] + if reads_name in V.graph.name_to_buffer.keys() + else None + for reads_name in body.reads_name2expr.keys() + ] + memory_addrs = [ + *body.reads_name2expr.values(), + *body.writes_name2expr.values(), + ] + + # the reordering_reindex in reads' simplify_reorder_and_tile + reordering_reindex = [same_reorder(range(len(index_vars)))] * len(memory_addrs) + for i, reads_buf in enumerate(reads_bufs): + if isinstance(reads_buf, ComputedBuffer) and hasattr( + reads_buf, "iter_reordering_reindex" + ): + reordering_reindex[i] = reads_buf.iter_reordering_reindex # type: ignore[has-type] + + def simplify_and_reorder(x_vars, support_vars, sizes, reordering_reindex=None): + sizes, reindex0, reindex1 = self._apply_loop_reordering( + x_vars, support_vars, sizes, memory_addrs, reordering_reindex + ) + # for NHWC: reindex0([0,1,2,3]) = [0,2,3,1], reindex1([0,1,2,3]) = [0,3,2,1] + x_vars = reindex0(x_vars) + sizes, reindex2, prune = V.graph.sizevars._simplify_loops( + x_vars, + sizes, + index_prevent_reordering(index_formulas, x_vars, sizes), + ) + x_vars = prune(x_vars) + # sizes, reindex1, prune = _simplify_loops(x_vars, sizes, index_formulas) + # x_vars = prune(x_vars) + # sizes, reindex2 = self._apply_loop_reordering(x_vars, sizes, memory_addrs) + reindex = fuse_reindexing(reindex1, reindex2) + return sizes, reindex, reindex1 + + support_vars = index_vars + reduce_vars + iter_ranges, iter_reindex, iter_reordering_reindex = simplify_and_reorder( + index_vars, support_vars, index_size, reordering_reindex + ) + reduce_ranges, reduce_reindex, _ = simplify_and_reorder( + reduce_vars, support_vars, reduce_size + ) + + # remember the reordering if there is no loop collapse.
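+ # A minimal worked example (hypothetical shapes): for a contiguous [2, 1, 3, 4] pointwise buffer, + # step 1 drops the size-1 dimension to give [2, 3, 4]; with strides (12, 4, 1) the remaining loops + # are contiguous, so step 2 fuses them into a single [24] loop. The loops collapsed, so + # len(iter_ranges) != len(index_vars) and the reordering is not remembered.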
+ if len(iter_ranges) == len(index_vars): + self.iter_reordering_reindex = iter_reordering_reindex + # retrace the loop body with simplification and reordering applied + (iter_vars, reduce_vars), var_ranges = dependencies.index_vars_no_squeeze( + iter_ranges, reduce_ranges, prefix="z" + ) + body = LoopBody( + body, [iter_reindex(iter_vars), reduce_reindex(reduce_vars)], var_ranges + ) + return (iter_ranges, reduce_ranges), body + + @staticmethod + def _apply_loop_reordering( + index_vars, + support_vars, + sizes, + memory_addrs, + reordering_reindex=None, + priority_idx=None, + ): + """ + Shuffle the order of loops around to hopefully improve performance. + """ + from .scheduler import pick_loop_order + + if priority_idx is None: + priority_idx = [] + + try: + strides = [ + V.graph.sizevars.stride_hints(expr, index_vars, support_vars) + for expr in memory_addrs + ] + assert len(strides) == len(memory_addrs) and len(strides[0]) == len( + index_vars + ) + # consider both layout (strides) and reordering (reordering_reindex) + if reordering_reindex is not None: + for i in range(len(memory_addrs)): + try: + strides[i] = reordering_reindex[i](strides[i]) + # if len(order) != len(strides), do not reorder + except AssertionError: + pass + order = list(reversed(pick_loop_order(strides, sizes, priority_idx))) + except Exception: + if config.debug: + log.warning( + "Did not simplify complex index:\n%s\n%s", + dict(zip(index_vars, sizes)), + memory_addrs, + ) + order = list(range(len(sizes))) + sizes = [sizes[i] for i in order] + return sizes, same_reorder(order), inverse_reorder(order) + + def get_reduction_size(self): + return self.data.get_reduction_size() + + def get_reduction_type(self): + return self.data.get_reduction_type() + + def is_no_op(self): + return self.data.is_zero_elements() + + def should_allocate(self): + return True + + def constant_to_device(self, device): + """Move this to a given device. Requires that all reads are to constants.""" + return self.data.constant_to_device(device) + + +class TemplateBuffer(Buffer): + """ + Represents a Triton (and, in the future, other types of) template operator + that we can fuse an epilogue onto.
+ """ + + def __init__(self, layout, inputs, make_kernel_render): + super().__init__(name=None, layout=layout) + self.inputs = InputsKernel.unwrap_storage(inputs) + self.make_kernel_render = make_kernel_render + self.name = V.graph.register_buffer(self) + + def get_read_writes(self): + return self.normalized_read_writes() + + def normalized_read_writes(self): + name = self.get_name() + indexer = self.layout.make_indexer() + + def dummy(index, rindex): + assert len(rindex) == 0 + return ops.store(name, indexer(index), "fake") + + deps = dependencies.extract_read_writes( + dummy, self.get_size(), (), normalize=True + ) + deps.reads = {dependencies.StarDep(x.get_name()) for x in self.inputs} + return deps + + def get_reduction_size(self): + return 1 + + def get_reduction_type(self): + return None + + def is_no_op(self): + return False + + def should_allocate(self): + return True + + def simplify_and_reorder( + self, + extra_indexing_constraints: Optional[Tuple[Dict[Any, Any], List[Any]]] = None, + ): + return ( + ( + self.get_size(), + (), + ), + None, + ) + + +class TritonTemplateBuffer(TemplateBuffer): + pass + + +class CUDATemplateBuffer(TemplateBuffer): + def __init__( + self, + layout, + inputs, + make_kernel_render, + workspace_size: int, + template: "CUDATemplate", # type: ignore[name-defined] # noqa: F821 + ): + super().__init__(layout, inputs, make_kernel_render) + # Global memory (in bytes) needed for this template. + self.workspace_size = workspace_size + self.template = template + + def get_workspace_size(self): + return self.workspace_size if self.workspace_size is not None else 0 + + +@dataclasses.dataclass +class InputsKernel(Buffer): + inputs: List[Buffer] + + def get_read_writes_input(self, x): + return dependencies.StarDep(x.get_name()) + + def get_read_writes(self): + star_dep = [] + for input in self.inputs: + if isinstance(input, list): + star_dep.extend([self.get_read_writes_input(x) for x in input]) + else: + star_dep.append(self.get_read_writes_input(input)) + + return dependencies.ReadWrites( + set(star_dep), + {dependencies.StarDep(self.get_name())}, + set(), + [], + None, + op_counts=collections.Counter(), + ) + + @classmethod + def unwrap_storage_for_input(cls, x): + if isinstance(x, TensorBox): + x = x.data + if isinstance(x, StorageBox): + x = x.data + if isinstance(x, BaseView) and not isinstance(x, ReinterpretView): + x = ExternKernel.realize_input(x) + if isinstance(x, TensorBox): + # when converting to ReinterpretView fails in the + # realize_input call above, the result will be wrapped + # into TensorBox / StorageBox pair as a result of the + # cls.copy_input call; so we should unwrap recursively + return cls.unwrap_storage_for_input(x) + assert isinstance(x, (Buffer, ReinterpretView)), x + return x + + @staticmethod + def unwrap_storage(inputs): + inputs_new = [] + for x in inputs: + if isinstance(x, list): + x = [InputsKernel.unwrap_storage_for_input(i) for i in x] + else: + x = InputsKernel.unwrap_storage_for_input(x) + inputs_new.append(x) + return inputs_new + + def is_extern(self): + return True + + +class NopKernel(InputsKernel): + def is_no_op(self): + return True + + +class ConcatKernel(NopKernel): + """ + There isn't actually a real kernel for concat, we just change the + storage for the upstream data. 
+ """ + + @classmethod + def create(cls, inputs, dim): + device = inputs[0].get_device() + dtype = inputs[0].get_dtype() + new_size = list(inputs[0].get_size()) + offsets_start = [0] + offsets_end = [new_size[dim]] + assert 0 <= dim < len(new_size) + for i in range(1, len(inputs)): + input_size = inputs[i].get_size() + offsets_start.append(new_size[dim]) + assert len(input_size) == len(new_size) + assert inputs[i].get_dtype() == dtype + assert inputs[i].get_device() == device + for j in range(len(new_size)): + if j == dim: + new_size[j] = new_size[j] + input_size[j] + else: + new_size[j] = V.graph.sizevars.guard_equals( + new_size[j], input_size[j] + ) + offsets_end.append(new_size[dim]) + + output_stride = FlexibleLayout.contiguous_strides(new_size) + # If any of the inputs is in CL format, use CL format for the output + for i in range(len(inputs)): + x = inputs[i] + if is_storage_and_layout(x): + layout = x.get_layout() + if ( + isinstance(layout, FixedLayout) + and layout.is_channels_last_contiguous() + ): + # use CL stride for the output + output_stride = make_channels_last_strides_for(new_size) + break + + concat_kernel = ConcatKernel( + name=None, + layout=FixedLayout( + device=device, + dtype=dtype, + size=new_size, + stride=output_stride, + ), + inputs=[], + ) + kernel = StorageBox(concat_kernel) + buffer_names = [] + for i in range(len(inputs)): + input_buffer = cls.realize_into( + inputs[i], + SliceView.create(kernel, dim, offsets_start[i], offsets_end[i]), + ) + concat_kernel.inputs.append(input_buffer) + + if isinstance(inputs[i].data, BaseView): + input_unwrapped = inputs[i].data.unwrap_view() + else: + input_unwrapped = inputs[i].data + + if ( + input_unwrapped.is_input_buffer() + and inputs[i].get_device().type == "cuda" + and not is_dynamic(input_buffer) + ): + buffer_names.append(input_buffer.get_name()) + + if len(buffer_names) > 1: + V.graph.register_list(buffer_names) + + concat_kernel.name = V.graph.register_buffer(concat_kernel) + concat_kernel.inputs = cls.unwrap_storage(concat_kernel.inputs) + + return kernel + + @classmethod + def can_realize_into_without_copy(cls, src): + if isinstance(src, TensorBox): + # unwrap a TensorBox + return cls.can_realize_into_without_copy(src.data) + + return isinstance(src.data.layout, FlexibleLayout) and not isinstance( + src.data, ExternKernelAlloc + ) + + @classmethod + def realize_into(cls, src, dst): + # Attempt to turn this into a ReinterpretView rather than assert. + # This has concessions around layout, as as_storage_and_layout + # can cause us to go from flexible to fixed layout. 
+ if not isinstance(dst, ReinterpretView): + if is_storage_and_layout(dst): + storage, layout = as_storage_and_layout(dst) + dst = ReinterpretView(storage, layout) + assert isinstance(dst, ReinterpretView), dst + if isinstance(src, TensorBox): + # unwrap a TensorBox + return cls.realize_into(src.data, dst) + if isinstance(src, StorageBox): + src.realize() + # ExternKernelAlloc has specific requirements for output layout, so we create a copy + assert hasattr(src.data, "layout") + if cls.can_realize_into_without_copy(src): + src.data.layout = AliasedLayout(dst) + return src.data + # introduce a copy + pw = Pointwise.create( + device=src.get_device(), + dtype=src.get_dtype(), + inner_fn=src.make_loader(), + ranges=[ + V.graph.sizevars.guard_equals(a, b) + for a, b in zip(src.get_size(), dst.get_size()) + ], + ) + return cls.realize_into(pw, dst) + + def should_allocate(self): + return True + + +@dataclasses.dataclass +class ExternKernel(InputsKernel): + constant_args: Tuple[Any, ...] = () + kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict) + output_view: Optional[ReinterpretView] = None + python_kernel_name: Optional[str] = None + cpp_kernel_name: Optional[str] = None + # FIXME: in some cases we still need to explicitly pass in ordered_kwargs_for_cpp_kernel + # We shouldn't need to do this since the information can be retrieved from op_overload._schema. + ordered_kwargs_for_cpp_kernel: Iterable[str] = dataclasses.field( + default_factory=list + ) + op_overload: Optional[ + Union[torch._ops.OpOverload, torch._ops.HigherOrderOperator] + ] = None + arg_properties: Optional[List[Dict[str, Any]]] = None + kwarg_properties: Optional[Dict[str, Dict[str, Any]]] = None + + def __init__( + self, + name, + layout, + inputs, + constant_args=(), + kwargs=None, + output_view=None, + python_kernel_name=None, + cpp_kernel_name=None, + ordered_kwargs_for_cpp_kernel=(), + op_overload=None, + ): + super().__init__( + name, + layout, + inputs, + ) + self.constant_args = constant_args + self.kwargs = kwargs if kwargs else {} + self.output_view = output_view + self.python_kernel_name = python_kernel_name + self.cpp_kernel_name = cpp_kernel_name + self.ordered_kwargs_for_cpp_kernel = ordered_kwargs_for_cpp_kernel + self.op_overload = op_overload + self.collect_arg_kwarg_properties() + + def collect_arg_kwarg_properties(self): + # if self.op_overload is torch._ops.OpOverload, we can use its schema to collect additional + # information for args and kwargs, e.g.
type and default value, to help with the cpp wrapper codegen + if ( + isinstance(self.op_overload, torch._ops.OpOverload) + and not self.ordered_kwargs_for_cpp_kernel + ): + self.ordered_kwargs_for_cpp_kernel = [ + x.name for x in self.op_overload._schema.arguments if x.kwarg_only + ] + self.arg_properties = ( + [ + { + "name": x.name, + "type": x.real_type, + "default_value": x.default_value, + } + for x in self.op_overload._schema.arguments + if not x.kwarg_only + ] + if isinstance(self.op_overload, torch._ops.OpOverload) + else [{} for i in range(len(self.inputs))] + ) + self.kwarg_properties = ( + { + x.name: {"type": x.real_type, "default_value": x.default_value} + for x in self.op_overload._schema.arguments + if x.kwarg_only + } + if isinstance(self.op_overload, torch._ops.OpOverload) + else {} + ) + + def decide_layout(self): + if isinstance(self.layout, FlexibleLayout): + self.apply_constraint() + self.freeze_layout() + + def codegen_comment(self, wrapper): + origin_str, detailed_origin_str = get_kernel_metadata(self, wrapper) + if origin_str: + wrapper.writeline(origin_str) + + def codegen(self, wrapper): + raise NotImplementedError() + + def get_kernel_name(self): + return self.cpp_kernel_name if V.graph.cpp_wrapper else self.python_kernel_name + + @staticmethod + def copy_input(x): + pw = Pointwise.create( + device=x.get_device(), + dtype=x.get_dtype(), + inner_fn=x.make_loader(), + ranges=x.get_size(), + origin_node=x.get_origin_node(), + traceback=x.get_traceback(), + ) + pw.realize() + return pw + + @classmethod + def process_kernel(cls, kernel, *args, **kwargs): + binded_args = {"args": args, "kwargs": kwargs} + + args_flat, args_spec = pytree.tree_flatten(binded_args) + + is_arg_tensor = [] + tensor_args = [] + non_tensor_args: List[Any] = [] + for arg in args_flat: + is_arg_tensor.append(isinstance(arg, IRNode)) + if is_arg_tensor[-1]: + tensor_args.append(arg) + else: + if isinstance(arg, sympy.Expr): + arg = V.graph.sizevars.shape_env.create_symintnode(arg, hint=None) + non_tensor_args.append(arg) + + def unflatten_args(new_tensor_args, new_non_tensor_args): + result = [] + it_tensors = iter(new_tensor_args) + it_non_tensors = iter(new_non_tensor_args) + for is_tensor in is_arg_tensor: + if is_tensor: + result.append(next(it_tensors)) + else: + result.append(next(it_non_tensors)) + r = pytree.tree_unflatten(result, args_spec) + return r.get("args", []), r.get("kwargs", {}) + + tensor_args = [cls.realize_input(x) for x in tensor_args] + + # freeze layout otherwise our output stride calculation might + # become incorrect + for x in tensor_args: + if is_storage_and_layout(x): + as_storage_and_layout(x, freeze=True) + + # We don't have generic shape formulas, so just burn in the + # shapes and run an example input. 
+ # TODO(jansel): replace this with dynamic shape formulas + example_args = [] + + # We need to retain the constant values of fake tensors that we originally + # propagated the graph with, because for some operators running without a + # constant would trigger an error / DataDependentException + for x in tensor_args: + if x.get_name() in V.graph.constants: + example_args.append(V.graph.constants[x.get_name()]) + else: + example_args.append(ir_node_to_tensor(x, guard_shape=True)) + + new_args, new_kwargs = unflatten_args(example_args, non_tensor_args) + example_output = kernel(*new_args, **new_kwargs) + + example_out_li = ( + [example_output] + if not isinstance(example_output, (list, tuple)) + else example_output + ) + for t in example_out_li: + if isinstance(t, torch.Tensor) and t.is_sparse: + msg = "sparsity not handled. Please file issue for sparse inference weights." + if stack_trace := V.graph.current_node.meta.get("stack_trace", None): + msg = f"{msg} Found from : \n {stack_trace}" + V.graph.disable_cudagraphs_reason = msg + + # TODO: Unconditionally do this, not just when example_output has + # unbacked symbols + if maybe_free_unbacked_symbols(example_output): + example_output = V.graph.current_node.meta["val"] + + return example_output, tensor_args, non_tensor_args, unflatten_args + + @classmethod + def convert_to_reinterpret_view(cls, x): + """ + In order to pass this to an extern kernel we need a + ReinterpretView not a View. This allows us to avoid some + unneeded copies. + """ + assert isinstance(x, BaseView) + if isinstance(x, ReinterpretView): + return x + + # NOTE: Don't use extract_read_writes here as it fails when + # make_loader() inlines the computation + x.unwrap_view().freeze_layout() + index_args, var_ranges = dependencies.index_vars_squeeze( + x.get_size(), prefix="r" + ) + range_vars = index_args[0] + index = x.make_indexer()(range_vars) + + index = V.graph.sizevars.simplify_with_ranges(index, var_ranges) + strides = V.graph.sizevars.stride_vars(index, range_vars) + offset = V.graph.sizevars.offset_var(index, range_vars) + expected = sympy_dot(range_vars, strides) + offset + + if index != expected: + log.debug( + "convert_to_reinterpret_view failed: stride=%s offset=%s index=%s", + strides, + offset, + index, + ) + raise NotImplementedError() + + return ReinterpretView( + data=x.data, + layout=FixedLayout( + device=x.get_device(), + dtype=x.get_dtype(), + size=x.get_size(), + stride=strides, + offset=offset, + ), + ) + + @classmethod + def realize_input(cls, x): + if x is None: + return NoneAsConstantBuffer() + if isinstance(x, (sympy.Expr, sympy.logic.boolalg.Boolean, int)): + return ShapeAsConstantBuffer(x) + if isinstance(x, Constant): + return V.graph.add_tensor_constant( + torch.tensor(x.value, dtype=x.get_dtype(), device=x.get_device()) + ) + if isinstance(x, ConstantBuffer): + return x + if isinstance(x, TensorBox): + return cls.realize_input(x.data) + if isinstance(x, ReinterpretView): + return ReinterpretView(cls.realize_input(x.data), x.get_layout()) + if isinstance(x, BaseView): + x.realize() + if is_storage_and_layout(x.unwrap_view()): + try: + return cls.convert_to_reinterpret_view(x) + except NotImplementedError: + pass + if isinstance(x, StorageBox): + # TODO(jansel): impose layout preference on realized buffer + x.realize() + return x + return cls.copy_input(x) + + @classmethod + def require_stride1(cls, x): + if is_storage_and_layout(x): + if len(x.get_stride()) == 0: + return x + for stride in x.get_stride(): + if stride == 1: + return x + return 
cls.copy_input(x) + + @classmethod + def require_stride_order(cls, x, order): + if x.get_numel() == 0: # Layout doesn't matter + return x + + # require x's layout to be stride-ordered as specified by `order` + if is_storage_and_layout(x): + while isinstance(x.get_layout(), AliasedLayout): + x = x.get_layout().view + if isinstance(x.get_layout(), FlexibleLayout): + # freeze the FlexibleLayout into a FixedLayout with the given stride_order + as_storage_and_layout( + x, freeze=True, want_contiguous=False, stride_order=order + ) + return x + elif isinstance( + x.get_layout(), FixedLayout + ) and x.get_layout().is_stride_ordered(order): + return x + elif isinstance(x.get_layout(), MutationLayout): + if isinstance(x.get_layout().real_layout(), FlexibleLayout): + raise AssertionError( + "the MutationLayout's real layout shouldn't be FlexibleLayout" + ) + elif isinstance( + x.get_layout().real_layout(), FixedLayout + ) and x.get_layout().real_layout().is_stride_ordered(order): + return x + + # TODO - Storage to InputBuffer + if isinstance(x, InputBuffer) and x.get_layout().is_stride_ordered(order): + return x + if ( + isinstance(x, TensorBox) + and isinstance(x.data, BaseView) + and not isinstance(x.data, ReinterpretView) + and is_storage_and_layout(x.unwrap_view()) + and not isinstance(x.unwrap_view().data, ExternKernelAlloc) + ): + try: + x.data = cls.convert_to_reinterpret_view(x.data) + return cls.require_stride_order(x, order) + except NotImplementedError: + pass + x = cls.copy_input(x) + as_storage_and_layout(x, freeze=True, want_contiguous=False, stride_order=order) + assert is_stride_order_storage_and_layout(x, order) + return x + + @classmethod + def require_channels_last(cls, x): + return cls.require_stride_order(x, NHWC_STRIDE_ORDER) + + @classmethod + def require_contiguous(cls, x): + return cls.require_stride_order(x, list(reversed(range(len(x.get_size()))))) + + def apply_constraint(self): + pass + + def codegen_const_args(self): + return map(V.graph.wrapper_code.val_to_arg_str, self.constant_args) + + def codegen_args(self): + args = [] + for i, x in enumerate(self.inputs): + if isinstance(x, list): + names = [i.codegen_reference() for i in x] + codegen_reference = f'[{", ".join(names)}]' + args.append(codegen_reference) + else: + if V.graph.cpp_wrapper: + assert self.arg_properties and i < len( + self.arg_properties + ), "invalid arg_properties access" + type_ = self.arg_properties[i].get("type") + args.append( + V.graph.wrapper_code.val_to_cpp_arg_str( # type: ignore[arg-type] + type_, x, self.is_legacy_abi_kernel() + ) + ) + else: + args.append(x.codegen_reference()) + args.extend(self.codegen_const_args()) + return args + + def get_kwargs_value(self, arg_name): + if arg_name in self.kwargs: + return self.kwargs.get(arg_name) + if self.kwarg_properties and self.kwarg_properties.get(arg_name): + return self.kwarg_properties.get(arg_name).get("default_value") # type: ignore[union-attr] + else: + raise AssertionError(f"{arg_name} not in self.kwarg_properties") + + def is_legacy_abi_kernel(self): + return False + + def codegen_kwargs(self): + if V.graph.cpp_wrapper: + kwargs = [] + for arg_name in self.ordered_kwargs_for_cpp_kernel: + v = self.get_kwargs_value(arg_name) + if isinstance(v, sympy.Expr): + kwargs.append(v) + else: + type_ = ( + self.kwarg_properties.get(arg_name).get("type") # type: ignore[union-attr] + if self.kwarg_properties and arg_name in self.kwarg_properties + else None + ) + kwargs.append( + V.graph.wrapper_code.val_to_cpp_arg_str( # type: ignore[arg-type] + type_, v,
self.is_legacy_abi_kernel() + ) + ) + else: + kwargs = [ + f"{k}={V.graph.wrapper_code.val_to_arg_str(v)}" # type: ignore[misc] + for k, v in self.kwargs.items() + ] + return kwargs + + def codegen_size_asserts(self, wrapper): + if config.size_asserts and not V.graph.cpp_wrapper: + size = V.graph.wrapper_code.codegen_shape_tuple(self.get_size()) + stride = V.graph.wrapper_code.codegen_shape_tuple(self.get_stride()) + wrapper.writeline( + f"assert_size_stride({self.get_name()}, {size}, {stride})" + ) + + def get_group_stride(self): + """ + Get output sizes and strides for template_codegen. + """ + _size = self.get_size() + _stride = self.get_stride() + # iter_ranges = _size of output tensor, reduce_range = [] because no reduction + return [_size, []], _stride + + def canonicalize(self): + """ + Manually get canonicalization of the output index + """ + # manually generate index formula for conv + sizevars = V.graph.sizevars + sizes = self.get_size() + strides = self.get_stride() + strides = [sizevars.size_hint(x) for x in strides] + index_vars = [sympy_index_symbol(f"d{i}") for i in range(len(sizes))] + # reorder index vars according to stride + index_order = sorted(range(len(strides)), key=strides.__getitem__, reverse=True) + lookup = {pos: idx for idx, pos in enumerate(index_order)} + order = [lookup[i] for i in range(len(lookup))] + index_vars = [index_vars[i] for i in order] + indexer = self.make_indexer() + index = indexer(index_vars) + + new_sizes, reindex, prune = V.graph.sizevars._simplify_loops( + index_vars, sizes, [index] + ) + + # assign new variables to each dimension to deal with numbering mismatches + # d0, d1, d2 could become d0, d2 -- which won't match d0, d1 + _, add_var = var_builder("c") + replacement = dict(zip(index_vars, reindex([add_var(x) for x in new_sizes]))) + + index = sympy_subs(sympy.expand(index), replacement) # type: ignore[arg-type] + return index, tuple(new_sizes) + + def get_unbacked_symbol_uses(self) -> Set[sympy.Symbol]: + # NB: It's not necessary to check regular inputs as we automatically + # have dependencies on them + r = set() + for arg in self.constant_args: + r |= maybe_free_unbacked_symbols(arg) + for arg in self.kwargs.values(): + r |= maybe_free_unbacked_symbols(arg) + return r + + def __str__(self): + kernel_name = getattr(self, "python_kernel_name", None) + lines = [ + f"python_kernel_name={kernel_name!r}", + ] + lines += [ + f"{field.name}={getattr(self, field.name)}" + for field in dataclasses.fields(self) + ] + lines.append(f"origin_node={self.origin_node!r}") + return self.str_helper(lines) + + __repr__ = __str__ + + +@dataclasses.dataclass +class ExternKernelOut(ExternKernel): + def codegen(self, wrapper): + self.codegen_comment(wrapper) + args = [*self.codegen_args(), *self.codegen_kwargs()] + wrapper.generate_extern_kernel_out( + self.output_view, + self.codegen_reference(), + args, + self.get_kernel_name(), + ) + + def __init__( + self, + layout, + inputs, + constant_args=(), + kwargs=None, + output_view=None, + python_kernel_name=None, + cpp_kernel_name=None, + ordered_kwargs_for_cpp_kernel=(), + op_overload=None, + ): + super().__init__( + None, + layout, + self.unwrap_storage(inputs), + constant_args, + kwargs or {}, + None, + python_kernel_name, + cpp_kernel_name, + ordered_kwargs_for_cpp_kernel, + op_overload, + ) + self.name = V.graph.register_buffer(self) + + def should_allocate(self): + return True + + +class RandomSeeds(ExternKernelOut): + def __init__(self, count: int, device: torch.device): + limits = torch.iinfo(torch.int64)
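+ # torch.iinfo(torch.int64) describes the representable range (limits.min == -2**63, + # limits.max == 2**63 - 1); these become the low/high constant args for the randint fallback below.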
+ super().__init__( + layout=FixedLayout( + device=device, + dtype=torch.int64, + size=[count], + ), + inputs=[], + constant_args=[limits.min, limits.max, [count]], + python_kernel_name="aten.randint.low_out", + cpp_kernel_name="at::randint_out", + ) + + +class ExternKernelAlloc(ExternKernel): + def codegen(self, wrapper): + self.codegen_comment(wrapper) + args = [*self.codegen_args(), *self.codegen_kwargs()] + V.graph.wrapper_code.generate_extern_kernel_alloc(self, args) + if isinstance(self.layout, Layout): + self.codegen_size_asserts(wrapper) + + def __init__( + self, + layout, + inputs, + constant_args=(), + kwargs=None, + python_kernel_name=None, + cpp_kernel_name=None, + ordered_kwargs_for_cpp_kernel=(), + op_overload=None, + ): + super().__init__( + None, + layout, + self.unwrap_storage(inputs), + constant_args, + kwargs or {}, + None, + python_kernel_name, + cpp_kernel_name, + ordered_kwargs_for_cpp_kernel, + op_overload, + ) + self.name = V.graph.register_buffer(self) + + def should_allocate(self): + return False + + def apply_constraint(self): + raise NotImplementedError + + +class UserDefinedTritonKernel(ExternKernel): + def get_kernel_and_configs(self): + from triton.runtime.autotuner import Autotuner + + from torch._higher_order_ops.triton_kernel_wrap import kernel_side_table + + kernel = kernel_side_table.get_kernel(self.kernel_idx) + configs = [] + if isinstance(kernel, Autotuner): + configs = kernel.configs + kernel = kernel.fn + return kernel, configs + + def codegen(self, wrapper): + kernel, configs = self.get_kernel_and_configs() + + # Definition of kernel + new_name, triton_meta = wrapper.define_user_defined_triton_kernel( + kernel, configs, self.kwargs + ) + + args = self.codegen_kwargs() + if V.graph.cpp_wrapper: + # in C++ wrapper, we don't pass constexpr args, as they don't + # get added as parameters to the PTX code compiled from the + # user-defined Triton kernel (only non-constexpr args do) + args = [arg for i, arg in enumerate(args) if i not in kernel.constexprs] + + # Call to kernel + self.codegen_comment(wrapper) + wrapper.generate_user_defined_triton_kernel( + new_name, + self.grid, + configs, + args, + triton_meta, + ) + + def should_allocate(self): + return False + + def has_side_effects(self): + # UserDefinedTritonKernel does not return anything, but rather + # modifies input in place, do not let it get DCEd + return True + + def get_unbacked_symbol_defs(self) -> Set[sympy.Symbol]: + return set() + + def get_mutation_names(self): + return [] + + def __init__(self, *, kernel_idx, grid, kernel_args): + inputs = [] + kwargs = dict() + constant_args = [] + for k, v in kernel_args.items(): + if isinstance(v, TensorBox): + t = InputsKernel.unwrap_storage_for_input(self.realize_input(v)) + inputs.append(t) + kwargs[k] = t + else: + constant_args.append(v) + kwargs[k] = v + + assert len(inputs) != 0 + device = inputs[0].get_device() + + super().__init__( + None, + NoneLayout(device), # type: ignore[arg-type] + inputs, + tuple(constant_args), + kwargs, + ) + self.name = V.graph.register_buffer(self) + self.kernel_idx = kernel_idx + self.grid = grid + + kernel, _ = self.get_kernel_and_configs() + # If we are autotuning, not all arguments will be passed + self.ordered_kwargs_for_cpp_kernel = [ + arg for arg in kernel.arg_names if arg in kernel_args + ] + + mark_node_as_mutating( + self, *[a for a in kernel_args.values() if isinstance(a, TensorBox)] + ) + + def get_alias_names(self): + return [i.get_name() for i in self.inputs] + + +def 
mark_node_as_mutating(cur_buffer, *mutated_ops): + """ + Marks the ops in mutated_ops as mutated and indicates to the scheduler + that these ops depend on cur_buffer. + """ + for op in mutated_ops: + assert isinstance(op, IRNode), op + V.graph.mark_buffer_mutated(op.get_name()) + assert hasattr(op, "layout") + MutationOutput(op.layout, op, cur_buffer) + + +class MutationOutput(ExternKernel): + def get_mutation_names(self): + return [self.inputs[0].get_name()] + + def __init__(self, layout, input, parent): + super().__init__(None, layout, [input, parent], ()) + self.name = V.graph.register_buffer(self) + + def should_allocate(self): + return False + + def is_no_op(self): + return True + + def has_side_effects(self): + return True + + def get_alias_names(self): + return [self.inputs[0].get_name()] + + +class InplaceBernoulliFallback(ExternKernel): + """ + This needs to be a custom class to handle mutation properly + """ + + def codegen(self, wrapper): + (x,) = (t.codegen_reference() for t in self.inputs) + wrapper.writeline( + f"{self.get_kernel_name()}({x}, {', '.join(map(repr, self.constant_args))}){wrapper.ending}" + ) + + def should_allocate(self): + return False + + def get_mutation_names(self): + return [self.inputs[0].get_name()] + + def get_unbacked_symbol_defs(self) -> Set[sympy.Symbol]: + return set() + + def __init__(self, x, *constant_args): + super().__init__( + None, + NoneLayout(x.get_device()), # type: ignore[arg-type] + self.unwrap_storage([x]), + constant_args, + ) + self.name = V.graph.register_buffer(self) + self.python_kernel_name = "aten.bernoulli_" + self.cpp_kernel_name = ( + "aoti_torch_bernoulli_" + if config.abi_compatible + else "at::native::bernoulli_" + ) + mark_node_as_mutating(self, x) + + +# Used to deal with torch.complex types +class InplaceCopyFallback(ExternKernel): + """ + This needs to be a custom class to handle mutation properly + """ + + def codegen(self, wrapper): + (dst, src, non_blocking) = self.codegen_args() + wrapper.writeline( + f"{self.get_kernel_name()}({dst}, {src}, {non_blocking}){wrapper.ending}" + ) + + def should_allocate(self): + return False + + def get_mutation_names(self): + return [self.inputs[0].get_name()] + + def get_unbacked_symbol_defs(self) -> Set[sympy.Symbol]: + return set() + + def __init__( + self, + layout, + inputs, + constant_args, + ): + super().__init__( + None, + layout, + inputs, + constant_args, + python_kernel_name="aten.copy_", + cpp_kernel_name=( + "aoti_torch_copy_" if config.abi_compatible else "at::_ops::copy_::call" + ), + ) + self.name = V.graph.register_buffer(self) + + @classmethod + def create(cls, dst, src, non_blocking: bool = False): + inputs = [cls.realize_input(t) for t in [dst, src]] + constant_args = (non_blocking,) + result = InplaceCopyFallback( + NoneLayout(dst.get_device()), # type: ignore[arg-type] + inputs, + constant_args, + ) + mark_node_as_mutating(result, dst) + return result + + +class MutatingFirstArgExternKernel(ExternKernel): + """ + This needs to be a custom class to handle mutation properly + """ + + def codegen(self, wrapper): + argrefs = [ + *(t.codegen_reference() for t in self.inputs), + *map(repr, self.constant_args), + ] + wrapper.writeline( + f"{self.get_kernel_name()}({', '.join(argrefs)}){wrapper.ending}" + ) + + def should_allocate(self): + return False + + def get_mutation_names(self): + return [self.inputs[0].get_name()] + + def get_unbacked_symbol_defs(self) -> Set[sympy.Symbol]: + return set() + + def has_side_effects(self): + return True + +
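+# A sketch of how the mutation classes above cooperate (hypothetical op, not from the original +# source): a lowering for an in-place op such as aten.relu_(x) would realize x, construct its +# kernel with NoneLayout(x.get_device()), and call mark_node_as_mutating(kernel, x); the resulting +# MutationOutput buffer is what lets the scheduler order readers of x relative to the mutation. + +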
+class ResizeStorageBytes(MutatingFirstArgExternKernel): + def __init__(self, variable, new_size): + assert isinstance(new_size, int), "TODO: dynamic shapes" + super().__init__( + None, + NoneLayout(variable.get_device()), # type: ignore[arg-type] + self.unwrap_storage([variable]), + constant_args=(new_size,), + ) + V.graph.mark_buffer_mutated(variable.get_name()) + self.name = V.graph.register_buffer(self) + self.python_kernel_name = "inductor_ops.resize_storage_bytes_" + self.cpp_kernel_name = "torch::inductor::resize_storage_bytes_" + V.graph.never_reuse_buffers.add(variable.data.get_name()) + mark_node_as_mutating(self, variable) + + +class ScatterFallback(ExternKernel): + """ + This needs to be a custom class to handle mutation properly. + This class handles both aten.scatter_ and aten.scatter_reduce_. + It also handles the case of `src` being a scalar properly. + """ + + def codegen(self, wrapper): + reduce = self.kwargs["reduce"] + if V.graph.cpp_wrapper: + # Follow aten/src/ATen/native/ReductionType.h:get_operator_enum + get_operator_enum = {"add": "sum", "multiply": "prod"} + if reduce in get_operator_enum: + reduce = get_operator_enum[reduce] + + if self.src_is_tensor: + (x, index, src) = (t.codegen_reference() for t in self.inputs) + else: + (x, index) = (t.codegen_reference() for t in self.inputs) + src = self.constant_args[1] + wrapper.generate_scatter_fallback( + x, + [x, self.constant_args[0], index, src], + self.get_kernel_name(), + self.python_kernel_name, + self.src_is_tensor, + reduce, + self.codegen_kwargs(), + ) + + def should_allocate(self): + return False + + def get_cpp_kernel(self): + reduce = self.kwargs["reduce"] + if self.python_kernel_name == "aten.scatter_": + if self.src_is_tensor: + kernel = ( + "at::scatter_out" if reduce is None else "at::scatter_reduce_out" + ) + else: + assert ( + reduce is None + ), "Expect reduce to be None for aten.scatter_ with scalar src" + kernel = "at::scatter_out" + else: + assert ( + reduce is not None + ), "Expect reduce to be not None for aten.scatter_reduce_" + kernel = "at::scatter_reduce_out" + return kernel + + def get_mutation_names(self): + return [self.inputs[0].get_name()] + + def get_unbacked_symbol_defs(self) -> Set[sympy.Symbol]: + return set() + + def __init__( + self, + op_overload, + python_kernel_name, + x, + dim: int, + index, + src, + *, + reduce: Optional[str] = None, + include_self: bool = True, + ): + assert python_kernel_name in {"aten.scatter_", "aten.scatter_reduce_"} + self.src_is_tensor = isinstance(src, TensorBox) + + constant_args: Tuple[Any, ...]
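+ # For example (hypothetical call): aten.scatter_(x, 0, index, 2.0) has a scalar src, so only + # x and index become tensor inputs and the scalar travels in constant_args as (dim, src); + # with a Tensor src it would be inputs = (x, index, src) and constant_args = (dim,).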
+ if self.src_is_tensor: + tensors = [self.realize_input(t) for t in [x, index, src]] + constant_args = (dim,) + else: + tensors = [self.realize_input(t) for t in [x, index]] + constant_args = (dim, src) + + super().__init__( + None, + NoneLayout(x.get_device()), # type: ignore[arg-type] + self.unwrap_storage(tensors), + constant_args, + {"reduce": reduce, "include_self": include_self}, + python_kernel_name=python_kernel_name, + ordered_kwargs_for_cpp_kernel=["reduce", "include_self"], + op_overload=op_overload, + ) + self.cpp_kernel_name = self.get_cpp_kernel() + self.name = V.graph.register_buffer(self) + mark_node_as_mutating(self, x) + + +class IndexPutFallback(ExternKernel): + """ + This needs to be a custom class to handle mutation and indices properly + """ + + def codegen(self, wrapper): + (x, values, *valid_indices) = (t.codegen_reference() for t in self.inputs) + indices = [] + iter_valid_indices = iter(valid_indices) + for i, _ in enumerate(self.indices): + if self.indices[i] is not None: + indices.append(next(iter_valid_indices)) + else: + indices.append(V.graph.wrapper_code.none_str) + + wrapper.generate_index_put_fallback( + self.get_kernel_name(), x, indices, values, *self.codegen_const_args() + ) + + def should_allocate(self): + return False + + def get_mutation_names(self): + return [self.inputs[0].get_name()] + + def get_unbacked_symbol_defs(self) -> Set[sympy.Symbol]: + return set() + + def __init__(self, op_overload, x, indices, values, accumulate): + self.indices = indices + valid_indices = [i for i in indices if i is not None] + tensors = [self.realize_input(x) for x in [x, values, *valid_indices]] + cpp_kernel_name = ( + "aoti_torch_index_put_out" if config.abi_compatible else "at::index_put_out" + ) + super().__init__( + None, + NoneLayout(x.get_device()), # type: ignore[arg-type] + self.unwrap_storage(tensors), + (accumulate,), + python_kernel_name="aten.index_put_", + cpp_kernel_name=cpp_kernel_name, + op_overload=op_overload, + ) + self.name = V.graph.register_buffer(self) + mark_node_as_mutating(self, x) + + +class DeviceCopy(ExternKernelOut): + @classmethod + def create(cls, x, device): + if ( + not x.is_extern() + and all( + (r.name in V.graph.constants and isinstance(r, dependencies.MemoryDep)) + for r in x.get_reads() + ) + and not config.aot_inductor.use_runtime_constant_folding + ): + return x.constant_to_device(device) + + V.graph.add_device_info(device) + V.graph.add_device_info(x.get_device()) + + developer_warning("DeviceCopy in input program") + return DeviceCopy( + FlexibleLayout( + device=device, + dtype=x.get_dtype(), + size=x.get_size(), + ), + [cls.realize_input(x)], + ) + + def codegen(self, wrapper): + args = self.codegen_args() + assert len(args) == 1 + if self.output_view: + wrapper.codegen_device_copy(args[0], self.output_view.codegen_reference()) + else: + wrapper.codegen_device_copy(args[0], self.codegen_reference()) + + +class DynamicScalar(ExternKernel): + """ + The result of a call to aten._local_scalar_dense. + """ + + def get_reads(self): + return () + + def should_allocate(self): + return False + + # TODO: handle bools carefully + def __init__(self, sym, data): + data.realize() + super().__init__(None, NoneLayout(torch.device("cpu")), self.unwrap_storage([data])) # type: ignore[arg-type] + if isinstance(sym, sympy.Symbol): + self.sym = sym + self.is_bool = False + else: + # Special case for boolean. 
For Reasons(TM), we don't represent + # boolean variables directly in sympy; instead, we generate an + # indicator integer variable which we then convert to a boolean by + # testing i0 == 1. We have to identify the underlying indicator + # variable, and then bind i0 to the appropriate integer value + # based on the runtime boolean. + assert isinstance(sym, sympy.Eq), sym + assert isinstance(sym.args[0], sympy.Symbol), sym + assert sym.args[1] == 1, sym + self.sym = sym.args[0] + self.is_bool = True + + def get_unbacked_symbol_defs(self) -> Set[sympy.Symbol]: + return {self.sym} + + def codegen(self, wrapper): + wrapper.codegen_dynamic_scalar(self) + + +class AssertScalar(ExternKernel): + """ + The result of a call to aten._assert_scalar + """ + + def get_reads(self): + return () + + def should_allocate(self): + return False + + def __init__(self, scalar, msg): + super().__init__( + # Buffer(name, layout) + None, + NoneLayout(torch.device("cpu")), # type: ignore[arg-type] + # InputsKernel(inputs) + [], + ) # type: ignore[arg-type] + self.scalar = scalar + self.msg = msg + + def has_side_effects(self): + return True + + def get_unbacked_symbol_uses(self): + return free_unbacked_symbols(self.scalar) + + def codegen(self, wrapper): + if V.graph.cpp_wrapper: + pass + else: + wrapper.writeline( + f"if not {V.graph.wrapper_code.codegen_python_sizevar(self.scalar)}:" + ) + wrapper.writeline(f" raise RuntimeError({repr(self.msg)})") + # No one should ever use this buffer, but for uniformity + # define the variable and assign it None + wrapper.writeline(f"{self.get_name()} = None") + + +@dataclasses.dataclass +class ExternKernelNode: + name: str + node: export_schema.Node + + +has_c_shim = { + aten._embedding_bag.default, + aten._fft_c2c.default, + aten._scaled_dot_product_efficient_attention.default, + aten._scaled_dot_product_flash_attention.default, + aten._scaled_mm.default, + aten.addmm.out, + aten.bmm.out, + aten.copy_.default, + aten.mm.out, + aten.repeat_interleave.Tensor, + aten.nonzero.default, + aten.view.dtype, + aten.view_as_real.default, +} + + +def get_aten_cpp_kernel_name(kernel): + # Calling with the default kernel name can lead to ambiguous behavior like the following example. + # repeat_interleave(const at::Tensor & repeats, c10::optional<int64_t> output_size=c10::nullopt) + # repeat_interleave(const at::Tensor & self, int64_t repeats, + # c10::optional<int64_t> dim=c10::nullopt, c10::optional<int64_t> output_size=c10::nullopt) + assert ( + isinstance(kernel, torch._ops.OpOverload) and kernel.namespace == "aten" + ), "Invalid aten kernel" + opname = ( + kernel.__name__.split(".")[0] + if kernel._overloadname == "default" + else kernel.__name__.replace(".", "_") + ) + return f"at::_ops::{opname}::call" + + +class FallbackKernel(ExternKernelAlloc): + args_default_value: List[Dict[str, Any]] + + def __init__( + self, + layout, + kernel, + tensor_args, + nontensor_args, + unflatten_args, + kwargs=None, + ): + super().__init__( + layout, + tuple(tensor_args), + tuple(nontensor_args), + op_overload=kernel, + ) + # We need output buffers for generating kernel arguments in the + # abi-compatible mode, where we retrieve outputs by passing each individual + # output through the abi-compatible interface.
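+ # E.g. (hypothetical two-output op): for a schema returning (Tensor, Tensor), self.outputs + # will hold the two output buffers, and the ABI-compatible wrapper passes each one's handle + # to the shim call individually.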
+ self.outputs: Sequence[Any] = [] + self.use_runtime_dispatch = False + self.abi_compatible_kernel = None + + assert isinstance( + kernel, + ( + torch._ops.OpOverload, + torch._ops.HigherOrderOperator, + ), + ), f"Failed to create FallbackKernel for {kernel}: {type(kernel)} not supported" + self.op_overload = kernel + + self.unflatten_args = unflatten_args + self.kwargs = {} if kwargs is None else kwargs + V.graph.warn_fallback(self.python_kernel_name) + + # args that are aliased + self.alias_names: List[str] = [] + # args that are mutated AND returned from the op + self.mutation_names: List[str] = [] + + if isinstance(self.op_overload, torch._ops.HigherOrderOperator): + # We assume here that HOPs with FallbackKernel are functional. + # This may not always be true! HOPs must individually opt-in to + # FallbackKernel, so please check this if you opt in. + return + + if "_c10d_functional" in self.op_overload.name(): + # _c10d_functional kernels are lowered into _CollectiveKernel which + # derives from FallbackKernel for the cpp codegen. The kernels + # don't pass the can_auto_functionalize check, but their mutation + # is handled properly by _CollectiveKernel. + return + + schema = self.op_overload._schema + + # NOTE: [FallbackKernel supported operators] + # We only support the following types of operators: + # - functional ops + # - view ops + # - inplace aten ops + # - mutating ops that are auto-functionalizable. That is, + # the operator may mutate any number of inputs, but its outputs + # may not alias any of the inputs. + # + # The unsupported cases usually do not show up here (because + # AOTAutograd functionalized them away); the only way for an in-place + # op to show up here is if a lowering or pass introduced it. + if torch._library.utils.mutates_and_returns_first_arg(self.op_overload): + self.mutation_names.append(tensor_args[0].get_name()) + return + + if schema.is_mutable and not can_auto_functionalize(kernel): + raise NotImplementedError( + f"NYI: Can't generate FallbackKernel for {kernel}" + ) + + schema_args = schema.arguments + args, kwargs = self.unflatten_args(self.inputs, self.constant_args) + + def handle_aliasing_and_mutation(info, arg): + # Assertions to make sure we didn't mismatch args + if isinstance(info.type, torch.ListType): + assert isinstance(arg, (list, tuple)) + is_optional_tensor = isinstance( + info.type, torch.OptionalType + ) and isinstance(info.type.getElementType(), torch.TensorType) + if is_optional_tensor or isinstance(info.type, torch.TensorType): + # PyTorch also accepts None and scalar types for args marked as "Tensor". + # We're not going to check all of them here. + assert not isinstance(arg, (tuple, list)) + + if arg is None: + return + if info.alias_info is None: + return + # can_auto_functionalize already filters out mutable List[Tensor]. + # We can support this in the future, but this is very uncommon.
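+ # For instance (hypothetical custom-op schema): in "mylib::scale_(Tensor(a!) x, float s) -> ()", + # `x` carries alias_info with is_write=True, so it is recorded in alias_names and marked + # as mutating below.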
+ assert isinstance(info.type, torch.TensorType) or is_optional_tensor + self.alias_names.append(arg.get_name()) + if info.alias_info.is_write: + mark_node_as_mutating(self, arg) + + for info, arg in torch._library.utils.zip_schema(schema, args, kwargs): + handle_aliasing_and_mutation(info, arg) + + def set_cpp_kernel(self, kernel): + from .codegen.wrapper import get_cpp_op_schema + + assert ( + not kernel._schema.is_mutable + ), f"mutable {kernel.__name__} is not supported with cpp_wrapper" + + # These checks are here because ops that return aliasing tensors will + # return type Tensor& instead of Tensor, but codegen will always write + # type Tensor on the LHS. + def is_not_write(arg): + return arg.alias_info is None or not arg.alias_info.is_write + + assert all( + is_not_write(x) for x in kernel._schema.arguments + ), f"{kernel.__name__} with alias_info arguments is not supported with cpp_wrapper" + assert all( + is_not_write(x) for x in kernel._schema.returns + ), f"{kernel.__name__} with alias_info returns is not supported with cpp_wrapper" + + self.cpp_kernel_name = kernel._schema.name + self.cpp_kernel_overload_name = kernel._schema.overload_name + self.cpp_kernel_key = f"{self.cpp_kernel_name.replace('::', '_')}_{self.cpp_kernel_overload_name}" # type: ignore[union-attr] + + self.cpp_op_schema = get_cpp_op_schema(kernel) + self.init_args_default_value(kernel._schema) + + def is_legacy_abi_kernel(self): + return ( + config.c_shim_version == "1" + and "_scaled_dot_product_flash_attention" in str(self.python_kernel_name) + ) + + def init_args_default_value(self, schema): + self.args_default_value = [ + { + "name": x.name, + "type": x.real_type, + "value": x.default_value, + } + for x in schema.arguments + if not x.kwarg_only + ] + + def get_pos_arg_value(self, pos, kwargs): + # positional args may be provided in kwargs + pos_arg_name = self.args_default_value[pos]["name"] + if pos_arg_name in kwargs: + log.debug( + "Found argument %s with value %s from kwargs", + pos_arg_name, + kwargs[pos_arg_name], + ) + return kwargs[pos_arg_name] + + assert hasattr( + self, "args_default_value" + ), "self.args_default_value has to be provided" + assert pos < len( + self.args_default_value + ), f"expected the index {pos} to be smaller than len(self.args_default_value): {len(self.args_default_value)}" + arg_default_value = self.args_default_value[pos]["value"] + log.debug( + "Use default value %s for argument %s", arg_default_value, pos_arg_name + ) + return arg_default_value + + def codegen_args(self): + @dataclasses.dataclass + class Shim: + ref: Any + + def __repr__(self): + return self.ref + + tensor_args = [Shim(x.codegen_reference()) for x in self.inputs] + args, kwargs = self.unflatten_args(tensor_args, self.constant_args) + # Now we set up abi_compatible_kernel after self.python_kernel_name + # and kwargs are adjusted appropriately.
+ # For sdpa, we need the v2 version since v1 didn't consider optional arg + # FIXME: no need to do this after we switch to the torchgen-ed C shim + self.abi_compatible_kernel = ( + f"{self.cpp_kernel_name}_v2" + if self.cpp_kernel_name in {"at::_scaled_dot_product_flash_attention"} + and config.c_shim_version == "1" + else self.cpp_kernel_name + ) + + if V.graph.cpp_wrapper and isinstance(self.op_overload, torch._ops.OpOverload): + args = [ + V.graph.wrapper_code.val_to_cpp_arg_str( + param.real_type, x, self.is_legacy_abi_kernel() + ) + for param, x in zip(self.op_overload._schema.arguments, args) + ] + else: + args = [V.graph.wrapper_code.val_to_arg_str(x) for x in args] + + # Previously, we want to maintain forward-compatibility by skipping + # default args in the serialized artifacts in fbcode. However, + # some of our shim interfaces require default values being set. + # Discussed with Sherlock offline and we decided to allow serializing + # default args into the C++ wrapper code for now. We will refine this + # part if we see real FC requirement. More details related to FC + # can be found at: + # https://docs.google.com/document/d/1FzWm-sHYwmRi3x_g036kOxd99KaYquUsA-L5JwOn8ys/edit?usp=sharing + if V.graph.cpp_wrapper and hasattr(self, "args_default_value"): + self.fill_non_provided_args(args, kwargs, convert_val_to_str=True) + + # let self.codegen_kwargs handle kwargs + self.kwargs.update(kwargs) + return args + + @staticmethod + def find_device(tensor_args, example_output): + if tensor_args: + return tensor_args[0].get_device() + if isinstance(example_output, torch.Tensor): + return example_output.device + if isinstance(example_output, (list, tuple)): + devices = {FallbackKernel.find_device(None, x) for x in example_output} + # Remove None + devices = [device for device in devices if device] + if len(devices) == 1: + return devices[0] + for device in devices: + if device.type == "cuda": + return device + return devices[0] + return None + + def has_side_effects(self): + if isinstance(self.op_overload, torch._ops.HigherOrderOperator): + return False + return get_schema_info(self.op_overload).is_mutable() + + def get_alias_names(self): + return self.alias_names + + def get_mutation_names(self): + assert len(self.mutation_names) <= 1 + return self.mutation_names + + def fill_non_provided_args(self, args, kwargs, convert_val_to_str=False): + assert isinstance(args, (list, tuple)) + if isinstance(args, tuple): + args = list(args) + assert hasattr(self, "args_default_value") + n_args = len(args) + n_pos_args = len(self.args_default_value) + # For cpp wrapper, if some positional args are not provided, we need to check + # if they're in the kwargs or use their default value + if n_args < n_pos_args: + log.debug( + "%s has %d unprovided positional arguments. " + "Will check if they are in the keyword arguments or will use default values.", + self.op_overload, + n_pos_args - n_args, + ) + pos_args = [ + self.get_pos_arg_value(i, kwargs) for i in range(n_args, n_pos_args) + ] + if convert_val_to_str: + pos_args = [V.graph.wrapper_code.val_to_arg_str(x) for x in pos_args] + args.extend(pos_args) + return args + + # ProxyExecutor Design Note + # We export the ExternFallbackNodes (for custom ops) into a serialized file + # and run it with a host side proxy executor to address the ABI problem + # This is currently only implemented for fbcode. Eventually, we will also make this work for OSS. 
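# A standalone restatement (illustrative only) of the device-selection rule in
# find_device above: drop Nones, take the unique device when there is one, and
# otherwise prefer a cuda device over cpu.
import torch

def pick_device(devices):
    devices = [d for d in devices if d is not None]
    if not devices:
        return None
    if len(set(devices)) == 1:
        return devices[0]
    return next((d for d in devices if d.type == "cuda"), devices[0])

print(pick_device([torch.device("cpu"), torch.device("cuda", 0)]))  # cuda:0
print(pick_device([torch.device("cpu")]))                           # cpu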
+ # Detailed design doc can be found at + # https://docs.google.com/document/d/1wC4DOZFaYym2t1Esz0X5yxlLI3RDnSiyRbUus3bkJ64/edit?usp=sharing + def export_extern_kernel_node(self): + assert isinstance(self, FallbackKernel) + args, kwargs = self.unflatten_args(self.inputs, self.constant_args) + args = self.fill_non_provided_args(args, kwargs) + ordered_kwargs = [ + kwargs.get(key, None) for key in self.ordered_kwargs_for_cpp_kernel + ] + + serializer = GraphModuleSerializer(None, None) # type: ignore[arg-type] + named_arguments = serializer.serialize_inputs(self.op_overload, args, kwargs) # type: ignore[arg-type] + + # serialize_outputs + def handle_single_output(return_type, output): + if isinstance(return_type, torch.TensorType): + # For single Tensor + out = output + if isinstance(output, (list, tuple)): + assert len(output) == 1 + out = output[0] + return export_schema.Argument.create( + as_tensor=export_schema.TensorArgument(name=out.get_name()) + ) + elif isinstance(return_type, torch.ListType) and isinstance( + return_type.getElementType(), torch.TensorType + ): + # For single TensorList + return export_schema.Argument.create( + as_tensors=[ + export_schema.TensorArgument(name=out.get_name()) + for out in output + ] + ) + else: + raise RuntimeError(f"Unsupported return type {type(return_type)}") + + target = self.op_overload + returns = target._schema.returns # type: ignore[union-attr] + if len(returns) == 1: + return_type = returns[0].real_type + output_arguments = [handle_single_output(return_type, self.outputs)] + else: + # For tuple returns, e.g "-> (Tensor, Tensor)" or "-> (Tesnor, Tensor[])" + assert isinstance(self.outputs, tuple) + assert len(returns) == len(self.outputs) + output_arguments = [ + handle_single_output(return_schema.real_type, output) + for return_schema, output in zip(returns, self.outputs) + ] + + node = ExternKernelNode( + name=self.get_name(), + node=export_schema.Node( + target=self.op_overload.name(), # type: ignore[union-attr] + inputs=named_arguments, + outputs=output_arguments, + metadata={}, + ), + ) + + V.graph.extern_kernel_nodes.append(node) + + return [*args, *ordered_kwargs] + + def codegen(self, wrapper): + kernel = self.op_overload + if kernel.namespace == "aten": # type: ignore[union-attr] + # Aten Fallback Ops + assert isinstance(kernel, torch._ops.OpOverload) + if V.graph.cpp_wrapper: + if ( + config.is_fbcode() + and kernel not in has_c_shim + # C shim v2 is torchgen-ed, which should cover all aten ops. + # If you do hit a missed op, please update gen_aoti_c_shim.py. + and config.c_shim_version == "1" + ): + log.warning( + "%s is missing a c-shim implementation, using proxy executor as fallback", + kernel, + ) + self.use_runtime_dispatch = True + self.set_cpp_kernel(kernel) + else: + self.cpp_kernel_name = get_aten_cpp_kernel_name(kernel) + schema = kernel._schema + self.init_args_default_value(schema) + else: + self.python_kernel_name = str(kernel) + + elif isinstance(kernel, torch._ops.HigherOrderOperator): + self.python_kernel_name = f"torch.ops.higher_order.{kernel.__name__}" + else: + # For non-aten OpOverload, i.e. 
custom ops + if V.graph.cpp_wrapper: + self.use_runtime_dispatch = True + self.set_cpp_kernel(kernel) + else: + self.python_kernel_name = f"{kernel.__module__.replace('._ops.', '.ops.')}.{kernel.__name__}" # type: ignore[union-attr] + + if self.use_runtime_dispatch: + self.codegen_comment(wrapper) + + exported_args = None + args = None + if config.is_fbcode() and V.graph.cpp_wrapper: + exported_args = self.export_extern_kernel_node() + else: + args = [*self.codegen_args(), *self.codegen_kwargs()] + + wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed( + self.get_name(), + self.get_kernel_name(), + args, + self.cpp_op_schema, + self.cpp_kernel_key, + self.cpp_kernel_overload_name, + self.op_overload, + exported_args, + self.outputs, + ) + else: + self.codegen_comment(wrapper) + args = [*self.codegen_args(), *self.codegen_kwargs()] + V.graph.wrapper_code.generate_fallback_kernel(self, args) + if isinstance(self.layout, Layout): + self.codegen_size_asserts(wrapper) + + @staticmethod + def tensor_to_layout(output: torch.Tensor): + return FixedLayout( + output.device, + output.dtype, + convert_shape_to_inductor(output.size()), + convert_shape_to_inductor(output.stride()), + ) + + @classmethod + def create(cls, kernel, *args, **kwargs): + fake_incorrect_kernels = (aten._fused_moving_avg_obs_fq_helper_functional,) + context = ( + V.graph.fake_mode if kernel not in fake_incorrect_kernels else nullcontext() + ) + with context: + ( + example_output, + tensor_args, + non_tensor_args, + unflatten_args, + ) = cls.process_kernel(kernel, *args, **kwargs) + + device = cls.find_device(tensor_args, example_output) + assert device, "Not sure where to find device info" + + packed = cls( + MultiOutputLayout(device), + kernel, + tensor_args, + non_tensor_args, + unflatten_args, + ) + + def generate_output(output, indices): + if isinstance(output, (list, tuple)): + return type(output)( + generate_output(output[i], indices + [(type(output), i)]) + for i in range(len(output)) + ) + elif isinstance(output, dict): + return { + key: generate_output(val, indices + [(type(output), key)]) + for key, val in output.items() + } + elif isinstance(output, torch.Tensor): + return MultiOutput( + cls.tensor_to_layout(output), + packed, + indices, + ) + elif isinstance(output, int): + return output + elif isinstance(output, torch.SymInt): + return output.node.expr + else: + assert ( + output is None + ), f"FallbackKernel output type {type(output)} is not supported" + return None + + outputs = generate_output(example_output, []) + if isinstance(outputs, (list, tuple, dict)): + packed.outputs = outputs # type: ignore[assignment] + else: + packed.outputs = [outputs] + return outputs + + def apply_constraint(self): + return super().apply_constraint() + + +@dataclasses.dataclass +class ComplexView(FallbackKernel): + """View a complex number as two dtyped numbers or vice versa""" + + def should_allocate(self): + return False + + def get_alias_names(self): + # Signal to codegen that our output buffer isn't safe to reuse + return [self.inputs[0].get_name()] + + def __init__( + self, + layout, + kernel, + tensor_args, + nontensor_args, + unflatten_args, + ): + super().__init__( + layout, + kernel, + tensor_args, + nontensor_args, + unflatten_args, + ) + + +@dataclasses.dataclass +class MultiOutputLayout(IRNode): + device: torch.device + + +class MultiOutput(ExternKernel): + # Given an input MultiOutputLayout buffer, indexes out an actual buffer + # from that result. 
This doesn't actually produce multiple outputs, + # that's MultiOutputLayout! + def codegen_list_tuple_access(self, basename, indices): + if len(indices) > 0: + itype, i = indices[0] + if itype == list: + return self.codegen_list_tuple_access(f"{basename}[{i}]", indices[1:]) + elif itype == tuple: + # cpp wrapper code needs to use std::get<> to access a tuple + tuple_access = V.graph.wrapper_code.codegen_tuple_access( + basename, self.get_name(), str(i) + ) + return self.codegen_list_tuple_access(tuple_access, indices[1:]) + elif itype == dict: + return self.codegen_list_tuple_access(f"{basename}['{i}']", indices[1:]) + else: + raise AssertionError("non supported index type") + else: + return basename + + def codegen(self, wrapper): + wrapper.codegen_multi_output( + self.get_name(), + self.codegen_list_tuple_access(self.inputs[0].get_name(), self.indices), + ) + self.codegen_unbacked_symbol_defs(wrapper) + + def __init__(self, layout, input, indices: List[Tuple[Any, ...]]): + super().__init__(None, layout, [input], ()) + self.name = V.graph.register_buffer(self) + self.indices = indices + + def get_unbacked_symbol_uses(self) -> Set[sympy.Symbol]: + return self.inputs[0].get_unbacked_symbol_uses() + + def should_allocate(self): + return False + + def get_alias_names(self): + return [ + inp.get_name() + for inp in self.inputs + if isinstance(inp, FallbackKernel) and len(inp.get_alias_names()) > 0 + ] + + +def _prepare_convolution_fusion_create( + cls, + x: "TensorBox", + weight: "TensorBox", + bias: "TensorBox", + padding: List[int], + stride: List[int], + dilation: List[int], + groups: int, + transposed: bool = False, + output_padding: Optional[List[int]] = None, +): + """ + This function is a helper function to prepare inputs, layout and constant args + for convolution post-op fusion's create function, including deciding the output + layout (channels first or channels last), realizing inputs and make them etc. The + function only supports the CPU device since conv post-op fusion kernel is only + supported on CPU right now. + """ + + # Port from aten/src/ATen/native/ConvUtils.h: _conv_input_size + def _conv_input_size( + output_size, weight_size, padding, output_padding, stride, dilation, groups + ): + assert len(output_size) == len(weight_size), "Expect input dim == weight dim" + dim = len(output_size) + assert dim > 2, "Expect input dim > 2" + + BATCH_DIM = 0 + WEIGHT_INPUT_CHANNELS_DIM = 1 + input_size = [] + input_size.append(output_size[BATCH_DIM]) + input_size.append(weight_size[WEIGHT_INPUT_CHANNELS_DIM] * groups) + for d in range(2, dim): + kernel = (weight_size[d] - 1) * dilation[d - 2] + 1 + input_size_d = ( + (output_size[d] - 1) * stride[d - 2] + - (padding[d - 2] * 2) + + kernel + + output_padding[d - 2] + ) + input_size.append(input_size_d) + return list(map(int, input_size)) + + # The size of prepacked_weight is the prepacked weight size of deconv: + # Groups > 1: [g*o, i/g, ...] + # Groups == 1: [o, i, ...] + # Returns original weight size in [i, o, ...] 
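# A numeric sanity check (illustrative, standalone) of the _conv_input_size
# formula above: kernel = (w - 1) * dilation + 1 and
# input_size_d = (out_d - 1) * stride - 2 * padding + kernel + output_padding.
import torch

out_d, stride, padding, dilation, w, output_padding = 8, 2, 1, 1, 3, 0
kernel = (w - 1) * dilation + 1
input_size_d = (out_d - 1) * stride - 2 * padding + kernel + output_padding
assert input_size_d == 15

# conv_transpose2d produces exactly the "conv input" size being recovered:
x = torch.randn(1, 4, out_d, out_d)
weight = torch.randn(4, 2, w, w)  # deconv weight layout [i, o/g, kH, kW], groups=1
y = torch.nn.functional.conv_transpose2d(x, weight, stride=stride, padding=padding)
assert y.shape[-2:] == (input_size_d, input_size_d)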
+ def _original_deconv_weight_size( + prepacked_weight, + groups, + ): + prepacked_weight_size = prepacked_weight.size() + dim = len(prepacked_weight_size) + assert dim > 2, "Expect weight dim > 2" + if groups > 1: + weight_size = [] + weight_size.append(prepacked_weight_size[1] * groups) + weight_size.append(prepacked_weight_size[0] / groups) + for d in range(2, dim): + weight_size.append(prepacked_weight_size[d]) + else: + weight_size = prepacked_weight.transpose(0, 1).size() + return weight_size + + x.realize() + weight.realize() + if bias is not None: + bias.realize() + with V.graph.fake_mode: + # TODO cleaned up the fake_tensor trace as Linear implementation + x_fake = ir_node_to_tensor(x, guard_shape=True) + weight_fake = ir_node_to_tensor(weight, guard_shape=True) + dims = len(x_fake.size()) - 2 + assert 0 < len(padding) <= dims + assert 0 < len(dilation) <= dims + assert 0 < len(stride) <= dims + padding = pad_listlike(padding, dims) + dilation = pad_listlike(dilation, dims) + stride = pad_listlike(stride, dims) + if output_padding is None: + output_padding = pad_listlike([0], dims) + else: + assert 0 < len(output_padding) <= dims + output_padding = pad_listlike(output_padding, dims) + assert isinstance(groups, int) + if transposed: + # When transposed, the size of the prepacked oneDNN weight is different + # from the PyTorch weight. We're not able to run aten conv with such + # size. We infer the output size from the input params here: + weight_size = _original_deconv_weight_size(weight_fake, groups) + input_size = x_fake.size() + output_size = _conv_input_size( + input_size, + weight_size, + padding, + output_padding, + stride, + dilation, + groups, + ) + else: + bias_fake = ( + ir_node_to_tensor(bias, guard_shape=True) if bias is not None else bias + ) + output = torch.ops.aten.convolution( + x_fake, + weight_fake, + bias_fake, + stride, + padding, + dilation, + transposed, + output_padding, + groups, + ) + output_size = output.size() + + req_stride_order = [0] + list(reversed(range(1, len(stride) + 1))) + req_stride_order = [len(req_stride_order)] + req_stride_order + output_stride = make_channels_last_strides_for(output_size) + + x = cls.require_stride_order(x, req_stride_order) + assert x.get_device().type == "cpu" and weight.get_device().type == "cpu" + inputs = [x, weight] + + kernel_layout = FixedLayout( + x.get_device(), + x.get_dtype(), + convert_shape_to_inductor(output_size), + convert_shape_to_inductor(output_stride), + ) + constant_args = [padding, stride, dilation, groups] + if transposed: + constant_args.insert(1, output_padding) + + if bias is not None: + inputs.append(bias) + else: + constant_args.insert(0, bias) + return inputs, constant_args, kernel_layout, req_stride_order + + +def _prepare_linear_fusion_create( + cls, + x: "TensorBox", + weight: "TensorBox", + bias: "TensorBox", +): + """ + This function is a helper function to prepare inputs, layout and constant args + for linear post-op fusion's create function. The function only supports the CPU device + since linear post-op fusion kernel is only supported on CPU right now. + """ + x.realize() + weight.realize() + if bias is not None: + bias.realize() + + *m, _ = x.get_size() + # The weight has been transposed during the qlinear weight prepack process. 
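# What make_channels_last_strides_for yields for a 4D size (an illustrative
# re-derivation, not the inductor helper itself): strides are contiguous over
# (N, H, W, C) while the size stays in (N, C, H, W) order.
import torch

def channels_last_strides(size):
    n, c, h, w = size
    return (h * w * c, 1, w * c, c)

t = torch.empty(2, 3, 4, 5).to(memory_format=torch.channels_last)
assert t.stride() == channels_last_strides((2, 3, 4, 5))  # (60, 1, 15, 3)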
+ # https://github.com/pytorch/pytorch/blob/4979f9c0d72490970e2019bb1d2284f83d93f76b/ + # aten/src/ATen/native/quantized/cpu/qlinear_prepack.cpp#L291 + _, oc = weight.get_size() + output_size = list(m) + [oc] + req_stride_order = list(reversed(range(len(x.get_size())))) + + x = cls.require_stride_order(x, req_stride_order) + assert x.get_device().type == "cpu" and weight.get_device().type == "cpu" + inputs = [x, weight] + + output_stride = make_contiguous_strides_for(output_size) + kernel_layout = FixedLayout( + x.get_device(), + x.get_dtype(), + output_size, + output_stride, + ) + constant_args: List[Any] = [] + + if bias is not None: + inputs.append(bias) + else: + constant_args.insert(0, bias) + return inputs, constant_args, kernel_layout, req_stride_order + + +class ConvolutionUnary(ExternKernelAlloc): + def __init__( + self, + layout, + inputs, + constant_args=(), + ): + super().__init__( + layout, + inputs, + constant_args, + None, + python_kernel_name="torch.ops.mkldnn._convolution_pointwise", + cpp_kernel_name="mkldnn::_convolution_pointwise", + ) + self.cpp_kernel_key = "convolution_pointwise" + self.cpp_op_schema = """ + at::Tensor( + const at::Tensor& input_t, + const at::Tensor& weight_t, + const c10::optional& bias_opt, + at::IntArrayRef padding, + at::IntArrayRef stride, + at::IntArrayRef dilation, + int64_t groups, + c10::string_view attr, + torch::List> scalars, + c10::optional algorithm)""" + + def codegen(self, wrapper): + wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed( + self.get_name(), + self.get_kernel_name(), + self.codegen_args(), + self.cpp_op_schema, + self.cpp_kernel_key, + ) + if isinstance(self.layout, Layout): + self.codegen_size_asserts(wrapper) + + @classmethod + def create( + cls, + x: "TensorBox", + weight: "TensorBox", + bias: "TensorBox", + padding_: List[int], + stride_: List[int], + dilation_: List[int], + groups: int, + attr, + scalars: Optional[List[Any]], + algorithm, + ): + (inputs, constant_args, kernel_layout, _) = _prepare_convolution_fusion_create( + cls, x, weight, bias, padding_, stride_, dilation_, groups + ) + constant_args = constant_args + [ + attr, + may_convert_to_optional(scalars), + algorithm, + ] + return ConvolutionUnary( + layout=kernel_layout, + inputs=inputs, + constant_args=constant_args, + ) + + +class ConvolutionBinary(ExternKernelAlloc): + def __init__( + self, + layout, + inputs, + constant_args=(), + cpp_constant_args=(), + ): + super().__init__( + layout, + inputs, + constant_args, + None, + python_kernel_name="torch.ops.mkldnn._convolution_pointwise.binary", + cpp_kernel_name="mkldnn::_convolution_pointwise", + ) + self.cpp_kernel_overload_name = "binary" + self.cpp_kernel_key = "convolution_pointwise_binary" + self.cpp_op_schema = """ + at::Tensor( + const at::Tensor& input_t, + const at::Tensor& other_t, + const at::Tensor& weight_t, + const c10::optional& bias_opt, + at::IntArrayRef padding, + at::IntArrayRef stride, + at::IntArrayRef dilation, + int64_t groups, + c10::string_view binary_attr, + c10::optional alpha, + c10::optional unary_attr, + torch::List> unary_scalars, + c10::optional unary_algorithm)""" + self.cpp_constant_args = cpp_constant_args + + def codegen(self, wrapper): + wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed( + self.get_name(), + self.get_kernel_name(), + self.codegen_args(), + self.cpp_op_schema, + self.cpp_kernel_key, + self.cpp_kernel_overload_name, + ) + if isinstance(self.layout, Layout): + self.codegen_size_asserts(wrapper) + + @classmethod + def create( + 
cls, + x: "TensorBox", + other: "TensorBox", + weight: "TensorBox", + bias: "TensorBox", + padding_: List[int], + stride_: List[int], + dilation_: List[int], + groups: int, + binary_attr: str, + binary_alpha: Optional[float], + unary_attr: Optional[str], + unary_scalars: Optional[List[Any]], + unary_algorithm: Optional[str], + ): + ( + inputs, + constant_args, + kernel_layout, + req_stride_order, + ) = _prepare_convolution_fusion_create( + cls, x, weight, bias, padding_, stride_, dilation_, groups + ) + other = cls.require_stride_order(other, req_stride_order) + inputs.insert(1, other) + constant_args = constant_args + [ + binary_attr, + binary_alpha, + unary_attr, + may_convert_to_optional(unary_scalars), + unary_algorithm, + ] + return ConvolutionBinary( + layout=kernel_layout, + inputs=inputs, + constant_args=constant_args, + ) + + +class ConvolutionBinaryInplace(ExternKernelAlloc): + def __init__( + self, + kernel_layout, + inputs, + constant_args=(), + ): + # Due to constrain of op.call, other (Tensor&) should be at input[0] + reordered_inputs = [inputs[1], inputs[0]] + inputs[2:] + + super().__init__( + kernel_layout, + reordered_inputs, + constant_args, + None, + python_kernel_name="torch.ops.mkldnn._convolution_pointwise_.binary", + cpp_kernel_name="mkldnn::_convolution_pointwise_", + ) + self.cpp_kernel_overload_name = "binary" + self.cpp_kernel_key = "convolution_pointwise_binary_" + # TODO: op.call: input[0] should be at::Tensor& + self.cpp_op_schema = """ + at::Tensor&( + at::Tensor& other_t, + const at::Tensor& input_t, + const at::Tensor& weight_t, + const c10::optional& bias_opt, + at::IntArrayRef padding, + at::IntArrayRef stride, + at::IntArrayRef dilation, + int64_t groups, + c10::string_view binary_attr, + c10::optional alpha, + c10::optional unary_attr, + torch::List> unary_scalars, + c10::optional unary_algorithm)""" + + def codegen(self, wrapper): + wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed( + self.get_name(), + self.get_kernel_name(), + self.codegen_args(), + self.cpp_op_schema, + self.cpp_kernel_key, + self.cpp_kernel_overload_name, + ) + + def get_mutation_names(self): + return [self.inputs[0].get_name()] + + def get_unbacked_symbol_defs(self) -> Set[sympy.Symbol]: + return set() + + @classmethod + def create( + cls, + x: "TensorBox", + other: "TensorBox", + weight: "TensorBox", + bias: "TensorBox", + padding_: List[int], + stride_: List[int], + dilation_: List[int], + groups: int, + binary_attr: str, + binary_alpha: Optional[float], + unary_attr: Optional[str], + unary_scalars: Optional[List[Any]], + unary_algorithm: Optional[str], + ): + ( + inputs, + constant_args, + _, + req_stride_order, + ) = _prepare_convolution_fusion_create( + cls, x, weight, bias, padding_, stride_, dilation_, groups + ) + other = cls.require_stride_order(other, req_stride_order) + inputs.insert(1, other) + constant_args = constant_args + [ + binary_attr, + binary_alpha, + unary_attr, + may_convert_to_optional(unary_scalars), + unary_algorithm, + ] + packed = ConvolutionBinaryInplace( + kernel_layout=NoneLayout(inputs[1].get_device()), # type: ignore[arg-type] + inputs=inputs, + constant_args=constant_args, + ) + mark_node_as_mutating(packed, inputs[1]) + # This op mutates in place which means that the result is not the + # target but rather the input that is being mutated + # init reorders the inputs, so inputs[1] becomes packed.inputs[0] + return packed.inputs[0] + + +class MKLPackedLinear(ExternKernelAlloc): + def __init__( + self, + layout, + inputs, + 
constant_args=(), + ): + super().__init__( + layout, + inputs, + constant_args, + None, + python_kernel_name="torch.ops.mkl._mkl_linear", + cpp_kernel_name="mkl::_mkl_linear", + ) + self.cpp_kernel_key = "mkl_linear" + self.cpp_op_schema = """ + at::Tensor( + const at::Tensor& self, + const at::Tensor& mkl_weight_t, + const at::Tensor& origin_weight_t, + const c10::optional& bias_opt, + const int64_t prepack_batch_size)""" + + def codegen(self, wrapper): + wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed( + self.get_name(), + self.get_kernel_name(), + self.codegen_args(), + self.cpp_op_schema, + self.cpp_kernel_key, + ) + + @classmethod + def create(cls, x, packed_w, orig_w, batch_size): + x = cls.require_stride1(cls.realize_input(x)) + orig_w = cls.require_stride1(cls.realize_input(orig_w)) + *m, _ = x.get_size() + oc, _ = orig_w.get_size() + output_size = list(m) + [oc] + output_stride = make_contiguous_strides_for(output_size) + inputs = [x, packed_w, orig_w] + constant_args = [None, batch_size] + + return MKLPackedLinear( + layout=FixedLayout( + x.get_device(), x.get_dtype(), output_size, output_stride + ), + inputs=inputs, + constant_args=constant_args, + ) + + +class LinearUnary(ExternKernelAlloc): + def __init__( + self, + layout, + inputs, + constant_args=(), + ): + super().__init__( + layout, + inputs, + constant_args, + None, + python_kernel_name="torch.ops.mkldnn._linear_pointwise", + cpp_kernel_name="mkldnn::_linear_pointwise", + ) + self.cpp_kernel_key = "linear_pointwise" + self.cpp_op_schema = """ + at::Tensor( + const at::Tensor& input_t, + const at::Tensor& weight_t, + const c10::optional& bias_opt, + c10::string_view attr, + torch::List> scalars, + c10::optional algorithm)""" + + def codegen(self, wrapper): + wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed( + self.get_name(), + self.get_kernel_name(), + self.codegen_args(), + self.cpp_op_schema, + self.cpp_kernel_key, + ) + + @classmethod + def create(cls, x, w, b, attr, scalars, algorithm): + x = cls.require_contiguous(cls.realize_input(x)) + w = cls.require_contiguous(cls.realize_input(w)) + + *m, ic = x.get_size() + oc, ic = w.get_size() + inputs = [x, w] + constant_args = [attr, scalars if scalars else [-1], algorithm] + if b is not None: + b = cls.require_contiguous(cls.realize_input(b)) + inputs.append(b) + else: + constant_args.insert(0, None) + + return LinearUnary( + layout=FlexibleLayout( + device=x.get_device(), + dtype=x.get_dtype(), + size=list(m) + [oc], + ), + inputs=inputs, + constant_args=constant_args, + ) + + def apply_constraint(self): + pass + + +class LinearBinary(ExternKernelAlloc): + kernel = "torch.ops.mkldnn._linear_pointwise.binary" + + def __init__( + self, + layout, + inputs, + constant_args=(), + ): + super().__init__( + layout, + inputs, + constant_args, + None, + python_kernel_name="torch.ops.mkldnn._linear_pointwise.binary", + cpp_kernel_name="mkldnn::_linear_pointwise", + ) + self.cpp_kernel_overload_name = "binary" + self.cpp_kernel_key = "linear_pointwise_binary" + self.cpp_op_schema = """ + at::Tensor( + const at::Tensor& input_t, + const at::Tensor& other_t, + const at::Tensor& weight_t, + const c10::optional& bias_opt, + c10::string_view attr) + """ + + def codegen(self, wrapper): + wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed( + self.get_name(), + self.get_kernel_name(), + self.codegen_args(), + self.cpp_op_schema, + self.cpp_kernel_key, + self.cpp_kernel_overload_name, + ) + + @classmethod + def create(cls, x, y, w, b, attr): + x = 
cls.require_contiguous(cls.realize_input(x)) + y = cls.require_contiguous(cls.realize_input(y)) + w = cls.require_contiguous(cls.realize_input(w)) + + *m, ic = x.get_size() + oc, ic = w.get_size() + + inputs = [x, y, w] + constant_args = [attr] + if b is not None: + b = cls.require_contiguous(cls.realize_input(b)) + inputs.append(b) + else: + constant_args.insert(0, b) + + return LinearBinary( + layout=FlexibleLayout( + device=x.get_device(), + dtype=x.get_dtype(), + size=list(m) + [oc], + ), + inputs=inputs, + constant_args=constant_args, + ) + + def apply_constraint(self): + pass + + +class ConvolutionTransposeUnary(ExternKernelAlloc): + def __init__( + self, + layout, + inputs, + constant_args=(), + ): + super().__init__( + layout, + inputs, + constant_args, + None, + python_kernel_name="torch.ops.mkldnn._convolution_transpose_pointwise", + cpp_kernel_name="mkldnn::_convolution_transpose_pointwise", + ) + self.cpp_kernel_key = "convolution_transpose_pointwise" + self.cpp_op_schema = """ + at::Tensor( + const at::Tensor& input_t, + const at::Tensor& weight_t, + const c10::optional& bias_opt, + at::IntArrayRef padding, + at::IntArrayRef output_padding, + at::IntArrayRef stride, + at::IntArrayRef dilation, + int64_t groups, + c10::string_view attr, + torch::List> scalars, + c10::optional algorithm)""" + + def codegen(self, wrapper): + wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed( + self.get_name(), + self.get_kernel_name(), + self.codegen_args(), + self.cpp_op_schema, + self.cpp_kernel_key, + ) + + @classmethod + def create( + cls, + x: "TensorBox", + weight: "TensorBox", + bias: "TensorBox", + padding_: List[int], + output_padding_: List[int], + stride_: List[int], + dilation_: List[int], + groups_: int, + attr, + scalars: Optional[List[Any]], + algorithm, + ): + transposed = True + ( + inputs, + constant_args, + kernel_layout, + _, + ) = _prepare_convolution_fusion_create( + cls, + x, + weight, + bias, + padding_, + stride_, + dilation_, + groups_, + transposed, + output_padding_, + ) + constant_args = constant_args + [ + attr, + may_convert_to_optional(scalars), + algorithm, + ] + return ConvolutionTransposeUnary( + layout=kernel_layout, + inputs=inputs, + constant_args=constant_args, + ) + + +class MkldnnRnnLayer(ExternKernelAlloc): + def __init__( + self, + layout, + inputs, + constant_args=(), + ): + super().__init__( + layout, + inputs, + constant_args, + None, + python_kernel_name="aten.mkldnn_rnn_layer", + cpp_kernel_name="at::mkldnn_rnn_layer", + ) + + @classmethod + def create( + cls, + x: "TensorBox", + w0: "TensorBox", + w1: "TensorBox", + w2: "TensorBox", + w3: "TensorBox", + hx: "TensorBox", + cx: "TensorBox", + reverse: bool, + batch_sizes: List[int], + mode: int, + hidden_size: int, + num_layers: int, + has_biases: bool, + bidirectional: bool, + batch_first: bool, + train: bool, + ): + x = cls.require_stride1(cls.realize_input(x)) + # If batch_first, x has been permuted in lstm before entering the mkldnn_rnn_layer. + # Make sure x is contiguous in batch_first case. + x.freeze_layout() + w0 = cls.require_stride1(cls.realize_input(w0)) + w1 = cls.require_stride1(cls.realize_input(w1)) + w2 = cls.require_stride1(cls.realize_input(w2)) + w3 = cls.require_stride1(cls.realize_input(w3)) + hx = cls.require_stride1(cls.realize_input(hx)) + hx.freeze_layout() + cx = cls.require_stride1(cls.realize_input(cx)) + cx.freeze_layout() + + input_size = x.get_size() + assert len(input_size) == 3, "Expect lstm input to be 3D" + # batch_first is handled in the lstm OP. 
When entering + # rnn_layer here, we'll always have batch_first = False + seq_length, mini_batch, input_size = input_size + output_shape = [seq_length, mini_batch, hidden_size] + + hy_shape = hx.get_size() + cy_shape = cx.get_size() + + res: List[IRNode] = [] + + inputs = [x, w0, w1, w2, w3, hx, cx] + constant_args = [ + reverse, + batch_sizes, + mode, + hidden_size, + num_layers, + has_biases, + bidirectional, + batch_first, + train, + ] + + packed = MkldnnRnnLayer( + MultiOutputLayout(x.get_device()), + inputs=inputs, + constant_args=constant_args, + ) + + def get_strides_of_lstm_output(output_shape, batch_first): + assert len(output_shape) == 3, "Expect output_shape to be 3D" + return make_contiguous_strides_for(output_shape) + + output_sizes = [output_shape, hy_shape, cy_shape] + output_strides = [ + get_strides_of_lstm_output(output_shape, batch_first), + make_contiguous_strides_for(hy_shape), + make_contiguous_strides_for(cy_shape), + ] + output_ir = [ + MultiOutput( + FixedLayout( + x.get_device(), + x.get_dtype(), + output_size, + output_stride, + ), + packed, + [(tuple, i)], + ) + for i, (output_size, output_stride) in enumerate( + zip(output_sizes, output_strides) + ) + ] + + return output_ir + + +class QConvPointWisePT2E(ExternKernelAlloc): + def __init__( + self, + layout, + inputs, + constant_args=(), + ): + """ + if bias is not None + - inputs = [x, w, b, weight_scale, weight_zp] + - const_args is: [stride, padding, dilation, groups, x_scale, x_zp, o_inv_scale, o_zp, + fp32_output, unary_attr, unary_scalars, unary_algorithm] + else + - inputs = [x, w, weight_scale, weight_zp] + - const_args is: [bias, stride, padding, dilation, groups, x_scale, x_zp, o_inv_scale, o_zp, + fp32_output, unary_attr, unary_scalars, unary_algorithm] + """ + self.has_bias = len(inputs) == 5 + super().__init__( + layout, + inputs, + constant_args, + None, + python_kernel_name="torch.ops.onednn.qconv2d_pointwise", + cpp_kernel_name="onednn::qconv2d_pointwise", + ) + self.cpp_kernel_key = "qconv2d_pointwise" + self.cpp_op_schema = """ + at::Tensor( + at::Tensor act, + double act_scale, + int64_t act_zero_point, + at::Tensor weight, + at::Tensor weight_scales, + at::Tensor weight_zero_points, + c10::optional bias, + torch::List stride, + torch::List padding, + torch::List dilation, + int64_t groups, + double inv_output_scale, + int64_t output_zero_point, + c10::optional output_dtype, + c10::string_view attr, + torch::List> scalars, + c10::optional algorithm)""" + + def codegen(self, wrapper): + # Parser the inputs and constant + args = [x.codegen_reference() for x in self.inputs] + const_args = [] + const_args.extend(self.codegen_const_args()) + + x = args[0] + packed_weight = args[1] + bias = args[2] if self.has_bias else const_args[0] + w_scale, w_zp = args[-2], args[-1] + ( + stride, + padding, + dilation, + groups, + x_scale, + x_zp, + o_inv_scale, + o_zp, + output_dtype, + unary_attr, + unary_scalars, + unary_algorithm, + ) = const_args[-12:] + + codegen_args = ( + x, + x_scale, + x_zp, + packed_weight, + w_scale, + w_zp, + bias, + stride, + padding, + dilation, + groups, + o_inv_scale, + o_zp, + output_dtype, + unary_attr, + unary_scalars, + unary_algorithm, + ) + wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed( + self.get_name(), + self.get_kernel_name(), + codegen_args, + self.cpp_op_schema, + self.cpp_kernel_key, + ) + if isinstance(self.layout, Layout): + self.codegen_size_asserts(wrapper) + + @classmethod + def create( + cls, + x: "TensorBox", + x_scale: float, + x_zp: int, + 
weight: "TensorBox", # packed_weight + w_scale: "TensorBox", + w_zp: "TensorBox", + bias: "TensorBox", + stride_: List[int], + padding_: List[int], + dilation_: List[int], + groups: int, + o_inv_scale: float, + output_zero_point: int, + output_dtype, + unary_attr, + unary_scalars, + unary_algorithm, + ): + transposed = False + output_padding = None + (inputs, constant_args, kernel_layout, _) = _prepare_convolution_fusion_create( + cls, + x, + weight, + bias, + padding_, + stride_, + dilation_, + groups, + transposed, + output_padding, + ) + # swap padding and stride to align with functional conv arg order + if bias is None: + constant_args[1], constant_args[2] = constant_args[2], constant_args[1] + else: + constant_args[0], constant_args[1] = constant_args[1], constant_args[0] + + w_scale.realize() + w_zp.realize() + inputs = inputs + [w_scale, w_zp] + constant_args = constant_args + [ + x_scale, + x_zp, + o_inv_scale, + output_zero_point, + output_dtype, + unary_attr, + may_convert_to_optional(unary_scalars), + unary_algorithm, + ] + + if output_dtype is not None: + assert output_dtype in [torch.float32, torch.bfloat16] + # in _prepare_convolution_fusion_create, we use x.dtype (uint8) to create kernel_layout + # if we set output_dtype is not None, the output buf should be output_dtype instead of uint8. + kernel_layout.dtype = output_dtype + + return QConvPointWisePT2E( + layout=kernel_layout, + inputs=inputs, + constant_args=constant_args, + ) + + +class QConvPointWiseBinaryPT2E(ExternKernelAlloc): + def __init__( + self, + layout, + inputs, + constant_args=(), + ): + """ + Needs input/weight/output qparams + if bias is not None + - inputs = [x, w, b, accum, w_scale, w_zp] + - const_args = [stride, padding, dilation, groups, x_scale, x_zp, accum_scale, accum_zp, o_inv_scale, o_zp, + fp32_output, binary_attr, aplha, unary_attr, unary_scalars, unary_algorithm] + else + - inputs = [x, w, accum, w_scale, w_zp] + - const_args = const_args is: [bias, stride, padding, dilation, groups, x_scale, x_zp, accum_scale, + accum_zp, o_inv_scale, o_zp, fp32_output, binary_attr, aplha, unary_attr, unary_scalars, unary_algorithm] + """ + self.has_bias = len(inputs) == 6 + self.idx_for_inplace_sum = 3 if self.has_bias else 2 + super().__init__( + layout, + inputs, + constant_args, + None, + python_kernel_name="torch.ops.onednn.qconv2d_pointwise.binary", + cpp_kernel_name="onednn::qconv2d_pointwise", + ) + self.cpp_kernel_overload_name = "binary" + self.cpp_kernel_key = "qconv2d_pointwise_binary" + self.cpp_op_schema = """ + at::Tensor( + at::Tensor act, + double act_scale, + int64_t act_zero_point, + at::Tensor accum, + double accum_scale, + int64_t accum_zero_point, + at::Tensor weight, + at::Tensor weight_scales, + at::Tensor weight_zero_points, + c10::optional bias, + torch::List stride, + torch::List padding, + torch::List dilation, + int64_t groups, + double inv_output_scale, + int64_t output_zero_point, + c10::optional output_dtype, + c10::string_view binary_attr, + c10::optional alpha, + c10::optional attr, + torch::List> scalars, + c10::optional algorithm)""" + + def codegen(self, wrapper): + # Parser the inputs and constant + args = [x.codegen_reference() for x in self.inputs] + const_args = [] + const_args.extend(self.codegen_const_args()) + + x = args[0] + packed_weight = args[1] + bias = args[2] if self.has_bias else const_args[0] + accum, w_scale, w_zp = args[-3], args[-2], args[-1] + ( + stride, + padding, + dilation, + groups, + x_scale, + x_zp, + accum_scale, + accum_zp, + o_inv_scale, + 
o_zp, + output_dtype, + binary_attr, + alpha, + unary_attr, + unary_scalars, + unary_algorithm, + ) = const_args[-16:] + conv_args = ( + x, + x_scale, + x_zp, + accum, + accum_scale, + accum_zp, + packed_weight, + w_scale, + w_zp, + bias, + stride, + padding, + dilation, + groups, + o_inv_scale, + o_zp, + output_dtype, + binary_attr, + alpha, + unary_attr, + unary_scalars, + unary_algorithm, + ) + wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed( + self.get_name(), + self.get_kernel_name(), + conv_args, + self.cpp_op_schema, + self.cpp_kernel_key, + self.cpp_kernel_overload_name, + ) + if isinstance(self.layout, Layout): + self.codegen_size_asserts(wrapper) + + def get_mutation_names(self): + return [self.inputs[self.idx_for_inplace_sum].get_name()] + + def get_unbacked_symbol_defs(self) -> Set[sympy.Symbol]: + return set() + + @classmethod + def create( + cls, + x: "TensorBox", + x_scale, + x_zp, + accum: "TensorBox", + accum_scale, + accum_zp, + weight: "TensorBox", # packed_weight + w_scale, + w_zp, + bias: "TensorBox", + stride_: List[int], + padding_: List[int], + dilation_: List[int], + groups: int, + o_inv_scale: "TensorBox", + output_zero_point: "TensorBox", + output_dtype, + binary_attr, + alpha, + unary_attr, + unary_scalars, + unary_algorithm, + ): + transposed = False + output_padding = None + ( + inputs, + constant_args, + kernel_layout, + req_stride_order, + ) = _prepare_convolution_fusion_create( + cls, + x, + weight, + bias, + padding_, + stride_, + dilation_, + groups, + transposed, + output_padding, + ) + + accum = cls.require_stride_order(accum, req_stride_order) + inputs.append(accum) + + # swap padding and stride to align with functional conv arg order + if bias is None: + constant_args[1], constant_args[2] = constant_args[2], constant_args[1] + else: + constant_args[0], constant_args[1] = constant_args[1], constant_args[0] + + w_scale.realize() + w_zp.realize() + inputs = inputs + [w_scale, w_zp] + constant_args = constant_args + [ + x_scale, + x_zp, + accum_scale, + accum_zp, + o_inv_scale, + output_zero_point, + output_dtype, + binary_attr, + alpha, + unary_attr, + may_convert_to_optional(unary_scalars), + unary_algorithm, + ] + + assert ( + binary_attr == "sum" + ), "For now, only post op sum is supported in QConvPointWiseBinaryPT2E." + + packed = QConvPointWiseBinaryPT2E( + layout=NoneLayout(accum.get_device()), + inputs=inputs, + constant_args=constant_args, + ) + mark_node_as_mutating(packed, accum) + + # Return accum since it has been inplace changed. 
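# A toy model (plain tensors, not inductor IR) of the post-op "sum" convention
# above: the kernel writes into accum, so lowering hands callers the mutated
# input instead of a freshly allocated output buffer.
import torch

def qconv_sum_post_op(conv_out: torch.Tensor, accum: torch.Tensor) -> torch.Tensor:
    accum.add_(conv_out)  # post-op sum accumulates into the existing buffer
    return accum          # downstream users keep reading accum, not a new buffer

accum = torch.ones(2)
result = qconv_sum_post_op(torch.full((2,), 3.0), accum)
assert result is accum and accum.tolist() == [4.0, 4.0]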
+ return packed.inputs[packed.idx_for_inplace_sum] + + +class QLinearPointwisePT2E(ExternKernelAlloc): + def __init__( + self, + layout, + inputs, + constant_args=(), + has_bias=True, + x_scale_zp_are_tensors=False, + ): + """ + if bias is not None + - inputs = [x, w, b, weight_scale, weight_zp] + - const_args is: [x_scale, x_zp, o_inv_scale, o_zp, + fp32_output, unary_attr, unary_scalars, unary_algorithm] + else + - inputs = [x, w, weight_scale, weight_zp] + - const_args is: [bias, x_scale, x_zp, o_inv_scale, o_zp, + fp32_output, unary_attr, unary_scalars, unary_algorithm] + """ + self.has_bias = has_bias + self.x_scale_zp_are_tensors = x_scale_zp_are_tensors + super().__init__( + layout, + inputs, + constant_args, + None, + python_kernel_name=( + "torch.ops.onednn.qlinear_pointwise.tensor" + if x_scale_zp_are_tensors + else "torch.ops.onednn.qlinear_pointwise.default" + ), + cpp_kernel_name="onednn::qlinear_pointwise", + ) + self.cpp_kernel_overload_name = "tensor" if x_scale_zp_are_tensors else "" + self.cpp_kernel_key = "qlinear_pointwise" + x_scale_type_str, x_zp_type_str = ( + ("at::Tensor", "at::Tensor") + if x_scale_zp_are_tensors + else ("double", "int64_t") + ) + self.cpp_op_schema = f""" + at::Tensor( + at::Tensor act, + {x_scale_type_str} act_scale, + {x_zp_type_str} act_zero_point, + at::Tensor weight, + at::Tensor weight_scales, + at::Tensor weight_zero_points, + c10::optional bias, + double inv_output_scale, + int64_t output_zero_point, + c10::optional output_dtype, + std::string post_op_name, + torch::List> post_op_args, + std::string post_op_algorithm)""" + + def codegen(self, wrapper): + # Parser the inputs and constant + args = [x.codegen_reference() for x in self.inputs] + const_args = [] + const_args.extend(self.codegen_const_args()) + + x = args[0] + packed_weight = args[1] + bias = args[2] if self.has_bias else const_args[0] + w_scale, w_zp = args[-2], args[-1] + if self.x_scale_zp_are_tensors: + assert len(args) >= 4 + x_scale, x_zp = args[-4], args[-3] + ( + o_inv_scale, + o_zp, + output_dtype, + unary_attr, + unary_scalars, + unary_algorithm, + ) = const_args[-6:] + else: + assert len(const_args) >= 8 + ( + x_scale, + x_zp, + o_inv_scale, + o_zp, + output_dtype, + unary_attr, + unary_scalars, + unary_algorithm, + ) = const_args[-8:] + + codegen_args = ( + x, + x_scale, + x_zp, + packed_weight, + w_scale, + w_zp, + bias, + o_inv_scale, + o_zp, + output_dtype, + unary_attr, + unary_scalars, + unary_algorithm, + ) + wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed( + self.get_name(), + self.get_kernel_name(), + codegen_args, + self.cpp_op_schema, + self.cpp_kernel_key, + self.cpp_kernel_overload_name, + ) + if isinstance(self.layout, Layout): + self.codegen_size_asserts(wrapper) + + @classmethod + def create( + cls, + x: "TensorBox", + x_scale: float, + x_zp: int, + weight: "TensorBox", # packed_weight + w_scale: "TensorBox", + w_zp: "TensorBox", + bias: "TensorBox", + o_inv_scale: float, + output_zero_point: int, + output_dtype, + unary_attr, + unary_scalars, + unary_algorithm, + ): + (inputs, constant_args, kernel_layout, _) = _prepare_linear_fusion_create( + cls, + x, + weight, + bias, + ) + + if isinstance(x_scale, TensorBox) and isinstance(x_zp, TensorBox): + x_scale.realize() + x_zp.realize() + inputs = inputs + [x_scale, x_zp] + x_scale_zp_are_tensors = True + else: + assert isinstance(x_scale, float) and isinstance(x_zp, int) + constant_args = constant_args + [x_scale, x_zp] + x_scale_zp_are_tensors = False + w_scale.realize() + w_zp.realize() + 
inputs = inputs + [w_scale, w_zp] + constant_args = constant_args + [ + o_inv_scale, + output_zero_point, + output_dtype, + unary_attr, + may_convert_to_optional(unary_scalars), + unary_algorithm, + ] + + if output_dtype is not None: + assert output_dtype in [torch.float32, torch.bfloat16] + # in _prepare_linear_fusion_create, we use x.dtype (uint8) to create kernel_layout + # if we set fp32_output, the output buf should be dtype float32 instead of uint8. + kernel_layout.dtype = output_dtype + + return QLinearPointwisePT2E( + layout=kernel_layout, + inputs=inputs, + constant_args=constant_args, + has_bias=(bias is not None), + x_scale_zp_are_tensors=x_scale_zp_are_tensors, + ) + + +@dataclasses.dataclass +class MutableBox(IRNode): + """ + TensorBox / StorageBox allow in-place mutation of Tensors + """ + + data: IRNode + + def __getattr__(self, name): + fn = getattr(self.data, name) + if callable(fn): + return fn + raise AttributeError(f"{type(self.data).__name__}.{name} not callable") + + def realize(self): + return self.data.realize() + + def get_unbacked_symbol_uses(self) -> Set[sympy.Symbol]: + return self.data.get_unbacked_symbol_uses() + + def codegen_reference(self, writer=None): + return self.data.codegen_reference(writer) + + @property + def layout(self): + return self.data.layout # type: ignore[attr-defined] + + def get_layout(self): + return self.layout + + def get_size(self): + return self.data.get_size() + + @property + def dtype(self): + return self.data.dtype + + def __str__(self): + if isinstance(self.data, MutableBox): + line0 = f"{type(self).__name__}({type(self.data).__name__}(" + endl = "))" + inner = self.data.data + else: + line0 = f"{type(self).__name__}(" + inner = self.data + endl = ")" + + lines = [ + line0, + indent(str(inner)), + endl, + ] + return "\n".join(lines) + + __repr__ = __str__ + + +class TensorBox(MutableBox): + @staticmethod + def create(data): + return TensorBox(StorageBox(data)) + + +class StorageBox(MutableBox): + def is_input_buffer(self): + if isinstance(self.data, (InputBuffer, ReinterpretView)): + return self.data.get_name() in V.graph.graph_inputs + return False + + def realize(self): + if isinstance( + self.data, + ( + ComputedBuffer, + InputsKernel, + InputBuffer, + ReinterpretView, + TemplateBuffer, + ), + ): + return self.data.get_name() + assert isinstance(self.data, (Pointwise, Reduction, Scan)), type(self.data) + origin_node = self.data.get_origin_node() + traceback = self.data.get_traceback() + self.data = ComputedBuffer( + name=None, + layout=FlexibleLayout( + device=self.data.get_device(), + dtype=self.data.get_dtype(), + size=self.data.get_size(), + ), + data=self.data, + ) + self.data.name = V.graph.register_buffer(self.data) + self.data.origins = self.origins + self.data.origin_node = origin_node + self.data.traceback = traceback + return self.data.name + + def realize_hint(self): + """ + Called on buffers we expect to be forced to realize later. + """ + if ( + isinstance(self.data, (Pointwise, Reduction)) + and self.num_reads() > 1 + and self.is_pointwise_non_scalar_tensor_num_reads_larger_than_one() + ): + self.realize() + + def has_exceeded_max_reads(self): + return isinstance(self.data, Pointwise) and ( + self.num_reads() > config.realize_acc_reads_threshold + or self.has_large_inner_fn() + ) + + def mark_reuse(self, users): + """ + A heuristic to decide if we should realize a tensor + that is used multiple times. 
+ """ + + def should_realize_on_cpu(loops: Union[Pointwise, Reduction]): + """ + The heuristic for realizing reused result of heavy ops on cpu + """ + heavy_ops = ["exp"] # a list of heavy ops + fn_str = loops.inner_fn_str() + return any((op + "(") in fn_str for op in heavy_ops) + + if ( + users > 1 + and isinstance(self.data, (Pointwise, Reduction)) + and ( + self.num_reads() > config.realize_reads_threshold + or self.has_large_inner_fn() + or (is_cpu(self.data) and should_realize_on_cpu(self.data)) + ) + ): + self.realize() + + @cache_on_self + def num_reads(self): + data = self.data + if isinstance(data, (InputsKernel, InputBuffer, ReinterpretView)): + return 1 + if isinstance(data, ComputedBuffer): + read_writes = data.get_read_writes() + else: + assert isinstance(data, (Pointwise, Reduction)), type(data) + read_writes = ComputedBuffer( + name=None, + layout=FlexibleLayout( + device=data.get_device(), + dtype=data.get_dtype(), + size=data.get_size(), + ), + data=data, + ).get_read_writes() + return len(read_writes.reads) + + @cache_on_self + def is_pointwise_non_scalar_tensor_num_reads_larger_than_one(self): + # Skip the check for non Pointwise instances + return ( + (sum(read.index != 0 for read in self.data.get_reads()) > 1) + if isinstance(self.data, Pointwise) + and all( + not isinstance(read, dependencies.StarDep) + for read in self.data.get_reads() + ) + else True + ) + + +@dataclasses.dataclass +class Subgraph(IRNode): + name: str + graph_module: torch.fx.GraphModule + graph: Optional["GraphLowering"] = None + + +@dataclasses.dataclass +class Conditional(ExternKernel): + predicate: Optional[DynamicScalar] = None + operands: Optional[List[TensorBox]] = None + true_subgraph: Optional[Subgraph] = None + false_subgraph: Optional[Subgraph] = None + outputs: Optional[List[MultiOutput]] = None + + def __init__( + self, + predicate: DynamicScalar, + operands: List[TensorBox], + true_subgraph: Subgraph, + false_subgraph: Subgraph, + layout: MultiOutputLayout, + ): + self.predicate = predicate + self.operands = operands + self.true_subgraph = true_subgraph + self.false_subgraph = false_subgraph + + super().__init__( + name=None, + layout=layout, # type: ignore[arg-type] + inputs=[predicate, *operands], # type: ignore[list-item] + ) + + self.name = V.graph.register_buffer(self) + + @classmethod + def create( + cls, + predicate: TensorBox, + true_fn: Subgraph, + false_fn: Subgraph, + operands: List[TensorBox], + ): + predicate = cls.realize_input(predicate) + operands = [cls.realize_input(x) for x in operands] + + fx_operands = V.graph.current_node.args[-1] + fake_operands = [x.meta["val"] for x in fx_operands] # type: ignore[union-attr] + + for subgraph in (true_fn, false_fn): + if subgraph.graph is None: + # create and lower subgraphs + subgraph.graph = V.graph.make_subgraph( + gm=subgraph.graph_module, + example_inputs=fake_operands, + subgraph_name=subgraph.name, + ) + with V.set_graph_handler(subgraph.graph): + subgraph.graph.run(*fake_operands) + + true_outputs = true_fn.graph.graph_outputs # type: ignore[union-attr] + false_outputs = true_fn.graph.graph_outputs # type: ignore[union-attr] + + def _aliased_buffers(outputs): + buffers = [ + output.unwrap_view() if isinstance(output, ReinterpretView) else output + for output in outputs + ] + # assuming the same buffer is represented by the same IRNode object + return len({id(buffer) for buffer in buffers}) < len(outputs) + + for name, outputs in (("true_fn", true_outputs), ("false_fn", false_outputs)): + if 
_aliased_buffers(true_outputs): + raise AssertionError( + "Output aliasing is currently not supported in compiled torch.cond. " + f"The outputs of the {name} subgraph of torch.cond are aliased: {outputs}" + ) + + # make sure true and false outputs are structurally equivalent + assert len(true_outputs) == len(false_outputs), (true_outputs, false_outputs) + for i, (to, fo) in enumerate(zip(true_outputs, false_outputs)): + assert to.get_size() == fo.get_size(), (i, to, fo) + assert to.get_stride() == fo.get_stride(), (i, to, fo) + assert to.get_device() == fo.get_device(), (i, to, fo) + assert to.get_dtype() == fo.get_dtype(), (i, to, fo) + assert to.get_layout().offset == fo.get_layout().offset, (i, to, fo) + + conditional = Conditional( + predicate=predicate, + operands=operands, + true_subgraph=true_fn, + false_subgraph=false_fn, + # use predicate device for consistent codegen-ing + layout=MultiOutputLayout(predicate.get_device()), + ) + + outputs = [ + MultiOutput( + FixedLayout( + device=output.get_device(), + dtype=output.get_dtype(), + size=output.get_size(), + stride=output.get_stride(), + offset=output.get_layout().offset, + ), + conditional, + [(list, i)], + ) + # as the true and false outputs are equivalent, + # we can use either of them here as a "template" + for i, output in enumerate(true_outputs) + ] + + conditional.outputs = outputs + return outputs + + def codegen(self, wrapper): + wrapper.codegen_conditional(self) + + +class InterpreterShim(torch.fx.Interpreter): + @staticmethod + @functools.lru_cache(None) + def _dummy_gm(): + return torch.fx.symbolic_trace(identity) + + def __init__(self, graph, submodules): + # call super() with a placeholder to avoid constructing a + # GraphModule which is very expensive (it does codegen). + super().__init__(self._dummy_gm(), garbage_collect_values=False) + self.module = self # type: ignore[assignment] + self.graph = graph + self.submodules = submodules + self.extra_traceback = False + self.fetch_attr = submodules.__getitem__ + self.current_node = None + + def run_node(self, n: torch.fx.Node) -> Any: + self.current_node = n + return super().run_node(n) + + def run(self, *args, **kwargs): + with V.set_interpreter_handler(self): + return super().run(*args, **kwargs) + + +class LoopBody: + """ + Captures the body of a Loops subclass into an FX graph. Persists any + indexing simplifications and makes it easier to analyze loop bodies. 
+ """ + + def __init__(self, fn, args, var_ranges): + super().__init__() + self.var_ranges = var_ranges + self.indexing_exprs = {} + self.indexing_exprs_name = {} + self.reads = [] + self.writes = [] + self.reads_name2expr = {} + self.writes_name2expr = {} + self.other = [] + self.submodules = {"get_index": self.get_index} + self.subblocks = {} + self.indirect_vars = [] + self.root_block = LoopBodyBlock(self, fn, args) + self.indexing = None + + @cache_on_self + def get_nodes(self): + all_graphs = itertools.chain( + (self.root_block.graph,), + (block.graph for block in self.subblocks.values()), + ) + return [node for graph in all_graphs for node in graph.nodes] + + @cache_on_self + def bounds(self): + # Doing a local import to avoid dumping all the code here + from .bounds import BoundVars + + return BoundVars(self) + + def debug_str(self): + lines = [f"var_ranges = {dict(self.var_ranges)}"] + lines.extend([f"{name} = {val}" for name, val in self.indexing_exprs.items()]) + lines.extend( + [ + block.debug_str(name) + for name, block in itertools.chain( + [("body", self.root_block)], self.subblocks.items() + ) + ] + ) + return "\n".join(lines) + + def add_index_expr(self, expr: sympy.Expr, category, buf_name): + getattr(self, category).append(expr) + if buf_name is not None: + getattr(self, f"{category}_name2expr")[buf_name] = expr + if expr not in self.indexing_exprs_name: + name = f"index{len(self.indexing_exprs)}" + self.indexing_exprs_name[expr] = name + self.indexing_exprs[name] = expr + return self.indexing_exprs_name[expr] + + def add_submodule(self, block, prefix): + """Not actually for nn.Modules, but subblocks in generated code are mapped to FX call_module opcodes""" + if prefix[-1].isnumeric() and prefix not in self.submodules: + name = prefix + else: + name = f"{prefix}{len(self.submodules)}" + self.submodules[name] = block + return name + + def add_indirect(self, size): + name = f"indirect{len(self.indirect_vars)}" + var = sympy_index_symbol(name) + self.indirect_vars.append(var) + return var + + def replace_indirect(self, old, new): + """Swap in a variable used in indirect indexing""" + if str(old) == str(new): + return + assert self.indexing is not None + self.indexing = {k: sympy_subs(v, {old: new}) for k, v in self.indexing.items()} + + def get_index(self, name): + assert self.indexing is not None + return self.indexing[name] + + def __call__(self, *indices): + index = list(itertools.chain.from_iterable(indices)) + assert len(index) == len(self.var_ranges), (index, self.var_ranges) + assert all(v not in self.var_ranges for v in index) + replacements = dict(zip(self.var_ranges.keys(), index)) + self.indexing = { + name: sympy_subs(expr, replacements) + for name, expr in self.indexing_exprs.items() + } + result = self.root_block() + self.indexing = None + return result + + +class LoopBodyBlock: + """ + Captures the body of a Loops subclass into an FX graph. + In normal cases there will be a 1:1 mapping between LoopBody and + LoopBodyBlock, hower in the case of ops.masked() the masked out + operations will manifest as an extra LoopBodyBlock. 
+ """ + + def __init__(self, body: LoopBody, fn: Callable[..., Any], args: List[Any]): + self.body = body + + def add_index(expr, category, buf_name=None): + return tracer.create_proxy( + "call_module", + "get_index", + (self.body.add_index_expr(expr, category, buf_name),), + {}, + ) + + class CaptureIndexing(V.WrapperHandler): # type: ignore[name-defined] + self.name = "CaptureIndexing" + + def load(self, name: str, index: sympy.Expr): + index = add_index(index, "reads", name) + return self._inner.load(name, index) + + def store(self, name, index, value, mode=None): + index = add_index(index, "writes", name) + return self._inner.store(name, index, value, mode) + + def store_reduction(self, name, index, value): + index = add_index(index, "writes", name) + return self._inner.store_reduction(name, index, value) + + def reduction(self, dtype, src_dtype, reduction_type, value): + result = self._inner.reduction(dtype, src_dtype, reduction_type, value) + if "welford" in reduction_type: + return tuple(result[i] for i in range(3)) + return result + + def index_expr(self, index, dtype): + if isinstance(index, (int, sympy.Integer)): + return self._inner.constant(int(index), dtype) + index = add_index(index, "other") + return self._inner.index_expr(index, dtype) + + def bucketize( + self, + values, + offsets_name: str, + offsets_size: sympy.Expr, + indexing_dtype: torch.dtype, + right: bool, + ): + offsets_size = add_index(offsets_size, "other") + return self._inner.bucketize( + values, offsets_name, offsets_size, indexing_dtype, right + ) + + @staticmethod + def masked(mask_proxy, masked_body: Callable[..., Any], other_proxy): + """ + Recursively capture the masked out body in another LoopBodyBlock + """ + + subblock: LoopBodyBlock + + def shim(mask, other): + return V.ops.masked(mask, subblock, other) + + name = self.body.add_submodule(shim, "masked_subblock") + subblock = LoopBodyBlock(self.body, masked_body, []) + self.body.subblocks[name] = subblock + return tracer.create_proxy( + "call_module", name, (mask_proxy, other_proxy), {} + ) + + @staticmethod + def scan( + dtype_proxy, combine_fn: Callable[..., Any], value_proxy, init_proxy + ): + def shim(dtype, value, init): + return V.ops.scan(dtype, combine_fn, value, init) + + name = self.body.add_submodule(shim, "scan") + return tracer.create_proxy( + "call_module", name, (dtype_proxy, value_proxy, init_proxy), {} + ) + + def frexp(self, value_proxy): + result = self._inner.frexp(value_proxy) + # Proxies are iterable, but some methods expect tuples/lists + return (result[0], result[1]) + + @staticmethod + def indirect_indexing(index_proxy, size, check=True): + """ + Flow data from tensors into indexing formulas. + Introduce a call_module to update the indexing. 
+ """ + + var = self.body.add_indirect(size) + + def set_indirect(new_var): + self.body.replace_indirect( + var, V.ops.indirect_indexing(new_var, size, check) + ) + + tracer.create_proxy( + "call_module", + self.body.add_submodule(set_indirect, f"set_{var}"), + (index_proxy,), + {}, + ) + return var + + @staticmethod + def output(result): + tracer.create_proxy("output", "output", (result,), {}) + + tracer = torch.fx.Tracer() + tracer.graph = torch.fx.Graph(tracer_cls=tracer.__class__) + proxy_ops = tracer.create_proxy("placeholder", "ops", (), {}) + + from .index_propagation import IndexPropagation + from .sizevars import SimplifyIndexing + + handler: Any = SimplifyIndexing( + CaptureIndexing(proxy_ops), self.body.var_ranges + ) + if config.constant_and_index_propagation: + handler = IndexPropagation(handler) + + with V.set_ops_handler(handler): + # This indirection is just a cute way to get IndexPropagation to + # unwrap the return value. + ops.output(fn(*args)) + self.graph = tracer.graph + + def __call__(self): + graph = self.graph + submodules = self.body.submodules + + return InterpreterShim(graph, submodules).run(V.get_ops_handler()) + + def debug_str(self, name="block"): + code = torch.fx.GraphModule(self.body.submodules, self.graph).code + return re.sub( + # strip `; del var0` suffixes to make output prettier + r";[^\n]*", + "", + code.strip().replace("def forward(", f"def {name}("), + ) + + +class Wait(ExternKernelAlloc): + """ + Wait should not be used by itself. It should always be constructed in tandem + with a collective op that produces a work to wait on. + """ + + def __init__( + self, + layout, + inputs, + constant_args=(), + ): + super().__init__(layout, inputs, constant_args) + + def should_allocate(self): + return False + + def codegen(self, wrapper): + from .codegen.wrapper import ReuseLine + + wrapper.add_import_once( + "from torch.distributed._functional_collectives_impl import _wait_tensor" + ) + (input_collective,) = (t.codegen_reference() for t in self.inputs) + wrapper.writeline(f"{input_collective} = _wait_tensor({input_collective})") + + # wait op still needs to produce a 'buffer' that represents the tensor output. + # this is a symbolic gesture, and it gets handled by WrapperCodegen. + # codegen outputs a '# reuse' line that assigns the input buffer here ('input_collective') + # to a new name (`self.get_name()`) and `del`s the old name. + wrapper.writeline(ReuseLine(wrapper, self.inputs[0], self, delete_old=False)) + + @classmethod + def create(cls, collective_op: "TensorBox"): + # TODO(whc) i'm not sure what's going on here, this probably means I missed something upstream + collective_op.decide_layout() + return Wait( + layout=AliasedLayout(collective_op), + inputs=[collective_op], + ) + + def get_alias_names(self): + # Signal to codegen that our output buffer isn't safe to reuse + return [self.inputs[0].codegen_reference()] + + def get_mutation_names(self): + # The generated `_wait_tensor` op mutates the input tensor + return [self.inputs[0].codegen_reference()] + + +class CollectiveKernel(ExternKernel): + """ + Each collective should follow the pattern: + - extend InPlaceCollectiveKernel or OutOfPlaceCollectiveKernel. 
+    - the kernel delegates into the c10d process group, which returns a 'work' obj
+    - the work obj is registered via _register_tensor_work so it can be waited on later
+    """
+
+    def __init__(self, layout, inputs, constant_args):
+        super().__init__(None, layout, inputs, constant_args)
+        self.name = V.graph.register_buffer(self)
+
+    def should_emit_register_tensor_work(self):
+        return True
+
+    def should_emit_find_or_create_pg(self):
+        return True
+
+    def codegen_collective(self, wrapper, output_name, input_names):
+        # factored out so the boilerplate can be handled in CollectiveKernel.codegen
+        raise NotImplementedError("Must implement")
+
+    def codegen_output(self, wrapper, output_name, input_names):
+        # factored out so the boilerplate can be handled in CollectiveKernel.codegen
+        raise NotImplementedError("Must implement")
+
+    @classmethod
+    def wrap_inputs_as_inplace(cls, inputs):
+        def wrap_input(var):
+            op = InPlaceHint(
+                FlexibleLayout(var.get_device(), var.get_dtype(), var.get_size()), var
+            )
+            return TensorBox.create(op)
+
+        return list(map(wrap_input, inputs))
+
+    def codegen(self, wrapper):
+        wrapper.add_import_once("import torch.distributed as dist")
+        wrapper.add_import_once("import torch.distributed.distributed_c10d as c10d")
+        wrapper.add_import_once(
+            "import torch.distributed._functional_collectives_impl as fun_col_impl"
+        )
+        # extract references to our args in string form for codegen output
+        input_names = [t.codegen_reference() for t in self.inputs]
+        output_name = self.get_name()
+        tag, ranks, group_size = self.constant_args
+
+        if self.should_emit_find_or_create_pg():
+            # TODO: avoid more than one ref of the same pg (even though they are cached inside the api)
+            wrapper.writeline(
+                f"{output_name}_pg = c10d._find_or_create_pg_by_ranks_and_tag('{tag}', {ranks}, {group_size})"
+            )
+
+        self.codegen_output(wrapper, output_name, input_names)
+        self.codegen_collective(wrapper, output_name, input_names)
+        if self.should_emit_register_tensor_work():
+            wrapper.writeline(
+                f"fun_col_impl._register_tensor_work({output_name}, {output_name}_work)"
+            )
+
+
+class InPlaceCollectiveKernel(CollectiveKernel):
+    """
+    InPlaceCollectiveKernels are those with in-out arguments such as all_reduce.
+    Extend this kernel if your collective needs to modify its inputs in-place.
+    """
+
+    def __init__(self, layout, inputs, constant_args):
+        super().__init__(layout, inputs, constant_args)
+
+    def should_allocate(self):
+        return False
+
+    def has_side_effects(self):
+        return True
+
+    def codegen_output(self, wrapper, output_name, input_names):
+        if len(input_names) > 1:
+            wrapper.writeline(f"{output_name} = [{','.join(input_names)}]")
+        else:
+            wrapper.writeline(f"{output_name} = {input_names[0]}")
+
+
+class OutOfPlaceCollectiveKernel(CollectiveKernel):
+    """
+    OutOfPlaceCollectiveKernels are those that allocate their
+    outputs and leave their inputs in place, such as all_gather.
+    """
+
+    def __init__(self, layout, inputs, outputs, constant_args):
+        super().__init__(layout, inputs + outputs, constant_args)
+        self.outputs = outputs
+        self.original_inputs = inputs
+        # NOTE: As seen in issue #108780, output buffers of out-of-place collectives
+        # could be incorrectly reused. As a safety measure, here we just ban the reuse of them.
+        # TODO: A better fix is to figure out how to propagate the aliases properly,
+        # so that the buffer is only reused after all its users have consumed it.
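+        # never_reuse_buffers is assumed to be consulted by the memory planner
+        # before it hands a dead buffer's storage to a later kernel, so adding
+        # the output names below is a conservative opt-out rather than a fix.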
+        for x in self.outputs:
+            V.graph.never_reuse_buffers.add(x.name)
+
+    def should_allocate(self):
+        return False
+
+    def has_side_effects(self):
+        return True
+
+    def codegen_output(self, wrapper, output_name, input_names):
+        input_names = [t.codegen_reference() for t in self.original_inputs]
+        wrapper.writeline(f"{output_name}_inputs = [{','.join(input_names)}]")
+        wrapper.writeline(f"{output_name} = [{','.join(x.name for x in self.outputs)}]")
+
+    @classmethod
+    def create_output_buffers(cls, inputs, size_cb=None):
+        outputs = []
+        for input in inputs:
+            new_size = input.get_size()
+            if size_cb is not None:
+                size_cb(new_size)
+            # new_size[0] *= group_size
+
+            buff = OutputBuffer(
+                layout=FlexibleLayout(
+                    device=input.get_device(),
+                    dtype=input.get_dtype(),
+                    size=new_size,
+                ),
+            )
+            outputs.append(buff)
+        return outputs
+
+    @classmethod
+    def create_output_nodes(cls, coll, output_buffers):
+        return [
+            MultiOutputNoSizeAssert(
+                out_t.layout,
+                coll,
+                f"[{i}]",
+            )
+            for i, out_t in enumerate(output_buffers)
+        ]
+
+
+class InPlaceHint(ExternKernel):
+    """
+    Helper OP to encode an in/out argument that we try to make in-place whenever possible.
+    Wrap the input of your in-place op to enable this behavior.
+
+    The design is based on two key decisions:
+    - this node is responsible for allocating the in/out buffer used by the collective.
+      This is controlled by the ``should_allocate`` method, which returns True here and
+      False for the collective node
+    - the scheduler special-cases this node and enables it to reuse its input.
+    """
+
+    def codegen(self, wrapper):
+        input_name = self.inputs[0].codegen_reference()
+        output_name = self.get_name()
+        if not wrapper.did_reuse(self, self.inputs[0]):
+            wrapper.writeline(f"{output_name}.copy_({input_name})  # no reuse")
+
+    def __init__(self, layout, input):
+        input = self.realize_input(input)
+        super().__init__(None, layout, self.unwrap_storage([input]), ())
+        self.name = V.graph.register_buffer(self)
+
+    def should_allocate(self):
+        return True
+
+
+class OutputBuffer(ExternKernel):
+    """
+    Represents one output buffer of an op that requires multiple of them.
+    """
+
+    def __init__(self, layout):
+        super().__init__(name=None, layout=layout, inputs=[])
+        self.name = V.graph.register_buffer(self)
+
+    def should_allocate(self):
+        return True
+
+    def codegen(self, wrapper):
+        wrapper.writeline(f"# collective out buffer {self.name}")
+
+
+class MultiOutputNoSizeAssert(MultiOutput):
+    """
+    Extracts a partial output from a multi-output op.
+    Works like MultiOutput but doesn't assert the size; that the sizes match
+    must instead be guaranteed by the op emitting this node.
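+    e.g. codegen simply emits `buf3 = buf2[0]` (names illustrative), with none
+    of the size/stride assertions MultiOutput would add.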
+ """ + + def __init__(self, layout, input, index): + super().__init__(layout, input, []) + self.index = index + + def codegen(self, wrapper): + wrapper.writeline( + f"{self.get_name()} = {self.inputs[0].get_name()}{self.index}" + ) + + +class Broadcast(InPlaceCollectiveKernel): + def __init__(self, layout, inputs, constant_args, src): + super().__init__(layout, inputs, constant_args) + self.src = src + + def get_mutation_names(self): + return [self.inputs[0].get_name()] + + def get_unbacked_symbol_defs(self) -> Set[sympy.Symbol]: + return set() + + @classmethod + def create( + cls, x: "TensorBox", src: int, tag: str, ranks: List[int], group_size: int + ): + inplace_inputs = cls.wrap_inputs_as_inplace([x]) + packed = Broadcast( + layout=NoneLayout(inplace_inputs[0].get_device()), # type: ignore[arg-type] + inputs=inplace_inputs, + constant_args=[tag, ranks, group_size], + src=src, + ) + mark_node_as_mutating(packed, inplace_inputs[0]) + return inplace_inputs[0] + + def codegen_collective(self, wrapper, output_name, input_names): + wrapper.writeline( + f"{output_name}_work = dist.broadcast(" + f"{output_name}, async_op=True, group={output_name}_pg, src={self.src})" + ) + + +class AllReduceCoalesced(InPlaceCollectiveKernel): + def __init__(self, layout, inputs, constant_args, reduce_op): + super().__init__(layout, inputs, constant_args) + self.reduce_op = reduce_op + + def should_allocate(self): + return False + + def get_mutation_names(self): + return [self.inputs[0].get_name()] + + def get_unbacked_symbol_defs(self) -> Set[sympy.Symbol]: + return set() + + @classmethod + def create( + cls, + inputs: List["TensorBox"], + reduce_op: str, + tag: str, + ranks: List[int], + group_size: int, + ): + inplace_inputs = cls.wrap_inputs_as_inplace(inputs) + packed = AllReduceCoalesced( + layout=NoneLayout(inplace_inputs[0].get_device()), # type: ignore[arg-type] + inputs=inplace_inputs, + constant_args=[tag, ranks, group_size], + reduce_op=reduce_op, + ) + mark_node_as_mutating(packed, inplace_inputs[0]) + return inplace_inputs + + def codegen_collective(self, wrapper, output_name, input_names): + wrapper.writeline( + f"{output_name}_work = dist.all_reduce_coalesced(" + f"{output_name}, " + f"op=fun_col_impl._str_to_reduce_op('{str(self.reduce_op)}'), " + f"group={output_name}_pg, " + "async_op=True)" + ) + + +class AllReduce(InPlaceCollectiveKernel): + def __init__(self, layout, inputs, constant_args, reduce_op): + super().__init__(layout, inputs, constant_args) + self.reduce_op = reduce_op + + def get_mutation_names(self): + return [self.inputs[0].get_name()] + + def get_unbacked_symbol_defs(self) -> Set[sympy.Symbol]: + return set() + + @classmethod + def create( + cls, x: "TensorBox", reduce_op: str, tag: str, ranks: List[int], group_size: int + ): + inplace_inputs = cls.wrap_inputs_as_inplace([x]) + + packed = AllReduce( + layout=NoneLayout(inplace_inputs[0].get_device()), # type: ignore[arg-type] + inputs=inplace_inputs, + constant_args=[tag, ranks, group_size], + reduce_op=reduce_op, + ) + mark_node_as_mutating(packed, inplace_inputs[0]) + return inplace_inputs[0] + + def codegen_collective(self, wrapper, output_name, input_names): + wrapper.writeline( + f"{output_name}_work = dist.all_reduce(" + f"{output_name}, async_op=True, group={output_name}_pg, op=fun_col_impl._str_to_reduce_op('{str(self.reduce_op)}'))" + ) + + +class AllGatherIntoTensor(OutOfPlaceCollectiveKernel): + def __init__(self, layout, inputs, outputs, constant_args): + super().__init__(layout, inputs, outputs, constant_args) + 
+ @classmethod + def create(cls, x: "TensorBox", tag: str, ranks: List[int], group_size: int): + inputs = [cls.realize_input(x)] + + def compute_size(new_size): + new_size[0] *= group_size + + outputs = cls.create_output_buffers(inputs, compute_size) + + layout = MultiOutputLayout(inputs[0].get_device()) + + packed = AllGatherIntoTensor( + layout=layout, + inputs=inputs, + outputs=outputs, + constant_args=[tag, ranks, group_size], + ) + return cls.create_output_nodes(packed, outputs)[0] + + def codegen_collective(self, wrapper, output_name, input_names): + wrapper.writeline( + f"{output_name}_work = dist.all_gather_into_tensor(" + f"{output_name}[0], {output_name}_inputs[0], async_op=True, group={output_name}_pg)" + ) + + +class ReduceScatterTensor(OutOfPlaceCollectiveKernel): + def __init__(self, layout, inputs, outputs, constant_args, reduce_op): + super().__init__(layout, inputs, outputs, constant_args) + self.reduce_op = reduce_op + + @classmethod + def create( + cls, + x: "TensorBox", + reduce_op: str, + tag: str, + ranks: List[int], + group_size: int, + ): + inputs = [cls.realize_input(x)] + + def compute_size(new_size): + new_size[0] //= group_size + + outputs = cls.create_output_buffers(inputs, compute_size) + + layout = MultiOutputLayout(inputs[0].get_device()) + + packed = ReduceScatterTensor( + layout=layout, + inputs=inputs, + outputs=outputs, + constant_args=[tag, ranks, group_size], + reduce_op=reduce_op, + ) + return cls.create_output_nodes(packed, outputs)[0] + + def codegen_collective(self, wrapper, output_name, input_names): + wrapper.writeline( + f"{output_name}_work = dist.reduce_scatter_tensor(" + f"{output_name}[0], {output_name}_inputs[0], " + f"async_op=True, group={output_name}_pg, op=fun_col_impl._str_to_reduce_op('{str(self.reduce_op)}'))" + ) + + +class AllGatherIntoTensorCoalesced(OutOfPlaceCollectiveKernel): + def __init__(self, layout, inputs, outputs, constant_args): + super().__init__(layout, inputs, outputs, constant_args) + + @classmethod + def create( + cls, + inputs: List["TensorBox"], + tag: str, + ranks: List[int], + group_size: int, + ): + inputs = [cls.realize_input(x) for x in inputs] + + def compute_size(new_size): + new_size[0] *= group_size + + outputs = cls.create_output_buffers(inputs, compute_size) + + layout = MultiOutputLayout(inputs[0].get_device()) + + packed = AllGatherIntoTensorCoalesced( + layout=layout, + inputs=inputs, + outputs=outputs, + constant_args=[tag, ranks, group_size], + ) + + return outputs + # return cls.create_output_nodes(packed, outputs) + + def codegen_collective(self, wrapper, output_name, input_names): + wrapper.writeline( + f"{output_name}_work = fun_col_impl._all_gather_into_tensor_coalesced_fallback(" + f"output_tensors={output_name}, " + f"input_tensors={output_name}_inputs, " + f"group={output_name}_pg, " + "async_op=True)" + ) + + +class ReduceScatterTensorCoalesced(OutOfPlaceCollectiveKernel): + def __init__(self, layout, inputs, outputs, constant_args, reduce_op): + super().__init__(layout, inputs, outputs, constant_args) + self.reduce_op = reduce_op + + @classmethod + def create( + cls, + inputs: List["TensorBox"], + reduce_op: str, + tag: str, + ranks: List[int], + group_size: int, + ): + inputs = [cls.realize_input(x) for x in inputs] + + def compute_size(new_size): + new_size[0] //= group_size + + outputs = cls.create_output_buffers(inputs, compute_size) + + layout = MultiOutputLayout(inputs[0].get_device()) + + _ = ReduceScatterTensorCoalesced( + layout=layout, + inputs=inputs, + outputs=outputs, + 
constant_args=[tag, ranks, group_size],
+            reduce_op=reduce_op,
+        )
+
+        return outputs
+
+    def codegen_collective(self, wrapper, output_name, input_names):
+        wrapper.writeline(
+            f"{output_name}_work = fun_col_impl._reduce_scatter_tensor_coalesced_fallback("
+            f"output_tensors={output_name}, "
+            f"input_tensors={output_name}_inputs, "
+            f"op=fun_col_impl._str_to_reduce_op('{str(self.reduce_op)}'), "
+            f"group={output_name}_pg, "
+            "async_op=True)"
+        )
+
+
+# TODO(yifu): replace the CollectiveKernel IR hierarchy with _CollectiveKernel.
+class _CollectiveKernel(FallbackKernel):
+    def should_allocate(self):
+        return False
+
+    def has_side_effects(self):
+        return True
+
+    # This is identical to FallbackKernel.set_cpp_kernel(), minus the
+    # part that checks against input aliasing and mutation.
+    def set_cpp_kernel(self, kernel):
+        from .codegen.wrapper import get_cpp_op_schema
+
+        self.cpp_kernel_name = kernel._schema.name
+        self.cpp_kernel_overload_name = kernel._schema.overload_name
+        self.cpp_kernel_key = f"{self.cpp_kernel_name.replace('::', '_')}_{self.cpp_kernel_overload_name}"  # type: ignore[union-attr]
+
+        self.cpp_op_schema = get_cpp_op_schema(kernel)
+        self.ordered_kwargs_for_cpp_kernel = [
+            x.name for x in kernel._schema.arguments if x.kwarg_only
+        ]
+
+    # NOTE: [In-Place Collective Safety]
+    # Between the initiation and completion of an in-place collective, the
+    # input buffers are subject to both volatile reads and volatile writes.
+    # They must not be read, written to or reused by another kernel. To ensure
+    # these constraints, we model collective -> wait_tensor as a two-step
+    # mutation of the input buffers.
+    @classmethod
+    def create_inplace(
+        cls, kernel, inputs: Union[TensorBox, List[TensorBox]], *args, **kwargs
+    ) -> None:
+        cpp_kernel_name = kernel._name
+        python_kernel_name = cpp_kernel_name.replace("::", ".")
+        with V.graph.fake_mode:
+            (
+                example_output,
+                tensor_args,
+                non_tensor_args,
+                unflatten_args,
+            ) = cls.process_kernel(kernel, inputs, *args, **kwargs)
+        for tensor_arg in tensor_args:
+            tensor_arg.realize()
+
+        packed = cls(
+            NoneLayout(tensor_args[0].get_device()),
+            kernel,
+            tensor_args,
+            non_tensor_args,
+            unflatten_args,
+        )
+        packed.cpp_kernel_name = cpp_kernel_name
+        packed.python_kernel_name = python_kernel_name
+
+        def mark_mutation(x):
+            if isinstance(x.data, BaseView):
+                x = x.data.unwrap_view()
+            MutationOutput(x.layout, x, packed)
+
+        pytree.tree_map(lambda inp: mark_mutation(inp), inputs)
+
+    # NOTE: [Out-of-Place Collective Safety]
+    # Between the initiation and completion of an out-of-place collective:
+    #
+    # Input buffers:
+    # - Are subject to volatile reads
+    # - Can be read by another kernel
+    # - Must not be written to or reused by another kernel
+    #
+    # Output buffers:
+    # - Are subject to volatile writes
+    # - Must not be read, written to or reused by another kernel
+    #
+    # To ensure the safety of input buffers without sacrificing read
+    # availability, we add input buffers as read deps of wait_tensor kernels.
+    #
+    # To ensure the safety of output buffers, we model wait_tensor as a
+    # mutation to the output buffer. Note we also assume the user program is
+    # correct and that the output buffer is not consumed by kernels other than
+    # wait_tensor.
+    #
+    # TODO(yifu): add a pre-grad pass to validate the correctness of collective
+    # usage in the user program.
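+    # Concretely, for `buf = all_gather(inp); wait_tensor(buf)` we emit a
+    # MutationOutput on `buf` for the wait kernel (the write dep, see
+    # _WaitKernel.create_wait) and a StarDep on `inp` (the read dep, see
+    # _WaitKernel.get_read_writes), which together keep both buffers alive and
+    # ordered until the wait retires.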
+ @classmethod + def create_out_of_place( + cls, kernel, inputs: Union[TensorBox, List[TensorBox]], *args, **kwargs + ): + cpp_kernel_name = kernel._name + python_kernel_name = cpp_kernel_name.replace("::", ".") + with V.graph.fake_mode: + ( + example_output, + tensor_args, + non_tensor_args, + unflatten_args, + ) = cls.process_kernel(kernel, inputs, *args, **kwargs) + for tensor_arg in tensor_args: + tensor_arg.realize() + + if isinstance(example_output, list): + device = cls.find_device(tensor_args, example_output) + packed = cls( + MultiOutputLayout(device), + kernel, + tensor_args, + non_tensor_args, + unflatten_args, + ) + packed.cpp_kernel_name = cpp_kernel_name + packed.python_kernel_name = python_kernel_name + packed.outputs = [ + MultiOutput( + cls.tensor_to_layout(tensor), + packed, + [(list, i)], + ) + for i, tensor in enumerate(example_output) + ] + return packed.outputs + else: + packed = cls( + cls.tensor_to_layout(example_output), + kernel, + tensor_args, + non_tensor_args, + unflatten_args, + ) + packed.cpp_kernel_name = cpp_kernel_name + packed.python_kernel_name = python_kernel_name + packed.outputs = [packed] + return packed + + +class _WaitKernel(_CollectiveKernel): + def get_volatile_reads(self): + inp = self.inputs[0] + if isinstance(inp, _CollectiveKernel): + # Out-of-place single-output + return [inp.inputs[0]] + elif isinstance(inp, MultiOutput): + # This can be two things: + # 1. Out-of-place multi-output coll + # 2. In-place coll with inputs coming from another MultiOutput + coll = inp.inputs[0] + # Case 1 + if isinstance(coll, _CollectiveKernel): + _, idx = inp.indices[0] + return [coll.inputs[idx]] + # Case 2 + return [] + else: + # In-place requires no additional deps handling for volatile + # reads since the inputs are mutated. + return [] + + @classmethod + def create_wait(cls, kernel, inp: TensorBox) -> None: + with V.graph.fake_mode: + ( + example_output, + tensor_args, + non_tensor_args, + unflatten_args, + ) = cls.process_kernel(kernel, inp) + packed = cls( + NoneLayout(inp.get_device()), + kernel, + tensor_args, + non_tensor_args, + unflatten_args, + ) + if isinstance(inp.data, BaseView): + inp = inp.data.unwrap_view() + MutationOutput(inp.layout, inp, packed) + + def get_read_writes(self): + read_writes = super().get_read_writes() + # See [Out-of-Place Collective Safety]. 
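+        # The collective's inputs are its "volatile reads": other kernels may
+        # still read them, but they must not be freed or reused before this
+        # wait, so they are surfaced to the scheduler as explicit StarDeps.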
+ volatile_reads = self.get_volatile_reads() + for vr in volatile_reads: + read_writes.reads.add(dependencies.StarDep(vr.get_name())) + return read_writes + + +# NB: recursive structure here reflects val_to_arg_str, avoid +# calling free_unbacked_symbols on "exotic" types that don't get pexpr +# treatment +def maybe_free_unbacked_symbols(s): + if isinstance(s, (SymTypes, sympy.Expr)): + # This branch should be impossible in return position + return free_unbacked_symbols(s) + elif isinstance(s, (tuple, list)): + r = set() + for t in s: + r |= maybe_free_unbacked_symbols(t) + return r + elif isinstance(s, torch.Tensor): + # This branch is impossible in constant-args position + return free_unbacked_symbols(s) + else: + return set() + + +class AllToAllSingle(OutOfPlaceCollectiveKernel): + def __init__( + self, + layout, + inputs, + outputs, + constant_args, + output_split_sizes, + input_split_sizes, + ): + super().__init__(layout, inputs, outputs, constant_args) + self.output_split_sizes = output_split_sizes + self.input_split_sizes = input_split_sizes + + def get_unbacked_symbol_uses(self) -> Set[sympy.Symbol]: + r = set() + if self.output_split_sizes is not None: + r |= free_unbacked_symbols(self.output_split_sizes) + if self.input_split_sizes is not None: + r |= free_unbacked_symbols(self.input_split_sizes) + return r + + @classmethod + def create( + cls, + x: "TensorBox", + output_split_sizes: Optional[List[Expr]], + input_split_sizes: Optional[List[Expr]], + tag: str, + ranks: List[int], + group_size: int, + ): + inputs = [cls.realize_input(x)] + + def compute_size(new_size): + if output_split_sizes is not None: + new_size[0] = sum(output_split_sizes) + + outputs = cls.create_output_buffers(inputs, compute_size) + + layout = MultiOutputLayout(inputs[0].get_device()) + + packed = AllToAllSingle( + layout=layout, + inputs=inputs, + outputs=outputs, + constant_args=[tag, ranks, group_size], + output_split_sizes=output_split_sizes, + input_split_sizes=input_split_sizes, + ) + return cls.create_output_nodes(packed, outputs)[0] + + def codegen_collective(self, wrapper, output_name, input_names): + tag, ranks, group_size = self.constant_args + + # TODO: might be necessary to do some pretty printing on + # split sizes + wrapper.writeline( + f"{output_name}_work = dist.all_to_all_single(" + f"{output_name}[0], {output_name}_inputs[0], " + f"output_split_sizes={self.output_split_sizes}, " + f"input_split_sizes={self.input_split_sizes}, " + f"group={output_name}_pg, async_op=True)" + ) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/pattern_matcher.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/pattern_matcher.py new file mode 100644 index 0000000000000000000000000000000000000000..c48251fc352d66ba8ef51b6ee12830e216c71a24 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/pattern_matcher.py @@ -0,0 +1,1524 @@ +from __future__ import annotations + +import dataclasses +import functools +import inspect +import itertools +import logging +import operator +import os +import re +from collections import defaultdict +from typing import ( + Any, + Callable, + DefaultDict, + Dict, + Iterable, + List, + NoReturn, + Optional, + Set, + Union, +) + +from typing_extensions import TypeGuard + +import torch +import torch._guards +import torch.fx +import torch.utils._pytree as pytree +from torch._dispatch.python import enable_python_dispatcher +from torch._dynamo.utils import counters +from 
torch._prims_common import is_integer_dtype +from torch.fx import Node +from torch.fx.experimental.proxy_tensor import make_fx, maybe_disable_fake_tensor_mode +from torch.fx.experimental.symbolic_shapes import guard_size_oblivious +from torch.fx.immutable_collections import immutable_dict, immutable_list + +from .._functorch import config as functorch_config +from .._functorch.aot_autograd import aot_function, make_boxed_func +from .._functorch.partitioners import default_partition +from .._subclasses import FakeTensorMode +from ..fx import Transformer +from . import config +from .decomposition import select_decomp_table +from .lowering import fallback_node_due_to_unsupported_type + +log = logging.getLogger(__name__) +aten = torch.ops.aten +prims = torch.ops.prims + +Constant = Any +NodeOrConstant = Union[Constant, torch.fx.Node] + + +class Multiple: + pass + + +# Sentinel indicating multiple quantities can be matched +MULTIPLE = Multiple() + + +class Match: + """ + Represents a successfully matched pattern. + """ + + def __init__(self, pattern: PatternExpr, args=None, kwargs=None): + super().__init__() + self.pattern = pattern + # The input nodes that must be passed in to the result + self.args = args or [] + self.kwargs = kwargs or {} + # The nodes matched in this expression + self.nodes: List[torch.fx.Node] = [] + # Mapping CallFunction to the node.target + self.targets: Dict[_TargetExpr, torch.fx.node.Target] = {} + self.ctx: Optional[MatchContext] = None + self.replacement_graph: Optional[torch.fx.Graph] = None + + @property + def graph(self) -> torch.fx.Graph: + assert self.ctx + return self.ctx.graph + + def extend(self, other: Match): + if self.kwargs: + for key in set(self.kwargs.keys()) & set(other.kwargs.keys()): + if self.kwargs[key] != other.kwargs[key]: + raise FailedMatch("kwarg mismatch: {}", key) + self.args.extend(other.args) + self.nodes.extend(other.nodes) + self.kwargs.update(other.kwargs) + self.targets.update(other.targets) + + def bundle(self) -> Match: + # Wrap args in an extra list + self.args = [tuple(self.args)] if self.args else [] + return self + + def __repr__(self): + return f"Match(..., {self.args}, {self.kwargs})" + + def erase_nodes(self, graph: torch.fx.Graph): + for n in reversed(self.nodes): + if not n._erased: + graph.erase_node(n) + + def output_nodes(self) -> List[Optional[torch.fx.Node]]: + assert self.ctx + return [ + (self.ctx.pattern_to_node[p] if p is not None else None) + for p in self.ctx.outputs + ] + + def output_node(self) -> torch.fx.Node: + return next(p for p in self.output_nodes() if p) + + def replace_with_graph(self, replacement_graph, args): + assert self.ctx + ReplacementPatternEntry.replace_with_graph( + self, self.ctx.graph, replacement_graph, args + ) + + def replace_by_example(self, replacement_fn, args, trace_fn=None, run_dce=True): + assert self.ctx + if trace_fn is None: + trace_fn = functools.partial(fwd_only, run_dce=run_dce) + replacement = trace_fn( + replacement_fn, torch.fx.map_arg(args, lambda arg: arg.meta["val"]) + ) + ReplacementPatternEntry.replace_with_graph( + self, + self.ctx.graph, + replacement, + args, + ) + + +class FailedMatch(RuntimeError): + def __init__(self, format_string, *args, **kwargs): + self.format_string = format_string + # We want to construct error messages lazily instead of eagerly, as + # constructing them eagerly can significantly worsen compile times. + if len(format_string) > 200: + raise RuntimeError( + f"Format string too long - use lazy construction of strings instead. 
Format string is\n {format_string}" + ) + self.args = args + self.kwargs = kwargs + + def __str__(self): + return self.format_string.format(*self.args, **self.kwargs) + + def __bool__(self): + return False + + +def is_match(m: Union[Match, FailedMatch]) -> TypeGuard[Match]: + """ + TypeGuards cannot act on `self`. Thus this function exists to let mypy + recognize FailedMatch.__bool__ as a TypeGuard. + """ + return bool(m) + + +class MatchContext: + """ + State needed while running PatternExpr._match(). + """ + + def __init__( + self, + outputs: List[Optional[PatternExpr]], + pattern_to_node: Optional[Dict[PatternExpr, Node]] = None, + *, + graph: torch.fx.Graph, + ): + self.outputs = outputs + self.pattern_to_node = {} if pattern_to_node is None else pattern_to_node + self.graph = graph + self.exclusive_node_set: List[NodeOrConstant] = [] + + def match(self, pattern, node): + """wrapper to check reused nodes in patterns""" + if pattern in self.pattern_to_node: + if self.pattern_to_node[pattern] == node: + return Match(pattern) # already checked this node + else: + return FailedMatch("repeated pattern differs") + m = pattern._match(node, self) + assert pattern not in self.pattern_to_node + self.pattern_to_node[pattern] = node if m else None + m.ctx = self + return m + + def filter_multi_user_patterns(self): + return { + pattern: node + for pattern, node in self.pattern_to_node.items() + if pattern.has_multiple_users() and node is not None + } + + +class PatternExpr: + """ + Base class for types of patterns + """ + + def _match( + self, node: torch.fx.Node, ctx: MatchContext + ) -> Union[Match, FailedMatch]: + raise NotImplementedError() + + def match(self, node: torch.fx.Node) -> Union[Match, FailedMatch]: + try: + return MatchContext([self], graph=node.graph).match(self, node) + except FailedMatch as e: + return e + + def has_multiple_users(self) -> bool: + return False + + def __repr__(self): + return self.__class__.__name__ + "()" + + def find_anchor_nodes(self, ctx: MatchContext, searched): + if self in ctx.pattern_to_node: + yield ctx.pattern_to_node[self] + + +class Arg(PatternExpr): + """ + Capture an arg which will become an input to the handler. Args are + passed in depth first order. + """ + + def _match(self, node: NodeOrConstant, ctx: MatchContext): + return Match(self, args=[node]) # matches anything + + +class Ignored(PatternExpr): + """ + Match an arg, but don't pass it to handler + """ + + def _match(self, node: NodeOrConstant, ctx: MatchContext): + return Match(self) # matches anything + + def __repr__(self): + return "*" + + def pretty_print(self, pp: PatternPrettyPrinter): + return "Ignored()" + + +class KeywordArg(PatternExpr): + """ + Capture a kwarg which will become an input to the handler. + """ + + def __init__(self, name: str): + super().__init__() + self.name = name + + def __repr__(self): + return f"KeywordArg({self.name!r})" + + def _match(self, node: NodeOrConstant, ctx: MatchContext): + return Match(self, kwargs={self.name: node}) # matches anything + + +class ExclusiveKeywordArg(PatternExpr): + """ + Capture a kwarg which will become an input to the handler. 
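+
+    Unlike KeywordArg, a given node may only be captured once per match; a
+    repeated node makes the match fail.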
+ """ + + def __init__(self, name): + super().__init__() + self.name = name + + def __repr__(self): + return f"ExclusiveKeywordArg({self.name!r})" + + def _match(self, node: NodeOrConstant, ctx: MatchContext): + if node in ctx.exclusive_node_set: + return FailedMatch("exclusive arg appears twice") + + ctx.exclusive_node_set.append(node) + return Match(self, kwargs={self.name: node}) # matches anything + + +class _TargetExpr(PatternExpr): + """ + Base class for filtering match by node.target + """ + + op: Optional[str] = None + + def __init__(self, fns, users=1): + if not self.op: + raise NotImplementedError("Shouldn't directly use _BaseNodeMatch") + super().__init__() + fns = [fns] if callable(fns) or isinstance(fns, str) else list(fns) + for fn in list(fns): + if isinstance(fn, torch._ops.OpOverloadPacket): + fns.extend([getattr(fn, overload) for overload in fn.overloads()]) + + self.fns: List[Union[Callable[..., Any], str]] = fns + self.fns_set: Set[Union[Callable[..., Any], str]] = set(fns) + self.users: Union[int, Multiple] = users + + def fns_repr(self) -> str: + first_repr = self.fns[0] + if not isinstance(first_repr, str): + first_repr = first_repr.__name__ + + if len(self.fns) > 1: + return f"[{first_repr}, ...]" + elif self.fns[0] is getattr(torch, first_repr, None): + return f"torch.{first_repr}" + elif isinstance(self.fns[0], torch._ops.OpOverload): + return str(self.fns[0]) + else: + return first_repr + + def __repr__(self): + return f"{self.__class__.__name__}({self.fns_repr()})" + + def has_multiple_users(self) -> bool: + return isinstance(self.users, Multiple) or self.users > 1 + + def find_anchor_nodes(self, ctx: MatchContext, searched): + raise NotImplementedError() + + def _match_fns(self, node: torch.fx.Node): + return ( + isinstance(node, torch.fx.Node) + and node.op == self.op + and extract_target(node) in self.fns_set + ) + + def _match_users(self, node: torch.fx.Node, ctx: MatchContext): + return ( + self in ctx.outputs + or self.users is MULTIPLE + or len(node.users) == self.users + ) + + +class _TargetArgsExpr(_TargetExpr): + """ + Base class for filtering match by node.{target,args,kwargs} + """ + + def __init__(self, fns, *args, _users=1, **kwargs): + super().__init__(fns, _users) + self.args = tuple(args) + self.kwargs = dict(kwargs) + if any( + isinstance(x, (dict, list, tuple)) + for x in itertools.chain(args, kwargs.values()) + ): + self.flatten = self.pytree_flatten + else: + self.flatten = self.simple_flatten + self.flat_args_kwargs = self.flatten(self.args, self.kwargs) + + @staticmethod + def simple_flatten(args, kwargs: Dict[Any, Any]): + return (*args, *kwargs.values()), (len(args), *kwargs.keys()) + + @staticmethod + def pytree_flatten(args, kwargs: Dict[Any, Any]): + def norm_spec(s: pytree.TreeSpec): + if s.type is None: + return s + mapping = {immutable_list: list, tuple: list, immutable_dict: dict} + return pytree.TreeSpec( + mapping.get(s.type, s.type), + s.context, + list(map(norm_spec, s.children_specs)), + ) + + flat, spec = pytree.tree_flatten([args, kwargs]) + spec = norm_spec(spec) + return flat, spec + + def __repr__(self): + args = [ + self.fns_repr(), + *map(repr, self.args), + *[f"{k}={v}" for k, v in self.kwargs.items()], + ] + return f"{self.__class__.__name__}({', '.join(args)})" + + def pretty_print(self, pp: PatternPrettyPrinter): + args = [ + self.fns_repr(), + *(pp.pretty_print(x) for x in self.args), + *[f"{k}={pp.pretty_print(v)}" for k, v in self.kwargs.items()], + ] + if isinstance(self.users, Multiple): + 
            args.append("_users=MULTIPLE")
+        elif self.users > 1:
+            args.append(f"_users={self.users}")
+
+        joiner_str = ", "
+        return f"{self.__class__.__name__}({joiner_str.join(args)})"
+
+    def _match(self, node: torch.fx.Node, ctx: MatchContext):
+        if not self._match_fns(node) or len(node.args) != len(self.args):
+            return FailedMatch("function_mismatch: node={}, pattern={}", node, self)
+
+        if not self._match_users(node, ctx):
+            return FailedMatch("multiple_users {}", self)
+
+        _args = node.args
+        _kwargs = node.kwargs
+        if len(_kwargs) < len(self.kwargs):
+            from torch.fx.operator_schemas import normalize_function
+
+            normalized_args_and_kwargs = normalize_function(
+                node.target, node.args, node.kwargs
+            )
+
+            if normalized_args_and_kwargs is None:
+                return FailedMatch("function_mismatch: node={}, pattern={}", node, self)
+            else:
+                _args, _kwargs = normalized_args_and_kwargs
+                if len(_args) == len(self.args) and len(_kwargs) >= len(self.kwargs):
+                    _kwargs = {i: _kwargs[i] for i in _kwargs if i in self.kwargs}
+                else:
+                    return FailedMatch(
+                        "function_mismatch: node={}, pattern={}", node, self
+                    )
+        else:
+            _kwargs = {i: _kwargs[i] for i in _kwargs if i in self.kwargs}
+
+        node_items, node_spec = self.flatten(_args, _kwargs)
+        self_items, self_spec = self.flat_args_kwargs
+        if node_spec != self_spec:
+            return FailedMatch("args_structure {} {}", node_spec, self_spec)
+        assert len(node_items) == len(self_items)
+
+        m = Match(self)
+        for i, pattern, child_node in zip(itertools.count(), self_items, node_items):
+            if isinstance(pattern, PatternExpr):
+                child_match = ctx.match(pattern, child_node)
+                if not child_match:
+                    return child_match
+                m.extend(child_match)
+            elif isinstance(child_node, torch.fx.Node) or child_node != pattern:
+                return FailedMatch(
+                    "constant_args: {} {!r}!={!r}", node, child_node, pattern
+                )
+        m.nodes.append(node)
+        m.targets[self] = node.target
+        return m
+
+    def find_anchor_nodes(self, ctx: MatchContext, searched):
+        """
+        This is used when we are matching a pattern with multiple outputs.
+        There is a partial match (stored in ctx) and we want to walk
+        this pattern to find a connection to an already-matched node.
+
+        Yields candidate nodes that `self._match` might like.
+ """ + if self in ctx.pattern_to_node: + yield ctx.pattern_to_node[self] + return + + for pattern in self.flat_args_kwargs[0]: + if isinstance(pattern, PatternExpr): + for other_node in pattern.find_anchor_nodes(ctx, searched): + if not isinstance(other_node, torch.fx.Node): + continue + for node in other_node.users: + if node not in searched: + if self._match_fns(node): + yield node + searched.add(node) + + +class CallFunction(_TargetArgsExpr): + """ + Matches a call_function node in the FX graphs: `fns[i](*args, **kwargs)` + """ + + op = "call_function" + + +class CallMethod(_TargetArgsExpr): + """ + Matches a call_method node in the FX graphs: `fns[i].method(*args, **kwargs)` + """ + + op = "call_method" + + +class CallModule(_TargetArgsExpr): + """ + Matches a call_module node in the FX graphs: `module(*args, **kwargs)` + """ + + op = "call_module" + + +class _TargetExprVarArgs(_TargetExpr): + """ + Matches a call_function node with any arguments which are passed into the pattern + """ + + def _match(self, node: torch.fx.Node, ctx: MatchContext): + if not self._match_fns(node): + return FailedMatch("function_mismatch") + + if not self._match_users(node, ctx): + return FailedMatch("multiple_users") + + m = Match(self) + m.nodes.append(node) + m.targets[self] = node.target + m.args.extend(node.args) + m.kwargs.update(node.kwargs) + return m + + +class CallFunctionVarArgs(_TargetExprVarArgs): + op = "call_function" + + +class CallMethodVarArgs(_TargetExprVarArgs): + op = "call_method" + + +class CallModuleVarArgs(_TargetExprVarArgs): + op = "call_module" + + +class ListOf(PatternExpr): + """ + Matches a repeated pattern + """ + + def __init__(self, pattern: PatternExpr, partial=False): + super().__init__() + assert isinstance(pattern, PatternExpr) + self.pattern = pattern + self.partial = partial + + def __repr__(self): + return f"{self.__class__.__name__}({self.pattern})" + + def _match(self, node: List[torch.fx.Node], ctx: MatchContext): # type: ignore[override] + if not isinstance(node, (list, tuple)) or len(node) == 0: + return FailedMatch("non_list") + m = Match(self) + # Propagating patterns with multiple users will ensure we don't revisit + # the same nodes + pattern_to_node = ctx.filter_multi_user_patterns() + matched = False + for i, child_node in enumerate(node): + child_ctx = MatchContext( + ctx.outputs, pattern_to_node, graph=child_node.graph + ) + child_match = child_ctx.match(self.pattern, child_node) + pattern_to_node = child_ctx.filter_multi_user_patterns() + if not child_match: + if not self.partial: + return FailedMatch("list[{}]: {}", i, child_match) + continue + matched = True + m.extend(child_match.bundle()) + if not matched: + return FailedMatch("list: no_match") + return m.bundle() + + +class MultiOutputPattern(PatternExpr): + def __init__(self, outputs): + super().__init__() + assert all(isinstance(x, (PatternExpr, type(None))) for x in outputs), outputs + self.outputs: List[Optional[PatternExpr]] = outputs + + @property + def fns(self): + assert self.outputs[0] and hasattr(self.outputs[0], "fns") + return self.outputs[0].fns + + def __repr__(self): + return f"{self.__class__.__name__}({self.outputs})" + + def pretty_print(self, pp: PatternPrettyPrinter): + args = [pp.pretty_print(x) for x in self.outputs] + joiner_str = f",\n{' '}" + str_out = f"{self.__class__.__name__}([{joiner_str.join(args)}" + str_out = f"{str_out}\n])" + return str_out + + def _match(self, node: torch.fx.Node, ctx: MatchContext): + m = ctx.match(self.outputs[0], node) + if not m: + return m 
+ + for pattern in self.outputs[1:]: + if pattern is None: + continue + child_match = self._match_from_anchors(pattern, ctx) + if not child_match: + return child_match + m.extend(child_match) + + return m + + def _match_from_anchors(self, pattern, ctx): + prior = dict(ctx.pattern_to_node) + m = FailedMatch("no anchor found") + for node in pattern.find_anchor_nodes(ctx, set()): + m = ctx.match(pattern, node) + if m: + return m + # revert any partial matches + ctx.pattern_to_node = dict(prior) + return m + + def match(self, node: torch.fx.Node) -> Union[Match, FailedMatch]: + try: + return MatchContext(self.outputs, graph=node.graph).match(self, node) + except FailedMatch as e: + return e + + +class RepeatedExpr(PatternExpr): + """ + Checks for a repeated pattern. Useful for repeated operations after a node such as `split` or `unbind` + """ + + def __init__(self, inner_pattern: PatternExpr): + super().__init__() + assert hasattr(inner_pattern, "fns") + self.inner_pattern = inner_pattern + + @property + def fns(self): + return self.inner_pattern.fns + + def _match(self, node: torch.fx.Node, ctx: MatchContext): + m = ctx.match(self.inner_pattern, node) + if not m: + return m + ctx.pattern_to_node.pop( + self.inner_pattern, + ) + # Check all anchor nodes match the pattern + for anchor_node in self.inner_pattern.find_anchor_nodes(ctx, set()): + anchor_m = MatchContext([self], graph=node.graph).match( + self.inner_pattern, anchor_node + ) + if not anchor_m: + return anchor_m + m.extend(anchor_m) + return m + + +class PatternPrettyPrinter: + """ + Serializes Patterns to executable python. + XXX: currently only used and tested for fuse attention patterns. May not cover + all patterns. + """ + + def __init__(self): + self.namespace = torch.fx.graph._Namespace() + self.memoized_objs_names: Dict[PatternExpr, str] = {} + self.memoized_objs_pp: Dict[PatternExpr, str] = {} + + @staticmethod + def run(obj: PatternExpr, output_name="output"): + """ + Serializes obj to python code with obj written out to `output_name` + """ + + pp = PatternPrettyPrinter() + assert hasattr(obj, "pretty_print") + out_str = obj.pretty_print(pp=pp) + + output = [] + for key in pp.memoized_objs_names: + output.append(f"{pp.memoized_objs_names[key]} = {pp.memoized_objs_pp[key]}") + + output.append(f"{output_name} = {out_str}") + + return "\n".join(output) + + def pretty_print(self, obj): + if isinstance(obj, _TargetArgsExpr): + if memoized_name := self.memoized_objs_names.get(obj): + return memoized_name + else: + return self.memoize(obj) + if hasattr(obj, "pretty_print"): + return obj.pretty_print(self) + + return repr(obj) + + def memoize(self, obj): + obj_str = obj.pretty_print(self) + obj_name = obj.fns_repr() + for prefix in ("aten.", "torch.", "prims."): + obj_name = obj_name.replace(prefix, "") + + tmp_name = self.namespace.create_name(obj_name, None) + self.memoized_objs_names[obj] = tmp_name + self.memoized_objs_pp[obj] = obj_str + return tmp_name + + +@dataclasses.dataclass +class PatternEntry: + pattern: PatternExpr + extra_check: Callable[[Match], bool] + + def apply(self, match: Match, graph: torch.fx.Graph, node: torch.fx.Node): + raise NotImplementedError() + + def register(self, pass_dicts, target=None, prepend=False): + if target is None: + assert hasattr(self.pattern, "fns") + for fn in self.pattern.fns: + self.register(pass_dicts, fn, prepend=prepend) + elif isinstance(pass_dicts, (dict, PatternMatcherPass)): + if prepend: + pass_dicts[target].insert(0, self) + else: + pass_dicts[target].append(self) + else: + 
for x in pass_dicts: + self.register(x, target, prepend=prepend) + + +@dataclasses.dataclass +class LoweringPatternEntry(PatternEntry): + handler: Callable[..., Any] + + def apply(self, match: Match, graph: torch.fx.Graph, node: torch.fx.Node): + handler = functools.wraps(self.handler)(functools.partial(self.handler, match)) + with graph.inserting_before(node): + replacement = graph.call_function(handler, tuple(match.args), match.kwargs) + replacement.meta.update(node.meta) + node.replace_all_uses_with(replacement) + assert match.nodes[-1] is node + match.erase_nodes(graph) + + +@dataclasses.dataclass +class GraphPatternEntry(PatternEntry): + """ + A pattern that runs a function on the FX graph + """ + + handler: Callable[..., Any] + + def apply(self, match: Match, graph: torch.fx.Graph, node: torch.fx.Node): + with graph.inserting_before(node): + self.handler(match, *match.args, **match.kwargs) + + +@dataclasses.dataclass +class ReplacementPatternEntry(PatternEntry): + normalize_args: Callable[..., List[Any]] + + @staticmethod + def replace_with_graph( + match: Match, + graph: torch.fx.Graph, + replacement_graph: torch.fx.Graph, + args: List[Any], + ): + output_nodes = match.output_nodes() + first_node = output_nodes[0] + + class Replacer(torch.fx.Interpreter): + call_method = None # type: ignore[assignment] + call_module = None # type: ignore[assignment] + get_attr = None # type: ignore[assignment] + + def run_node(self, node) -> Any: + if node.op in ("placeholder", "output"): + return super().run_node(node) + if node.op == "call_function": + target = node.target + args, kwargs = self.fetch_args_kwargs_from_env(node) + result = graph.call_function(target, args, kwargs) + if "val" in node.meta and "val" not in result.meta: + result.meta["val"] = node.meta["val"] + if isinstance(node.meta["val"], torch.Tensor): + assert "tensor_meta" in node.meta + result.meta["tensor_meta"] = node.meta["tensor_meta"] + return result + raise NotImplementedError(f"unhandled {node}") + + output_nodes = match.output_nodes() + + if len(output_nodes) == 1: + last_node = output_nodes[0] + else: + assert output_nodes[0] + nodes = list(output_nodes[0].graph.nodes) + indices = [ + (nodes.index(n), n) + for n in output_nodes + if isinstance(n, torch.fx.Node) + ] + last_node = min(indices, key=lambda tup: tup[0])[1] + + def percolate_tags(node, recompute_tag, input_stops): + queue = [node] + visited = set() + + while queue: + arg = queue.pop() + if ( + arg not in visited + and arg not in input_stops + and hasattr(arg, "meta") + ): + visited.add(arg) + arg.meta["recompute"] = recompute_tag + queue.extend(arg.all_input_nodes) + + with graph.inserting_before(last_node): + replacement = Replacer(replacement_graph).run(*args) + if isinstance(replacement, torch.fx.Node): + replacement = [replacement] + + def maybe_getitem(node): + if node.op != "call_function": + return None + if node.target != operator.getitem: + return None + assert len(node.args) == 2 + return node.args[1] + + def replace(old, new): + if old is None: + assert new is None + return + assert isinstance(old, torch.fx.Node) + if new is None: + old.replace_all_uses_with(None) + graph.erase_node(old) + return + if isinstance(new, torch.fx.Node): + if "val" not in new.meta: + new.meta.update(old.meta) + + # Preserve the recompute tags in the replacement graph. We + # look at the recompute tags of the original output node to + # propagate the tag from the output all the way to the input + # args (named as args in the replace_with_graph). 
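+                    # For example, if the matched output node carried
+                    # meta["recompute"], every node the replacement graph
+                    # inserts between that output and the original inputs
+                    # receives the same tag via percolate_tags above.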
+ # Note that this is best effort. Since patterns are from + # many to many, there is no easy way to correctly map the + # recomputable tags. It is possible in some scenarios that we + # incorrectly tag some nodes as recomputables. + if "recompute" in old.meta: + percolate_tags(new, old.meta["recompute"], args) + + old.replace_all_uses_with(new) + graph.erase_node(old) + return + + # `new` is not a node: it's a list of nodes. + # + # This happens when we want to replace a node that has a single + # packed return with multiple unpacked returns. We need to do + # some graph surgery here. + # + # Example: + # def original_graph(x): + # a = op(x) + # b = a[0] + # c = a[1] + # ... + # + # Assume that we want to replace op(x) with the graph + # def new_op(x): + # w = x + 1 + # z = x + 2 + # return (w, z) + # + # We need to replace `op` with the contents of `new_op`, + # and then rewrite a[0] to be w and a[1] to be z, as so: + # def new_graph(x): + # w = x + 1 + # z = x + 2 + # b = w + # c = z + # ... + old_uses = list(old.users.keys()) + for user in old_uses: + idx = maybe_getitem(user) + if idx is None: + raise AssertionError("can't handle") + replace(user, new[idx]) + graph.erase_node(old) + + if len(output_nodes) == len(replacement): + for old, new in zip(output_nodes, replacement): + replace(old, new) + else: + assert len(output_nodes) == 1 + replace(output_nodes[0], replacement) + + match.erase_nodes(graph) + + def apply(self, match: Match, graph: torch.fx.Graph, node: torch.fx.Node): + self.replace_with_graph( + match, + graph, + match.replacement_graph, # type: ignore[arg-type] + self.normalize_args(*match.args, **match.kwargs), + ) + + +def _return_true(match): + return True + + +def log_trace_failure(search_fn, e): + log.info( + "Replacement pattern %s failed to apply due to shape mismatch: %s", + search_fn.__name__, + e, + ) + + +def register_replacement( + search_fn, + replace_fn, + example_inputs: Iterable[Any], + trace_fn: Callable[[Callable[..., Any], Iterable[Any]], torch.fx.GraphModule], + pass_dicts, + extra_check=_return_true, + scalar_workaround=(), + exclusive_arg_names=(), + search_fn_pattern=None, +): + """ + Create a replacement rule based on example functions that get traced + to create patterns. This supports both training and inference when + run on a joint forward+backward graph. + + Args: + search_fn: traced to give original pattern + replace_fn: traced to give replacement graph + example_inputs: example inputs for initial trace + trace_fn: fwd_only or joint_fwd_bwd + pass_dict: dict of passes to register to + extra_check: additional check to run on match(using real shapes) + """ + argnames_static = [*inspect.signature(search_fn).parameters.keys()] + + def check_fn(match: Match): + """ + Often shapes get burned into the pattern, so our initial match ran with + `ignore_types=(int, ...)`. + + Recheck the match with the correct shapes. + """ + argnames = list(argnames_static) + for name in argnames: + if name not in match.kwargs: + raise RuntimeError( + f"Not all inputs to pattern found in match.kwargs. Perhaps one " + f"of the inputs is unused? 
argnames={argnames}, match.kwargs={match.kwargs}" + ) + + args = list( + torch.fx.map_arg( + [match.kwargs[name] for name in argnames], lambda n: n.meta["val"] + ) + ) + sym_args: List[torch.SymInt] = [] + with torch._dynamo.utils.detect_fake_mode(args): + for i, grad in enumerate(requires_grad): + if isinstance(args[i], torch.Tensor): + if grad and is_integer_dtype(args[i].dtype): + return False + + args[i] = torch.empty_strided( + args[i].size(), + args[i].stride(), + dtype=args[i].dtype, + device=args[i].device, + requires_grad=grad, + ) + for v in itertools.chain(args[i].shape, args[i].stride()): + if isinstance(v, torch.SymInt) and all( + guard_size_oblivious(v != a) for a in sym_args + ): + sym_args.append(v) + + if sym_args: + # AOT Autograd and make fx will dedupe symbolic shape size + # accesses of sym ints that appear as inputs + # We don't want the sym_size uses to interfere with pattern matching + # so we provide them as inputs. + # Later, when we actually do the replacement, the symbolic shape + # sizes will get re-traced and added to the graph. + + def search_fn_new(*args_new): + return search_fn(*args_new[len(args_new) - len(args) :]) + + try: + specific_graph = trace_fn(search_fn_new, sym_args + args) + except RuntimeError as e: + log_trace_failure(search_fn, e) + return False + + # correct argnames in the graph + sym_arg_names = [] + for i, placeholder in zip( + range(len(sym_args) + len(args)), + specific_graph.graph.nodes, + ): + if i < len(sym_args): + sym_arg_names.append(placeholder.target) + continue + + with specific_graph.graph.inserting_after(placeholder): + new_node = specific_graph.graph.placeholder( + argnames[i - len(sym_args)] + ) + new_node.target = new_node.name + placeholder.replace_all_uses_with(new_node) + specific_graph.graph.erase_node(placeholder) + + argnames = sym_arg_names + argnames + else: + try: + specific_graph = trace_fn(search_fn, args) + except RuntimeError as e: + log_trace_failure(search_fn, e) + return False + + specific_pattern = fx_to_pattern( + specific_graph, + argnames=argnames, + exclusive_arg_names=exclusive_arg_names, + scalar_workaround=scalar_workaround, + ) + specific_pattern_match = specific_pattern.match(match.output_nodes()[0]) # type: ignore[arg-type] + if specific_pattern_match and extra_check(specific_pattern_match): + # trace the pattern using the shapes from the user program + match.replacement_graph = trace_fn(replace_fn, args) # type: ignore[assignment] + return True + return False + + def normalize_args(**kwargs): + args = [] + for name in argnames_static: + args.append(kwargs.pop(name)) + for i in range(1, len(kwargs) + 1): + if f"tangents_{i}" not in kwargs: + break + args.append(kwargs.pop(f"tangents_{i}")) + assert not kwargs, f"leftover kwargs: {kwargs!r}" + return args + + if trace_fn is joint_fwd_bwd: + # If inference mode is enabled during compilation, assume that we don't + # want to match on any training graph patterns + if torch.is_inference_mode_enabled(): + return False + + # TODO: Revisit the functionalize_rng_ops for lowmem dropout + with functorch_config.patch(functionalize_rng_ops=False): + requires_grad: List[bool] = [ + isinstance(x, torch.Tensor) and x.requires_grad for x in example_inputs + ] + if search_fn_pattern is None: + pattern = gen_pattern( + search_fn, + example_inputs, + trace_fn, + scalar_workaround, + exclusive_arg_names, + ) + else: + pattern = search_fn_pattern + + pattern_repr = PatternPrettyPrinter.run(pattern) + assert pattern_repr not in _seen_patterns + 
        _seen_patterns.add(pattern_repr)
+        pattern = ReplacementPatternEntry(
+            pattern=pattern,
+            extra_check=check_fn,
+            normalize_args=normalize_args,
+        )
+        pattern.register(pass_dicts)
+        return pattern.pattern
+
+
+@functorch_config.patch(functionalize_rng_ops=False)
+def gen_pattern(
+    search_fn, example_inputs, trace_fn, scalar_workaround=(), exclusive_arg_names=()
+) -> PatternExpr:
+    argnames = [*inspect.signature(search_fn).parameters.keys()]
+
+    if scalar_workaround == ():
+        scalar_workaround = {}
+    flat_inputs = []
+    input_idx = 0  # Positional arguments index
+
+    for argname in argnames:
+        if argname in scalar_workaround:
+            flat_inputs.append(scalar_workaround[argname])
+        else:
+            flat_inputs.append(example_inputs[input_idx])
+            input_idx += 1
+
+    search_gm = trace_fn(search_fn, flat_inputs)
+    return fx_to_pattern(
+        search_gm,
+        ignore_types=(int, float, list, torch.device, torch.dtype),
+        argnames=argnames,
+        scalar_workaround=scalar_workaround,
+        exclusive_arg_names=exclusive_arg_names,
+    )
+
+
+def register_lowering_pattern(
+    pattern: PatternExpr, extra_check=_return_true, *, pass_dict, prepend=False
+):
+    """
+    Register an aten-to-inductor-IR replacement pattern. The decorated
+    function is saved and then called at lowering time, allowing direct
+    pattern-to-inductor-IR conversion.
+    """
+
+    def decorator(handler):
+        assert callable(handler)
+        LoweringPatternEntry(
+            pattern=pattern, extra_check=extra_check, handler=handler
+        ).register(pass_dict, prepend=prepend)
+        handler._inductor_lowering_function = True
+        return handler
+
+    return decorator
+
+
+def register_graph_pattern(
+    pattern: PatternExpr, extra_check=_return_true, *, pass_dict, prepend=False
+):
+    """
+    Register a pattern that runs a function on the FX graph, allowing
+    custom transformation code.
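+
+    A sketch of typical usage (the pattern and pass dict are illustrative):
+
+        @register_graph_pattern(
+            CallFunction(aten.add.Tensor, KeywordArg("x"), KeywordArg("y")),
+            pass_dict=my_patterns,
+        )
+        def handle_add(match: Match, x, y):
+            ...  # inspect or rewrite match.output_node() via match.graph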
+ """ + + def decorator(handler): + assert callable(handler) + GraphPatternEntry( + pattern=pattern, extra_check=extra_check, handler=handler + ).register(pass_dict, prepend=prepend) + return handler + + return decorator + + +def is_start_of_fx_graph(graph: torch.fx.Graph, node: torch.fx.Node) -> bool: + # first node in the graph + return node is next(iter(graph.nodes)) + + +# match: copy_, relu_, _set_grad_enabled, manual_seed, enter_functional_autocast, etc +_mutation_op_re = re.compile(r"_$|_[.]|(\b|_)(set|enter|exit|seed)(\b|_)") + + +def is_mutation_op(node: torch.fx.Node) -> bool: + if node.op == "call_function": + if _mutation_op_re.search(node.target.__name__): # type: ignore[union-attr] + return True + elif node.op == "call_method": + if _mutation_op_re.search(node.target): # type: ignore[union-attr, arg-type] + return True + return node.kwargs.get("out") is not None + + +def get_mutation_region_id(graph: torch.fx.Graph, node: torch.fx.Node) -> int: + n = node + while "mutation_region_id" not in n.meta and not is_start_of_fx_graph(graph, n): + n = n.prev + mutation_region_id = n.meta.get("mutation_region_id", 0) + while n is not node: + n = n.next + if is_mutation_op(n): + mutation_region_id += 1 + n.meta["mutation_region_id"] = mutation_region_id + return mutation_region_id + + +def should_compute_mutation_region_ids(graph: torch.fx.GraphModule) -> bool: + return "mutation_region_id" not in next(iter(graph.nodes)).meta + + +def compute_mutation_region_ids(graph: torch.fx.GraphModule): + mutation_region_id = 0 + for nd in graph.nodes: + if is_mutation_op(nd): + mutation_region_id += 1 + nd.meta["mutation_region_id"] = mutation_region_id + + +class PatternMatcherPass: + def __init__( + self, prevent_match_across_mutations=False, pass_name: Optional[str] = None + ): + super().__init__() + self.patterns: DefaultDict[ + torch.fx.node.Target, List[PatternEntry] + ] = defaultdict(list) + self.prevent_match_across_mutations = prevent_match_across_mutations + self.pass_name = pass_name + + def __getitem__(self, item: torch.fx.node.Target) -> List[PatternEntry]: + return self.patterns[item] + + def apply(self, graph: torch.fx.GraphModule) -> int: + if not self.patterns: + return 0 + if isinstance(graph, torch.fx.GraphModule): + graph = graph.graph + if self.prevent_match_across_mutations: + if should_compute_mutation_region_ids(graph): + compute_mutation_region_ids(graph) + get_mutation_region_id_partial = functools.partial( + get_mutation_region_id, graph + ) + count = 0 + for node in reversed(graph.nodes): + target = extract_target(node) + if ( + node.op in ["call_function", "call_method", "call_module"] + and target in self.patterns + ): + # conservatively not applying pattern for cpu input, + # since some of the patterns induce codegen and split nodes. 
+ # Note: we will only skip cpu compute if disable_cpp_codegen=True + if fallback_node_due_to_unsupported_type(node, allow_cpu_inputs=False): + continue + + for entry in self.patterns[target]: + if node._erased: + break + m = entry.pattern.match(node) + # pattern match crosses mutation barrier - discard + if ( + self.prevent_match_across_mutations + and is_match(m) + and len(set(map(get_mutation_region_id_partial, m.nodes))) != 1 # type: ignore[possibly-undefined] + ): + continue + if os.environ.get("TORCHINDUCTOR_PATTERN_MATCH_DEBUG") == node.name: + log.warning("%s%s %s %s", node, node.args, m, entry.pattern) + if is_match(m) and entry.extra_check(m): + count += 1 + entry.apply(m, graph, node) # type: ignore[arg-type] + counters["inductor"]["pattern_matcher_count"] += 1 + counters["inductor"]["pattern_matcher_nodes"] += len(m.nodes) + return count + + def clear(self): + self.patterns.clear() + + +def _not_implemented(*args, **kwargs) -> NoReturn: + raise NotImplementedError() + + +def fx_to_pattern( + gm, + ignore_types=(), + argnames=(), + scalar_workaround=(), + exclusive_arg_names=(), +) -> PatternExpr: + """ + Convert an FX graph into a PatternExpr. This is useful for simple + patterns that can only match single functions and fixed-length lists. + """ + # scalar_workaround is a hack to capture dropout_p + # see https://github.com/pytorch/pytorch/issues/97894 + scalar_workaround = scalar_workaround or {} + inv_scalar_workaround = {v: k for k, v in scalar_workaround.items()} + assert len(inv_scalar_workaround) == len(scalar_workaround) + + def process_arg(x): + if isinstance(x, (float, int)) and x in inv_scalar_workaround: + return KeywordArg(inv_scalar_workaround[x]) + if type(x) in ignore_types: + return Ignored() + if isinstance(x, list) and all(isinstance(y, Ignored) for y in x) and x: + return Ignored() + return x + + argnum = itertools.count() + + class Converter(torch.fx.Interpreter): + call_method = _not_implemented + call_module = _not_implemented + get_attr = _not_implemented + + def placeholder(self, target, args, kwargs): + n = next(argnum) + if n < len(argnames): + name = argnames[n] + elif argnames: + assert target.startswith("tangent") + name = target + else: + target = re.sub(r"_\d+$", "", target) # de-mangle arg name + name = target + if name in exclusive_arg_names: + return ExclusiveKeywordArg(name) + else: + return KeywordArg(name) + + def call_function(self, target, args, kwargs): + args, kwargs = pytree.tree_map(process_arg, (args, kwargs)) + if list in ignore_types: + # Handle a burned in tensor size which are now [Ignored(), Ignored(), ...] 
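+ # (the extra process_arg pass below collapses a list made up entirely
+ # of Ignored() elements into one Ignored(), so concrete sizes traced
+ # into the graph do not over-specialize the pattern)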
+ args = [process_arg(a) for a in args] + kwargs = {k: process_arg(a) for k, a in kwargs.items()} + return CallFunction(target, *args, **kwargs) + + def run_node(self, n): + rv = super().run_node(n) + if n.op == "output" and isinstance(rv, tuple): + assert len(rv) == len(n.args[0]) + for r, arg in zip(rv, n.args[0]): + r.users = len(arg.users) + else: + rv.users = len(n.users) + return rv + + pattern = Converter(gm).run() + if not isinstance(pattern, PatternExpr): + return MultiOutputPattern(pytree.tree_leaves(pattern)) + return pattern + + +@torch.no_grad() +def fwd_only(fn, args, *, run_dce=True) -> torch.fx.GraphModule: + """Build a normalized inference graph, for use with fx_to_pattern""" + # TODO - look into using aot autograd, asserting no mutating ops here + with enable_python_dispatcher(): + mode = ( + "real" if not torch._inductor.utils.any_is_symbolic(*args) else "symbolic" + ) + gm = make_fx(fn, select_decomp_table(), tracing_mode=mode)(*args) + if run_dce: + gm.graph.eliminate_dead_code() + gm.recompile() + return gm + + +@torch.enable_grad() +def joint_fwd_bwd(fn, args) -> torch.fx.GraphModule: + """Build a normalized training graph, for use with fx_to_pattern""" + gm: Optional[torch.fx.GraphModule] = None + + def record_joint_graph(joint_graph, inputs, **kwargs): + nonlocal gm + assert not gm + gm = clone_graph(joint_graph) + return default_partition(joint_graph, inputs, **kwargs) + + with torch._guards.tracing(None): + aot_function( + fn, + lambda g, i: make_boxed_func(g), + partition_fn=record_joint_graph, + decompositions=select_decomp_table(), + keep_inference_input_mutations=True, + enable_log=False, + )(*args) + assert gm + + from .fx_passes.joint_graph import pointless_view + + matcher_pass = PatternMatcherPass() + + pattern = CallFunction( + torch.ops.aten.view.default, KeywordArg("arg"), KeywordArg("size") + ) + GraphPatternEntry( + pattern=pattern, handler=pointless_view, extra_check=_return_true + ).register(matcher_pass.patterns) + matcher_pass.apply(gm.graph) # type: ignore[arg-type] + + # remove in/out specs + gm.graph._codegen = torch.fx.graph.CodeGen() + gm.graph.eliminate_dead_code() + gm.recompile() + return gm + + +def _args(n: torch.fx.Node) -> List[torch.fx.node.Argument]: + args: List[torch.fx.node.Argument] = list() + torch.fx.map_arg((n.args, n.kwargs), args.append) + return args + + +def stable_topological_sort(graph: torch.fx.Graph): + # Nodes are in exactly one of these three collections: + + # - Nodes in `pending` are waiting to be processed (in reverse order): + pending = list(reversed(graph.nodes)) + + # - Nodes in `ready` have been processed and are already in the correct + # order. + ready = set() + + # - `waiting` is a mapping from a dependency to nodes which depend on that + # dependency. + waiting = defaultdict(list) + + # The cursor indicates the last processed node so we can add new nodes + # after it. + cursor = None + while pending: + node = pending.pop() + waiting_for = [x for x in _args(node) if x not in ready] + if waiting_for: + # We have unprocessed input nodes. Might as well wait for the last + # arg so an already sorted list will only recheck this node once. + waiting[waiting_for[-1]].append(node) + else: + ready.add(node) + if cursor and cursor.next is not node: + cursor.append(node) + cursor = node + # Mark the nodes that have been waiting for this node to finish as + # ready to check again. 
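+ # (reversed() keeps their original relative order, because `pending`
+ # is consumed from the end like a stack)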
+ pending.extend(reversed(waiting.pop(node, ()))) + + assert not waiting and len(ready) == len(graph.nodes) + + +def init_once_fakemode(fn: Callable[..., Any]): + """Wrapper around lazy init functions in fx_passes/""" + + @functools.lru_cache(None) + @functools.wraps(fn) + def lazy_init(): + counters_ref = counters["inductor"].copy() + + with torch._guards.tracing( + None + ), maybe_disable_fake_tensor_mode(), FakeTensorMode(): + result = fn() + + # clear view matches encountered during tracing + counters["inductor"] = counters_ref + + return result + + return lazy_init + + +def config_flag(name): + """Function for extra_check to put pass behind a flag""" + + def flag_check(match): + return getattr(config, name) + + return flag_check + + +def clone_graph(input_graph: torch.fx.GraphModule) -> torch.fx.GraphModule: + class CopyGraph(Transformer): + def run_node(self, old_node): + new_node = super().run_node(old_node) + if isinstance(new_node, torch.fx.Proxy): + new_node.node.meta.update(old_node.meta) + new_node.node.name = self.new_graph._graph_namespace.create_name( + old_node.name, None + ) + return new_node + + return CopyGraph(input_graph).transform() + + +_seen_patterns: Set[str] = set() + + +def get_arg_value( + node: torch.fx.Node, arg_number: int, kwarg_name: Optional[str] = None +): + return ( + node.args[arg_number] + if len(node.args) > arg_number + else node.kwargs.get(kwarg_name) # type: ignore[arg-type] + ) + + +def filter_nodes(nodes: Iterable[torch.fx.Node], fn) -> List[torch.fx.Node]: + fns = [fn] + if isinstance(fn, torch._ops.OpOverloadPacket): + fns.extend([getattr(fn, overload) for overload in fn.overloads()]) + + return [node for node in nodes if node.target in fns] + + +def extract_target(node: Node): + """For call_function and call_method, we directly use the target function; + For call_module, the target is string, and we treat the module class + as a function. + """ + if node.op == "call_module": + return getattr(node.graph.owning_module, node.target).__class__ # type: ignore[arg-type] + return node.target diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/scheduler.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..5091f69000bc2974c62ddae3b8f610aa06f87ce1 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/scheduler.py @@ -0,0 +1,2445 @@ +import collections +import dataclasses +import functools +import itertools +import logging +import math +import operator +import os +import pprint +import textwrap +from typing import ( + Any, + Counter, + DefaultDict, + Dict, + Generic, + List, + Optional, + Sequence, + Set, + Tuple, + TypeVar, + Union, +) + +import sympy + +import torch +from torch._dynamo.utils import dynamo_timed +from torch._inductor.metrics import get_metric_table, is_metric_table_enabled +from torch.utils._triton import has_triton + +from . 
import comms, config, dependencies, ir, metrics +from .codegen.common import get_scheduling_for_device, Kernel +from .comm_analysis import estimate_nccl_collective_runtime +from .dependencies import Dep, MemoryDep, StarDep, WeakDep +from .ir import ComputedBuffer, MultiOutput, MultiOutputLayout +from .sizevars import SimplifyIndexing +from .utils import ( + cache_on_self, + cmp, + free_symbol_has, + get_device_tflops, + get_dtype_size, + get_gpu_dram_gbps, + green_text, + is_collective, + is_wait, + red_text, + sympy_product, +) +from .virtualized import V + + +log = logging.getLogger(__name__) +fusion_log = torch._logging.getArtifactLogger(__name__, "fusion") + + +class WhyNoFuse: + # TODO when we drop support for Python < 3.10, we can use + # @dataclass(slots=True) instead of manually specifying __slots__. + __slots__ = ["node1", "node2", "reason", "args"] + reason: str + args: Tuple[Any, ...] + + def __init__(self, node1: "BaseSchedulerNode", node2: "BaseSchedulerNode"): + self.node1 = node1 + self.node2 = node2 + + def __call__(self, reason, *args): + self.reason = reason + self.args = args + fusion_log.debug(self) + + def __str__(self): + return f"cannot fuse {self.node1.get_name()} with {self.node2.get_name()}: " + ( + self.reason % self.args + ) + + +def pformat(obj): + if isinstance(obj, set): + # pformat has trouble with sets of sympy exprs + obj = sorted(obj, key=str) + result = pprint.pformat(obj, indent=4) + if "\n" in result: + return f"\n{textwrap.indent(result, ' '*4)}" + return result + + +class OutputNode: + def __init__(self, dep): + self.unmet_dependencies = {dep} + self.inverse_users = [] + + def is_reduction(self): + return False + + def get_alias_names(self): + return () + + def get_name(self): + return "OUTPUT" + + __repr__ = get_name + + +def _prune_redundant_deps(node, name_to_fused_node): + """ + Prunes weakdeps intended for mutation ordering + on an upstream fused node if after fusion there is another dependency + on the fused upstream node, making the weakdep redundant + + In essence this enforces an ordering on fusions. As fusions occur, weakdeps will + be incrementally removed, enabling other fusions, ensuring they are fused in order. 
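+ 
+ Example: if B carries a WeakDep on A purely for mutation ordering, and
+ after fusion B also reads a buffer produced by the fused group that
+ contains A, the WeakDep adds no ordering and is pruned.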
+ """ + name_to_dep_count: Counter[str] = collections.Counter() + + for dep in node.unmet_dependencies: + if not isinstance(dep, WeakDep): + name_to_dep_count[name_to_fused_node[dep.name].get_name()] += 1 + + def should_prune(dep): + if isinstance(dep, WeakDep): + is_redundant = ( + name_to_dep_count[name_to_fused_node[dep.name].get_name()] > 0 + ) + # These can occur because fused nodes always gather deps from their snodes + # If B has a weakdep on A + # B gets fused with C, then any time BC is fused, the weakdep will reappear + is_self_dep = name_to_fused_node[dep.name] == node + return is_redundant or is_self_dep + else: + return False + + deps_to_prune = {dep for dep in node.unmet_dependencies if should_prune(dep)} + + if deps_to_prune: + node.unmet_dependencies = node.unmet_dependencies - deps_to_prune + node.set_read_writes(node.read_writes.remove_reads(deps_to_prune)) + + +# TODO(xmfan): reuse an existing mapping for this if it exists, or formalize this into ir.py:ExternKernel +kernel_name_to_op = { + "extern_kernels.convolution": torch.ops.aten.convolution, + "extern_kernels.mm": torch.ops.aten.mm, + "extern_kernels.bmm": torch.ops.aten.bmm, + "extern_kernels.addmm": torch.ops.aten.addmm, +} + + +class BaseSchedulerNode: + def __init__(self, scheduler: "Scheduler", node: ir.Buffer): + self.scheduler: Scheduler = scheduler + self.node: ir.Buffer = node + self.users: List[NodeUser] = [] + self.inverse_users: List[BaseSchedulerNode] = [] + self.node_users: List[BaseSchedulerNode] = [] + self.set_read_writes(node.get_read_writes()) + self.ancestors: Set[str] = set() + self.min_order: int + self.max_order: int + self.last_usage: Set[ + str + ] = set() # buffers that won't be used after this kernel + self.written = False + + def __repr__(self): + return f"{type(self).__name__}(name={self.get_name()!r})" + + def debug_str(self) -> str: + """Longer form printout for trace logs""" + name = self.get_name() + lines = [ + f"{name}: {type(self).__name__}({type(getattr(self, 'node', None)).__name__})", + f"{name}.writes = {pformat(self.read_writes.writes)}", + f"{name}.unmet_dependencies = {pformat(self.unmet_dependencies)}", + f"{name}.met_dependencies = {pformat(self.read_writes.reads - self.unmet_dependencies)}", + f"{name}.users = {self.users}", + ] + try: + lines += [ + self.debug_str_extra(), + ] + except Exception: + log.warning("Ignoring error in debug_str()", exc_info=True) + + return "\n".join(lines).rstrip() + + def debug_str_extra(self) -> str: + return "" + + def log_details(self): + log.info( + "%s: unmet_dependencies = %s, writes = %s", + self, + self.unmet_dependencies, + self.read_writes.writes, + ) + + def update_mutated_names(self, renames: Dict[str, str]): + self.set_read_writes(self.read_writes.rename(renames)) + + def add_mutation_dep(self, dep): + self.set_read_writes(self.read_writes.with_read(dep)) + + def add_fake_dep(self, dep): + self.set_read_writes(self.read_writes.with_read(dep)) + + def set_users(self, users: List["NodeUser"]): + # deduplicate + result: Dict[int, NodeUser] = {} + for use in users: + if id(use.node) in result: + result[id(use.node)] = use.merge(result[id(use.node)]) + else: + result[id(use.node)] = use + self.users = list(result.values()) + + def set_last_usage( + self, future_used_buffers: Set[str], mutation_real_name: Dict[str, str] + ): + used_buffers = self.used_or_aliased_buffer_names() + used_buffers = {mutation_real_name.get(k, k) for k in used_buffers} + self.last_usage = used_buffers - future_used_buffers + + def get_aliases(self): + 
return self.node.get_alias_names() + + def get_mutations(self): + return self.node.get_mutation_names() + + def has_aliasing_or_mutation(self): + return bool(self.get_aliases() or self.get_mutations()) + + def set_read_writes(self, rw: dependencies.ReadWrites): + self.read_writes: dependencies.ReadWrites = rw + self.unmet_dependencies = self.read_writes.reads + self.prune_deps() + + def op_counts(self): + return self.read_writes.op_counts + + def used_buffer_names(self) -> Set[str]: + return { + dep.name + for dep in itertools.chain(self.read_writes.reads, self.read_writes.writes) + } + + def used_or_aliased_buffer_names(self) -> Set[str]: + used_names = set() + + for dep in itertools.chain(self.read_writes.reads, self.read_writes.writes): + used_names.add(dep.name) + if V.graph.name_to_buffer.get(dep.name): + layout = V.graph.name_to_buffer[dep.name].get_layout() + # needed to avoid deallocating aliased buffer + # if there are still uses of aliases ahead + if isinstance(layout, ir.AliasedLayout): + used_names.add(layout.view.data.get_name()) + return used_names + + def prune_deps(self): + self.unmet_dependencies = { + dep + for dep in self.unmet_dependencies + if dep.name not in self.scheduler.available_buffer_names + } + + def prune_weak_deps(self): + # Prune weak dependencies on buffers that have been removed + def should_prune(dep): + return isinstance(dep, WeakDep) and dep.name in V.graph.removed_buffers + + to_remove = {dep for dep in self.read_writes.reads if should_prune(dep)} + self.set_read_writes(self.read_writes.remove_reads(to_remove)) + + def prune_redundant_deps(self, name_to_fused_node): + _prune_redundant_deps(self, name_to_fused_node) + + def get_name(self) -> str: + return self.node.get_name() + + def get_first_name(self) -> str: + return self.get_name() + + def get_names(self) -> Set[str]: + return {self.get_name()} + + def get_nodes(self) -> Sequence["BaseSchedulerNode"]: + return [self] + + def get_device(self): + return self.node.get_device() + + def is_reduction(self): + return False + + def is_split_scan(self): + return False + + def is_template(self): + return False + + def is_extern(self): + return False + + def is_foreach(self): + return False + + def can_inplace(self, read_dep: dependencies.MemoryDep): + return False + + def has_side_effects(self): + return False + + def decide_inplace_update(self): + """ + Decide if there should be inplace updates for the node + and record the decision in the active kernel. + """ + if not self.node.should_allocate(): + return + + if isinstance(self, (SchedulerNode,)) and ( + self.node.get_alias_names() or self.node.get_mutation_names() + ): + return + + if ( + ( + isinstance(self, (SchedulerNode,)) + # o what have i done. 
lets make this an api + or ( + isinstance(self, ExternKernelSchedulerNode) + and isinstance(self.node, (ir.AllReduce, ir.InPlaceHint)) + ) + ) + and config.inplace_buffers + and ( + not isinstance(V.kernel, torch._inductor.codegen.triton.TritonKernel) + or getattr(V.kernel, "mutations", None) is not None + ) + ): + from .codegen.wrapper import buffer_reuse_key + + ordered_reads = sorted(self.read_writes.reads, key=lambda x: x.name) + + for read in ordered_reads: + input_node: Optional[ + BaseSchedulerNode + ] = self.scheduler.name_to_node.get(read.name) + if input_node and V.graph.wrapper_code.can_reuse(input_node, self): + assert input_node.users is not None + remaining_uses = [ + x + for x in input_node.users + if x.node.get_name() + not in self.scheduler.available_buffer_names + ] + if ( + len(remaining_uses) == 1 + and remaining_uses[0].can_inplace + and remaining_uses[0].node is self + and not isinstance( + input_node.node.get_layout(), + ( + ir.MultiOutputLayout, + ir.MutationLayout, + ir.AliasedLayout, + ), + ) + and not ( + isinstance( + input_node.node, (ir.FallbackKernel, ir.MultiOutput) + ) + and len(input_node.node.get_alias_names()) > 0 + ) + and buffer_reuse_key(input_node.node) + == buffer_reuse_key(self.node) + ): + # hacky check for if V.kernel is a real kernel or NullHandler + if hasattr(V.kernel, "args"): + # if there isn't a triton kernel, then we don't need to call triton-specific things. + # but TODO this might be a convenient place to signal to the Collective kernels to inplace + # (and, can we make "kernel" less generic of a name?) + V.kernel.args.make_inplace( + input_node.get_name(), self.get_name() + ) + # mutations not tracked in cpp kernels + if isinstance( + V.kernel, torch._inductor.codegen.triton.TritonKernel + ): + V.kernel.mutations.add(input_node.get_name()) + V.kernel.mutations.add(self.get_name()) + + # update last usage of reused node + self.last_usage.discard(input_node.get_name()) + + V.kernel.inplace_update_buffers[ + self.get_name() + ] = input_node.get_name() + break + + def allocate(self): + if not self.node.should_allocate(): + return + + if isinstance(self, (SchedulerNode,)) and ( + self.node.get_alias_names() or self.node.get_mutation_names() + ): + V.graph.wrapper_code.codegen_allocation(self.node) + return + + # hacky check for if V.kernel is a real kernel or NullHandler + if ( + hasattr(V.kernel, "args") + and self.get_name() in V.kernel.inplace_update_buffers + ): + V.graph.wrapper_code.codegen_inplace_reuse( + self.scheduler.name_to_node[ + V.kernel.inplace_update_buffers[self.get_name()] + ].node, + self.node, + ) + else: + V.graph.wrapper_code.codegen_allocation(self.node) + + def can_free(self): + # There's no real allocated buffer, no need to free it + if isinstance(self.node.layout, ir.NoneLayout): + return False + for use in self.users: + if isinstance(use.node, OutputNode): + return False + return True + + def codegen_originating_info(self, buffer, only_once=True): + if not config.comment_origin: + return + + if only_once and self.written: + return + origins = self.node.origins + out_lines = [] + + for o in origins: + if o.op == "output": + # These are boring and samey + continue + + out_lines.append("") + # TODO(voz): Should the pragma be constant somewhere? 
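+ # Each origin is emitted as a comment block of the form:
+ #   #pragma CMT ORIGIN:
+ #   #pragma CMT <op> <target> [seq_nr:<n>]
+ #   #pragma CMT <last stack-trace line, if available>
+ #   #pragma CMT END ORIGIN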
+ out_lines.append("#pragma CMT ORIGIN:") + op_info_str = f"#pragma CMT {o.op} {o.target}" + if "seq_nr" in o.meta: + op_info_str = op_info_str + f" seq_nr:{o.meta['seq_nr']}" + out_lines.append(op_info_str) + if "stack_trace" in o.meta: + stack_trace = f"{o.meta['stack_trace']}" + stack_trace_last_line = stack_trace.split("|")[-1] + out_lines.append( + "#pragma CMT " + + stack_trace_last_line.replace("{", "{{") + .replace("}", "}}") + .replace("\n", "\\") + ) + out_lines.append("#pragma CMT END ORIGIN") + out_lines.append("") + + if len(out_lines) == 0: + return + + # TODO(voz): Ostensibly, we should not need this. But there are cases where C++ codegen does + # not use BracesBuffer, so we have no good indicator of a C++ buffer atm. + buffer.writelines(out_lines) + self.written = True + + def get_read_write_buffers_sizes(self) -> int: + """ + Counting the number of bytes accessed for a kernel is + surprisingly tricky. In particular, there is a differentiation + between 'theoretical' memory accesses and practical memory + accesses. For example, a layernorm kernel may actually access an + input 3 times, but in theory, it only needs to access its input + once (and may be optimized to do so through say, persistent + reductions) + + Another example is that even though a buffer is passed in, we may + not access the entire buffer. This may occur if we are accessing + a slice of the buffer. Another tricky case is for indirect + indexing, where the amount of bytes accessed depends on the + values of the input. + + What this function aims to compute is the memory accesses for + worst-case inputs, best-case optimization. What this means is + that for each buffer we compute the amount of potential accesses in two ways and take the minimum. + + 1. Numel in ranges multiplied by number of deps the buffer has + 2. 
The buffer size + """ + if isinstance(self, NopKernelSchedulerNode): + return 0 + if isinstance(self, ExternKernelSchedulerNode) and isinstance( + self.node, MultiOutput + ): + return 0 + + if isinstance(self, SchedulerNode): + node_numel = V.graph.sizevars.size_hint( + sympy_product(self.get_ranges()[0]) + * sympy_product(self.get_ranges()[1]) + ) + else: + node_numel = int(1e9) + buf_accesses = collections.defaultdict(list) + for dep in self.read_writes.reads | self.read_writes.writes: + buf_accesses[dep.name].append(dep) + + reads = {dep.name for dep in self.read_writes.reads} + writes = {dep.name for dep in self.read_writes.writes} + + def is_materialized(buf, snodes): + users = self.scheduler.name_to_node[buf].users + buf_uses = {user.node for user in users} + return len(buf_uses - set(snodes)) > 0 + + if isinstance(self, FusedSchedulerNode): + removed_buffers = { + dep for dep in writes if not is_materialized(dep, self.snodes) + } + writes = writes - removed_buffers + reads = reads - removed_buffers + node_bytes = 0 + + for buf_name in reads | writes: + buf_accessed_elems = sum([node_numel for dep in buf_accesses[buf_name]]) + buf: Union[ir.Buffer, ir.TensorBox] + if buf_name in V.graph.name_to_buffer: + buf = V.graph.name_to_buffer[buf_name] + elif buf_name in V.graph.graph_inputs: + buf = V.graph.graph_inputs[buf_name] + else: + continue + + def get_buf_elems(buf): + return V.graph.sizevars.size_hint(sympy_product(buf.get_size())) + + # Kind of a lazy way to get the MultiOutput nodes corresponding to + # a MultiOutputLayout + if isinstance(buf.layout, MultiOutputLayout): + users = self.scheduler.name_to_node[buf.get_name()].users + buf_elems = sum(get_buf_elems(user.node.node) for user in users) + else: + buf_elems = get_buf_elems(buf) + + node_bytes += min(buf_elems, buf_accessed_elems) * get_dtype_size( + buf.get_dtype() + ) + + return node_bytes + + def get_estimated_runtime(self) -> float: + """ + Returns estimated op runtime in nanoseconds (ns) + """ + layout = None + dtype = None + if not hasattr(self, "node") or not self.node: + assert isinstance( + self, (FusedSchedulerNode, ForeachKernelSchedulerNode) + ), f"{type(self)=}" + assert self.snodes + if not self.snodes[0].node: + return 0 + layout = self.snodes[0].node.get_layout() + dtype = self.snodes[0].node.get_dtype() + else: + layout = self.node.get_layout() + dtype = self.node.get_dtype() + + if "cuda" != layout.device.type: + # default to no reordering based on runtime + return 0 + + # Collective kernels + if is_collective(self.node): + return estimate_nccl_collective_runtime(self.node) + elif is_wait(self.node): + # ir.Wait is only used for collective ops. + # The time needed for the collective op is already estimated and considered + # when we are processing the collective op IR node, so ir.Wait takes 0 time + # since it doesn't take extra time to get the result after the collective is completed. 
+ return 0 + + try: + gpu_memory_bandwidth = get_gpu_dram_gbps() + gpu_flops = get_device_tflops(dtype) * 10**12 + except Exception: + return 0 + + if isinstance(self, ExternKernelSchedulerNode): + assert isinstance(self.node, ir.ExternKernel), f"{type(self.node)=}" + op = kernel_name_to_op.get( + getattr(self.node, "python_kernel_name", ""), None + ) + + # if there is a resolved op, dry-run using fake mode and record flop count + if op is not None: + from torch._subclasses.fake_tensor import FakeTensorMode + from torch.utils.flop_counter import FlopCounterMode + + with FakeTensorMode(), FlopCounterMode( + display=False + ) as flop_counter_mode: + from .ir import ir_node_to_tensor + + fake_inputs = [ + ir_node_to_tensor(input, guard_shape=False) + for input in self.node.inputs + ] + cls = self.node.__class__ + cls.process_kernel(op, *fake_inputs, **self.node.kwargs) + + # TODO(xmfan): find a better heuristic to model FLOPS/latency relationship + factor = 1.0 + counted_flops = flop_counter_mode.get_total_flops() + counted_bytes = self.get_read_write_buffers_sizes() + compute_time = (factor * counted_flops / gpu_flops) * 1e9 + transfer_time = counted_bytes / gpu_memory_bandwidth + + # Return estimated runtime in nanoseconds + return max(compute_time, transfer_time) + + elif isinstance(self, FusedSchedulerNode) or isinstance( + self.node, ComputedBuffer + ): + # Return estimated runtime in nanoseconds (bytes / gbps) + return self.get_read_write_buffers_sizes() / gpu_memory_bandwidth + + return 0 + + +class ExternKernelSchedulerNode(BaseSchedulerNode): + def debug_str_extra(self) -> str: + return f"{self.get_name()}.node.kernel = {getattr(self.node, 'python_kernel_name', None)}" + + def is_extern(self): + return True + + def has_side_effects(self): + return hasattr(self.node, "has_side_effects") and self.node.has_side_effects() + + def can_inplace(self, read_dep: dependencies.MemoryDep): + if self.get_aliases() or self.is_template(): + return False + + if read_dep.name not in self.scheduler.name_to_node: + # don't allow reuse of an 'input' buffer, we don't own it + # (would this have been fixed if I tracked mutations properly above?) 
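+ # (names absent from name_to_node are graph inputs or constants,
+ # i.e. buffers the scheduler did not allocate and does not own)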
+ return False + if not isinstance( + self.node, (torch._inductor.ir.AllReduce, torch._inductor.ir.InPlaceHint) + ): + # TODO make this a property of the IR + return False + + if len(self.read_writes.writes) == 1: + write_dep = next(iter(self.read_writes.writes)) + numel_diff = read_dep.get_numel() - write_dep.get_numel() + return V.graph.sizevars.simplify(numel_diff) == 0 + + return False + + +class NopKernelSchedulerNode(BaseSchedulerNode): + pass + + +class SchedulerNode(BaseSchedulerNode): + def __init__( + self, + scheduler: "Scheduler", + node: Union[ir.ComputedBuffer, ir.TemplateBuffer], + ): + super().__init__(scheduler, node) + self._compute_attrs() + + def _compute_attrs( + self, + extra_indexing_constraints: Optional[Tuple[Dict[Any, Any], List[Any]]] = None, + ): + assert isinstance(self.node, (ir.ComputedBuffer, ir.TemplateBuffer)) + self._sizes, self._body = self.node.simplify_and_reorder( + extra_indexing_constraints=extra_indexing_constraints + ) + + group_fn = self.scheduler.get_backend(self.node.get_device()).group_fn + self.group = (self.node.get_device(), group_fn(self._sizes)) + + if isinstance(self.node, ir.TemplateBuffer): + self.set_read_writes(self.node.normalized_read_writes()) + else: + self.set_read_writes( + dependencies.extract_read_writes( + self._body, *self._sizes, normalize=True + ) + ) + + def recompute_size_and_body( + self, extra_indexing_constraints: Tuple[Dict[Any, Any], List[Any]] + ): + self._compute_attrs(extra_indexing_constraints=extra_indexing_constraints) + + def debug_str_extra(self) -> str: + name = self.get_name() + lines = [ + f"{name}.group.device = {self.group[0]}", + f"{name}.group.iteration = {self.group[1]}", + f"{name}.sizes = {self._sizes}", + ] + if self.get_aliases(): + lines.append(f"{name}.aliases = {pformat(self.get_aliases())}") + if self.get_mutations(): + lines.append(f"{name}.mutations = {pformat(self.get_mutations())}") + if isinstance(self._body, ir.LoopBody): + lines.append(f"class {name}_loop_body:") + lines.append(textwrap.indent(self._body.debug_str(), " ")) + return "\n".join(lines) + + def get_ranges(self): + return self._sizes + + def is_reduction(self): + assert isinstance( + self.node, (ir.ComputedBuffer, ir.TemplateBuffer) + ), f"{type(self.node)=}" + return bool(self.node.get_reduction_type()) + + def is_split_scan(self): + assert isinstance( + self.node, (ir.ComputedBuffer, ir.TemplateBuffer) + ), f"{type(self.node)=}" + return isinstance(self.node, ir.ComputedBuffer) and isinstance( + self.node.data, ir.SplitScan + ) + + def is_template(self): + return isinstance(self.node, ir.TemplateBuffer) + + def get_template_node(self): + return self.node if self.is_template() else None + + def run(self, *index_vars): + self.decide_inplace_update() + self.mark_run() + self.codegen(index_vars) + + def mark_run(self): + self.allocate() + + def ranges_from_index_vars(self, index_vars): + sizes = self._sizes + assert sum(map(len, sizes)) == sum(map(len, index_vars)) + var_ranges = dict( + zip( + itertools.chain.from_iterable(index_vars), + itertools.chain.from_iterable(sizes), + ) + ) + return var_ranges + + def codegen(self, index_vars): + var_ranges = self.ranges_from_index_vars(index_vars) + try: + with V.set_ops_handler( + SimplifyIndexing(V.get_ops_handler(), var_ranges) + ), V.kernel.set_current_node(self): + self._body(*index_vars) + except Exception: + log.fatal("Error in codegen for %s", self.node) + raise + + def pointwise_read_writes(self): + """ + Get the memory dependencies in the non-reduction axis. 
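+ 
+ E.g. for a row-wise reduction over an [M, N] buffer, the body is
+ evaluated with every reduction index pinned to 0, so the returned
+ deps range only over the pointwise (non-reduction) axes.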
+ """ + sizes, reduction_sizes = self._sizes + + def fn(index): + return self._body(index, [sympy.Integer(0) for _ in reduction_sizes]) + + return dependencies.extract_read_writes(fn, sizes) + + def can_inplace(self, read_dep: dependencies.MemoryDep): + if self.get_aliases() or self.is_template(): + return False + if len(self.read_writes.writes) == 1 and isinstance( + read_dep, dependencies.MemoryDep + ): + write_dep = next(iter(self.read_writes.writes)) + assert isinstance(write_dep, dependencies.MemoryDep), f"{type(write_dep)=}" + return read_dep.index == write_dep.index and read_dep.size == write_dep.size + return False + + @cache_on_self + def _get_atomic_add_buffers(self) -> Set[str]: + buffers_store_as_atomic_add = set() + if isinstance(self._body, ir.LoopBody): + for node in self._body.get_nodes(): + if ( + node.op == "call_method" + and node.target == "store" + and ( + ("mode" in node.kwargs and node.kwargs["mode"] == "atomic_add") + or (len(node.args) == 5 and node.args[4] == "atomic_add") + ) + ): + buffers_store_as_atomic_add.add( + node.kwargs["name"] + if "name" in node.kwargs + else (node.args[1] if len(node.args) >= 2 else "") + ) + return buffers_store_as_atomic_add + + def has_atomic_add(self, check_buf): + return check_buf in self._get_atomic_add_buffers() + + +class FusedSchedulerNode(BaseSchedulerNode): + """ + This is a "fake" scheduler node that represents a group of scheduler nodes + that are meant to be fused together. The way it does this is by maintaining + its unmet dependencies as the union of its constituent nodes. + """ + + @classmethod + def fuse(cls, node1: BaseSchedulerNode, node2: BaseSchedulerNode): + assert node1.scheduler is node2.scheduler + assert isinstance(node1, (SchedulerNode, FusedSchedulerNode)) and isinstance( + node2, (SchedulerNode, FusedSchedulerNode) + ) + return cls(node1.scheduler, list(node1.get_nodes()) + list(node2.get_nodes())) # type: ignore[arg-type] + + def __init__(self, scheduler: "Scheduler", snodes: List[SchedulerNode]): + # NB: No need to call super().__init__() because we don't need to re-use any of its logic. 
+ self.snodes = snodes + self.scheduler = scheduler + self.node: ir.Buffer = None # type: ignore[assignment] + self.users: List[NodeUser] = [] + self.inverse_users = [] + self.node_users = [] + self.group = max(snodes, key=lambda x: int(x.is_reduction())).group + self.ancestors = set.union( + *[x.ancestors for x in snodes if x.ancestors is not None] + ) + + self.set_read_writes( + dependencies.ReadWrites.merge_list([x.read_writes for x in snodes]) + ) + + self.unmet_dependencies = { + dep + for dep in set.union(*[x.unmet_dependencies for x in snodes]) + if dep.name not in self.get_names() + } - self.read_writes.writes + self.min_order = min([x.min_order for x in self.snodes]) + self.max_order = max([x.max_order for x in self.snodes]) + + @cache_on_self + def get_name(self) -> str: + return "_".join([x.get_name() for x in self.snodes]) + + def get_first_name(self) -> str: + return self.snodes[0].get_name() + + @cache_on_self + def get_names(self) -> Set[str]: + return set.union(*[x.get_names() for x in self.snodes]) + + def debug_str_extra(self) -> str: + lines = [ + f"{self.get_name()}.snodes[{i}] =\n{node.debug_str()}" + for i, node in enumerate(self.snodes) + ] + return textwrap.indent("\n".join(lines).rstrip(), " ") + + def set_last_usage( + self, future_used_buffers: Set[str], mutation_real_name: Dict[str, str] + ): + # Set self.last_usage using the global information + # This will be used for inter-kernel optimisations + super().set_last_usage(future_used_buffers, mutation_real_name) + # Set self.last_usage on the snodes + # This will be used for optimisations within the kernel + future_used_buffers: Set[str] = set() + for node in reversed(self.snodes): + node.set_last_usage(future_used_buffers, mutation_real_name) + future_used_buffers.update(node.last_usage) # type: ignore[arg-type] + + @cache_on_self + def used_buffer_names(self) -> Set[str]: + return set.union(*[x.used_buffer_names() for x in self.snodes]) + + @cache_on_self + def used_or_aliased_buffer_names(self) -> Set[str]: + return set.union(*[x.used_or_aliased_buffer_names() for x in self.snodes]) + + def get_nodes(self) -> List[SchedulerNode]: + return self.snodes + + def __repr__(self): + return f"{type(self).__name__}(nodes={self.get_name()})" + + @cache_on_self + def is_reduction(self): + return any(x.is_reduction() for x in self.snodes) + + @cache_on_self + def is_split_scan(self): + return any(x.is_split_scan() for x in self.snodes) + + @cache_on_self + def is_template(self): + return any(x.is_template() for x in self.snodes) + + @cache_on_self + def get_template_node(self): + for node in self.snodes: + if node.is_template(): + return node + return None + + def get_device(self): + return self.group[0] + + @cache_on_self + def has_aliasing_or_mutation(self): + return any(x.has_aliasing_or_mutation() for x in self.snodes) + + @cache_on_self + def op_counts(self): + op_counts: Counter[str] = collections.Counter() + for node in self.snodes: + op_counts.update(node.op_counts()) + return op_counts + + def has_atomic_add(self, check_buf): + return any( + ( + isinstance(sub_schedule_node1, SchedulerNode) + and sub_schedule_node1.has_atomic_add(check_buf) + ) + for sub_schedule_node1 in self.get_nodes() + ) + + # None of these need to be implemented, as a FusedSchedulerNode is just an + # abstraction for scheduling purposes + def update_mutated_names(self, renames: Dict[str, str]): + raise NotImplementedError + + def add_mutation_dep(self, name): + raise NotImplementedError + + def set_users(self, users: List["NodeUser"]): + 
raise NotImplementedError + + def get_aliases(self): + raise NotImplementedError + + def get_mutations(self): + raise NotImplementedError + + def can_inplace(self, read_dep: dependencies.MemoryDep): + raise NotImplementedError + + def allocate(self): + raise NotImplementedError + + def can_free(self): + raise NotImplementedError + + def debug_str(self) -> str: + """Longer form printout for trace logs""" + name = self.get_name() + node_typestr = ",".join(type(n).__name__ for n in self.snodes) + lines = [ + f"{name}: {type(self).__name__}({node_typestr})", + f"{name}.writes = {pformat(self.read_writes.writes)}", + f"{name}.unmet_dependencies = {pformat(self.unmet_dependencies)}", + f"{name}.met_dependencies = {pformat(self.read_writes.reads - self.unmet_dependencies)}", + f"{name}.users = {self.users}", + ] + try: + lines += [ + self.debug_str_extra(), + ] + except Exception: + log.warning("Ignoring error in debug_str()", exc_info=True) + + return "\n".join(lines).rstrip() + + +class ForeachKernelSchedulerNode(FusedSchedulerNode): + """Scheduler node which consists of a list of scheduler nodes that each operate on a + distinct tensor in a list of tensors.""" + + def get_consumer_subnode_for(self, producer): + if producer.get_name() in self.read_to_node: + return self.read_to_node[producer.get_name()] + + return None + + def get_producer_subnode_for(self, consumer): + for rd in consumer.read_writes.reads: + if rd.name in self.name_to_node: + return self.name_to_node[rd.name] + + return None + + @classmethod + def can_fuse(cls, producer, consumer): + why = WhyNoFuse(producer, consumer) + if producer.is_foreach() and consumer.is_foreach(): + foreach_match = len(producer.snodes) == len(consumer.snodes) + if not foreach_match: + why("foreach do not have same length") + return foreach_match and all( + producer.scheduler.can_fuse(l, r) + for l, r in zip(producer.snodes, consumer.snodes) + ) + elif consumer.is_foreach(): + consumer_subnode = consumer.get_consumer_subnode_for(producer) + if consumer_subnode is not None: + return consumer.scheduler.can_fuse(producer, consumer_subnode) + + why("candidate producer is not dep of any foreach consumer") + return False + + elif producer.is_foreach(): + producer_subnode = producer.get_producer_subnode_for(consumer) + if producer_subnode is not None: + return producer.scheduler.can_fuse(producer_subnode, consumer) + + why("candidate consumer has no dep in any foreach producer") + return False + + raise AssertionError( + "At least one node passed to ForeachKernelSchedulerNode.can_fuse should be a foreach node" + ) + + @classmethod + def fuse(cls, producer, consumer): + assert producer.is_foreach() or consumer.is_foreach() + prev_node_1 = None + prev_node_2 = None + if producer.is_foreach() and consumer.is_foreach(): + fused_nodes = [ + FusedSchedulerNode.fuse(l, r) + for l, r in zip(producer.snodes, consumer.snodes) + ] + elif producer.is_foreach(): + producer_subnode = producer.get_producer_subnode_for(consumer) + fused_nodes = [] + prev_node_1 = producer + prev_node_2 = None + for node in producer.snodes: + if node is producer_subnode: + new_node = FusedSchedulerNode.fuse(node, consumer) + prev_node_2 = new_node + fused_nodes.append(new_node) + else: + fused_nodes.append(node) + + elif consumer.is_foreach(): + consumer_subnode = consumer.get_consumer_subnode_for(producer) + fused_nodes = [] + prev_node_1 = consumer + prev_node_2 = None + + for node in consumer.snodes: + if node is consumer_subnode: + new_node = FusedSchedulerNode.fuse(producer, node) + 
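# prev_node_1/prev_node_2 let __init__ merge dependency state from the +
# two pre-fusion nodes instead of recomputing it from every snode. +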
prev_node_2 = new_node + fused_nodes.append(new_node) + else: + fused_nodes.append(node) + + return cls(producer.scheduler, fused_nodes, prev_node_1, prev_node_2) # type: ignore[possibly-undefined] + + def __init__( + self, + scheduler: "Scheduler", + nodes: List[SchedulerNode], + prev_node_1=None, + prev_node_2=None, + ): + self.read_to_node = {} + self.name_to_node = {} + + if prev_node_1 is None or prev_node_2 is None: + super().__init__(scheduler, nodes) + + for node in nodes: + for read in node.read_writes.reads: + self.read_to_node[read.name] = node + + for name in node.get_names(): + self.name_to_node[name] = node + else: + self.scheduler = scheduler + self.snodes = nodes + self.node: ir.Buffer = None # type: ignore[assignment] + self.users: List[NodeUser] = [] + + self.set_read_writes( + dependencies.ReadWrites.merge_list( + [prev_node_1.read_writes, prev_node_2.read_writes] + ) + ) + + self.unmet_dependencies = { + dep + for dep in set.union( + prev_node_1.unmet_dependencies, prev_node_2.unmet_dependencies + ) + if dep.name not in self.get_names() + } - self.read_writes.writes + + self.min_order = min([prev_node_1.min_order, prev_node_2.min_order]) + self.max_order = max([prev_node_1.max_order, prev_node_2.max_order]) + + foreach_node = prev_node_1 if prev_node_1.is_foreach() else prev_node_2 + other_node = prev_node_2 if prev_node_1.is_foreach() else prev_node_1 + + self.ancestors = foreach_node.ancestors + self.ancestors.update(other_node.ancestors) + + self.name_to_node = foreach_node.name_to_node + for name in other_node.get_names(): + self.name_to_node[name] = other_node + + self.group = (nodes[0].get_device(), "foreach") + + self.origins: Set[torch.fx.Node] = set() + + def mark_run(self): + raise NotImplementedError + + def codegen(self): + assert isinstance(self.node, ir.ComputedBuffer), f"{type(self.node)=}" + self.node.get_store_function()(self.node.make_loader()()) + + def can_free(self): + return NotImplementedError + + def is_foreach(self): + return True + + def get_subkernel_nodes(self): + """Returns a list of nodes which comprise the foreach kernel, operating on corresponding elements of our input lists. + These nodes may be vertically fused.""" + return list(self.snodes) + + def get_nodes(self): + """Returns all nodes contained in this kernel, unpacking fused nodes into their constituent scheduler nodes.""" + return list(itertools.chain.from_iterable(x.get_nodes() for x in self.snodes)) + + def get_first_name(self): + return self.snodes[0].get_first_name() + + def prune_redundant_deps(self, name_to_fused_node): + _prune_redundant_deps(self, name_to_fused_node) + + for node in self.snodes: + node.prune_redundant_deps(name_to_fused_node) + + +def pick_loop_order(stride_lengths, sizes, priority_idx=()): + """ + A heuristic to decide loop iteration orders. This has not been well + tuned and may be something we should autotune. 
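+ 
+ Two notes from the comparator below: size-1 dimensions are pushed to
+ the end since their placement cannot matter, and when priority_idx is
+ given only those buffers' strides drive the ordering.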
+ """ + + @functools.cmp_to_key + def index_cmp(a, b): + if sizes[a] == 1 or sizes[b] == 1: + # 1-sizes don't matter, just move them to the end + return cmp(sizes[a] == 1, sizes[b] == 1) + + stride_len_a = [sl[a] for sl in stride_lengths] + stride_len_b = [sl[b] for sl in stride_lengths] + + # equivalent to + # np.logical_or(stride_lengths[:, b] == 0, stride_lengths[:, a] < stride_lengths[:, b]).all() + a_first = sum( + sl_b == 0 or sl_a < sl_b for sl_a, sl_b in zip(stride_len_a, stride_len_b) + ) + b_first = sum( + sl_a == 0 or sl_b < sl_a for sl_a, sl_b in zip(stride_len_a, stride_len_b) + ) + if a_first > b_first: + return -1 + if b_first > a_first: + return 1 + + # otherwise contiguous + return cmp(b, a) + + order = list(reversed(range(len(stride_lengths[0])))) + if len(priority_idx) > 0: + # if we have priority node, only use that node's order + stride_lengths = [stride_lengths[pi] for pi in priority_idx] + if config.pick_loop_orders: + order.sort(key=index_cmp) + return order + + +@dataclasses.dataclass +class NodeUser: + node: BaseSchedulerNode + can_inplace: bool = False + + # A weak user must be scheduled after a given node, but doesn't actually + # use the result + is_weak: bool = False + + def __hash__(self): + return hash((self.node.get_name(), self.can_inplace, self.is_weak)) + + def __eq__(self, other): + return ( + self.get_name() == other.get_name() + and self.can_inplace == other.can_inplace + and self.is_weak == other.is_weak + ) + + def get_name(self): + return self.node.get_name() + + def merge(self, other: "NodeUser") -> "NodeUser": + assert self.node is other.node + return NodeUser( + self.node, + self.can_inplace and other.can_inplace, + self.is_weak and other.is_weak, + ) + + +_post_grad_graph_counter = itertools.count() + + +class Scheduler: + @dynamo_timed + def __init__(self, nodes): + super().__init__() + self.backends = {} + self.fuse_cache = {} + self.post_grad_graph_id = next(_post_grad_graph_counter) + + self.nodes = [] + self.available_buffer_names = { + *V.graph.graph_inputs.keys(), + *V.graph.constants.keys(), + } + + self.nodes = [self.create_scheduler_node(n) for n in nodes] + + # some new constants could have been created above + self.available_buffer_names.update(V.graph.constants.keys()) + for node in self.nodes: + node.prune_deps() + + self.name_to_node: Dict[str, BaseSchedulerNode] = { + n.get_name(): n for n in self.nodes + } + self.name_to_fused_node: Dict[ + str, BaseSchedulerNode + ] = dict() # set in fuse_nodes() + + # mutation_real_name: Maps back to the original name for codegen + # Example: + # If you mutate buf0 inside of buf1's kernel, then: + # mutation_real_name = {"buf0" : "buf1"} + # all subsequent uses of buf0 become buf1's usage in dependency graph + self.mutation_real_name = {} + + # We handle mutation by renaming modified versions of the same + # buffer in the dependency graph to prevent cycles. 
+ # mutation_renames: tracks the current name for a given buffer + # (changed once per mutation) + # Example: + # If you mutate buf0 inside of buf1's kernel, then: + # mutation_renames = {"buf1" : "buf0"} + # in codegen we only use buf0, never buf1 + self.mutation_renames = {} + + self.compute_dependencies() + self.topological_sort_schedule() + self.dead_node_elimination() + if config.reorder_for_compute_comm_overlap: + comms.decide_global_ordering_of_comms(self.nodes) + self.compute_ancestors() + + metrics.ir_nodes_pre_fusion += len(self.nodes) + V.debug.ir_pre_fusion(self.nodes) + self.num_orig_nodes = len(self.nodes) + self.name_to_fused_node = {n.get_name(): n for n in self.nodes} + self.create_foreach_nodes() + self.topological_sort_schedule() + self.logged_slow_fusion = set() + self.fuse_nodes() + if config.reorder_for_compute_comm_overlap: + # Refresh node_users and inverse_users to reflect fused nodes + self.compute_node_users() + self.nodes = comms.reorder_compute_and_comm_for_overlap(self.nodes) + self.compute_last_usage() + V.debug.ir_post_fusion(self.nodes) + V.debug.graph_diagram(self.nodes) + self.debug_draw_graph() + + # used during codegen: + self.current_device: torch.device = None # type: ignore[assignment] + self.buffer_names_to_free = set() + + # fx graph node to the position it appears in the graph + # for debug attribution + self.origin_to_index = {} + + get_metric_table("graph_stats").add_row( + lambda: { + "graph_id": self.post_grad_graph_id, + "num_nodes_before_fusion": self.num_orig_nodes, + "num_nodes_after_fusion": len(self.nodes), + } + ) + + def debug_draw_graph(self): + """Generate an image of the graph for debugging""" + if os.environ.get("INDUCTOR_WRITE_SCHEDULER_GRAPH", None) == "1": + from .debug import draw_buffers + + draw_buffers(self.nodes, print_graph=True) + + def debug_print_nodes(self, label): + if log.isEnabledFor(logging.INFO): + log.info("%s:", label) + for node in self.nodes: + node.log_details() + + def create_scheduler_node(self, node): + assert ( + node.origins is not None + ), "All nodes passed to scheduling must have an origin" + if node.is_no_op(): + return NopKernelSchedulerNode(self, node) + elif isinstance(node, (ir.ComputedBuffer, ir.TemplateBuffer)): + return SchedulerNode(self, node) + elif isinstance(node, ir.ExternKernel): + return ExternKernelSchedulerNode(self, node) + else: + raise NotImplementedError(node) + + def create_foreach_nodes(self): + removed_node_names = set() + fe_nodes = [] + kept_node_names = self.name_to_fused_node.keys() + + for names in V.graph.lists.values(): + names = [ + name + for name in names + if name in kept_node_names + and not isinstance(self.name_to_node[name], NopKernelSchedulerNode) + ] + if not names: + # All nodes eliminated + continue + + removed_node_names.update(names) + snodes = [self.name_to_node[name] for name in names] + + fe_node = ForeachKernelSchedulerNode(self, snodes) # type: ignore[arg-type] + + fe_nodes.append(fe_node) + + for name in names: + self.name_to_fused_node[name] = fe_node + + self.nodes = [ + node for node in self.nodes if node.get_name() not in removed_node_names + ] + fe_nodes + + def compute_dependencies(self): + """ + Create dependency edges between nodes, handling aliasing and + mutation properly. + """ + + T = TypeVar("T") + + class DedupList(Generic[T]): + """ + This data structure behaves like a list except it makes sure the + elements remain unique. 
+ Normally one could use a set/dict for this purpose however + the list in question gets elements appended as it is being + iterated over which means that we need to keep the list + semantics. + """ + + def __init__(self, items=None, membership=None): + self.items = items or list() + self.membership = membership or set() + + def append(self, node_user: T) -> None: + if node_user in self.membership: + return + self.items.append(node_user) + self.membership.add(node_user) + + def __add__(self, other: "DedupList[T]") -> "DedupList[T]": + new_membership = set.union(self.membership, other.membership) + new_items = self.items + [ + x for x in other.items if x not in self.membership + ] + return DedupList(new_items, new_membership) + + name_to_users: DefaultDict[str, DedupList[NodeUser]] = collections.defaultdict( + DedupList + ) + + # handle aliasing by using python aliasing in name_to_users + # if foo aliases bar then we will make name_to_users["foo"] point + # to the same python list as name_to_users["bar"] + for node1 in self.nodes: + node1_name = node1.get_name() + for node2_name in node1.get_aliases(): + if node1_name in name_to_users and node2_name in name_to_users: + # merge the two + list1 = name_to_users[node1_name] + list2 = name_to_users[node2_name] + combined = list1 + list2 + for key in name_to_users.keys(): + if name_to_users[key] is list1 or name_to_users[key] is list2: + name_to_users[key] = combined + elif node1_name in name_to_users: + name_to_users[node2_name] = name_to_users[node1_name] + else: + name_to_users[node1_name] = name_to_users[node2_name] + + def rename(n): + if n in self.mutation_renames: + return rename(self.mutation_renames[n]) + return n + + def dep_closure(node_name): + reachable_names = {node_name} + node = self.name_to_node[node_name] + write_dep = next(iter(node.read_writes.writes)) + for read_dep in node.read_writes.reads: + if ( + read_dep.name in self.name_to_node + and isinstance(read_dep, dependencies.MemoryDep) + and isinstance(write_dep, dependencies.MemoryDep) + and read_dep.index == write_dep.index + and read_dep.size == write_dep.size + ): + reachable_names.update(dep_closure(read_dep.name)) + return reachable_names + + def add_user(used_by_name, user_node, can_inplace=False, is_weak=False): + name_to_users[rename(used_by_name)].append( + NodeUser(user_node, can_inplace, is_weak) + ) + + unbacked_symbol_to_origin_node = {} + + for node in self.nodes: + log.debug("scheduling %s", node.node) + + # unbacked symbols don't follow ordinary buffer dependencies, so + # we track their def/uses separately + unbacked_symbol_defs = sorted( + node.node.get_unbacked_symbol_defs(), key=lambda x: x.name + ) + for s in unbacked_symbol_defs: + assert isinstance(s, sympy.Symbol) + # Pick the first definer as canonical. There may be multiple + # because if a MultiOutputLayout buffer propagates an unbacked + # symint to multiple outputs, they will all claim to def it. 
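+ # (the choice is otherwise arbitrary: the entry merely anchors the
+ # StarDep added for each use of the symbol below)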
+ if s not in unbacked_symbol_to_origin_node: + unbacked_symbol_to_origin_node[s] = node + + unbacked_symbol_uses = sorted( + node.node.get_unbacked_symbol_uses(), key=lambda x: x.name + ) + # if a kernel takes unbacked symints, register dependencies + for s in unbacked_symbol_uses: + assert ( + s in unbacked_symbol_to_origin_node + ), f"{s} not in {unbacked_symbol_to_origin_node}" + node.add_fake_dep(StarDep(unbacked_symbol_to_origin_node[s].get_name())) + + # a node will mutate either 0 or 1 buffers + assert len(node.get_mutations()) <= 1 + for alt_name in node.get_mutations(): + alt_name = rename(alt_name) + # this node must run after the prior writer + add_user(alt_name, node) + node.add_mutation_dep(StarDep(alt_name)) + for other_node in name_to_users[alt_name].items: + # this node must run after all prior readers + other_name = rename(other_node.get_name()) + known_dep_node_names = dep_closure(node.get_name()) + if other_name not in known_dep_node_names: + # If this node already directly or indirectly depends on other_node, + # we don't need to insert an extra dep. + node.add_mutation_dep(WeakDep(other_name)) + add_user(other_name, node, is_weak=True) + + # add normal non-mutation dependencies + for read in node.read_writes.reads: + is_weak = isinstance(read, WeakDep) + add_user(read.name, node, node.can_inplace(read), is_weak) + + node.update_mutated_names(self.mutation_renames) + + # update our renaming scheme for the next iteration + for alt_name in node.get_mutations(): + self.mutation_renames[rename(alt_name)] = node.get_name() + self.mutation_renames[alt_name] = node.get_name() + self.mutation_real_name[node.get_name()] = self.mutation_real_name.get( + alt_name, alt_name + ) + + # make sure outputs aren't dead-code-eliminated + for node_name in V.graph.get_output_names(): + log.debug("scheduling output %s", node_name) + add_user(node_name, OutputNode(StarDep(node_name))) + + # make sure unbacked symints aren't dead-code-eliminated + for node in V.graph.graph_outputs: + for s in node.get_unbacked_symbol_uses(): + assert ( + s in unbacked_symbol_to_origin_node + ), f"{s} not in {unbacked_symbol_to_origin_node.keys()}" + node_name = unbacked_symbol_to_origin_node[s].node.name + log.debug("scheduling output %s for unbacked symint %s", node_name, s) + add_user(node_name, OutputNode(StarDep(node_name))) + + # make sure input mutation isn't dead-code-eliminated + for name in self.mutation_renames: + if name in V.graph.graph_inputs: + add_user(name, OutputNode(StarDep(name))) + V.graph.mutated_inputs.add(name) + + inp_names = { + name: index for index, name in enumerate(V.graph.graph_inputs.keys()) + } + V.graph.mutated_input_idxs = [ + inp_names[name] for name in V.graph.mutated_inputs + ] + + # copy users information onto the nodes + for node in self.nodes: + node.set_users(name_to_users[node.get_name()].items) + + # populate inverse_users + for node in self.nodes: + for user in node.users: + user.node.inverse_users.append(node) + + def compute_node_users(self): + # set up buffer name to (fused)snode mapping + buf_to_snode = {} + for node in self.nodes: + if isinstance(node, FusedSchedulerNode): + for x in node.snodes: + buf_to_snode[x.get_name()] = node + buf_to_snode[node.get_name()] = node + + for node in self.nodes: + node.node_users = [] + node.inverse_users = [] + + # compute inverse_users + for node in self.nodes: + inverse_users = [] + for dep in node.unmet_dependencies: + assert dep.name in buf_to_snode + dep_node = buf_to_snode[dep.name] + inverse_users.append(dep_node) + 
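# inverse_users: the (possibly fused) producers of this node's unmet deps +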
node.inverse_users = inverse_users
+
+        # compute node_users
+        # TODO: ideally, we should deduplicate .users and .node_users,
+        # but currently .users contains extra information that's difficult to
+        # extract into a standalone container.
+        node_to_users: Dict[BaseSchedulerNode, List[BaseSchedulerNode]] = {}
+        for node in self.nodes:
+            for inverse_user in node.inverse_users:
+                node_to_users.setdefault(inverse_user, []).append(node)
+        for node, users in node_to_users.items():
+            node.node_users = users
+
+    def dead_node_elimination(self):
+        """
+        Remove any nodes without users.
+        """
+        again = True  # repeat until a fixed point
+        while again:
+            updated_nodes = []
+            for node in self.nodes:
+
+                def can_eliminate_user(user: NodeUser):
+                    return user.is_weak or user.get_name() in V.graph.removed_buffers
+
+                can_eliminate = not node.has_side_effects() and all(
+                    can_eliminate_user(u) for u in node.users
+                )
+
+                if not can_eliminate:
+                    updated_nodes.append(node)
+                else:
+                    # dead code
+                    log.debug("removed dead node: %s", node.get_name())
+                    V.graph.removed_buffers.add(node.get_name())
+
+            again = len(self.nodes) > len(updated_nodes)
+            self.nodes = updated_nodes
+
+        # Prune any WeakDeps no longer needed
+        for node in self.nodes:
+            node.prune_weak_deps()
+
+    def topological_sort_schedule(self):
+        """
+        Ensure self.nodes is in topologically sorted order
+        """
+        seen: Set[ir.Buffer] = set()
+        name_to_node: Dict[str, ir.Buffer] = dict()
+        result: List[ir.Buffer] = []
+
+        def visit(n):
+            if n not in seen:
+                seen.add(n)
+                for dep in sorted(n.unmet_dependencies, key=lambda d: d.name):
+                    visit(name_to_node[dep.name])
+                result.append(n)
+
+        for node in self.nodes:
+            for name in node.get_names():
+                name_to_node[name] = node
+        for node in self.nodes:
+            visit(node)
+        self.nodes = result
+
+    def compute_ancestors(self):
+        """
+        Populate each node.ancestors
+        """
+        # note self.nodes is topologically sorted
+        name_to_ancestors: Dict[str, Set[str]] = {}
+        for node in self.nodes:
+            ancestors = set()
+            for dep in node.unmet_dependencies:
+                ancestors.add(dep.name)
+                ancestors |= name_to_ancestors[dep.name]
+            name_to_ancestors[node.get_name()] = ancestors
+            node.ancestors = ancestors
+
+        for order, node in enumerate(self.nodes):
+            node.min_order = order
+            node.max_order = order
+
+    def fuse_nodes(self):
+        """
+        Mutates self.nodes to combine nodes into FusedSchedulerNodes.
+        """
+        for i in range(10):
+            old_len = len(self.nodes)
+            fusion_log.debug(
+                "===== attempting fusion (%d/10): %d nodes =====", i + 1, old_len
+            )
+            self.fuse_nodes_once()
+            new_len = len(self.nodes)
+            fusion_log.debug(
+                "completed fusion round (%d/10): fused %d nodes into %d nodes\n",
+                i + 1,
+                old_len,
+                new_len,
+            )
+            if new_len == old_len or new_len == 1:
+                fusion_log.debug("===== fusion complete (%d iterations) =====", i + 1)
+                break
+
+    def benchmark_fused_nodes(self, nodes):
+        """
+        Benchmark a fused list of nodes and return the execution time
+        in milliseconds, along with the compiled module path, on
+        randomly generated inputs.
+        """
+        assert len(nodes) > 0
+        device = nodes[0].get_device()
+        V.graph.scheduler = self
+        self.current_device = device
+        backend = self.get_backend(device)
+        return backend.benchmark_fused_nodes(nodes)
+
+    def speedup_by_fusion(self, node1, node2):
+        """
+        If config.benchmark_fusion is False, always return True.
+        Otherwise, return True only if fusion can bring a speedup.
+        """
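+        # A minimal sketch of the decision rule implemented below (illustrative
+        # only; the real code also handles templates, foreach kernels, CPU, and
+        # Triton compilation errors):
+        #
+        #     ms1, _ = self.benchmark_fused_nodes(node1.get_nodes())
+        #     ms2, _ = self.benchmark_fused_nodes(node2.get_nodes())
+        #     ms_fused, _ = self.benchmark_fused_nodes(node1.get_nodes() + node2.get_nodes())
+        #     return ms_fused < ms1 + ms2  # fuse only if it actually wins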
+ """ + if not config.benchmark_fusion: + return True + + if ( + node1.is_template() + and not isinstance(node1.get_template_node(), ir.TritonTemplateBuffer) + or node1.is_foreach() + or node2.is_foreach() + ): + # TODO support benchmarking epilogue fusion + return True + + node_list_1 = node1.get_nodes() + device = node_list_1[0].get_device() + + # don't support benchmark fusion for CPU right now. + if device.type == "cpu": + return True + + node_list_2 = node2.get_nodes() + node_list_fused = node_list_1 + node_list_2 + + # We can not accurately benchmark kernel using atomic_add + # due to how we generate random integer inputs. + # Skip benchmarking them by allowing fusion. + if any( + hasattr(n.node, "data") + and hasattr(n.node.data, "scatter_mode") + and n.node.data.scatter_mode == "atomic_add" + for n in node_list_fused + ): + return True + + from triton.compiler.errors import CompilationError + + why = WhyNoFuse(node1, node2) + + try: + ms1, path1 = self.benchmark_fused_nodes(node_list_1) + if math.isinf(ms1): + why("register spilling of the first kernel") + return False + ms2, path2 = self.benchmark_fused_nodes(node_list_2) + if math.isinf(ms2): + why("register spilling of the second kernel") + return False + ms_fused, path_fused = self.benchmark_fused_nodes(node_list_fused) + if math.isinf(ms_fused): + why("register spilling of the fused kernel") + return False + except CompilationError as e: + # workaround triton issue: https://github.com/openai/triton/issues/2151 + if "Loop-carried variable" in str(e): + return True # allow fusion + else: + raise + + if fusion_log.isEnabledFor(logging.DEBUG): + if ms_fused < ms1 + ms2: + fusion_log.debug( + "can fuse (benchmark): fusing %s with %s cause %sx speedup", + node1.get_names(), + node2.get_names(), + green_text(f"{(ms1 + ms2) / ms_fused:.3f}"), + ) + else: + fusion_log.debug( + "cannot fuse (benchmark): fusing %s with %s cause %sx slowdown", + node1.get_names(), + node2.get_names(), + red_text(f"{ms_fused / (ms1 + ms2):.3f}"), + ) + + if ( + is_metric_table_enabled("slow_fusion") + and ms_fused >= ms1 + ms2 + and (path1, path2) not in self.logged_slow_fusion + ): + self.logged_slow_fusion.add((path1, path2)) + get_metric_table("slow_fusion").add_row( + lambda: { + "kernel1_path": path1, + "kernel1_latency": ms1, + "kernel2_path": path2, + "kernel2_latency": ms2, + "fused_kernel_path": path_fused, + "fused_kernel_latency": ms_fused, + "slow_down_ratio": ms_fused / (ms1 + ms2), + } + ) + return ms_fused < ms1 + ms2 + + def fuse_nodes_once(self): + """ + Mutates self.nodes to combine nodes into FusedSchedulerNodes. 
+
+        This relies on two key functions to control the logic:
+            - self.can_fuse(): checks if a fusion is legal
+            - self.score_fusion(): assigns priority to a given fusion
+        """
+        fused_nodes = set(self.nodes)
+        for node1, node2 in self.get_possible_fusions():
+            node1 = self.name_to_fused_node[node1.get_first_name()]
+            node2 = self.name_to_fused_node[node2.get_first_name()]
+            if self.can_fuse(node1, node2) and not self.will_fusion_create_cycle(
+                node1, node2
+            ):
+                if not self.speedup_by_fusion(node1, node2):
+                    continue
+                fusion_log.debug(
+                    "fusing %s with %s", node1.get_name(), node2.get_name()
+                )
+
+                # the can_fuse call above asserts that node2 has the same device
+                device = node1.get_device()
+                node3 = self.get_backend(device).fuse(node1, node2)
+                fused_nodes.remove(node1)
+                fused_nodes.remove(node2)
+                fused_nodes.add(node3)
+                self.name_to_fused_node.update(
+                    {n.get_name(): node3 for n in node3.get_nodes()}
+                )
+        self.nodes = sorted(fused_nodes, key=lambda x: x.min_order)
+        self.topological_sort_schedule()
+        self.prune_redundant_deps()
+
+    def prune_redundant_deps(self):
+        for node in self.nodes:
+            node.prune_redundant_deps(self.name_to_fused_node)
+
+    def get_possible_fusions(self):
+        """
+        Helper to find all legal fusion opportunities, sorted by self.score_fusion()
+        """
+        possible_fusions = []
+        seen = set()
+
+        def check_all_pairs(nodes):
+            for node1_index, node1 in enumerate(nodes):
+                for node2 in nodes[node1_index + 1 :]:
+                    key = (node1, node2)
+                    if key in seen:
+                        continue
+                    seen.add(key)
+
+                    if self.can_fuse(node1, node2):
+                        possible_fusions.append(key)
+                    elif (node2.is_template() or node2.is_foreach()) and self.can_fuse(
+                        node2, node1
+                    ):
+                        # foreach fusions and epilogue fusions are order dependent
+                        possible_fusions.append((node2, node1))
+
+        buffer_names_grouping = collections.defaultdict(list)
+        for node in self.nodes:
+            for buf in node.used_buffer_names():
+                buffer_names_grouping[buf].append(node)
+        for node_grouping in buffer_names_grouping.values():
+            check_all_pairs(node_grouping)
+
+        if config.aggressive_fusion:
+            group_grouping = collections.defaultdict(list)
+            for node in self.nodes:
+                group = getattr(node, "group", None)
+                if group:
+                    group_grouping[group].append(node)
+            for node_grouping in group_grouping.values():
+                check_all_pairs(node_grouping)
+
+        possible_fusions.sort(key=self.score_fusion_key, reverse=True)
+        fusion_log.debug("found %d possible fusions", len(possible_fusions))
+        return possible_fusions
+
+    def will_fusion_create_cycle(self, node1, node2):
+        """
+        Determine whether there's a path from node1 to node2 (or vice-versa)
+        caused indirectly by other fusions.
+        """
+
+        def found_path(node):
+            # only fused nodes can introduce new ancestors.
+            if isinstance(node, FusedSchedulerNode) and node not in visited:
+                visited.add(node)
+                if node.get_names().issubset(combined_ancestors):
+                    # All fusion outputs are in the ancestors of node1 and node2,
+                    # thus they cannot introduce a new path:
+                    #
+                    #   1. if the output is neither a descendant of node1 nor of
+                    #      node2, the output cannot introduce a path
+                    #   2. due to [can_fuse]: if WLOG the output is a descendant of
+                    #      node1, it cannot be on path(node1->node2), hence it
+                    #      cannot be an ancestor of node2
+                    #   3. due to [acyclic]: if WLOG the output is a descendant of
+                    #      node1, it cannot be an ancestor of node1
+                    return False
+                else:
+                    # continue DFS of new ancestors introduced by the fusion
+                    return bool(combined_names & node.ancestors) or any(
+                        found_path(self.name_to_fused_node[n])
+                        for n in node.ancestors - combined_ancestors
+                    )
+            return False
+
+        visited = set()
+        combined_names = node1.get_names() | node2.get_names()
+        combined_ancestors = (node1.ancestors | node2.ancestors) - combined_names
+        cycle = any(found_path(self.name_to_fused_node[n]) for n in combined_ancestors)
+        if cycle:
+            WhyNoFuse(node1, node2)("will create cycle")
+        return cycle
+
+    def can_fusion_increase_peak_memory(
+        self, node1: BaseSchedulerNode, node2: BaseSchedulerNode
+    ):
+        """
+        This function prevents fusion of nodes that can increase the memory
+        footprint. The problem is more common in horizontal fusion, where nodes
+        that are far apart in the original order get fused, lengthening the live
+        intervals of tensors. This is very evident in models with activation
+        checkpointing, where the recomputed nodes from different checkpointed
+        regions get fused and significantly increase the memory footprint.
+
+        The current attempt is a quick, possibly hacky, heuristic to prevent the
+        fusion of nodes that are far away in the original order. For example, if
+        node1 spans orders [10, 12] and node2 spans [100, 101], the score is
+        max(|10 - 101|, |100 - 12|) = 91 > 64, so the fusion is rejected.
+
+        A better but harder-to-implement heuristic would be to use the live
+        intervals of the buffers: find the region of peak pressure in the
+        original program and prevent fusions that cross that peak region. We
+        might need special care or a good approximation in this implementation,
+        as fusion of nodes changes live intervals, and re-computing live
+        intervals and peak memory after each fusion can introduce large
+        compilation overhead.
+        """
+        proximity_score = max(
+            abs(node1.min_order - node2.max_order),
+            abs(node2.min_order - node1.max_order),
+        )
+        return proximity_score > 64
+
+    def can_fuse(self, node1: BaseSchedulerNode, node2: BaseSchedulerNode):
+        """
+        Determine if it is possible to combine node1 and node2 into a
+        single fused node.
+        """
+
+        if node1 is node2:
+            return False
+
+        why = WhyNoFuse(node1, node2)
+
+        if (
+            isinstance(node1, (ExternKernelSchedulerNode, NopKernelSchedulerNode))
+            and not node1.is_template()
+        ):
+            why("node1 is extern or nop")
+            return False
+        if (
+            isinstance(node2, (ExternKernelSchedulerNode, NopKernelSchedulerNode))
+            and not node2.is_template()
+        ):
+            why("node2 is extern or nop")
+            return False
+
+        if node2.get_names() & node1.ancestors:
+            why("node1 must go before node2")
+            return False
+
+        if (
+            isinstance(node1, (FusedSchedulerNode, SchedulerNode))
+            and isinstance(node2, SchedulerNode)
+            and isinstance(node2._body, ir.LoopBody)
+        ):
+            # Fix issue: https://github.com/pytorch/pytorch/issues/108963
+            # Check: if node2 reads a buf that is a mutation buf of node1
+            # (a SchedulerNode) or of one of the nodes within node1 (a
+            # FusedSchedulerNode), look up the corresponding mutation buf and
+            # check whether it is stored with atomic_add mode.
+            # If so, disable the fusion of node1 and node2.
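+            # For instance (illustrative buffer names): if a node inside node1
+            # scatters into buf5 with scatter_mode="atomic_add" and node2 reads
+            # buf5, a fused kernel could let node2 observe partially accumulated
+            # values, so the check below rejects the fusion.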
+ if any( + ( + node2_used_buf in self.mutation_renames + and node1.has_atomic_add(self.mutation_renames[node2_used_buf]) + ) + for node2_used_buf in node2._body.reads_name2expr.keys() + ): + return False + + if node2.is_template(): + why("templates can only fuse epilogues") + return False + if node1.is_template() and ( + node2.has_aliasing_or_mutation() + or node2.is_reduction() + or not config.epilogue_fusion + ): + why("template epilogue not satisfied") + return False + + device = node1.get_device() + device2 = node2.get_device() + if device != device2: + why("device mismatch (%s vs %s)", device, device2) + return False + del device2 + + no_shared_data = self.score_fusion_memory(node1, node2) == 0 + if no_shared_data and ( + not config.aggressive_fusion or node1.is_reduction() or node2.is_reduction() + ): + why("no shared data") + return False # heuristic not needed for correctness + + if ( + not node1.is_foreach() + and not node2.is_foreach() + and len(node1.get_nodes()) + len(node2.get_nodes()) > config.max_fusion_size + ): + why("exceeds max fusion") + return False # heuristic not needed for correctness + + if node1.get_names() & node2.ancestors: + # node2 depends on node1 outputs + if not self.can_fuse_vertical(node1, node2): + return False + return self.get_backend(device).can_fuse_vertical(node1, node2) + else: # nodes don't depend on each other, but may have common reads + if self.can_fusion_increase_peak_memory(node1, node2): + why("will increase peak memory") + return False + return self.get_backend(device).can_fuse_horizontal(node1, node2) + + def can_fuse_vertical(self, node1, node2): + """ + Check if it is legal to fuse a consumer (node2) into a producer (node1). + + We can fuse them if all the reads of node2 either match + corresponding writes in node1, or are written by nodes that can + be scheduled before the fusion of node1 and node2. + + We also disable fusion of a write subsequent to a read if the reads + and writes do not align. + """ + node1_names = node1.get_names() + computed_deps = set() + why = WhyNoFuse(node1, node2) + + # StarDep doesn't match MemoryDep, different indices don't match + # However, broadcasting sometimes strips dimensions, and if that's the case + # we still can match unmet dep + # if there's indirect indexing, don't match it + def fusable_read_and_write(read: Dep, write: Dep): + return ( + self.mutation_renames.get(read.name, read.name) == write.name + and (isinstance(read, MemoryDep) and isinstance(write, MemoryDep)) + and not free_symbol_has(read.index, "tmp") + and not free_symbol_has(write.index, "tmp") + and read.index == write.index + and len(read.size) >= len(write.size) + and read.size[: len(write.size)] == write.size + ) + + for rd in node2.unmet_dependencies: + for cd in node1.read_writes.writes: + if fusable_read_and_write(rd, cd): + computed_deps.add(rd) + + remaining_deps = {dep.name for dep in node2.unmet_dependencies - computed_deps} + if remaining_deps & node1_names: + # MemoryDeps didn't match and read different locations of the same buffer. 
+ # Examples here include: + # - MemoryDep("foo", x) != MemoryDep("foo", x + 1) + # - MemoryDep("foo", x) != StarDep("foo") + why("memory deps did not match") + return False + for name in remaining_deps: + if node1_names & self.name_to_fused_node[name].ancestors: + why("intermediate nodes between node1 & node2") + return False + + # similar to can_inplace, if we are going to fuse a write subsequent to a read + # require that the indexing and size is the same + for write in node2.read_writes.writes: + for read in node1.read_writes.reads: + if write.name != self.mutation_renames.get(read.name, read.name): + continue + + # bail on StarDep + if not fusable_read_and_write(read=read, write=write): + why("fusing a write into a read with different indexing formula") + return False + + return True + + def score_fusion(self, node1: BaseSchedulerNode, node2: BaseSchedulerNode): + """ + Assign a score (higher comes first) to the fusion of node1 + and node2. When different fusions conflict with each other, + this is the way we decide what order to run them in. + + Our current score is based on: + - Estimate of the saved memory operations + - Fusions closer together in original order + """ + memory_score = self.score_fusion_memory(node1, node2) + proximity_score = -max( + abs(node1.min_order - node2.max_order), + abs(node2.min_order - node1.max_order), + ) + return ( + node1.is_template() == config.epilogue_fusion_first and memory_score > 0, + node1.is_reduction() == node2.is_reduction() and memory_score > 0, + memory_score, + proximity_score, + ) + + def score_fusion_memory(self, node1, node2): + """ + The first term in our fusion score that estimates number of saved memory operations. + """ + common_memory_deps = (node1.read_writes.reads | node1.read_writes.writes) & ( + node2.read_writes.reads | node2.read_writes.writes + ) + common_memory_deps = { + dep for dep in common_memory_deps if not dep.has_unbacked_symbols() + } + return sum(dep.numbytes_hint() for dep in common_memory_deps) + + def score_fusion_key(self, nodes): + """ + Shim for list.sort(key=...) + """ + node1, node2 = nodes + return self.score_fusion(node1, node2) + + def compute_last_usage(self): + """ + Populate node.last_usage recursively (also for the nodes within a FusedSchedulerNode) + """ + + future_used_buffers = set() + for node_name in V.graph.get_output_names(): + future_used_buffers.add(node_name) + + for node in reversed(self.nodes): + node.set_last_usage(future_used_buffers, self.mutation_real_name) + future_used_buffers.update(node.last_usage) + + def free_buffers(self): + """Free any buffers that are no longer needed""" + for name in sorted( + self.buffer_names_to_free + - V.graph.removed_buffers + - V.graph.wrapper_code.freed + ): + if name in self.name_to_node: + node = self.name_to_node[name] + if node.can_free(): + V.graph.wrapper_code.codegen_free(node.node) + elif name in V.graph.graph_inputs: + storage = V.graph.graph_inputs[name].data + assert isinstance(storage, ir.StorageBox) and storage.is_input_buffer() + V.graph.wrapper_code.codegen_free(storage.data) + + self.buffer_names_to_free.clear() + + def remove_kernel_local_buffers(self): + """ + Any buffers that are both created and have a last use in the + same kernel can be removed. 
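+        For example (illustrative buffer names): if buf0 is written by one node
+        of a fused kernel and all of its non-weak readers are fused into the
+        same kernel, buf0 never needs a global-memory allocation and can be
+        dropped from the kernel's output buffers.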
+ """ + + # V.kernel.store_buffer_names should represent the set of nodes + # get fused + fused_node_names = V.kernel.store_buffer_names + names_to_remove = [] + for out_buf in V.kernel.store_buffer_names: + users = self.name_to_node[out_buf].users + assert users is not None + users = {user.get_name() for user in users if not user.is_weak} + if users.issubset(fused_node_names): + names_to_remove.append(out_buf) + + def remove_filter(n): + return ( + n not in V.kernel.must_keep_buffers + and n not in V.kernel.args.input_buffers + and n not in self.mutation_renames + and n not in self.mutation_real_name + ) + + names_to_remove = list(filter(remove_filter, names_to_remove)) + + for name in names_to_remove: + if name in V.kernel.args.inplace_buffers: + buf = V.kernel.args.inplace_buffers[name] + if isinstance(buf, str) and buf.startswith("REMOVED"): + continue + remove = all(n in names_to_remove for n in buf.other_names) + if remove: + self.remove_inplace_buffer(name) + V.kernel.inplaced_to_remove.add(name) + else: + self.remove_buffer(name) + + def remove_buffer(self, name): + # Assign a special value instead of deleting the entry + # because we still rely on output_buffers's length to + # generate unique arg name. + log.debug("remove_buffer(%r)", name) + V.kernel.args.output_buffers[name] = "REMOVED" + V.kernel.removed_buffers.add(name) + + def remove_inplace_buffer(self, name): + log.debug("removing_inplace_buffer(%r)", name) + inner_name = V.kernel.args.inplace_buffers[name].inner_name + V.kernel.args.inplace_buffers[name] = inner_name.replace( + "in_out_ptr", "REMOVED" + ) + V.kernel.removed_buffers.add(name) + + def flush(self): + for backend in self.backends.values(): + backend.flush() + self.free_buffers() + + def codegen_extern_call(self, scheduler_node: ExternKernelSchedulerNode): + assert isinstance(scheduler_node, ExternKernelSchedulerNode) + # 'decide_inplace_update' stores the inplace update decisions in + # the current kernel from where 'allocate' retrieve those decisions. + # We have to make sure there is a non-NULL kernel handler to store + # those inplace update decisions. + with V.set_kernel_handler(Kernel(increase_kernel_count=False)): + scheduler_node.decide_inplace_update() + scheduler_node.allocate() + node = scheduler_node.node + assert isinstance(node, ir.ExternKernel), f"{type(node)=}" + node.codegen(V.graph.wrapper_code) + self.free_buffers() + + def create_backend(self, device: torch.device): + assert ( + device.type != "cuda" or device.index is not None + ), f"{device} should have been normalized in lowering" + V.graph.add_device_info(device) + + device_scheduling = get_scheduling_for_device(device.type) + if device_scheduling is None: + raise RuntimeError(f"Unsupported device type: {device.type}") + + if device.type == "cuda" and not has_triton(): + device_props = torch.cuda.get_device_properties(device) + if device_props.major < 7: + raise RuntimeError( + f"Found {device_props.name} which is too old to be supported by the triton GPU compiler, which is used as the backend. Triton only supports devices of CUDA Capability >= 7.0, but your device is of CUDA capability {device_props.major}.{device_props.minor}" # noqa: B950 + ) + else: + raise RuntimeError( + "Cannot find a working triton installation. 
More information on installing Triton can be found at https://github.com/openai/triton" # noqa: B950 + ) + + return device_scheduling(self) + + def get_backend(self, device: torch.device): + if device not in self.backends: + self.backends[device] = self.create_backend(device) + return self.backends[device] + + def enter_context(self, node): + def get_order(n): + if n not in self.origin_to_index: + self.origin_to_index.update({n: i for i, n in enumerate(n.graph.nodes)}) + return self.origin_to_index[n] + + # Use a dict to have ordering + origins = { + (get_order(e), e): None for n in node.get_nodes() for e in n.node.origins + } + origins = list(origins.keys()) + if origins: + _, last = max(origins, key=operator.itemgetter(0)) + V.graph.wrapper_code.enter_context(last) + + @dynamo_timed + def codegen(self): + for node in self.nodes: + try: + log.debug( + "Generating code for node %s with estimated runtime %f", + node.get_name(), + node.get_estimated_runtime(), + ) + except Exception as e: + log.debug( + "Generating code for node %s with estimated runtime 0.0", + node.get_name(), + ) + + self.enter_context(node) + + if not isinstance(node, NopKernelSchedulerNode): + device = node.get_device() + if ( + device != self.current_device + or node.is_extern() + or node.is_template() + ): + self.flush() + if device != self.current_device: + if device.type == "cuda": + if self.current_device and self.current_device.type == "cuda": + V.graph.wrapper_code.codegen_device_guard_exit() + assert device.index is not None, "device should have an index" + V.graph.wrapper_code.codegen_device_guard_enter(device.index) + elif self.current_device and self.current_device.type == "cuda": + V.graph.wrapper_code.codegen_device_guard_exit() + self.current_device = device + + self.buffer_names_to_free.update(node.last_usage) + + if node.is_template(): + node, *epilogue = node.get_nodes() + self.get_backend(device).codegen_template(node, epilogue) # type: ignore[possibly-undefined] + elif node.is_extern(): + self.codegen_extern_call(node) + elif node.is_foreach(): + self.get_backend(device).codegen_foreach(node) # type: ignore[possibly-undefined] + elif isinstance(node, (FusedSchedulerNode, SchedulerNode)): + self.get_backend(device).codegen_nodes(node.get_nodes()) # type: ignore[possibly-undefined] + else: + assert isinstance(node, NopKernelSchedulerNode) + node.allocate() + + if config.debug_check_inf_and_nan: + V.graph.wrapper_code.generate_inf_and_nan_checker(node) + + if config.triton.debug_sync_kernel: + self.get_backend(device).codegen_sync() # type: ignore[possibly-undefined] + + self.available_buffer_names.update(node.get_names()) + + if not isinstance(node, NopKernelSchedulerNode): + device = node.get_device() + if self.get_backend(device).ready_to_flush(): + self.flush() + + if self.current_device and self.current_device.type == "cuda": + # exit the outermost CUDA device guard. this is + # important for nested indentation codegen-ing. 
+ V.graph.wrapper_code.codegen_device_guard_exit() + + self.flush() + + def is_unaligned_buffer(self, buf_name): + if buf_name in V.graph.graph_inputs or buf_name in V.graph.constants: + # all graph inputs or constants are assumed to be aligned + return False + node = self.name_to_node[buf_name] + layout = node.node.get_layout() + if isinstance(layout, ir.AliasedLayout): + return not layout.maybe_guard_aligned() + else: + return False + + +class BaseScheduling: + def can_fuse_vertical(self, node1: BaseSchedulerNode, node2: BaseSchedulerNode): + """ + Check whether node1 and node2 can be vertically fused or not. + """ + raise NotImplementedError() + + def can_fuse_horizontal(self, node1: BaseSchedulerNode, node2: BaseSchedulerNode): + """ + Check whether node1 and node2 can be horizontally fused or not. + """ + raise NotImplementedError() + + def fuse(self, node1: BaseSchedulerNode, node2: BaseSchedulerNode): + """ + Fuse two nodes + """ + if node1.is_foreach() or node2.is_foreach(): + return ForeachKernelSchedulerNode.fuse(node1, node2) + else: + return FusedSchedulerNode.fuse(node1, node2) + + def group_fn(self, sizes): + """ + Process the iteration sizes in case a transformation needs to be applied. + """ + raise NotImplementedError() + + def codegen_template( + self, template_node: SchedulerNode, epilogue_nodes: List[SchedulerNode] + ): + """ + Given a template node, generate a kernel. + + This function is only available for triton now. If the third-party backend behaves as a sub-class + of TritonScheduling, it can override it or reuse it. + """ + raise NotImplementedError() + + def codegen_nodes(self, nodes: List[SchedulerNode]): + """ + Generate a kernel given a list of pre-fused nodes. + """ + raise NotImplementedError() + + def codegen_sync(self): + """ + Generate synchronization code for the kernel. This method depends on the hardware characteristics. + """ + raise NotImplementedError() + + def ready_to_flush(self) -> bool: + """ + Check whether the backend is requesting the scheduler to flush the generated kernel. + If not supported, please return False. + """ + return False + + def flush(self): + """ + Flush the generated kernel and python wrapper code to the source code file. + """ + raise NotImplementedError() + + def benchmark_fused_nodes(self, nodes): + """ + Benchmark fused list of nodes and return the execution time + in milliseconds on randomly generated inputs. + """ + raise NotImplementedError() diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/select_algorithm.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/select_algorithm.py new file mode 100644 index 0000000000000000000000000000000000000000..6c80626ed12ad696b6f75759844b235d5fba252e --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/select_algorithm.py @@ -0,0 +1,1156 @@ +import builtins +import functools +import inspect +import itertools +import logging +import operator +import sys +import textwrap +import time +from concurrent.futures import ThreadPoolExecutor +from io import StringIO + +from typing import Any, Callable, Dict, List, Optional, Union +from unittest.mock import patch + +import sympy + +import torch +from torch._dynamo.testing import rand_strided +from torch._dynamo.utils import counters, identity, preserve_rng_state + +from . 
import config, ir +from .autotune_process import TensorMeta, TritonBenchmarkRequest +from .codecache import code_hash, PersistentCache, PyCodeCache +from .codegen.common import ( + ChoiceCaller, + IndentedBuffer, + KernelTemplate, + PrimitiveInfoType, +) +from .codegen.triton import ( + gen_common_triton_imports, + texpr, + TritonKernel, + TritonPrinter, + TritonScheduling, +) +from .codegen.triton_utils import config_of, signature_to_meta +from .exc import CUDACompileError +from .utils import ( + do_bench, + get_dtype_size, + Placeholder, + sympy_dot, + sympy_product, + unique, +) +from .virtualized import V + +log = logging.getLogger(__name__) + +# correctness checks struggle with fp16/tf32 +VERIFY: Dict[str, Any] = dict() +PRINT_AUTOTUNE = True +DEBUG = False + + +class KernelNamespace: + pass + + +# these objects are imported from the generated wrapper code +extern_kernels = KernelNamespace() + + +class PartialRender: + """ + Some parts of a template need to be generated at the end, but + inserted into the template at the start. This allows doing a bunch + of replacements after the initial render. + """ + + def __init__(self, code, replacement_hooks): + super().__init__() + self.code = code + self.replacement_hooks = replacement_hooks + + def finalize(self): + code = self.code + assert code is not None, "can only be called once" + self.code = None + for key, fn in self.replacement_hooks.items(): + code = code.replace(key, fn()) + return code + + +class TritonTemplateKernel(TritonKernel): + def __init__( + self, + kernel_name, + input_nodes, + output_node, + defines, + num_stages, + num_warps, + grid_fn, + meta, + call_sizes, + use_jit=True, + prefix_args=0, + suffix_args=0, + epilogue_fn=identity, + *, + index_dtype, + ): + super().__init__( + sympy_product(output_node.get_size()), + sympy.Integer(1), + index_dtype=index_dtype, + ) + self.input_nodes = input_nodes + self.output_node = output_node + self.named_input_nodes = {} + self.defines = defines + self.kernel_name = kernel_name + self.template_mask = None + self.use_jit = use_jit + self.num_stages = num_stages + self.num_warps = num_warps + self.grid_fn = grid_fn + self.meta = meta + self.call_sizes = call_sizes + # for templates with fixed epilogues + self.prefix_args = prefix_args + self.suffix_args = suffix_args + self.epilogue_fn = epilogue_fn + self.render_hooks = dict() + self.triton_meta: Optional[Dict[str, object]] = None + + def need_numel_args(self): + return False + + def estimate_kernel_num_bytes(self): + """ + Estimate the total number of bytes this kernel takes. + For in/out nodes, sizes are counted twice: once for reading and + once for writing. 
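+        A worked example (illustrative shapes): for a matmul with fp16 inputs
+        A (1024x512) and B (512x256), output C (1024x256), and no in/out
+        (inplace) args, the estimate is
+        (1024*512 + 512*256 + 1024*256) elements * 2 bytes ~= 1.8 MB.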
+ """ + ninplace_args = len(unique(self.args.inplace_buffers.values())) + num_bytes = [] + for i, inp in enumerate(itertools.chain(self.input_nodes, (self.output_node,))): + size = V.graph.sizevars.size_hints(inp.get_size()) + numel = functools.reduce(operator.mul, size) + dtype_size = get_dtype_size(inp.get_dtype()) + num_bytes.append(numel * dtype_size * (1 + int(i < ninplace_args))) + return sum(num_bytes) + + def jit_lines(self): + if self.use_jit: + return "@triton.jit" + + argdefs, _, signature = self.args.python_argdefs() + triton_meta = { + "signature": signature_to_meta(signature, size_dtype=self.index_dtype), + "device": V.graph.scheduler.current_device.index, + "device_type": V.graph.scheduler.current_device.type, + "constants": {}, + } + triton_meta["configs"] = [config_of(signature)] + for arg_num in triton_meta["configs"][0].equal_to_1: # type: ignore[index] + triton_meta["constants"][arg_num] = 1 # type: ignore[index] + self.triton_meta = triton_meta + + inductor_meta = { + "kernel_name": str(Placeholder.DESCRIPTIVE_NAME), + "backend_hash": torch.utils._triton.triton_hash_with_backend(), + } + if config.profile_bandwidth or config.benchmark_kernel: + num_gb = self.estimate_kernel_num_bytes() / 1e9 + inductor_meta["kernel_num_gb"] = num_gb + return f""" + @triton_heuristics.template( + num_stages={self.num_stages}, + num_warps={self.num_warps}, + triton_meta={triton_meta!r}, + inductor_meta={inductor_meta!r}, + ) + @triton.jit + """ + + def def_kernel(self, *argnames): + """ + Hook called from template code to generate function def and + needed args. + """ + assert all(isinstance(x, str) for x in argnames) + renames = IndentedBuffer(initial_indent=1) + + named_args = self.input_nodes[ + self.prefix_args : len(self.input_nodes) - self.suffix_args + ] + + assert len(argnames) == len(named_args), ( + len(argnames), + len(named_args), + self.prefix_args, + len(self.input_nodes), + ) + + for input_node in self.input_nodes[: self.prefix_args]: + # get args in correct order + self.args.input(input_node.get_name()) + + for name, input_node in zip(argnames, named_args): + arg_name = f"arg_{name}" + self.named_input_nodes[name] = input_node + self.args.input_buffers[input_node.get_name()] = arg_name + + # The args may be duplicated, so renaming must be after args are de-duplicated. + for name in argnames: + input_node = self.named_input_nodes[name] + arg_name = self.args.input_buffers[input_node.get_name()] + if input_node.get_layout().offset == 0: + renames.writeline(f"{name} = {arg_name}") + else: + offset = texpr(self.rename_indexing(input_node.get_layout().offset)) + renames.writeline(f"{name} = {arg_name} + {offset}") + + for input_node in self.input_nodes[len(self.input_nodes) - self.suffix_args :]: + # get args in correct order + self.args.input(input_node.get_name()) + + def hook(): + # python_argdefs() cannot be run until after the rest of the template lazily adds more args + arg_defs, *_ = self.args.python_argdefs() + code = IndentedBuffer() + code.splice(gen_common_triton_imports()) + code.splice(self.jit_lines()) + code.writeline(f"def {self.kernel_name}({', '.join(arg_defs)}):") + with code.indent(): + code.splice(self.defines) + code.splice(renames.getvalue()) + return code.getvalue() + + assert "" not in self.render_hooks + self.render_hooks[""] = hook + return "" + + def size(self, name: str, index: int): + """ + Hook called from template code to get the size of an arg. + Will add needed args to pass it in if it is dynamic. 
+ """ + assert isinstance(index, int) + if name is None: + val = self.output_node.get_size()[index] + else: + assert isinstance(name, str) + val = self.named_input_nodes[name].get_size()[index] + return texpr(self.rename_indexing(val)) + + def stride(self, name, index): + """ + Hook called from template code to get the stride of an arg. + Will add needed args to pass it in if it is dynamic. + """ + assert isinstance(index, int) + if name is None: + val = self.output_node.get_stride()[index] + else: + assert isinstance(name, str) + val = self.named_input_nodes[name].get_stride()[index] + return texpr(self.rename_indexing(val)) + + def store_output(self, indices, val, mask): + """ + Hook called from template code to store the final output + (if the buffer hasn't been optimized away), then append any + epilogue fusions. + """ + assert isinstance(indices, (list, tuple)) + assert isinstance(val, str) + assert isinstance(mask, str) + assert self.template_mask is None + indices = list(map(TritonPrinter.paren, indices)) + index_symbols = [sympy.Symbol(x) for x in indices] + lengths = [V.graph.sizevars.simplify(s) for s in self.output_node.get_size()] + assert len(indices) == len(lengths) + + # glue to make generated code use same indexing from template + for name, range_tree_entry in zip( + indices, self.range_trees[0].construct_entries(lengths) + ): + range_tree_entry.set_name(name) + contiguous_index = sympy_dot( + ir.FlexibleLayout.contiguous_strides(lengths), index_symbols + ) + contiguous_index = self.rename_indexing(contiguous_index) + self.body.writeline("xindex = " + texpr(contiguous_index)) + self.range_trees[0].lookup(sympy.Integer(1), sympy_product(lengths)).set_name( + "xindex" + ) + self.template_mask = mask + self.template_indices = indices + output_index = self.output_node.get_layout().make_indexer()(index_symbols) + output_index = self.rename_indexing(output_index) + if output_index == contiguous_index: + output_index = sympy.Symbol("xindex") + + epilogue_args = [val] + for input_node in itertools.chain( + self.input_nodes[: self.prefix_args], + self.input_nodes[len(self.input_nodes) - self.suffix_args :], + ): + input_node.freeze_layout() + epilogue_args.append(input_node.make_loader()(index_symbols)) + + V.ops.store( + self.output_node.get_name(), + output_index, + self.epilogue_fn(*epilogue_args), + ) + self.codegen_body() + + def hook(): + # more stuff might have been added since the codegen_body above + self.codegen_body() + return textwrap.indent(self.body.getvalue(), " ").strip() + + assert "" not in self.render_hooks + self.render_hooks[""] = hook + return "" + + def render(self, template, kwargs): + return PartialRender( + template.render(**self.template_env(), **kwargs), + self.render_hooks, + ) + + def make_load(self, name, indices, mask): + """ + Optional helper called from template code to generate the code + needed to load from an tensor. + """ + assert isinstance(indices, (list, tuple)) + assert isinstance(name, str) + assert isinstance(mask, str) + stride = self.named_input_nodes[name].get_stride() + indices = list(map(TritonPrinter.paren, indices)) + assert len(indices) == len(stride) + index = " + ".join( + f"{texpr(self.rename_indexing(s))} * {i}" for s, i in zip(stride, indices) + ) + return f"tl.load({name} + ({index}), {mask})" + + def template_env(self): + """ + Generate the namespace visible in the template. 
+ """ + return { + fn.__name__: fn + for fn in [ + self.def_kernel, + self.size, + self.stride, + self.store_output, + self.make_load, + ] + } + + def indexing( + self, + index: sympy.Expr, + *, + dense_indexing=False, + copy_shape=None, + override_mask=None, + block_ptr=False, + ): + """ + Override the default indexing to use our custom mask and force + dense indexing. + """ + return super().indexing( + index, + dense_indexing=False, + copy_shape=self.template_mask, + override_mask=self.template_mask, + block_ptr=block_ptr, + ) + + def initialize_range_tree(self, pid_cache): + super().initialize_range_tree(pid_cache) + # ignore default codegen + self.body.clear() + self.indexing_code.clear() + + def call_kernel(self, name: str, node: Optional[ir.IRNode] = None): + wrapper = V.graph.wrapper_code + _, call_args, _ = self.args.python_argdefs() + call_args = [str(a) for a in call_args] + + for i in range(len(call_args)): + if V.graph.is_unspec_arg(call_args[i]): + call_args[i] = call_args[i] + ".item()" + if isinstance(call_args[i], sympy.Symbol): + call_args[i] = texpr(call_args[i]) + + if V.graph.cpp_wrapper: + # In the cpp_wrapper case, we have to compute CUDA launch grid at runtime + # if any dynamic dimension is involved. We rely on the Python version + # of the grid function to generate those grid configs, which may contain + # symbolic values. The wrapper will use cexpr to print out C++ code + # appropriately for the grid configs. + grid_args = [V.graph.sizevars.simplify(s) for s in self.call_sizes] + [ + self.meta + ] + grid = self.grid_fn(*grid_args) + + wrapper.generate_kernel_call( + name, + call_args, + device_index=V.graph.scheduler.current_device.index, + grid=grid, + triton_meta=self.triton_meta, + ) + else: + stream_name = wrapper.write_get_raw_stream( + V.graph.scheduler.current_device.index + ) + + wrapper.add_import_once(f"import {self.grid_fn.__module__}") + meta = wrapper.add_meta_once(self.meta) + + grid_call = [ + texpr(V.graph.sizevars.simplify(s)) for s in self.call_sizes + ] + [meta] + grid_call = f"{self.grid_fn.__module__}.{self.grid_fn.__name__}({', '.join(grid_call)})" + wrapper.writeline( + f"{name}.run({', '.join(call_args)}, grid={grid_call}, stream={stream_name})" + ) + + +@functools.lru_cache(None) +def _jinja2_env(): + try: + import jinja2 + + return jinja2.Environment( + undefined=jinja2.StrictUndefined, + ) + except ImportError: + return None + + +class TritonTemplate(KernelTemplate): + index_counter = itertools.count() + all_templates: Dict[str, "TritonTemplate"] = dict() + + def __init__(self, name: str, grid: Any, source: str, debug=False): + super().__init__(name) + self.grid = grid + self.template = self._template_from_string(source) + assert name not in self.all_templates, "duplicate template name" + self.all_templates[name] = self + self.debug = debug + + def generate( + self, + input_nodes, + layout, + num_stages, + num_warps, + prefix_args=0, + suffix_args=0, + epilogue_fn=identity, + **kwargs, + ): + assert self.template, "requires jinja2" + defines = StringIO() + for name, val in kwargs.items(): + defines.write(f" {name} : tl.constexpr = {val}\n") + defines = defines.getvalue() + + fake_out = ir.Buffer("buf_out", layout) + kernel_name = f"triton_{self.name}" + + numel = sympy_product(layout.size) + buffers = itertools.chain(input_nodes, (fake_out,)) + if not TritonScheduling.can_use_32bit_indexing(numel, buffers): + raise NotImplementedError( + "64-bit indexing is not yet implemented for triton templates" + ) + + kernel_options = dict( + 
input_nodes=input_nodes, + defines=defines, + num_stages=num_stages, + num_warps=num_warps, + grid_fn=self.grid, + meta=kwargs, + call_sizes=layout.size, + prefix_args=prefix_args, + suffix_args=suffix_args, + epilogue_fn=epilogue_fn, + index_dtype="tl.int32", + ) + with patch.object( + V.graph, "get_dtype", self._fake_get_dtype(fake_out) + ), TritonTemplateKernel( + kernel_name=kernel_name, + output_node=fake_out, + use_jit=True, + **kernel_options, + ) as kernel: + try: + code = kernel.render(self.template, kwargs).finalize() + except ZeroDivisionError: + # TODO(nmacchioni): fix sympy division by zero + return None + if self.debug: + print("Generated Code:\n", code) + extra = ( + "-".join( + [ + *[ + f"{kwarg}={repr(kwargs[kwarg])}" + for kwarg in sorted(kwargs.keys()) + ], + f"num_stages={num_stages}", + f"num_warps={num_warps}", + ] + ) + + "-" + ) + mod = PyCodeCache.load(code, extra) + _, call_args, _ = kernel.args.python_argdefs() + + expected_args = list(unique(x.get_name() for x in input_nodes)) + expected_args.extend([fake_out.get_name()]) + assert list(call_args)[: len(expected_args)] == expected_args, ( + call_args, + expected_args, + ) + extra_args = V.graph.sizevars.size_hints( + map(sympy.expand, call_args[len(expected_args) :]), + fallback=config.unbacked_symint_fallback, + ) + + kernel_hash_name = f"triton_{self.name}_{next(self.index_counter)}" + + def make_kernel_render(out_node): + kernel = TritonTemplateKernel( + kernel_name=str(Placeholder.KERNEL_NAME), + output_node=out_node, + use_jit=False, + **kernel_options, + ) + render = functools.partial( + kernel.render, + self.template, + kwargs, + ) + return kernel, render + + # create the BenchmarkRequest + assert mod.__file__ is not None + grid = self.grid( + *V.graph.sizevars.size_hints( + layout.size, + fallback=config.unbacked_symint_fallback, + ), + kwargs, + ) + bmreq = TritonBenchmarkRequest( + module_path=mod.__file__, + module_cache_key=mod.key, + kernel_name=kernel_name, + grid=grid, + extra_args=extra_args, + num_stages=num_stages, + num_warps=num_warps, + matrix_instr_nonkdim=kwargs.get("matrix_instr_nonkdim", 0), + input_tensor_meta=TensorMeta.from_irnodes(input_nodes), + output_tensor_meta=TensorMeta.from_irnodes(layout), + ) + + return TritonTemplateCaller( + kernel_hash_name, + input_nodes, + layout, + make_kernel_render, + extra.strip("-").replace("-", ", "), + bmreq, + log_info={ + "tile_shape": str( + ( + kwargs.get("BLOCK_M", -1), + kwargs.get("BLOCK_K", -1), + kwargs.get("BLOCK_N", -1), + ) + ), + "num_stages": num_stages, + "num_warps": num_warps, + "allow_tf32": str(kwargs.get("ALLOW_TF32", None)), + "acc_type": str(kwargs.get("ACC_TYPE", None)), + }, + ) + + +class ExternKernelChoice: + def __init__( + self, + kernel, + cpp_kernel=None, + *, + name=None, + has_out_variant=True, + op_overload=None, + use_fallback_kernel=False, + ): + super().__init__() + name = name or kernel.__name__ + assert callable(kernel) + assert not hasattr(extern_kernels, name), "duplicate extern kernel" + self.name = name + self.cpp_kernel_name = cpp_kernel + self.has_out_variant = has_out_variant + setattr(extern_kernels, name, kernel) + self.op_overload = op_overload + self.use_fallback_kernel = use_fallback_kernel + + def to_callable(self): + return getattr(extern_kernels, self.name) + + def call_name(self): + return f"extern_kernels.{self.name}" + + @functools.lru_cache(None) + def hash_key(self): + fn = self.to_callable() + parts = [ + self.name, + getattr(fn, "__name__", ""), + getattr(fn, "__module__", ""), + ] + try: 
+ parts.append(inspect.getsource(fn)) + except Exception: + pass + return code_hash("-".join(parts)) + + def bind( + self, + input_nodes, + layout, + ordered_kwargs_for_cpp_kernel=(), + **kwargs, + ): + self.ordered_kwargs_for_cpp_kernel = ordered_kwargs_for_cpp_kernel + return ExternKernelCaller( + self, input_nodes, layout, kwargs, has_out_variant=self.has_out_variant + ) + + +class TritonTemplateCaller(ChoiceCaller): + def __init__( + self, + name, + input_nodes, + layout, + make_kernel_render, + debug_extra, + bmreq, + log_info: Optional[ + Dict[str, Union[PrimitiveInfoType, List[PrimitiveInfoType]]] + ] = None, + ): + super().__init__(name, input_nodes, layout) + self.make_kernel_render = make_kernel_render + self.debug_extra = debug_extra + self.bmreq: TritonBenchmarkRequest = bmreq + if log_info is None: + log_info = {} + self.log_info: Dict[str, Any] = log_info + self.log_info.update( + { + "backend": "Triton", + "grid": str(self.bmreq.grid), + "num_stages": self.bmreq.num_stages, + "num_warps": self.bmreq.num_warps, + } + ) + + def benchmark(self, *args, out): + assert self.bmreq is not None + return self.bmreq.benchmark(*args, output_tensor=out) + + def __str__(self): + return f"TritonTemplateCaller({self.bmreq.module_path}, {self.debug_extra})" + + def call_name(self): + return f"template_kernels.{self.name}" + + def hash_key(self): + return "-".join( + [ + self.name.rsplit("_", 1)[0], + self.bmreq.module_cache_key, + ] + ) + + def output_node(self): + return ir.TensorBox.create( + ir.TritonTemplateBuffer( + layout=self.layout, + inputs=self.input_nodes, + make_kernel_render=self.make_kernel_render, + ) + ) + + def info_dict(self) -> Dict[str, Union[PrimitiveInfoType, List[PrimitiveInfoType]]]: + """Information returned here is logged to the autotune log file when that is enabled.""" + return self.log_info + + +class ExternKernelCaller(ChoiceCaller): + def __init__( + self, + choice: ExternKernelChoice, + input_nodes, + layout, + kwargs=None, + *, + has_out_variant=True, + ): + super().__init__(choice.name, input_nodes, layout) + self.choice = choice + self.kwargs = kwargs or {} + self.has_out_variant = has_out_variant + + def __str__(self): + return f"ExternKernelCaller({self.choice.call_name()})" + + def benchmark(self, *args, out): + if self.has_out_variant: + return super().benchmark(*args, out=out) + else: + algo = self.to_callable() + out_new = algo(*args) + torch._C._dynamo.guards.assert_size_stride( + out_new, tuple(out.size()), tuple(out.stride()) + ) + out.copy_(out_new) # for correctness checking + return do_bench(lambda: algo(*args)) + + def to_callable(self): + fn = self.choice.to_callable() + if self.kwargs: + return functools.partial(fn, **self.kwargs) + else: + return fn + + def hash_key(self): + return "-".join( + [ + self.choice.name, + *[ + f"{kwarg}={repr(self.kwargs[kwarg])}" + for kwarg in sorted(self.kwargs.keys()) + ], + self.choice.hash_key(), + ] + ) + + def output_node(self): + if config.abi_compatible and self.choice.use_fallback_kernel: + assert ( + self.choice.op_overload is not None + ), "Please provide an op_overload to use ir.FallbackKernel" + inner = ir.FallbackKernel.create( + self.choice.op_overload, *self.input_nodes, **self.kwargs + ) + else: + cls = ir.ExternKernelOut if self.has_out_variant else ir.ExternKernelAlloc + inner = cls( + layout=self.layout, + inputs=self.input_nodes, + python_kernel_name=self.choice.call_name(), + cpp_kernel_name=self.choice.cpp_kernel_name, + 
ordered_kwargs_for_cpp_kernel=self.choice.ordered_kwargs_for_cpp_kernel, + op_overload=self.choice.op_overload, + kwargs=self.kwargs, + ) + + return ir.TensorBox.create(inner) + + def info_dict(self) -> Dict[str, Union[PrimitiveInfoType, List[PrimitiveInfoType]]]: + """Information returned here is logged to the autotune log file when that is enabled.""" + return { + "backend": "extern", + "kernel_call_name": self.choice.call_name(), + } + + +class ErrorFromChoice(RuntimeError): + def __init__(self, msg, choice: ChoiceCaller, inputs_str): + msg += f"\nFrom choice {choice}\n{inputs_str}" + super().__init__(msg) + self.choice = choice + + +class AlgorithmSelectorCache(PersistentCache): + def __call__( + self, + name, + choices: List[ChoiceCaller], + input_nodes, + layout, + # optional dict mapping arg indices to the functions + # generating a torch.Tensor for that input from the + # corresponding ir.Buffer. if passed for a given + # arg, the function will be called instead of + # generating a random torch.Tensor for benchmarking. + input_gen_fns: Optional[Dict[int, Callable[[ir.Buffer], torch.Tensor]]] = None, + precompilation_timeout_seconds: int = 60 * 60, + ): + from .codegen.cuda.cuda_kernel import CUDATemplateCaller + + # TODO(nmacchioni): remove once CI tests are fixed + choices = [choice for choice in choices if choice is not None] + if len(choices) == 0: + raise RuntimeError( + "No choices to select, please consider adding ATEN into max_autotune_gemm_backends " + "config (defined in torch/_inductor/config.py) to allow at least one choice. " + ) + log.debug("Max autotune selects from %s choices.", str(len(choices))) + + if len(choices) == 1: + if not isinstance(choices[0], CUDATemplateCaller): + # CUDATemplateCaller still needs to go through autotuning process to retrieve workspace size. + return choices[0].output_node() + + @functools.lru_cache(None) + def make_benchmark_fn(): + return self.make_benchmark_fn(choices, input_nodes, layout, input_gen_fns) + + def precompile(choices): + if ( + precompilation_timeout_seconds is None + or precompilation_timeout_seconds <= 0 + ): + return + num_workers = min( + config.compile_threads, + torch.get_num_threads(), + len(choices), + ) + if num_workers <= 0: + return + log.info( + "Multithreaded precompilation for %d choices using %d worker threads", + len(choices), + num_workers, + ) + with ThreadPoolExecutor(max_workers=num_workers) as executor: + futures = executor.map( + lambda c: c.precompile(), + [c for c in choices if hasattr(c, "precompile")], + timeout=precompilation_timeout_seconds, + ) + try: + iterator = iter(futures) + while True: + try: + next(iterator) + except CUDACompileError: + log.error( # noqa: G201 + "CUDA Compilation error", exc_info=True + ) + except TimeoutError: + log.warning( + f"Precompilation timed out after {precompilation_timeout_seconds} seconds." # noqa: G004 + ) + except StopIteration: + pass + executor.shutdown(wait=True) + + def autotune(choices): + try: + precompile(choices) + except TimeoutError: + log.warning( + "Precompilation phase took longer than timeout allowed. 
Continuing" + ) + pass + return make_benchmark_fn()(choices) + + if config.autotune_in_subproc: + from .autotune_process import tuning_pool + + # do the optional warmup + tuning_pool.initialize() + + autotune_start_ts = time.time() + timings = self.lookup( + choices, + name, + repr([self.key_of(x) for x in input_nodes]), + autotune, + ) + autotune_elapse = time.time() - autotune_start_ts + if timings == {} or choices[0] not in timings: + return choices[0].output_node() + + if make_benchmark_fn.cache_info().currsize: + counters["inductor"]["select_algorithm_autotune"] += 1 + if ( + make_benchmark_fn.cache_info().currsize + or log.getEffectiveLevel() == logging.DEBUG + or config.trace.log_autotuning_results + ): + self.log_results(name, input_nodes, timings, autotune_elapse) + selected_choice = builtins.min(timings, key=timings.__getitem__).output_node() + log.debug("selected choice: %s", str(selected_choice)) + return selected_choice + + @classmethod + def make_benchmark_fn( + cls, + choices, + input_nodes, + layout, + input_gen_fns=None, + ): + if input_gen_fns is None: + input_gen_fns = {} + + # de-duplicate args + unique_example_inputs = { + x.get_name(): input_gen_fns.get(i, cls.benchmark_example_value)(x) + for i, x in enumerate(input_nodes) + } + example_inputs = list(unique_example_inputs.values()) + example_inputs_extern = [ + torch.as_strided( + unique_example_inputs[input_node.get_name()], + V.graph.sizevars.size_hints( + input_node.get_size(), + fallback=config.unbacked_symint_fallback, + ), + V.graph.sizevars.size_hints( + input_node.get_stride(), + fallback=config.unbacked_symint_fallback, + ), + V.graph.sizevars.size_hint( + input_node.get_layout().offset, + fallback=config.unbacked_symint_fallback, + ), + ) + for input_node in input_nodes + ] + + out = cls.benchmark_example_value(layout) + out_extern = torch.as_strided( + out, out.size(), out.stride(), V.graph.sizevars.size_hint(layout.offset) + ) + if VERIFY: + choices[0].benchmark(*example_inputs_extern, out=out_extern) + expected = out_extern.clone() + + if DEBUG: + print(f"{len(choices)} tuning requests:") + + def debug_str(): + def tensor_repr(x): + return ( + f"torch.empty_strided({tuple(x.size())!r}, {tuple(x.stride())!r}, " + f"dtype={x.dtype!r}, device={x.device.type!r})" + ) + + lines = [ + "inputs = [", + ] + for x in example_inputs: + lines.append(f" {tensor_repr(x)},") + lines += ["]", f"out = {tensor_repr(out)}", ""] + return "\n".join(lines) + + def benchmark_choice_in_current_process(choice): + out.zero_() + if isinstance(choice, ExternKernelCaller): + # aten kernels want the offset baked in for sliced tensors + result = choice.benchmark(*example_inputs_extern, out=out_extern) + else: + # triton templates want the base pointer for sliced tensors + result = choice.benchmark(*example_inputs, out=out) + if VERIFY: + torch.testing.assert_close(out_extern, expected, **VERIFY) + torch.cuda.synchronize() # shake out any CUDA errors + return result + + def benchmark_in_current_process(choices): + timings = {} + for choice in choices: + try: + timing = benchmark_choice_in_current_process(choice) + except CUDACompileError as e: + log.warning( + "CUDA compilation error: \n%s. 
\nIgnore this choice.", str(e) + ) + timing = float("inf") + except RuntimeError as e: + msg = str(e) + if "invalid argument" in msg: + msg += "\n\nThis may mean this GPU is too small for max_autotune mode.\n\n" + log.warning(msg) + timing = float("inf") + else: + if "illegal memory access" in msg: + msg += "\n\nEither error in template or triton bug.\n" + raise ErrorFromChoice(msg, choice, debug_str()) # noqa: TRY200 + except AssertionError as e: + raise AssertionError( # noqa: TRY200 + f"Incorrect result from choice {choice}\n\n{e}" + ) + + timings[choice] = timing + + return timings + + def benchmark_in_sub_process(choices): + from . import autotune_process + + # only benchmark triton kernel in sub process for now. + # ATen/Extern kernel are still benchmarked in the current process. + extern = [c for c in choices if isinstance(c, ExternKernelCaller)] + triton = [c for c in choices if not isinstance(c, ExternKernelCaller)] + + timings = benchmark_in_current_process(extern) + timings.update(autotune_process.benchmark_in_sub_process(triton)) + return timings + + benchmark = ( + benchmark_in_sub_process + if config.autotune_in_subproc + else benchmark_in_current_process + ) + + return benchmark + + @staticmethod + def log_results( + name: str, + input_nodes: List[ir.IRNode], + timings: Dict[ChoiceCaller, float], + elapse: float, + ): + V.debug.log_autotuning_results(name, input_nodes, timings, elapse) + if not (config.max_autotune or config.max_autotune_gemm) or not PRINT_AUTOTUNE: + return + sizes = ", ".join( + [ + "x".join( + map( + str, + V.graph.sizevars.size_hints( + n.get_size(), fallback=config.unbacked_symint_fallback + ), + ) + ) + for n in input_nodes + ] + ) + n = None if log.getEffectiveLevel() == logging.DEBUG else 10 + top_k = sorted(timings, key=timings.__getitem__)[:n] + best = top_k[0] + best_time = timings[best] + sys.stderr.write(f"AUTOTUNE {name}({sizes})\n") + for choice in top_k: + result = timings[choice] + if result: + sys.stderr.write( + f" {choice.name} {result:.4f} ms {best_time/result:.1%}\n" + ) + else: + sys.stderr.write( + f" {choice.name} {result:.4f} ms \n" + ) + + autotune_type_str = ( + "SubProcess" if config.autotune_in_subproc else "SingleProcess" + ) + sys.stderr.write(f"{autotune_type_str} AUTOTUNE takes {elapse:.4f} seconds\n") + + @staticmethod + def benchmark_example_value(node): + """ + Convert an ir.Buffer into a concrete torch.Tensor we can use for + benchmarking. + """ + if isinstance(node, ir.Layout): + node = ir.Buffer("fake", node) + # triton templates want the base tensor. + if isinstance(node, ir.BaseView): + node = node.unwrap_view() + # preserve rng states to avoid the rand_strided call below changes + # the rng states for the real model code. + with preserve_rng_state(): + return rand_strided( + V.graph.sizevars.size_hints( + node.get_size(), + fallback=config.unbacked_symint_fallback, + ), + V.graph.sizevars.size_hints( + node.get_stride(), + fallback=config.unbacked_symint_fallback, + ), + device=node.get_device(), + dtype=node.get_dtype(), + extra_size=node.layout.offset, + ) + + @staticmethod + def key_of(node): + """ + Extract the pieces of an ir.Buffer that we should invalidate cached + autotuning results on. 
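+        For example (illustrative): a contiguous float32 CUDA tensor of shape
+        (8, 16) with zero storage offset would produce the key
+        ("cuda", "torch.float32", 8, 16, 16, 1, 0).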
+ """ + sizevars = V.graph.sizevars + return ( + node.get_device().type, + str(node.get_dtype()), + *sizevars.size_hints( + node.get_size(), + fallback=config.unbacked_symint_fallback, + ), + *sizevars.size_hints( + node.get_stride(), + fallback=config.unbacked_symint_fallback, + ), + sizevars.size_hint( + node.get_layout().offset, + fallback=config.unbacked_symint_fallback, + ), + ) + + +_ALGORITHM_SELECTOR_CACHE: Optional[AlgorithmSelectorCache] = None + + +def autotune_select_algorithm(*args, **kwargs): + global _ALGORITHM_SELECTOR_CACHE + if _ALGORITHM_SELECTOR_CACHE is None: + _ALGORITHM_SELECTOR_CACHE = AlgorithmSelectorCache() + return _ALGORITHM_SELECTOR_CACHE(*args, **kwargs) + + +def realize_inputs(*args): + if len(args) == 1: + return ir.ExternKernel.require_stride1(ir.ExternKernel.realize_input(args[0])) + return [realize_inputs(x) for x in args] + + +# ensure lowering is imported so that `extern_kernels.*` is populated +from . import lowering # noqa: F401 diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/sizevars.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/sizevars.py new file mode 100644 index 0000000000000000000000000000000000000000..ceff1bddc913431ac7fb690e844606843b575220 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/sizevars.py @@ -0,0 +1,643 @@ +import functools +import itertools +import logging +from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Union + +import sympy +from sympy import Expr + +from torch.fx.experimental.symbolic_shapes import ShapeEnv +from torch.utils._sympy.functions import FloorDiv, ModularIndexing +from torch.utils._sympy.value_ranges import bound_sympy + +from .utils import sympy_index_symbol, sympy_subs, VarRanges +from .virtualized import V + +log = logging.getLogger(__name__) + + +# This class is a little awkward, because ShapeEnv is doing most of the heavy +# lifting and in some cases we should be directly passing through to ShapeEnv, +# but there is some extra inductor logic that needs to be handled here +class SizeVarAllocator: + def __init__(self, shape_env=None): + super().__init__() + if shape_env is None: + shape_env = ShapeEnv() + self.shape_env = shape_env + self.var_to_val = self.shape_env.var_to_val + self.replacements: Dict[sympy.Symbol, Expr] = self.shape_env.replacements + # Maps of dynamic sizes that have to be precomputed on the host to the kernel args. + # The basic idea is if we have some complicated sympy expression + # f(s0), we may choose to precompute it on the host and then replace + # all occurrences of that sympy expression with ps0, so that when we + # codegen we simply reference ps0 directly without repeating + # f(s0). Unlike regular size variables, ps variables cannot be + # guarded upon; so if we are asked to guard on a Sympy expression + # which potentially could have already had a precomputed replacement + # on it, we are obligated to invert the precomputed replacements + # (inv_precomputed_replacements). 
+        self.precomputed_replacements: Dict[Expr, sympy.Symbol] = dict()
+        self.inv_precomputed_replacements: Dict[sympy.Symbol, Expr] = dict()
+        self.stride_vars = self.make_stride_vars_cache()
+        self.simplify_with_ranges = self.make_simplify_with_ranges_cache()
+        self._simplify_loops = self.make_simplify_loops_cache()
+
+    def simplify(self, expr: Expr):
+        return sympy.expand(expr).xreplace(self.replacements)
+
+    def make_simplify_with_ranges_cache(self) -> Callable[[Expr, VarRanges], Expr]:
+        """
+        self._simplify_with_ranges() can be expensive; cache its results
+        """
+        cache: Dict[Tuple[Any, ...], Expr] = dict()
+        replacement_count = len(self.replacements)
+
+        def simplify_with_ranges(expr: Expr, var_ranges: VarRanges) -> Expr:
+            nonlocal replacement_count
+            if replacement_count != len(self.replacements):
+                # new replacements invalidate cached results
+                cache.clear()
+                replacement_count = len(self.replacements)
+            key = (expr, *var_ranges.items())
+            result = cache.get(key, None)
+            if result is None:
+                result = self._simplify_with_ranges(expr, var_ranges)
+                cache[key] = result
+            return result
+
+        return simplify_with_ranges
+
+    def make_simplify_loops_cache(self):
+        """
+        self._simplify_loops_impl() can be expensive; cache its results
+        """
+        cache: Dict[Tuple[Any, ...], Any] = dict()
+        replacement_count = len(self.replacements)
+
+        def simplify_loops(index_vars, sizes, index_formulas):
+            nonlocal replacement_count
+            if replacement_count != len(self.replacements):
+                # new replacements invalidate cached results
+                cache.clear()
+                replacement_count = len(self.replacements)
+            key = (*index_vars, *sizes, *index_formulas)
+            result = cache.get(key, None)
+            if result is None:
+                result = self._simplify_loops_impl(index_vars, sizes, index_formulas)
+                cache[key] = result
+            return result
+
+        return simplify_loops
+
+    def _simplify_with_ranges(self, expr: Expr, var_ranges: VarRanges) -> Expr:
+        """
+        Simplify an indexing expression with knowledge of the ranges of the
+        iteration variables.
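+
+        For example, if x0 is known to range over [0, 16), then
+        ModularIndexing(x0, 1, 32) can never wrap, and this pass can rewrite
+        it to plain x0.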
+ """ + + expr = join_dimensions(self.simplify(expr)) + original_expr = expr + + def remove_zero_terms(base, divisor): + """Symbols smaller than the divisor are zero""" + for v in base.free_symbols: + if v in var_ranges: + # var smaller than divisor can be removed + # if the rest is guaranteed to be multiple of divisor + rest = sympy.Wild("_rest", exclude=[v]) + m = base.match(v + rest) + if m and v not in m[rest].free_symbols: + gcd = sympy.gcd(m[rest], divisor) + if gcd == divisor: + if self.statically_known_leq(var_ranges[v], divisor): + base = m[rest] + return base + + def visit_indexing_div(base, divisor): + return FloorDiv(remove_zero_terms(base, divisor), divisor) + + def visit_modular_indexing(base, divisor, modulus): + base = remove_zero_terms(base, divisor) + base_pos = True + if isinstance(base, ModularIndexing): + # for modular indexing, biggest values from the ranges don't necessarily result in + # the biggest result, the biggest result is modulus - 1 + base_s = base.args[2] - 1 + elif not base.has(ModularIndexing): + # actual iteration range is to size-1 + iter_ranges_zero = {k: 0 for k, v in var_ranges.items()} + base_lowest = sympy_subs(base, iter_ranges_zero) + if self.statically_known_leq(0, base_lowest): # type: ignore[arg-type] + # can't replace with indexing div if base can be negative + base_pos = True + else: + base_pos = False + iter_ranges = {k: v - 1 for k, v in var_ranges.items()} + base_s = sympy_subs(base, iter_ranges) + else: + base_s = base + if self.statically_known_lt(base_s, modulus * divisor) and base_pos: + return FloorDiv(base, divisor) + return ModularIndexing(base, divisor, modulus) + + if expr.has(ModularIndexing): + expr = expr.replace( + ModularIndexing( + sympy.Wild("base"), + sympy.Wild("divisor"), + sympy.Wild("modulus"), + ), + visit_modular_indexing, + ) + + if expr.has(FloorDiv): + expr = expr.replace( + FloorDiv( + sympy.Wild("base"), + sympy.Wild("divisor"), + ), + visit_indexing_div, + ) + + if expr != original_expr: + return self._simplify_with_ranges(expr, var_ranges) + return expr + + def _simplify_loops_impl( + self, index_vars: List[sympy.Symbol], sizes, index_formulas + ): + """ + Try to remove as many axis from loop iterations as possible, by: + 1) removing size==1 dimensions + 2) fuse contiguous dimensions into a single loop + If channel_last = True, we will prevent the last dim fused with other dims + """ + sizes = list(map(self.simplify, sizes)) + + strides = [self.stride_vars(x, index_vars) for x in index_formulas] + assert len(sizes) == len(strides[0]), (len(sizes), len(strides[0])) + + for i in range(len(sizes)): + if sizes[i] == 1: + # remove dim + sizes[i] = None + + def can_merge_dims(a, b): + for k in range(len(strides)): + if self.simplify(strides[k][a] * sizes[a]) == self.simplify( + strides[k][b] + ): + # approximate test passed, try sound version + va = index_vars[a] + vb = index_vars[b] + v = sympy_index_symbol("_merge_tester") + expr1 = sympy_subs(index_formulas[k], {va: v * sizes[a], vb: 0}) + expr2 = sympy_subs(index_formulas[k], {va: 0, vb: v}) + if self.simplify(expr1) == self.simplify(expr2): + continue + return False + return True + + changed = True + while changed: + changed = False + for i, j in itertools.product( + reversed(range(len(sizes))), reversed(range(len(sizes))) + ): + if i == j or sizes[i] is None or sizes[j] is None: + continue + if can_merge_dims(i, j): + changed = True + sizes[i] = sizes[i] * sizes[j] + sizes[j] = None + + def reindex(index): + it = list(reversed(index)) + new_index = [] + for 
size in sizes: + if size is None: + new_index.append(sympy.Integer(0)) + else: + new_index.append(it.pop()) + assert not it + return new_index + + def prune(index): + assert len(index) == len(sizes) + return [i for i, s in zip(index, sizes) if s is not None] + + return [x for x in sizes if x is not None], reindex, prune + + # Note - [On Statically Known] + # + # The statically_known_* family of functions below replaces a prior system, called maybe_guard_*. The prior system + # operated by providing essentially a question, where the size hinted values were evaluated. If the condition was + # true, we add a guard and return True, otherwise, False. + # + # def maybe_guard_foo(args): + # if size_hinted_check(args): + # return False # No guard, no optim + # guard(args) # Make a guard + # return True # Safe to apply optimization + # + # The prior system incurred a guard, and green lit an optimization. + # + # The new system works in reverse - in the new system, if we know that the inputs are static, and evaluate the + # condition as true, we green light the optimization, and we do not incur a guard. If we cannot prove that, we + # return False. + # + # def maybe_guard_foo(args): + # if all_static(args): + # return True # Safe to apply optimization + # else: + # return False # No guard, no optim + + # See Note - [On Statically Known] + + def is_expr_static_and_true(self, expr: Union[Expr, int]) -> bool: + if expr in (True, False): + return bool(expr) + + try: + simplified = self.shape_env._maybe_evaluate_static(expr) + if simplified is not None: + return bool(simplified) + except Exception: + log.debug("Could not simplify %s", expr) + + return False + + def statically_known_equals(self, left: Expr, right: Expr) -> bool: + """ + Returns a bool indicating if it is sound to optimize as if left and right are equal. + """ + return self.is_expr_static_and_true(sympy.Eq(left, right)) # type: ignore[arg-type] + + # See Note - [On Statically Known] + def statically_known_list_equals(self, left: List[Expr], right: List[Expr]) -> bool: + """ + Returns a bool indicating if it is sound to optimize as if left and right lists are equal. + """ + if len(left) != len(right): + return False + if all(self.statically_known_equals(l, r) for l, r in zip(left, right)): + return True + return False + + # See Note - [On Statically Known] + def statically_known_leq(self, left: Expr, right: Expr) -> bool: + """ + Returns a bool indicating if it is sound to optimize as if left is less than or equal to right. + """ + expr = left <= right + return self.is_expr_static_and_true(expr) + + # See Note - [On Statically Known] + def statically_known_lt(self, left: Expr, right: Expr) -> bool: + """ + Returns a bool indicating if it is sound to optimize as if left is less than right. + """ + expr = left < right + return self.is_expr_static_and_true(expr) + + # See Note - [On Statically Known] + def statically_known_multiple_of(self, numerator: Expr, denominator: Expr) -> bool: + """ + Return a bool indicating if it is sound to optimize for the numerator being a multiple of the denominator. + """ + expr = sympy.Eq(numerator % denominator, 0) + return self.is_expr_static_and_true(expr) # type: ignore[arg-type] + + # The guard functions require you to ALREADY KNOW that a particular + # condition holds. If you don't know (you want to guard on an expression + # being a particular value, and then get access to that value), use + # the evaluate functions. 
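+    # A minimal sketch of the distinction (sv = V.graph.sizevars, illustrative):
+    #
+    #   sv.guard_equals(s0, s1)      # you already know s0 == s1; assert and guard it
+    #   n = sv.evaluate_min(s0, s1)  # discover the smaller value from hints,
+    #                                # then guard on the choice that was made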
+ + def guard_equals(self, left: Expr, right: Expr) -> Expr: + if isinstance(left, Expr): + left = sympy_subs(left, self.inv_precomputed_replacements) # type: ignore[arg-type] + if isinstance(right, Expr): + right = sympy_subs(right, self.inv_precomputed_replacements) # type: ignore[arg-type] + assert self.shape_env.evaluate_expr(sympy.Eq(left, right)) + return left + + def guard_leq(self, left: Expr, right: Expr) -> None: + return self.guard_lt(left, right + 1) + + def guard_lt(self, left: Expr, right: Expr) -> None: + assert self.shape_env.evaluate_expr(sympy.Lt(left, right)) + + def expect_true(self, expr: Expr, *, msg: str) -> None: + expr = sympy_subs(expr, self.inv_precomputed_replacements) # type: ignore[arg-type] + self.shape_env.defer_runtime_assert(expr, msg, fx_node=None) + + def expect_equals(self, left: Expr, right: Expr, *, msg: str) -> Expr: + # Prefer returning the expression without unbacked symints + if self.shape_env.is_unbacked_symint(left): + self.expect_true(sympy.Eq(left, right), msg=msg) # type: ignore[arg-type] + return right + elif self.shape_env.is_unbacked_symint(right): + self.expect_true(sympy.Eq(left, right), msg=msg) # type: ignore[arg-type] + return left + else: + return self.guard_equals(left, right) + + def guarded_order(self, seq): + """ + Return the order of a sequence as a permutation of range(len(seq)) and guard on that order not changing. + Used for generating block_ptrs. + """ + seq = [*map(self.remove_precomputed_replacements, seq)] + seq = [(self.size_hint(var), orig_idx, var) for orig_idx, var in enumerate(seq)] + seq.sort() + order = [-1] * len(seq) + last_var = None + for new_index, (_, orig_index, var) in enumerate(seq): + order[orig_index] = new_index + if last_var is not None: + self.guard_leq(last_var, var) + last_var = var + return order + + # The evaluate functions evaluate some symbolic sympy expression + # (NB: not necessarily an Expr) and return what the concrete result + # is, guarding on the expression being that result + + # NB: write evaluate_expr(sympy.Lt(a, b)) rather than evaluate_expr(a < b) + # as this will ensure that you actually have a sympy'ified expression, + # and will prevent you from incorrectly writing evaluate_expr(a == b) + # which does the wrong thing if a or b is a sympy expression + def evaluate_expr(self, left: Union[Expr, sympy.logic.boolalg.Boolean]) -> bool: + assert isinstance(left, (Expr, sympy.logic.boolalg.Boolean)), type(left) + return self.shape_env.evaluate_expr(sympy.sympify(left)) + + def evaluate_min(self, left: Expr, right: Expr) -> Expr: + """return the smaller of left and right, and guard on that choice""" + lv = self.size_hint(left) + rv = self.size_hint(right) + if lv <= rv: + self.guard_leq(left, right) + return left + else: + self.guard_leq(right, left) + return right + + def evaluate_max(self, left: Expr, right: Expr) -> Expr: + """return the larger of left and right, and guard on that choice""" + # Always choose the opposite of eval min for consistency + # This means min(a, b) and max(a, b) produce the same guards + min_val = self.evaluate_min(left, right) + return right if min_val is left else left + + def evaluate_static_shape(self, left: Expr) -> int: + right = self.size_hint(left) + self.guard_equals(left, sympy.Integer(right)) + return int(right) + + def evaluate_static_shapes(self, left: List[Expr]) -> List[int]: + return [self.evaluate_static_shape(x) for x in left] + + def remove_precomputed_replacements(self, expr: Expr) -> Expr: + if any(s.name.startswith("ps") for s in 
expr.free_symbols): # type: ignore[attr-defined] + return sympy_subs(expr, self.inv_precomputed_replacements) # type: ignore[arg-type] + return expr + + def symbolic_hint(self, expr: Expr) -> Expr: + # Substitute all hints into expr, but leave unbacked symints alone + if not isinstance(expr, Expr): + assert isinstance(expr, int) + return expr + free_symbols = expr.free_symbols + if not free_symbols: + return int(expr) # type: ignore[return-value] + expr = self.remove_precomputed_replacements(expr) + return sympy_subs(expr, self.var_to_val) + + def size_hint(self, expr: Expr, *, fallback: Optional[int] = None) -> int: + out = self.symbolic_hint(expr) + if not isinstance(out, (int, sympy.Integer)) and fallback is not None: + # Use the provided heuristic fallback hint + sym_vrs = { + s: self.shape_env.var_to_range.get(s, None) for s in expr.free_symbols + } + if all(vr is not None for vr in sym_vrs.values()): + expr_vr = bound_sympy(expr, sym_vrs) # type: ignore[arg-type] + lower = self.size_hint(expr_vr.lower) # type: ignore[arg-type] + upper = self.size_hint(expr_vr.upper) # type: ignore[arg-type] + fallback = min(max(fallback, lower), upper) + return fallback + try: + return int(out) + except Exception: + log.debug("failed on: %s", out) + raise + + def size_hints( + self, + exprs: Iterable[Expr], + *, + fallback: Optional[int] = None, + ) -> Tuple[int, ...]: + return tuple(self.size_hint(x, fallback=fallback) for x in exprs) + + def _lru_cache(self, fn, maxsize=None): + """ + Wrapper around functools.lru_cache that clears when replacements + has been invalidated. + """ + fn_cache = functools.lru_cache(maxsize)(fn) + prior_len = len(self.replacements) + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + nonlocal prior_len + if prior_len != len(self.replacements): + prior_len = len(self.replacements) + fn_cache.cache_clear() + return fn_cache(*args, **kwargs) + + return wrapper + + def make_stride_vars_cache(self): + cache = self._lru_cache(self._stride_vars) + + def stride_vars( + index: Expr, + vars: List[sympy.Symbol], + support_vars: Optional[List[sympy.Symbol]] = None, + ) -> List[Expr]: + if not support_vars: + support_vars = vars + return cache(index, tuple(vars), tuple(support_vars)) + + return stride_vars + + def _stride_vars( + self, index: Expr, vars: List[sympy.Symbol], support_vars: List[sympy.Symbol] + ) -> List[Expr]: + """Convert an indexing expression back into strides + + NOTE: This is only valid if the index is a standard strided offset + calculation. e.g. 10 * ModularIndexing(i0 + 1, 1, 2) would give a + stride of -10 because the index wraps around after the first element + + """ + strides = [] + index = self.simplify(index) + # remove any offset + index = index - sympy_subs( + index, {v: sympy.Integer(0) for v in support_vars if v != 0} + ) + for i in range(len(vars)): + # drop all the other dims + index_dim = sympy_subs( + index, + { + support_vars[j]: sympy.Integer(0) + for j in range(len(support_vars)) + if vars[i] != support_vars[j] and support_vars[j] != 0 + }, + ) + v = vars[i] + if v == 0: + strides.append(sympy.Integer(0)) + else: + # TODO(jansel): should we use sympy.diff here? 
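+                # For a standard strided index, index_dim is affine in v once
+                # all other support vars are zeroed, so the two-point
+                # difference index_dim(v=1) - index_dim(v=0) recovers the
+                # coefficient of v, i.e. the stride.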
+ strides.append( + sympy_subs(index_dim, {v: sympy.Integer(1)}) + - sympy_subs(index_dim, {v: sympy.Integer(0)}) + ) + return strides + + def offset_var(self, index: Expr, vars: List[sympy.Symbol]) -> Expr: + """Extract offset part of an indexing expression""" + index = self.simplify(index) + return sympy_subs(index, {v: sympy.Integer(0) for v in vars if v != 0}) + + def stride_hints( + self, + index: Expr, + vars: List[sympy.Symbol], + support_vars: Optional[List[sympy.Symbol]] = None, + ) -> List[int]: + for v in index.free_symbols: + if v.name.startswith("indirect"): # type: ignore[attr-defined] + index = sympy_subs(index, {v: 0}) # type: ignore[dict-item] + result = [] + for s in self.stride_vars(index, vars, support_vars): + try: + result.append(self.size_hint(s)) + except TypeError: + result.append(0) + return result + + def stride_order(self, index: Expr, vars: List[sympy.Symbol]) -> List[int]: + strides = tuple(map(abs, self.stride_hints(index, vars))) + order = list(range(len(strides))) + order.sort(key=lambda x: (strides[x] == 0, strides[x])) + return order + + def lookup_precomputed_size(self, expr: Expr) -> Expr: + if ( + isinstance(expr, (int, sympy.Symbol, sympy.Number)) + or expr.is_number + or expr.is_symbol + ): + return expr + expr = self.remove_precomputed_replacements(expr) + if expr not in self.precomputed_replacements: + sym = sympy_index_symbol(f"ps{len(self.precomputed_replacements)}") + self.precomputed_replacements[expr] = sym + self.inv_precomputed_replacements[sym] = expr + return self.precomputed_replacements[expr] + + def free_symbols(self) -> Set[sympy.Symbol]: + return set(self.var_to_val.keys()) - set(self.replacements.keys()) + + +def join_dimensions(expr: Expr) -> Expr: + if not isinstance(expr, sympy.Add) or not expr.has(ModularIndexing): + return expr # fast exit path + return _join_dimensions_cached(expr) + + +@functools.lru_cache(256) +def _join_dimensions_cached(expr: Expr) -> Expr: + """ + ModularIndexing(i0, 1, 32) + 32 * ModularIndexing(i0, 32, 4) + becomes + ModularIndexing(i0, 1, 128) + ModularIndexing(i0, 1, 32) + 32 * FloorDiv(i0, 32) + becomes i0 + + + This type of pattern can come from view operations + """ + assert isinstance(expr, sympy.Add) + + scale = sympy.Wild("scale", exclude=[0]) + base = sympy.Wild("base") + divisor = sympy.Wild("divisor") + mod1 = sympy.Wild("modulus") + mod2 = sympy.Wild("modulus2") + for term1 in expr.args: + m1 = term1.match(scale * ModularIndexing(base, divisor, mod1)) + if m1: + for term2 in expr.args: + m2 = term2.match( + m1[scale] + * m1[mod1] + * ModularIndexing(m1[base], m1[divisor] * m1[mod1], mod2) + ) + if m2 and term1 != term2: + expr = join_dimensions( + expr + - term1 + - term2 + + m1[scale] + * ModularIndexing(m1[base], m1[divisor], m1[mod1] * m2[mod2]) + ) + return expr + for term1 in expr.args: + m1 = term1.match(scale * ModularIndexing(base, divisor, mod1)) + if m1: + for term2 in expr.args: + m2 = term2.match( + m1[scale] * m1[mod1] * FloorDiv(m1[base], m1[divisor] * m1[mod1]) + ) + if m2 is not None: # in case of success we get an empty dict here + expr = join_dimensions( + expr + - term1 + - term2 + + m1[scale] * FloorDiv(m1[base], m1[divisor]) + ) + return expr + return expr + + +class SimplifyIndexing(V.WrapperHandler): # type: ignore[name-defined] + """ + A wrapper around .virtualize.ops that uses var range information to + simplify ModularIndexing/FloorDiv. 
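+
+    Every index that flows through load/store/store_reduction/index_expr is
+    first passed through simplify_with_ranges for the given var_ranges.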
+ """ + + def __init__(self, inner, var_ranges: VarRanges): + super().__init__(inner) + self.name = "SimplifyIndexing" + self._simplify: Callable[ + [Expr], Expr + ] = lambda index: V.graph.sizevars.simplify_with_ranges(index, var_ranges) + + def load(self, name: str, index: sympy.Expr): + return self._inner.load(name, self._simplify(index)) + + def store(self, name, index, value, mode=None): + return self._inner.store(name, self._simplify(index), value, mode=mode) + + def store_reduction(self, name, index, value): + return self._inner.store_reduction(name, self._simplify(index), value) + + def index_expr(self, index, dtype): + return self._inner.index_expr(self._simplify(index), dtype) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/utils.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..3611f6deaadb6f550d47ca43e9b7470b57ab64b3 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/utils.py @@ -0,0 +1,1428 @@ +from __future__ import annotations + +import collections +import contextlib +import dataclasses +import enum +import functools +import getpass +import inspect +import io +import itertools +import logging +import math +import operator +import os +import platform +import re +import shutil +import sys +import tempfile +import textwrap +import time +import unittest +from dataclasses import fields +from datetime import datetime +from io import StringIO +from typing import ( + Any, + Callable, + Dict, + Generic, + Iterable, + List, + NamedTuple, + Optional, + Protocol, + Set, + TypeVar, + Union, + ValuesView, +) +from unittest import mock + +import sympy +from typing_extensions import Concatenate, ParamSpec + +import torch +from torch._dynamo.device_interface import get_interface_for_device +from torch.autograd import DeviceType +from torch.autograd.profiler_util import EventList +from torch.utils._sympy.functions import CeilDiv, CleanDiv, FloorDiv, ModularIndexing +from . import config + +log = logging.getLogger(__name__) + +_T = TypeVar("_T") +VarRanges = Dict[sympy.Expr, sympy.Expr] + + +def do_bench_using_profiling(fn: Callable[[], Any], warmup=25, rep=100) -> float: + """ + Returns benchmark results by examining torch profiler events. + This could be more accurate as it doesn't count CPU side overhead. + However, this also requires manually excluding irrelevant event, e.g. + vectorized_elementwise_kernel which is used to fill L2 cache, + various CUDA events, etc, so could also be fragile. 
+ """ + + fn() + torch.cuda.synchronize() + cache = torch.empty(int(256e6 // 4), dtype=torch.int, device="cuda") + + # Estimate the runtime of the function + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + start_event.record() + for _ in range(5): + cache.zero_() + fn() + end_event.record() + torch.cuda.synchronize() + estimate_ms = start_event.elapsed_time(end_event) / 5 + + # compute number of warmup and repeat + n_warmup = max(1, int(warmup / estimate_ms)) + n_repeat = max(1, int(rep / estimate_ms)) + + # Warm-up + for _ in range(n_warmup): + fn() + + with torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CUDA, + ] + ) as p: + # Benchmark + for i in range(n_repeat): + # we clear the L2 cache before each run + cache.zero_() + # record time of `fn` + fn() + # Record clocks + torch.cuda.synchronize() + + log.debug("raw events") + log.debug(p.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1)) + + filtered_events = EventList( + [ + event + for event in p.events() + if event.device_type == DeviceType.CUDA and event.name != "Context Sync" + ] + ) + if len(filtered_events) % n_repeat != 0: + raise RuntimeError( + "Failed to divide all profiling events into #repeat groups. " + "#CUDA events: %d, #repeats: %s", + len(filtered_events), + n_repeat, + ) + num_event_per_group = len(filtered_events) / n_repeat + actual_events = EventList( + [ + event + for i, event in enumerate(filtered_events) + if i % num_event_per_group != 0 + ] + ) + actual_events._build_tree() + actual_events = actual_events.key_averages() + + log.debug("profiling time breakdown") + log.debug(actual_events.table(row_limit=-1)) + + res = sum(event.cuda_time_total for event in actual_events) / 1000.0 / n_repeat + log.debug("profiling results: %s ms", res) + return res + + +def do_bench(*args, **kwargs): + @functools.lru_cache(None) + def load_triton(): + try: + # NB: Lazily load triton, as importing triton is slow + # see https://github.com/openai/triton/issues/1599 + from triton.testing import do_bench as triton_do_bench + except ImportError as exc: + raise NotImplementedError("requires Triton") from exc + + # triton PR https://github.com/openai/triton/pull/1513 change the + # quantile fields name from 'percentiles' to 'quantiles' + # and change the default value from (0.5, 0.2, 0.8) to None. + # This may break inductor since a caller expects a tuple may get a item. + # + # Add a wrapper to maintain the same behavior for inductor. + # Maybe we should have own implementation of this function? 
+ return triton_do_bench, ( + "quantiles" + if inspect.signature(triton_do_bench).parameters.get("quantiles") + is not None + else "percentiles" + ) + + triton_do_bench, quantile_field_name = load_triton() + + if quantile_field_name not in kwargs: + kwargs[quantile_field_name] = (0.5, 0.2, 0.8) + return triton_do_bench(*args, **kwargs)[0] + + +@functools.lru_cache(None) +def has_torchvision_roi_align() -> bool: + try: + from torchvision.ops import roi_align # noqa: F401 + + return roi_align is not None and hasattr( + getattr(torch.ops, "torchvision", None), "roi_align" + ) + except ImportError: + return False + + +def conditional_product(*args): + return functools.reduce(operator.mul, [x for x in args if x]) + + +def decode_device(device: Union[Optional[torch.device], str]) -> torch.device: + if device is None: + return torch.tensor(0.0).device # default device + if isinstance(device, str): + device = torch.device(device) + if device.type != "cpu" and device.index is None: + device_interface = get_interface_for_device(device.type) + return torch.device(device.type, index=device_interface.Worker.current_device()) + return device + + +def sympy_product(it): + return functools.reduce(operator.mul, it, sympy.Integer(1)) + + +def sympy_dot(seq1, seq2): + assert len(seq1) == len(seq2) + return sympy.expand(sum(a * b for a, b in zip(seq1, seq2))) + + +def unique(it: Iterable[_T]) -> ValuesView[_T]: + return {id(x): x for x in it}.values() + + +def ceildiv( + numer: Union[int, sympy.Expr], denom: Union[int, sympy.Expr] +) -> Union[int, sympy.Expr]: + if isinstance(numer, sympy.Expr) or isinstance(denom, sympy.Expr): + return CeilDiv(numer, denom) + # TODO: There is a bug in a call to this function, to repro: + # python benchmarks/dynamo/huggingface.py --inductor -d cuda --accuracy + # --amp --only YituTechConvBert --dynamic-shapes + assert isinstance(numer, int) and isinstance( + denom, int + ), f"{numer}: {type(numer)}, {denom}: {type(denom)}" + return -(numer // -denom) + + +def next_power_of_2(n: int) -> int: + """Return the smallest power of 2 greater than or equal to n""" + n -= 1 + n |= n >> 1 + n |= n >> 2 + n |= n >> 4 + n |= n >> 8 + n |= n >> 16 + n |= n >> 32 + n += 1 + return n + + +def _type_of(key): + # Use the function here to get rid of dependencies on the Triton during the codegen. + # Refer to Triton implementation here: + # https://github.com/openai/triton/blob/98b5945d2aef679e00ebca8e07c35c3658ec76de/python/triton/runtime/jit.py#L238 + # `None` is nullptr. Implicitly convert to *i8. + if key is None: + return "*i8" + dtype_str = str(key).split(".")[-1] + tys = { + "bool": "i1", + "float8e4nv": "fp8e4nv", + "float8e5": "fp8e5", + "float8e4b15": "fp8e4b15", + "float8e4b15x4": "fp8e4b15x4", + "float8_e4m3fn": "fp8e4nv", + "float8_e5m2": "fp8e5", + "float16": "fp16", + "bfloat16": "bf16", + "float32": "fp32", + "float64": "fp64", + "int8": "i8", + "int16": "i16", + "int32": "i32", + "int64": "i64", + "uint8": "u8", + "uint16": "u16", + "uint32": "u32", + "uint64": "u64", + } + # reinterpret can create triton type + for v in list(tys.values()): + tys[v] = v + return key if isinstance(key, str) else f"*{tys[dtype_str]}" + + +def convert_shape_to_inductor( + lst: Iterable[Union[int, torch.SymInt]] +) -> List[sympy.Expr]: + """ + Gets the shape and stride of a tensor. For non-symbolic tensors, this is + trivial. But for symbolic tensors, we need to map from SymIntNode into + sympy.Expr. 
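+
+    For example, [2, s0] (with s0 a torch.SymInt) maps to
+    [sympy.Integer(2), s0.node.expr].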
+ """ + return [ + i.node.expr if isinstance(i, torch.SymInt) else sympy.Integer(i) for i in lst + ] + + +def convert_shape_to_symint( + lst: Iterable[Union[int, sympy.Expr]] +) -> List[Union[int, torch.SymInt]]: + """ + Takes a list of shapes from Inductor and converts them into symints (or just + ints if all shapes are static). + """ + from .virtualized import V + + return [ + i + if isinstance(i, int) + else int(i) + if isinstance(i, sympy.Integer) + else V.graph.sizevars.shape_env.create_symintnode(i, hint=None) + for i in lst + ] + + +def is_view(op: torch._ops.OpOverload): + """ + Does this op overload have aliasing + """ + assert isinstance(op, torch._ops.OpOverload) + return any(a.alias_info is not None for a in op._schema.arguments) + + +def is_pointwise_use(use): + if not use.op == "call_function": + return False + + if not ( + isinstance(use.target, torch._ops.OpOverload) or use.target is operator.getitem + ): + return False + + if use.target is operator.getitem or is_view(use.target): + return all(is_pointwise_use(u) for u in use.users) + + return torch.Tag.pointwise in use.target.tags + + +def gen_gm_and_inputs(target, args, kwargs): + g = torch.fx.Graph() + g_args = [] + a_args = [] + for n, arg in enumerate(args): + if isinstance(arg, torch.Tensor): + g_args.append(g.placeholder(f"arg{n}")) + a_args.append(arg) + else: + g_args.append(arg) + assert all(not isinstance(x, torch.Tensor) for x in kwargs.values()) + node = g.call_function(target, tuple(g_args), kwargs) + if ( + len(target._schema.returns) == 1 + and str(target._schema.returns[0].type) == "Tensor" + ): + node = (node,) + g.output(node) + + gm = torch.fx.GraphModule({}, g) + return gm, a_args + + +def synchronize(device: str = "cuda"): + if device == "cpu": + return + device_interface = get_interface_for_device(device) + if device_interface.is_available(): + device_interface.synchronize() + + +def timed( + model: Callable[..., Any], example_inputs, times: int = 1, device: str = "cuda" +) -> float: + synchronize(device) + torch.manual_seed(1337) + t0 = time.perf_counter() + for _ in range(times): + result = model(*example_inputs) + synchronize(device) + t1 = time.perf_counter() + # GC the result after timing + assert result is not None # type: ignore[possibly-undefined] + return t1 - t0 + + +def print_performance( + fn, args=(), times=10, repeat=10, baseline=1.0, device: str = "cuda" +): + timings = torch.tensor([timed(fn, args, times, device) for _ in range(repeat)]) + took = torch.median(timings) / times + print(f"{took/baseline:.6f}") + return took + + +def precompute_method(obj: Any, method: str): + """Replace obj.method() with a new method that returns a precomputed constant.""" + result = getattr(obj, method)() + setattr(obj, method, lambda: result) + + +def precompute_methods(obj: Any, methods: List[str]): + """Replace methods with new methods that returns a precomputed constants.""" + for method in methods: + precompute_method(obj, method) + + +def cmp(a, b) -> int: + return int(a > b) - int(a < b) + + +def pad_listlike(x, size): + if len(x) == 1: + return type(x)([x[0]]) * size + else: + return x + + +# Used to ensure that iterating over a set is deterministic +def tuple_sorted(x): + if len(x) == 0: + return [] + + def sort_func(elem): + if isinstance(elem, str): + return elem + else: + # We expect `elem` to be `scheduler.BaseSchedulerNode` type here, + # but we are not able to do isinstance assert because of circular dependency + return elem.get_name() + + return sorted(x, key=sort_func) + + +P = 
ParamSpec("P") +RV = TypeVar("RV", covariant=True) + + +class CachedMethod(Generic[P, RV], Protocol): + @staticmethod + def clear_cache(self) -> None: + ... + + def __call__(self, *args: P.args, **kwargs: P.kwargs) -> RV: + ... + + +# See https://github.com/python/mypy/issues/13222#issuecomment-1193073470 to understand the type signature +def cache_on_self(fn: Callable[Concatenate[Any, P], RV]) -> CachedMethod[P, RV]: + key = f"__{fn.__name__}_cache" + + @functools.wraps(fn) + def wrapper(self): + if not hasattr(self, key): + setattr(self, key, fn(self)) + return getattr(self, key) + + def clear_cache(self): + if hasattr(self, key): + delattr(self, key) + + wrapper.clear_cache = clear_cache # type: ignore[attr-defined] + return wrapper # type: ignore[return-value] + + +def aggregate_origins(node_schedule): + from . import ir + + if isinstance(node_schedule, list): + return functools.reduce( + operator.or_, + [ + node.node.origins + for node in node_schedule + if hasattr(node, "node") and node.node + ], + set(), + ) + elif isinstance(node_schedule, ir.ExternKernel): + return node_schedule.origins + else: + return set() + + +def get_fused_kernel_name(node_schedule, descriptive_names): + all_origins = aggregate_origins(node_schedule) + if descriptive_names == "original_aten": + # Bases the kernel name off of the top-level aten operator (i.e. pre-decompositions) + sources = [ + origin.meta["original_aten"]._overloadpacket.__name__ + for origin in all_origins + if origin.op == "call_function" + and "original_aten" in origin.meta + and origin.meta["original_aten"] is not None + ] + sources = sorted(set(sources)) + elif descriptive_names == "torch": + # Bases the kernel name off of the top-level "torch" operator (i.e. post-dynamo graph) + sources = [] + for origin in all_origins: + if origin.op == "call_function" and "source_fn_stack" in origin.meta: + source_fn = origin.meta["source_fn_stack"][-1] + if isinstance(source_fn[1], str): + sources.append(source_fn[1]) + else: + sources.append(source_fn[1].__name__) + sources = sorted(set(sources)) + elif descriptive_names == "inductor_node": + sources = [ + origin.name for origin in all_origins if origin.op == "call_function" + ] + else: + raise NotImplementedError + sources = sources + return "_".join(["fused"] + sources) + + +def get_kernel_metadata(node_schedule, wrapper): + all_origins = aggregate_origins(node_schedule) + inductor_nodes = [origin for origin in all_origins if origin.op == "call_function"] + + from_node_dict = collections.defaultdict(list) + original_aten_dict = collections.defaultdict(list) + for node in inductor_nodes: + if "original_aten" in node.meta and node.meta["original_aten"] is not None: + key = str(node.meta["original_aten"]._overloadpacket) + original_aten_dict[key].append(node.name) + if "from_node" in node.meta: + key = node.meta["from_node"][0][0] + from_node_dict[key].append(node.name) + metadata = ( + f"{wrapper.comment} Source Nodes: [{', '.join(sorted(from_node_dict.keys()))}], " + f"Original ATen: [{', '.join(sorted(original_aten_dict.keys()))}]" + ) + # trace back to original node here + detailed_metadata = [] + for original_node, nodes in sorted(from_node_dict.items()): + detailed_metadata.append( + f"{wrapper.comment} {original_node} => {', '.join(sorted(nodes))}" + ) + return metadata, "\n".join(detailed_metadata) + + +def dominated_nodes( + initial_queue: Iterable[torch.fx.Node], skip_filter=None +) -> Set[torch.fx.Node]: + """Returns the set of nodes whose values depend on those within initial_queue""" + 
+    initial_queue = list(initial_queue)
+    dominated_set = set(initial_queue)
+
+    while initial_queue:
+        node = initial_queue.pop()
+        for user in node.users:
+            if skip_filter and skip_filter(user):
+                continue
+            if user not in dominated_set:
+                dominated_set.add(user)
+                initial_queue.append(user)
+
+    return dominated_set
+
+
+def gather_origins(args, kwargs):
+    import itertools
+
+    from . import ir
+
+    def is_unrealized_node(n):
+        if isinstance(n, ir.TensorBox):
+            return is_unrealized_node(n.data)
+        if isinstance(n, ir.StorageBox):
+            return is_unrealized_node(n.data)
+        return isinstance(n, ir.IRNode) and isinstance(n, ir.Pointwise)
+
+    kwarg_origins = [val.origins for val in kwargs.values() if is_unrealized_node(val)]
+    arg_origins = [arg.origins for arg in args if is_unrealized_node(arg)]
+    return set(itertools.chain(*arg_origins, *kwarg_origins))
+
+
+def sympy_str(expr: sympy.Expr) -> str:
+    """
+    Normal sympy str is very slow; this is a lot faster. The results are
+    somewhat worse, as it doesn't do as much simplification, so don't
+    use this for final codegen.
+    """
+    if isinstance(expr, sympy.Symbol):
+        return expr.name
+    if isinstance(expr, sympy.Add):
+        return " + ".join(map(sympy_str, expr.args))
+    if isinstance(expr, sympy.Mul):
+        return " * ".join(map(sympy_str, expr.args))
+
+    if isinstance(expr, (ModularIndexing, CleanDiv, FloorDiv)):
+        return f"{expr.func.__name__}({', '.join(map(sympy_str, expr.args))})"
+    return str(expr)
+
+
+def sympy_index_symbol(name: str) -> sympy.Symbol:
+    """
+    Used to generate an integer, nonnegative symbol.
+    """
+    # This should never be used for creating shape/stride symbols, as those
+    # should all be allocated before Inductor.
+    assert name[0] != "s"
+    # NOTE: shape symbols are positive (> 0), but index variables are only
+    # non-negative (>= 0).
+    return sympy.Symbol(name, integer=True, nonnegative=True)
+
+
+def sympy_subs(expr: sympy.Expr, replacements: Dict[sympy.Expr, Any]) -> sympy.Expr:
+    """
+    When a replacement value is a string, it is converted to a symbol with
+    that name, inheriting the integer and nonnegative assumptions of the
+    expression it replaces.
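+
+    For example, sympy_subs(s0 + 1, {s0: "u0"}) yields u0 + 1, where u0
+    carries over s0's integer/nonnegative assumptions.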
+ """ + + def to_symbol(replaced, replacement): + assert isinstance(replaced, sympy.Expr) + if isinstance(replacement, str): + return sympy.Symbol( + replacement, + integer=replaced.is_integer, # type: ignore[attr-defined] + nonnegative=replaced.is_nonnegative, # type: ignore[attr-defined] + ) + else: + return replacement + + # xreplace is faster than subs, but is way more picky + return sympy.sympify(expr).xreplace( + {k: to_symbol(k, v) for k, v in replacements.items()} + ) + + +def free_symbol_startswith(index: sympy.Expr, prefix: str): + return any(v.name.startswith(prefix) for v in index.free_symbols) # type: ignore[attr-defined] + + +def free_symbol_has(index: sympy.Expr, pattern: str): + return any(pattern in v.name for v in index.free_symbols) # type: ignore[attr-defined] + + +def is_symbolic(a: Any) -> bool: + return isinstance(a, torch.SymInt) or ( + isinstance(a, torch.Tensor) + and any(is_symbolic(x) for x in itertools.chain(a.size(), a.stride())) + ) + + +def any_is_symbolic(*args: Any) -> bool: + return any(is_symbolic(a) for a in args) + + +def has_incompatible_cudagraph_ops(gm): + from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols + + forbidden_set = { + "aten._fused_moving_avg_obs_fq_helper.default", + "aten._fused_moving_avg_obs_fq_helper_functional.default", + "aten.multinomial.default", + "fbgemm.dense_to_jagged.default", + "fbgemm.jagged_to_padded_dense.default", + "run_and_save_rng_state", + "run_with_rng_state", + "aten._local_scalar_dense", + # Technically, it's not necessary to ban this, because an + # assert_scalar with constant arguments can be validly run + # with CUDA graphs, but the operator is also pointless with + # constant arguments, so might as well ban + "aten._assert_scalar", + } + if torch.are_deterministic_algorithms_enabled(): + forbidden_set.update( + { + "aten._unsafe_index_put.default", + "aten.index_put.default", + "aten.index_put_.default", + "aten.scatter.src", + "aten.scatter.reduce", + "aten.scatter.value_reduce", + "aten.scatter_add_", + "aten.scatter_add.default", + "aten.scatter_reduce.two", + "aten.scatter_reduce_.two", + "aten.scatter_reduce.two_out", + } + ) + for node in gm.graph.nodes: + if str(node.target) in forbidden_set: + return True + if (val := node.meta.get("val")) is not None and free_unbacked_symbols(val): + return True + return False + + +def output_node(gm: torch.fx.GraphModule): + """Get the output node from an FX graph""" + last_node = next(iter(reversed(gm.graph.nodes))) + assert last_node.op == "output" + return last_node + + +# Attempt to import AttrsDescriptor from Triton +try: + from triton.compiler.compiler import AttrsDescriptor + + attrs_descriptor_available = True + # Determine if 'ids_of_folded_args' is a valid field for AttrsDescriptor + attr_desc_fields = {f.name for f in fields(AttrsDescriptor)} + ids_of_folded_args_available = "ids_of_folded_args" in attr_desc_fields + divisible_by_8_available = "divisible_by_8" in attr_desc_fields +except ImportError: + attrs_descriptor_available = False + +# Define `instance_descriptor` function with clear conditional handling +if attrs_descriptor_available: + + def instance_descriptor( + divisible_by_16=None, + equal_to_1=None, + ids_of_folded_args=None, + divisible_by_8=None, + ): + # Prepare the arguments for AttrsDescriptor + kwargs = { + "divisible_by_16": divisible_by_16, + "equal_to_1": equal_to_1, + } + + # Conditionally add 'ids_of_folded_args' if it's available in AttrsDescriptor + if ids_of_folded_args_available: + 
kwargs["ids_of_folded_args"] = ids_of_folded_args + if divisible_by_8_available: + kwargs["divisible_by_8"] = divisible_by_8 + + # Instantiate AttrsDescriptor with the prepared arguments + return AttrsDescriptor(**kwargs) + +else: + # Define a namedtuple as a fallback when AttrsDescriptor is not available + instance_descriptor = collections.namedtuple( # type: ignore[no-redef] + "instance_descriptor", + ["divisible_by_16", "equal_to_1", "ids_of_folded_args", "divisible_by_8"], + defaults=[tuple(), tuple(), tuple(), tuple()], + ) + + +@functools.lru_cache(None) +def cache_dir() -> str: + cache_dir = os.environ.get("TORCHINDUCTOR_CACHE_DIR") + if cache_dir is None: + sanitized_username = re.sub(r'[\\/:*?"<>|]', "_", getpass.getuser()) + cache_dir = os.path.join( + tempfile.gettempdir(), + "torchinductor_" + sanitized_username, + ) + os.makedirs(cache_dir, exist_ok=True) + return cache_dir + + +@contextlib.contextmanager +def fresh_inductor_cache(cache_entries=None): + """ + Contextmanager that provides a clean tmp cachedir for inductor. + + Optionally, pass a dict as 'cache_entries' to get a list of filenames and sizes + generated with this cache instance. + """ + with tempfile.TemporaryDirectory() as inductor_cache_dir: + with mock.patch.dict( + os.environ, {"TORCHINDUCTOR_CACHE_DIR": inductor_cache_dir} + ): + triton_cache_dir = os.path.join(inductor_cache_dir, "triton") + with mock.patch.dict(os.environ, {"TRITON_CACHE_DIR": triton_cache_dir}): + yield + if isinstance(cache_entries, dict): + assert len(cache_entries) == 0, "expected empty cache_entries dict" + if os.path.exists(triton_cache_dir): + files = os.listdir(triton_cache_dir) + cache_entries.update( + { + f: os.path.getsize(os.path.join(triton_cache_dir, f)) + for f in files + if ".lock" not in f + } + ) + + +def argsort(seq) -> List[int]: + # preserve original order for equal strides + getter = seq.__getitem__ + a_r = range(len(seq)) + return list(reversed(sorted(a_r, key=getter, reverse=True))) # noqa: C413 + + +@functools.lru_cache(8) +def get_dtype_size(dtype): + return torch.empty((), dtype=dtype).element_size() + + +class LineContext(NamedTuple): + context: Any + + +class IndentedBuffer: + tabwidth = 4 + + def __init__(self, initial_indent=0): + self._lines = [] + self._indent = initial_indent + + def getvaluewithlinemap(self) -> tuple[str, list[tuple[int, LineContext]]]: + buf = StringIO() + p = 1 + linemap = [] + for line in self._lines: + if isinstance(line, DeferredLineBase): + line = line() + if line is None: + continue + elif isinstance(line, LineContext): + linemap.append((p, line.context)) + continue + assert isinstance(line, str) + buf.write(line) + buf.write("\n") + p += 1 + line.count("\n") + return buf.getvalue(), linemap + + def getvalue(self) -> str: + v, _ = self.getvaluewithlinemap() + return v + + def getrawvalue(self) -> str: + buf = StringIO() + for line in self._lines: + if isinstance(line, DeferredLineBase): + line = line() + if line is None: + continue + elif isinstance(line, LineContext): + continue + assert isinstance(line, str) + # backslash implies line continuation + if line.endswith("\\"): + buf.write(line[:-1]) + else: + buf.write(line) + buf.write("\n") + return buf.getvalue() + + def clear(self): + self._lines.clear() + + def __bool__(self): + return bool(self._lines) + + def prefix(self): + return " " * (self._indent * self.tabwidth) + + def newline(self): + self.writeline("\n") + + def writeline(self, line): + if isinstance(line, LineContext): + self._lines.append(line) + elif 
isinstance(line, DeferredLineBase): + self._lines.append(line.with_prefix(self.prefix())) + elif line.strip(): + self._lines.append(f"{self.prefix()}{line}") + else: + self._lines.append("") + + def writelines(self, lines): + for line in lines: + self.writeline(line) + + def indent(self, offset=1): + @contextlib.contextmanager + def ctx(): + self._indent += offset + try: + yield + finally: + self._indent -= offset + + return ctx() + + def do_indent(self, offset=1): + self._indent += offset + + def do_unindent(self, offset=1): + self._indent -= offset + + def splice(self, other_code, strip=False): + if isinstance(other_code, IndentedBuffer): + dedent = float("inf") + for line in other_code._lines: + if not isinstance(line, LineContext) and line: + dedent = min(dedent, len(line) - len(line.lstrip())) + if math.isinf(dedent): + dedent = 0 + for line in other_code._lines: + if isinstance(line, LineContext): + self._lines.append(line) + else: + IndentedBuffer.writeline(self, line[int(dedent) :]) + else: + other_code = textwrap.dedent(other_code) + if strip: + other_code = other_code.lstrip() + if not other_code: + return + other_code = other_code.rstrip() + for line in other_code.split("\n"): + self.writeline(line) + + def __repr__(self): + return f"{type(self)}({self.getvalue()})" + + +class DeferredLineBase: + """A line that can be 'unwritten' at a later time""" + + def __init__(self, line): + if not line.strip(): + line = "" + self.line = line + + def __call__(self) -> Optional[str]: + """Returns either self.line or None to indicate the line has been 'unwritten'""" + raise NotImplementedError() + + def _new_line(self, line: str) -> DeferredLineBase: + """Returns a new deferred line with the same condition""" + raise NotImplementedError() + + def with_prefix(self, prefix): + return self._new_line(f"{prefix}{self.line}") + + def lstrip(self): + return self._new_line(self.line.lstrip()) + + def __getitem__(self, index): + return self._new_line(self.line[index]) + + def __bool__(self): + return bool(self.line) + + def __len__(self): + return len(self.line) + + +@functools.lru_cache(None) +def is_big_gpu(index): + sms = torch.cuda.get_device_properties(index).multi_processor_count + if sms < 80: # V100 + log.warning("not enough SMs to use max_autotune_gemm mode") + return False + return True + + +def use_max_autotune() -> bool: + return ( + config.max_autotune or config.max_autotune_gemm or config.search_autotune_cache + ) + + +def _use_template_for_cuda(layout, allowed_layout_dtypes: List[torch.dtype]) -> bool: + return ( + use_max_autotune() + and layout.device.type == "cuda" + and layout.dtype in allowed_layout_dtypes + and is_big_gpu(layout.device.index or 0) + ) + + +def _use_autotune_backend(backend: str) -> bool: + return backend.upper() in [ + x.strip() for x in config.max_autotune_gemm_backends.upper().split(",") + ] + + +def use_triton_template(layout, *, enable_int32=False): + layout_dtypes = [torch.float16, torch.bfloat16, torch.float32] + if enable_int32: + layout_dtypes = [torch.float16, torch.bfloat16, torch.float32, torch.int32] + return _use_template_for_cuda(layout, layout_dtypes) and _use_autotune_backend( + "TRITON" + ) + + +def use_cutlass_template(layout): + from .codegen.cuda.cutlass_utils import try_import_cutlass + + # Do not use cutlass template on ROCm + if torch.version.hip: + return False + + layout_dtypes = [torch.float16, torch.bfloat16, torch.float32] + res = _use_template_for_cuda(layout, layout_dtypes) and _use_autotune_backend( + "CUTLASS" + ) + + if res: + if 
not try_import_cutlass(): + log.warning( + "Failed to import CUTLASS lib. Please check whether " + "_inductor.config.cuda.cutlass_dir is set correctly. " + "Skipping CUTLASS backend for now." + ) + return False + return res + + +def use_aten_gemm_kernels(): + return not use_max_autotune() or _use_autotune_backend("ATEN") + + +class DebugDirManager: + counter = itertools.count(0) + prev_debug_name: str + + def __init__(self): + self.id = next(DebugDirManager.counter) + + def __enter__(self): + self.prev_debug_name = torch._dynamo.config.debug_dir_root + self.new_name = f"{self.prev_debug_name}_tmp_{self.id}" + torch._dynamo.config.debug_dir_root = self.new_name + + def __exit__(self, *args): + shutil.rmtree(self.new_name) + torch._dynamo.config.debug_dir_root = self.prev_debug_name + + +def run_and_get_code(fn, *args, **kwargs): + from .graph import GraphLowering + + compile_to_module = GraphLowering.compile_to_module + source_codes = [] + + def patched_compile_to_module(self): + mod = compile_to_module(self) + with open(mod.__file__) as f: + source_codes.append(f.read()) + return mod + + # If FX code caching is enabled, a hit prevents getting the code. + with config.patch({"fx_graph_cache": False}): + with mock.patch.object( + GraphLowering, "compile_to_module", patched_compile_to_module + ): + torch._dynamo.reset() + result = fn(*args, **kwargs) + return result, source_codes + + +def run_and_get_triton_code(fn, *args, **kwargs): + _, source_codes = run_and_get_code(fn, *args, **kwargs) + # Can have two outputs if backwards was eagerly compiled + assert ( + 1 <= len(source_codes) <= 2 + ), f"expected one or two code outputs got {len(source_codes)}" + return source_codes[0] + + +@contextlib.contextmanager +def override_lowering(aten_op, override_fn): + """ + Override the lowering of aten_op with override_fn. + The first argument of override_fn is the original lowering fn. + """ + from torch._inductor import lowering + + orig_fn = lowering.lowerings[aten_op] + try: + lowering.lowerings[aten_op] = functools.partial(override_fn, orig_fn) + yield + finally: + lowering.lowerings[aten_op] = orig_fn + + +def add_scheduler_init_hook(pre_fn, post_fn=None): + """ + Add hook functions to be called at the beginning and end of Scheduler.__init__. + Used for unit tests. + """ + from torch._inductor.scheduler import Scheduler + + orig_fn = Scheduler.__init__ + + def wrapper(scheduler, nodes): + pre_fn(scheduler, nodes) + out = orig_fn(scheduler, nodes) + if post_fn: + post_fn(scheduler, nodes) + return out + + return unittest.mock.patch.object(Scheduler, "__init__", wrapper) + + +def developer_warning(msg): + """ + Warnings that will be actionable for PyTorch developers, but not + end users. Allows us to easily disable them in stable releases but + keep them on for nightly builds. + """ + if config.developer_warnings: + log.warning(msg) + else: + log.info(msg) + + +def get_num_bytes(*args: torch.Tensor, num_in_out_args: int = 0) -> int: + """ + Return the total number of bytes the arguments of tensor type takes. + + For in/out args, tensor sizes are counted twice: once for reading and + once for writing. + + The first num_in_out_args arguments are in out tensors. 
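+
+    Illustrative example: with two float32 tensors of 1024 elements each and
+    num_in_out_args=1, the first tensor is counted twice and the total is
+    1024 * 4 * 2 + 1024 * 4 = 12288 bytes.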
+ """ + return sum( + arg.numel() * arg.element_size() * (1 + int(i < num_in_out_args)) + for i, arg in enumerate(args) + if isinstance(arg, torch.Tensor) + ) + + +def create_bandwidth_info_str(ms, num_gb, gb_per_s, prefix="", suffix="", color=True): + info_str = f"{prefix}{ms:.3f}ms \t{num_gb:.3f} GB \t {gb_per_s:7.2f}GB/s{suffix}" + slow = ms > 0.012 and gb_per_s < 650 + return red_text(info_str) if color and slow else info_str + + +def get_benchmark_name(): + """ + An experimental API used only when config.benchmark_kernel is true. + + The benchmark name is only available at codegen time. So we can not + directly call it in benchmark_all_kernels which is run after codegen. + + The function assumes the argument after --only is the benchmark name. + It works for torchbench.py/hugginface.py/timm_models.py. But for ad-hoc + scripts, this function may return None. + + There are 2 flavors of --only argument we need handle: + 1. --only model_name + 2. --only=model_name + """ + try: + idx = sys.argv.index("--only") + if ( + idx + 1 < len(sys.argv) + and len(sys.argv[idx + 1]) > 0 + and sys.argv[idx + 1][0] != "-" + ): + return sys.argv[idx + 1] + except ValueError: + pass + + for arg in sys.argv: + if arg.startswith("--only="): + return arg[len("--only=") :] + + +def is_ones(items): + return all(x == 1 for x in items) + + +def is_zeros(items): + return all(x == 0 for x in items) + + +def is_cpu_device(inputs): + return all( + item.device == torch.device("cpu") + for item in inputs + if isinstance(item, torch.Tensor) + ) + + +def get_sympy_Expr_dtype(val: sympy.Expr) -> torch.dtype: + assert isinstance( + val, sympy.Expr + ), "only support sympy.Expr as input to get_sympy_Expr_dtype" + if val.is_integer: # type: ignore[attr-defined] + return torch.int64 + else: + return torch.float64 + + +@contextlib.contextmanager +def maybe_profile(should_profile, *args, **kwargs): + if should_profile: + with torch.profiler.profile(*args, **kwargs) as p: + yield p + else: + yield + + +def triton_config_to_hashable(cfg): + """ + Convert triton config to a tuple that can uniquely identify it. We can use + the return value as a dictionary key. 
+ """ + items = sorted(cfg.kwargs.items()) + items.append(("num_warps", cfg.num_warps)) + items.append(("num_stages", cfg.num_stages)) + return tuple(items) + + +def parallel_num_threads(): + threads = config.cpp.threads + if threads < 1: + threads = torch.get_num_threads() + return threads + + +HAS_COLORAMA = True +try: + import colorama +except ImportError: + HAS_COLORAMA = False + + +def _color_text(msg, color): + if not HAS_COLORAMA: + return msg + + return getattr(colorama.Fore, color.upper()) + msg + colorama.Fore.RESET + + +def green_text(msg): + return _color_text(msg, "green") + + +def yellow_text(msg): + return _color_text(msg, "yellow") + + +def red_text(msg): + return _color_text(msg, "red") + + +def blue_text(msg): + return _color_text(msg, "blue") + + +@functools.lru_cache(None) +def get_device_tflops(dtype): + from triton.testing import get_max_simd_tflops, get_max_tensorcore_tflops + + assert dtype in (torch.float16, torch.bfloat16, torch.float32) + + if inspect.signature(get_max_simd_tflops).parameters.get("clock_rate"): + # Triton API change in https://github.com/openai/triton/pull/2293 + from torch._utils_internal import max_clock_rate + + sm_clock = max_clock_rate() + if dtype in (torch.float16, torch.bfloat16): + return get_max_tensorcore_tflops(dtype, sm_clock) + + if torch.backends.cuda.matmul.allow_tf32: + return get_max_tensorcore_tflops(torch.float32, sm_clock) + else: + return get_max_simd_tflops(torch.float32, sm_clock) + else: + if dtype in (torch.float16, torch.bfloat16): + return get_max_tensorcore_tflops(dtype) + + if torch.backends.cuda.matmul.allow_tf32: + return get_max_tensorcore_tflops(torch.float32) + else: + return get_max_simd_tflops(torch.float32) + + +@functools.lru_cache(None) +def get_gpu_dram_gbps(): + from triton.testing import get_dram_gbps + + return get_dram_gbps() + + +def is_welford_reduction(reduction_type): + return reduction_type.startswith("welford") + + +def reduction_num_outputs(reduction_type): + return 3 if is_welford_reduction(reduction_type) else 1 + + +def get_max_y_grid(): + return 65535 + + +def is_linux() -> bool: + return platform.system() == "Linux" + + +def has_free_symbols(itr: Iterable[Any]): + return any(isinstance(x, sympy.Expr) and not x.is_number for x in itr) + + +def is_dynamic(*args): + from . import ir + + for t in args: + if isinstance(t, ir.TensorBox): + if has_free_symbols(t.data.get_size()) or ( + hasattr(t.data, "get_stride") and has_free_symbols(t.data.get_stride()) + ): + return True + elif isinstance(t, (ir.StorageBox, ir.BaseView, ir.ComputedBuffer)): + assert hasattr(t, "get_size") and hasattr(t, "get_stride") + if has_free_symbols(t.get_size()) or has_free_symbols(t.get_stride()): + return True + elif not isinstance(t, ir.IRNode): + continue + else: + raise TypeError(f"unexpected type for is_dynamic {type(t)}") + + return False + + +# Placeholder strings used in triton codegen. +class Placeholder(enum.Enum): + # The placeholder for the actual name of a triton kernel. + # e.g. for "def triton_" it would be "triton_" + KERNEL_NAME = "KERNEL_NAME" + + # The descriptive name of the triton kernel; when unique_kernel_names = False, this + # placeholder will be replaced with a string with more information. 
+ DESCRIPTIVE_NAME = "DESCRIPTIVE_NAME" + + +def pass_execution_and_save(func, gm, msg): + from .pattern_matcher import stable_topological_sort + + with tempfile.NamedTemporaryFile( + mode="w", + encoding="utf-8", + delete=False, + ) as f: + before_io = io.StringIO() + after_io = io.StringIO() + print(f"Before:\n{gm.graph}", file=f) + print(gm.graph, file=before_io) + start_time = datetime.now() + func(gm.graph) + time_elapsed = datetime.now() - start_time + # recompile graph + stable_topological_sort(gm.graph) + gm.graph.lint() + gm.recompile() + + print(f"After:\n{gm.graph}", file=f) + print(gm.graph, file=after_io) + t = before_io.getvalue() == after_io.getvalue() + log.info( + "%s, save before/after graph to %s, graph before/after are the same = %s, time elapsed = %s", + msg, + f.name, + t, + time_elapsed, + ) + + +def is_collective(node): + from . import ir + + return isinstance(node, ir.CollectiveKernel) or type(node) == ir._CollectiveKernel + + +def is_wait(node): + from . import ir + + return isinstance(node, ir.Wait) or type(node) == ir._WaitKernel + + +def num_fw_fixed_arguments(dynamo_gm_num_inputs: int, aot_fw_gm_num_inputs: int): + "Computes the number of inputs to the aot fw graph which have fixed addresses (params and buffers)" + num_rng_seed_offset_inputs = ( + 2 if torch._functorch.config.functionalize_rng_ops else 0 + ) + return aot_fw_gm_num_inputs - dynamo_gm_num_inputs - num_rng_seed_offset_inputs + + +def count_tangents(fx_g: torch.fx.GraphModule): + """ + Infers which inputs are static for a backwards graph + """ + + def is_saved_tensor(x): + return ( + "tangents" not in x.name + and "bwd_seed" not in x.name + and "bwd_base_offset" not in x.name + ) + + arg_count = 0 + static_arg_idxs = [] + for n in fx_g.graph.nodes: + if n.op == "placeholder": + if is_saved_tensor(n): + static_arg_idxs.append(arg_count) + arg_count += 1 + + assert static_arg_idxs == list(range(len(static_arg_idxs))) + return len(static_arg_idxs) + + +@dataclasses.dataclass +class BoxedBool: + value: bool + + def __bool__(self): + return self.value + + @staticmethod + def disable(obj): + if isinstance(obj, BoxedBool): + obj.value = False + return obj + return False + + +@contextlib.contextmanager +def collect_defined_kernels(kernel_list): + from .codegen.wrapper import WrapperCodeGen + + orig_define_kernel = WrapperCodeGen.define_kernel + + def new_define_kernel(wrapper, name, kernel_code, metadata, *args, **kwargs): + nonlocal kernel_list + kernel_list.append(kernel_code) + return orig_define_kernel(wrapper, name, kernel_code, metadata, *args, **kwargs) + + with unittest.mock.patch.object(WrapperCodeGen, "define_kernel", new_define_kernel): + yield diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/wrapper_benchmark.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/wrapper_benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..c0205659ef72bd43641339b7de73749ddfc9bc8e --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/wrapper_benchmark.py @@ -0,0 +1,299 @@ +import dataclasses +import tempfile +from collections import defaultdict + +import torch +from torch.autograd import DeviceType +from .utils import create_bandwidth_info_str, do_bench, get_num_bytes + +_kernel_category_choices = [ + "foreach", + "persistent_reduction", + "pointwise", + "reduction", + "split_scan", + "template", +] + + +def get_kernel_category_by_source_code(src_code): + """ + 
Similar to get_kernel_category but use the source code. Call this API + if we have not compile the src_code to module yet. + """ + choices = [ + ch for ch in _kernel_category_choices if f"@triton_heuristics.{ch}" in src_code + ] + if len(choices) == 1: + return choices[0] + else: + return "unknown" + + +def get_kernel_category(kernel_mod): + """ + Given the module defining a triton kernel, return the category of the kernel. + Category can be one of: + - pointwise + - reduction + - persistent_reduction + + Currently we simply decide the category depending on what decorator is imported + by the kernel. + """ + choices = [ch for ch in _kernel_category_choices if ch in kernel_mod.__dict__] + if len(choices) == 1: + return choices[0] + else: + return "unknown" + + +def get_triton_kernel(mod): + from torch._inductor.triton_heuristics import CachingAutotuner + + cand_list = [ + v + for k, v in mod.__dict__.items() + if k.startswith("triton_") and isinstance(v, CachingAutotuner) + ] + assert len(cand_list) == 1 + return cand_list[0] + + +def benchmark_all_kernels(benchmark_name, benchmark_all_configs): + """ + An experimental API used only when config.benchmark_kernel is true. + + Run the kernel benchmarks for all the kernels cached in PyCodeCache. + Used in the compiled modules. + + Put this method here rather than codegen it for convenience since its implementation + does not change based on different graph modules being compiled. + """ + from torch._inductor.codecache import PyCodeCache + + nfound = 0 + for kernel_key, kernel_mod in PyCodeCache.cache.items(): + if not hasattr(kernel_mod, "get_args") or not hasattr(kernel_mod, "call"): + continue + + triton_kernel = get_triton_kernel(kernel_mod) + kernel_category = get_kernel_category(kernel_mod) + args = kernel_mod.get_args() + num_in_out_ptrs = len( + [ + arg_name + for arg_name in triton_kernel.fn.arg_names + if arg_name.startswith("in_out_ptr") + ] + ) + num_gb = triton_kernel.inductor_meta.get("kernel_num_gb", None) + if num_gb is None: + num_gb = get_num_bytes(*args, num_in_out_args=num_in_out_ptrs) / 1e9 + + def get_info_str(ms, n_regs, n_spills, shared, prefix=""): + if not any(x is None for x in [n_regs, n_spills, shared]): + kernel_detail_str = ( + f" {n_regs:3} regs {n_spills:3} spills {shared:8} shared mem" + ) + else: + kernel_detail_str = "" + + gb_per_s = num_gb / (ms / 1e3) + return create_bandwidth_info_str( + ms, num_gb, gb_per_s, prefix=prefix, suffix=kernel_detail_str + ) + + kernel_desc = ( + f"{benchmark_name:20} {kernel_category[:3].upper()} {kernel_key[:10]}" + ) + if benchmark_all_configs: + assert hasattr(kernel_mod, "benchmark_all_configs") + bench_result = kernel_mod.benchmark_all_configs(args) + print(kernel_desc) + for launcher, ms in bench_result.items(): + print( + f" {get_info_str(ms, launcher.n_regs, launcher.n_spills, launcher.shared)} @ {launcher.config}" + ) + else: + ms = do_bench(lambda: kernel_mod.call(args), rep=40, fast_flush=True) + assert ( + len(triton_kernel.launchers) == 1 + ), "Autotuner should have selected the best config" + launcher = triton_kernel.launchers[0] + print( + get_info_str( + ms, + launcher.n_regs, + launcher.n_spills, + launcher.shared, + prefix=f"{kernel_desc} ", + ) + ) + + nfound += 1 + if nfound == 0: + print( + "No kernel with benchmark functionality found. 
Make sure you run inductor with config.benchmark_kernel being True" + ) + + +@dataclasses.dataclass +class ProfileEvent: + category: str + key: str + self_cuda_time_ms: float + # the benchmark is run multiple times and we average the count across all the + # runs. It should be an integer but define a float just in case. + count: float + + +def parse_profile_event_list(benchmark_name, event_list, wall_time_ms, nruns): + def get_self_cuda_time(ev): + """ + ev.self_cuda_time_total is in microsecond. Convert to millisecond. + """ + return ev.self_cuda_time_total / 1000 / nruns + + all_events = defaultdict(list) + + def add_event(ev, category): + profile_ev = ProfileEvent( + category=category, + key=ev.key, + self_cuda_time_ms=get_self_cuda_time(ev), + count=ev.count / nruns, # average across all runs + ) + all_events[category].append(profile_ev) + + for ev in event_list: + assert not ev.is_legacy, "Don't support the legacy profiler" + if ev.device_type == DeviceType.CPU: + # ignore the event on CPU side + continue + + category = "unknown" + if ev.key.startswith("triton_"): + if ev.key.startswith("triton_poi"): + category = "triton_pointwise" + elif ev.key.startswith("triton_red"): + category = "triton_reduction" + elif ev.key.startswith("triton_per"): + category = "triton_persistent_reduction" + else: + category = "triton_unknown" + + add_event(ev, category) + + def report_category(category, profile_events): + from tabulate import tabulate + + profile_events.sort(key=lambda ev: ev.self_cuda_time_ms, reverse=True) + + rows = [] + total_time = 0.0 + print(f"\n == {category} category kernels == ") + for ev in profile_events: + total_time += ev.self_cuda_time_ms + percent = f"{ev.self_cuda_time_ms / wall_time_ms * 100:.2f}%" + rows.append([ev.key[:120], ev.self_cuda_time_ms, ev.count, percent]) + rows.append( + ["Total", total_time, "", f"{total_time / wall_time_ms * 100:.2f}%"] + ) + print( + tabulate( + rows, headers=["Kernel", "Self CUDA TIME (ms)", "Count", "Percent"] + ) + ) + return total_time + + def report(): + category_list = [ + "triton_pointwise", + "triton_reduction", + "triton_persistent_reduction", + "triton_unknown", + "unknown", + ] + assert set(all_events.keys()).issubset( + set(category_list) + ), f"{list(all_events.keys())}" + + per_category_wall_time = {} + total_cuda_ms = 0.0 + for category in category_list: + if category in all_events: + _time = report_category(category, all_events[category]) + per_category_wall_time[category] = _time + total_cuda_ms += _time + + gpu_busy_percent = f"{total_cuda_ms / wall_time_ms * 100:.2f}%" + print(f"\nPercent of time when GPU is busy: {gpu_busy_percent}") + print(f"Total wall time {wall_time_ms:.3f} ms") + + # output such a line so we can gather such line from all compiled modules from all + # benchmarks and tabulate it! + # Columns: benchmark_name, pointwise_percent, reduction_percent, persistent_reduction_percent, + # unknown_category_percent, GPU_busy_percent, wall_time_ms + tabulate_line = f"Output for tabulate: {benchmark_name}" + for category in category_list: + percent = ( + f"{per_category_wall_time.get(category, 0.0) / wall_time_ms * 100:.2f}%" + ) + tabulate_line += f", {percent}" + tabulate_line += f", {gpu_busy_percent}, {wall_time_ms:.3f}ms" + + print(tabulate_line) + + report() + + +def compiled_module_main(benchmark_name, benchmark_compiled_module_fn): + """ + This is the function called in __main__ block of a compiled module. 
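+
+    Typical invocations, assuming the compiled module was saved as
+    module.py (the file name is illustrative):
+        python module.py        # benchmark the whole compiled module
+        python module.py -k     # additionally benchmark each kernel
+        python module.py -k -c  # benchmark every config of each kernel
+        python module.py -p     # profile and export a chrome trace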
+ """ + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument( + "--benchmark-kernels", + "-k", + action="store_true", + help="Whether to benchmark each individual kernels", + ) + parser.add_argument( + "--benchmark-all-configs", + "-c", + action="store_true", + help="Whether to benchmark each individual config for a kernel", + ) + parser.add_argument( + "--profile", + "-p", + action="store_true", + help="Whether to profile the compiled module", + ) + args = parser.parse_args() + + if args.benchmark_kernels: + benchmark_all_kernels(benchmark_name, args.benchmark_all_configs) + else: + times = 10 + repeat = 10 + wall_time_ms = benchmark_compiled_module_fn(times=times, repeat=repeat) * 1000 + + if not args.profile: + return + + with torch.profiler.profile(record_shapes=True) as p: + benchmark_compiled_module_fn(times=times, repeat=repeat) + + path = f"{tempfile.gettempdir()}/compiled_module_profile.json" + p.export_chrome_trace(path) + print(f"Profiling result for a compiled module of benchmark {benchmark_name}:") + print(f"Chrome trace for the profile is written to {path}") + event_list = p.key_averages(group_by_input_shape=True) + print(event_list.table(sort_by="self_cuda_time_total", row_limit=10)) + parse_profile_event_list( + benchmark_name, event_list, wall_time_ms, times * repeat + ) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/DimVector.h b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/DimVector.h new file mode 100644 index 0000000000000000000000000000000000000000..cb652fffcb14819d8ca5292daa012ad47f4c3fad --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/DimVector.h @@ -0,0 +1,2 @@ +#pragma once +#include diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/Dimname.h b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/Dimname.h new file mode 100644 index 0000000000000000000000000000000000000000..71836a9e25d3d82d9cd5024b2f33e147e14bf87e --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/Dimname.h @@ -0,0 +1 @@ +#include diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/DynamicLibrary.h b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/DynamicLibrary.h new file mode 100644 index 0000000000000000000000000000000000000000..523a21985f225eb72ac23c562e990fc105bd1ed4 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/DynamicLibrary.h @@ -0,0 +1,34 @@ +#pragma once + +#include +#include +#include + +namespace c10 { + +class DynamicLibraryError : public Error { + using Error::Error; +}; + +} // namespace c10 + +namespace at { + +struct DynamicLibrary { + AT_DISALLOW_COPY_AND_ASSIGN(DynamicLibrary); + + TORCH_API DynamicLibrary( + const char* name, + const char* alt_name = nullptr, + bool leak_handle = false); + + TORCH_API void* sym(const char* name); + + TORCH_API ~DynamicLibrary(); + + private: + bool leak_handle; + void* handle = nullptr; +}; + +} // namespace at diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/Formatting.h b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/Formatting.h new file mode 100644 index 0000000000000000000000000000000000000000..392e2a27b0130c7ba55621d6ac1d6fd4e989db02 --- 
/dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/Formatting.h @@ -0,0 +1 @@ +#include diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/MetaFunctions_inl.h b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/MetaFunctions_inl.h new file mode 100644 index 0000000000000000000000000000000000000000..2111758cb07be2a4ab5bfe932688ed394e53d1e8 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/MetaFunctions_inl.h @@ -0,0 +1,324 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunctions_inl.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +#if defined(AT_PER_OPERATOR_HEADERS) && defined(TORCH_ASSERT_ONLY_METHOD_OPERATORS) +#error This change adds a dependency on all pytorch operators, meaning the \ + file will need to be re-compiled every time an operator is changed or added. \ + Consider including a specific operator from \ + . \ + See NOTE [TORCH_ASSERT_ONLY_METHOD_OPERATORS]. +#endif + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include 
+#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + + diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/TensorSubclassLikeUtils.h b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/TensorSubclassLikeUtils.h new file mode 100644 index 0000000000000000000000000000000000000000..a9a0b4ecdcf8b9e323d41f0b39941528a2f0b0cd --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/TensorSubclassLikeUtils.h @@ -0,0 +1,86 @@ +#pragma once +#include +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#else +#include +#endif + +namespace at { + +// Note [Tensor-subclass-like Tensors] +// Tensor-subclass-like is defined as: +// - a Tensor subclass (via __torch_dispatch__ in Python or extending +// TensorImpl in C++) +// - anything else that shares the same perils as Tensor subclasses. +// For example, many Tensor subclasses do not have storage and meta Tensors +// do not have storage either, so meta Tensors belong here. +// +// We should ensure that PyTorch internals supports Tensor-subclass-like +// objects. In particular, Tensor-subclass-like objects struggle with two +// classes of operations that are problematic for Tensor subclasses: +// 1. Because some Tensor subclasses do not have storage, .item() or +// .data_ptr() calls are not good. +// 2. Certain in-place operations can eliminate the typing of the Tensor +// subclass. For example: +// >>> torch.zeros(input.sizes(), grad.options()).diag().copy_(input) +// If input is a Tensor subclass, then the above ends up either erroring out +// or returning a regular non-Tensor-subclass Tensor! 
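+//
+// An illustrative sketch of pitfall 2 (hypothetical user code, not part of
+// this header): materializing into a freshly allocated plain Tensor and
+// then writing in-place silently discards the subclass:
+//   at::Tensor buf = at::zeros(input.sizes(), grad.options());
+//   buf.diag().copy_(input);  // plain Tensor result even for subclass input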
+ +constexpr auto kFunctorchWrappedTensors = DispatchKeySet( + {DispatchKey::FuncTorchGradWrapper, + DispatchKey::FuncTorchBatched, + DispatchKey::Functionalize}); + +constexpr auto kTensorSubclassLike = + kFunctorchWrappedTensors | + DispatchKeySet( + {// WARNING: DO NOT put combined backend component + functionality keys + // here, you will incorrectly always match on the functionality key + // no matter the backend component + DispatchKey::Batched, + DispatchKey::Sparse, + DispatchKey::SparseCsr, + DispatchKey::Python}) | + DispatchKeySet(BackendComponent::MetaBit); + +inline bool isTensorSubclassLike(const Tensor& tensor) { + if (c10::impl::dispatch_mode_enabled()) + return true; + auto key_set = tensor.unsafeGetTensorImpl()->key_set(); + return !(key_set & kTensorSubclassLike).empty(); +} + +inline bool areAnyTensorSubclassLike(TensorList tensors) { + if (c10::impl::dispatch_mode_enabled()) + return true; + return std::any_of(tensors.begin(), tensors.end(), isTensorSubclassLike); +} + +inline bool areAnyOptionalTensorSubclassLike( + const c10::List>& tensors) { + if (c10::impl::dispatch_mode_enabled()) + return true; + return std::any_of( + tensors.begin(), tensors.end(), [](const optional& opt_tensor) { + return ( + opt_tensor.has_value() && isTensorSubclassLike(opt_tensor.value())); + }); +} + +// Helper function to deal testing truthfulness of a scalar tensor +// in a Composite Compliant manner. +// NOTE: This function expects a scalar tensor of boolean dtype. +// Eg. +// Non-Composite Compliant Pattern : (t == 0).all().item() +// Composite Compliant Patter : is_salar_tensor_true((t == 0).all()) +inline bool is_scalar_tensor_true(const Tensor& t) { + TORCH_INTERNAL_ASSERT(t.dim() == 0) + TORCH_INTERNAL_ASSERT(t.scalar_type() == kBool) + return at::equal(t, t.new_ones({}, t.options())); +} + +} // namespace at diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/ATenCUDAGeneral.h b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/ATenCUDAGeneral.h new file mode 100644 index 0000000000000000000000000000000000000000..c64643546a2c1097a7a323dafc6cf5079d1b2fd9 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/ATenCUDAGeneral.h @@ -0,0 +1,9 @@ +#pragma once + +#include +#include +#include + +#include + +// Use TORCH_CUDA_CPP_API or TORCH_CUDA_CU_API for exports from this folder diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/CUDAContext.h b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/CUDAContext.h new file mode 100644 index 0000000000000000000000000000000000000000..0cb024dd701b284502965cba681f1f9beb214592 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/CUDAContext.h @@ -0,0 +1,9 @@ +#pragma once + +#include + +// Preserved for BC, as many files depend on these includes +#include +#include +#include +#include diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/CUDADataType.h b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/CUDADataType.h new file mode 100644 index 0000000000000000000000000000000000000000..7b3f5c391f2ac074b2a50961fafc1b42dc1a2320 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/CUDADataType.h @@ -0,0 +1,115 @@ +#pragma once + +#include + 
+#include +#include + +namespace at::cuda { + +template +cudaDataType getCudaDataType() { + TORCH_INTERNAL_ASSERT(false, "Cannot convert type ", typeid(scalar_t).name(), " to cudaDataType.") +} + +template<> inline cudaDataType getCudaDataType() { + return CUDA_R_16F; +} +template<> inline cudaDataType getCudaDataType() { + return CUDA_R_32F; +} +template<> inline cudaDataType getCudaDataType() { + return CUDA_R_64F; +} +template<> inline cudaDataType getCudaDataType>() { + return CUDA_C_16F; +} +template<> inline cudaDataType getCudaDataType>() { + return CUDA_C_32F; +} +template<> inline cudaDataType getCudaDataType>() { + return CUDA_C_64F; +} + +// HIP doesn't define integral types +#ifndef USE_ROCM +template<> inline cudaDataType getCudaDataType() { + return CUDA_R_8U; +} +template<> inline cudaDataType getCudaDataType() { + return CUDA_R_8I; +} +template<> inline cudaDataType getCudaDataType() { + return CUDA_R_32I; +} +#endif + +#if !defined(USE_ROCM) +template<> inline cudaDataType getCudaDataType() { + return CUDA_R_16I; +} +template<> inline cudaDataType getCudaDataType() { + return CUDA_R_64I; +} +template<> inline cudaDataType getCudaDataType() { + return CUDA_R_16BF; +} +#endif + +inline cudaDataType ScalarTypeToCudaDataType(const c10::ScalarType& scalar_type) { + switch (scalar_type) { +// HIP doesn't define integral types +#ifndef USE_ROCM + case c10::ScalarType::Byte: + return CUDA_R_8U; + case c10::ScalarType::Char: + return CUDA_R_8I; + case c10::ScalarType::Int: + return CUDA_R_32I; +#endif + case c10::ScalarType::Half: + return CUDA_R_16F; + case c10::ScalarType::Float: + return CUDA_R_32F; + case c10::ScalarType::Double: + return CUDA_R_64F; + case c10::ScalarType::ComplexHalf: + return CUDA_C_16F; + case c10::ScalarType::ComplexFloat: + return CUDA_C_32F; + case c10::ScalarType::ComplexDouble: + return CUDA_C_64F; +#if !defined(USE_ROCM) + case c10::ScalarType::Short: + return CUDA_R_16I; + case c10::ScalarType::Long: + return CUDA_R_64I; + case c10::ScalarType::BFloat16: + return CUDA_R_16BF; +#if defined(CUDA_VERSION) && CUDA_VERSION >= 11080 + case c10::ScalarType::Float8_e4m3fn: + return CUDA_R_8F_E4M3; + case c10::ScalarType::Float8_e5m2: + return CUDA_R_8F_E5M2; +#endif +#else // USE_ROCM + case c10::ScalarType::BFloat16: + return CUDA_R_16BF; +#if defined(HIP_NEW_TYPE_ENUMS) + case c10::ScalarType::Float8_e4m3fnuz: + return HIP_R_8F_E4M3_FNUZ; + case c10::ScalarType::Float8_e5m2fnuz: + return HIP_R_8F_E5M2_FNUZ; +#else + case c10::ScalarType::Float8_e4m3fnuz: + return static_cast(1000); + case c10::ScalarType::Float8_e5m2fnuz: + return static_cast(1001); +#endif +#endif + default: + TORCH_INTERNAL_ASSERT(false, "Cannot convert ScalarType ", scalar_type, " to cudaDataType.") + } +} + +} // namespace at::cuda diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/Exceptions.h b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/Exceptions.h new file mode 100644 index 0000000000000000000000000000000000000000..c647bc2531b4bb624430ea454805197f68bfca0d --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/Exceptions.h @@ -0,0 +1,174 @@ +#pragma once + +#include +#include +#include + +#ifdef CUDART_VERSION +#include +#endif + +#include +#include +#include + + +namespace c10 { + +class CuDNNError : public c10::Error { + using Error::Error; +}; + +} // namespace c10 + +#define AT_CUDNN_FRONTEND_CHECK(EXPR, ...) 
\ + do { \ + auto error_object = EXPR; \ + if (!error_object.is_good()) { \ + TORCH_CHECK_WITH(CuDNNError, false, \ + "cuDNN Frontend error: ", error_object.get_message()); \ + } \ + } while (0) \ + +#define AT_CUDNN_CHECK_WITH_SHAPES(EXPR, ...) AT_CUDNN_CHECK(EXPR, "\n", ##__VA_ARGS__) + +// See Note [CHECK macro] +#define AT_CUDNN_CHECK(EXPR, ...) \ + do { \ + cudnnStatus_t status = EXPR; \ + if (status != CUDNN_STATUS_SUCCESS) { \ + if (status == CUDNN_STATUS_NOT_SUPPORTED) { \ + TORCH_CHECK_WITH(CuDNNError, false, \ + "cuDNN error: ", \ + cudnnGetErrorString(status), \ + ". This error may appear if you passed in a non-contiguous input.", ##__VA_ARGS__); \ + } else { \ + TORCH_CHECK_WITH(CuDNNError, false, \ + "cuDNN error: ", cudnnGetErrorString(status), ##__VA_ARGS__); \ + } \ + } \ + } while (0) + +namespace at::cuda::blas { +C10_EXPORT const char* _cublasGetErrorEnum(cublasStatus_t error); +} // namespace at::cuda::blas + +#define TORCH_CUDABLAS_CHECK(EXPR) \ + do { \ + cublasStatus_t __err = EXPR; \ + TORCH_CHECK(__err == CUBLAS_STATUS_SUCCESS, \ + "CUDA error: ", \ + at::cuda::blas::_cublasGetErrorEnum(__err), \ + " when calling `" #EXPR "`"); \ + } while (0) + +const char *cusparseGetErrorString(cusparseStatus_t status); + +#define TORCH_CUDASPARSE_CHECK(EXPR) \ + do { \ + cusparseStatus_t __err = EXPR; \ + TORCH_CHECK(__err == CUSPARSE_STATUS_SUCCESS, \ + "CUDA error: ", \ + cusparseGetErrorString(__err), \ + " when calling `" #EXPR "`"); \ + } while (0) + +// cusolver related headers are only supported on cuda now +#ifdef CUDART_VERSION + +namespace at::cuda::solver { +C10_EXPORT const char* cusolverGetErrorMessage(cusolverStatus_t status); + +constexpr const char* _cusolver_backend_suggestion = \ + "If you keep seeing this error, you may use " \ + "`torch.backends.cuda.preferred_linalg_library()` to try " \ + "linear algebra operators with other supported backends. " \ + "See https://pytorch.org/docs/stable/backends.html#torch.backends.cuda.preferred_linalg_library"; + +} // namespace at::cuda::solver + +// When cuda < 11.5, cusolver raises CUSOLVER_STATUS_EXECUTION_FAILED when input contains nan. +// When cuda >= 11.5, cusolver normally finishes execution and sets info array indicating convergence issue. +#define TORCH_CUSOLVER_CHECK(EXPR) \ + do { \ + cusolverStatus_t __err = EXPR; \ + if ((CUDA_VERSION < 11500 && \ + __err == CUSOLVER_STATUS_EXECUTION_FAILED) || \ + (CUDA_VERSION >= 11500 && \ + __err == CUSOLVER_STATUS_INVALID_VALUE)) { \ + TORCH_CHECK_LINALG( \ + false, \ + "cusolver error: ", \ + at::cuda::solver::cusolverGetErrorMessage(__err), \ + ", when calling `" #EXPR "`", \ + ". This error may appear if the input matrix contains NaN. ", \ + at::cuda::solver::_cusolver_backend_suggestion); \ + } else { \ + TORCH_CHECK( \ + __err == CUSOLVER_STATUS_SUCCESS, \ + "cusolver error: ", \ + at::cuda::solver::cusolverGetErrorMessage(__err), \ + ", when calling `" #EXPR "`. ", \ + at::cuda::solver::_cusolver_backend_suggestion); \ + } \ + } while (0) + +#else +#define TORCH_CUSOLVER_CHECK(EXPR) EXPR +#endif + +#define AT_CUDA_CHECK(EXPR) C10_CUDA_CHECK(EXPR) + +// For CUDA Driver API +// +// This is here instead of in c10 because NVRTC is loaded dynamically via a stub +// in ATen, and we need to use its nvrtcGetErrorString. +// See NOTE [ USE OF NVRTC AND DRIVER API ]. 
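+//
+// Illustrative use (assumed caller code, not defined here): wrapping a raw
+// driver-API call so a failure surfaces as a readable error message:
+//   CUcontext ctx;
+//   AT_CUDA_DRIVER_CHECK(at::globalContext().getNVRTC().cuCtxGetCurrent(&ctx));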
+#if !defined(USE_ROCM) + +#define AT_CUDA_DRIVER_CHECK(EXPR) \ + do { \ + CUresult __err = EXPR; \ + if (__err != CUDA_SUCCESS) { \ + const char* err_str; \ + CUresult get_error_str_err C10_UNUSED = at::globalContext().getNVRTC().cuGetErrorString(__err, &err_str); \ + if (get_error_str_err != CUDA_SUCCESS) { \ + AT_ERROR("CUDA driver error: unknown error"); \ + } else { \ + AT_ERROR("CUDA driver error: ", err_str); \ + } \ + } \ + } while (0) + +#else + +#define AT_CUDA_DRIVER_CHECK(EXPR) \ + do { \ + CUresult __err = EXPR; \ + if (__err != CUDA_SUCCESS) { \ + AT_ERROR("CUDA driver error: ", static_cast(__err)); \ + } \ + } while (0) + +#endif + +// For CUDA NVRTC +// +// Note: As of CUDA 10, nvrtc error code 7, NVRTC_ERROR_BUILTIN_OPERATION_FAILURE, +// incorrectly produces the error string "NVRTC unknown error." +// The following maps it correctly. +// +// This is here instead of in c10 because NVRTC is loaded dynamically via a stub +// in ATen, and we need to use its nvrtcGetErrorString. +// See NOTE [ USE OF NVRTC AND DRIVER API ]. +#define AT_CUDA_NVRTC_CHECK(EXPR) \ + do { \ + nvrtcResult __err = EXPR; \ + if (__err != NVRTC_SUCCESS) { \ + if (static_cast(__err) != 7) { \ + AT_ERROR("CUDA NVRTC error: ", at::globalContext().getNVRTC().nvrtcGetErrorString(__err)); \ + } else { \ + AT_ERROR("CUDA NVRTC error: NVRTC_ERROR_BUILTIN_OPERATION_FAILURE"); \ + } \ + } \ + } while (0) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/mps/MPSDevice.h b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/mps/MPSDevice.h new file mode 100644 index 0000000000000000000000000000000000000000..40ab07077293d16baea2ff294ed15387f4bffc91 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/mps/MPSDevice.h @@ -0,0 +1,85 @@ +// Copyright © 2022 Apple Inc. + +#pragma once +#include +#include +#include + + +#ifdef __OBJC__ +#include +#include +#include +typedef id MTLDevice_t; +typedef id MTLLibrary_t; +typedef id MTLComputePipelineState_t; +typedef id MTLLibrary_t; +#else +typedef void* MTLDevice; +typedef void* MTLDevice_t; +typedef void* MTLLibrary_t; +typedef void* MTLComputePipelineState_t; +typedef void* MTLLibrary_t; +#endif + +using namespace std; + +namespace at::mps { + +// Helper enum to check if a MPSGraph op is supported in a given macOS version +enum class MacOSVersion : uint32_t { + MACOS_VER_13_0_PLUS = 0, + MACOS_VER_13_1_PLUS, + MACOS_VER_13_2_PLUS, + MACOS_VER_13_3_PLUS, + MACOS_VER_14_0_PLUS, +}; + +//----------------------------------------------------------------- +// MPSDevice +// +// MPSDevice is a singleton class that returns the default device +//----------------------------------------------------------------- + +class TORCH_API MPSDevice { + public: + /** + * MPSDevice should not be cloneable. + */ + MPSDevice(MPSDevice& other) = delete; + /** + * MPSDevice should not be assignable. + */ + void operator=(const MPSDevice&) = delete; + /** + * Gets single instance of the Device. + */ + static MPSDevice* getInstance(); + /** + * Returns the single device. 
+ */ + MTLDevice_t device() { + return _mtl_device; + } + /** + * Returns whether running on Ventura or newer + */ + bool isMacOS13Plus(MacOSVersion version) const; + + MTLComputePipelineState_t metalIndexingPSO(const std::string &kernel); + MTLLibrary_t getMetalIndexingLibrary(); + + ~MPSDevice(); + + private: + static MPSDevice* _device; + MTLDevice_t _mtl_device; + MTLLibrary_t _mtl_indexing_library; + MPSDevice(); +}; + +TORCH_API bool is_available(); +TORCH_API bool is_macos_13_or_newer(MacOSVersion version = MacOSVersion::MACOS_VER_13_0_PLUS); +TORCH_API at::Allocator* GetMPSAllocator(bool useSharedAllocator = false); + +} // namespace at::mps diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/AmpKernels.h b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/AmpKernels.h new file mode 100644 index 0000000000000000000000000000000000000000..c463c80e1c6dcfff66bde315d59bf7fcb73e9860 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/AmpKernels.h @@ -0,0 +1,28 @@ +#pragma once + +#include +#include + +namespace at { +class Tensor; + +namespace native { + +using _amp_foreach_non_finite_check_and_unscale_cpu__fn = void (*)( + TensorList, + Tensor&, + const Tensor&); + +using _amp_update_scale_cpu__fn = Tensor& (*)( + Tensor&, + Tensor&, + const Tensor&, + double, + double, + int64_t); + +DECLARE_DISPATCH(_amp_foreach_non_finite_check_and_unscale_cpu__fn, _amp_foreach_non_finite_check_and_unscale_cpu_stub); +DECLARE_DISPATCH(_amp_update_scale_cpu__fn, _amp_update_scale_cpu_stub); + +} // namespace native +} // namespace at diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/CPUBlas.h b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/CPUBlas.h new file mode 100644 index 0000000000000000000000000000000000000000..3b30df1c21fad9473c9b588adc6fb82308150039 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/CPUBlas.h @@ -0,0 +1,189 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace at::native::cpublas { + +namespace internal { +void normalize_last_dims( + TransposeType transa, TransposeType transb, + int64_t m, int64_t n, int64_t k, + int64_t *lda, int64_t *ldb, int64_t *ldc); +} // namespace internal + +using gemm_fn = void(*)( + at::ScalarType type, + TransposeType transa, TransposeType transb, + int64_t m, int64_t n, int64_t k, + const Scalar& alpha, + const void *a, int64_t lda, + const void *b, int64_t ldb, + const Scalar& beta, + void *c, int64_t ldc); + +DECLARE_DISPATCH(gemm_fn, gemm_stub); + +template +void gemm( + TransposeType transa, TransposeType transb, + int64_t m, int64_t n, int64_t k, + at::opmath_type alpha, + const scalar_t *a, int64_t lda, + const scalar_t *b, int64_t ldb, + at::opmath_type beta, + scalar_t *c, int64_t ldc) { + internal::normalize_last_dims(transa, transb, m, n, k, &lda, &ldb, &ldc); + gemm_stub( + kCPU, c10::CppTypeToScalarType::value, + transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +} + +void gemm( + TransposeType transa, TransposeType transb, + int64_t m, int64_t n, int64_t k, + double alpha, + const double *a, int64_t lda, + const double *b, int64_t ldb, + double beta, + double *c, int64_t ldc); + +void gemm( + TransposeType transa, TransposeType transb, + int64_t m, int64_t n, int64_t k, + float alpha, + const 
float *a, int64_t lda, + const float *b, int64_t ldb, + float beta, + float *c, int64_t ldc); + +void gemm( + TransposeType transa, TransposeType transb, + int64_t m, int64_t n, int64_t k, + float alpha, + const at::BFloat16 *a, int64_t lda, + const at::BFloat16 *b, int64_t ldb, + float beta, + at::BFloat16 *c, int64_t ldc); + +void gemm( + TransposeType transa, TransposeType transb, + int64_t m, int64_t n, int64_t k, + const float alpha, + const at::BFloat16 *a, int64_t lda, + const at::BFloat16 *b, int64_t ldb, + const float beta, + float *c, int64_t ldc); + +void gemm( + TransposeType transa, TransposeType transb, + int64_t m, int64_t n, int64_t k, + float alpha, + const at::Half *a, int64_t lda, + const at::Half *b, int64_t ldb, + float beta, + at::Half *c, int64_t ldc); + +void gemm( + TransposeType transa, TransposeType transb, + int64_t m, int64_t n, int64_t k, + const float alpha, + const at::Half *a, int64_t lda, + const at::Half *b, int64_t ldb, + const float beta, + float *c, int64_t ldc); + +void gemm( + TransposeType transa, TransposeType transb, + int64_t m, int64_t n, int64_t k, + c10::complex alpha, + const c10::complex *a, int64_t lda, + const c10::complex *b, int64_t ldb, + c10::complex beta, + c10::complex *c, int64_t ldc); + +void gemm( + TransposeType transa, TransposeType transb, + int64_t m, int64_t n, int64_t k, + c10::complex alpha, + const c10::complex *a, int64_t lda, + const c10::complex *b, int64_t ldb, + c10::complex beta, + c10::complex *c, int64_t ldc); + +void gemm( + TransposeType transa, TransposeType transb, + int64_t m, int64_t n, int64_t k, + int64_t alpha, + const int64_t *a, int64_t lda, + const int64_t *b, int64_t ldb, + int64_t beta, + int64_t *c, int64_t ldc); + +template +void gemm_batched( + TransposeType transa, TransposeType transb, + int64_t batch_size, int64_t m, int64_t n, int64_t k, + scalar_t alpha, + const scalar_t * const *a, int64_t lda, + const scalar_t * const *b, int64_t ldb, + const scalar_t beta, + scalar_t * const *c, int64_t ldc); + +template +void gemm_batched_with_stride( + TransposeType transa, TransposeType transb, + int64_t batch_size, int64_t m, int64_t n, int64_t k, + scalar_t alpha, + const scalar_t *a, int64_t lda, int64_t batch_stride_a, + const scalar_t *b, int64_t ldb, int64_t batch_stride_b, + scalar_t beta, + scalar_t *c, int64_t ldc, int64_t batch_stride_c); + +using axpy_fn = void(*)(at::ScalarType type, int64_t n, const Scalar& a, const void *x, int64_t incx, void *y, int64_t incy); + +DECLARE_DISPATCH(axpy_fn, axpy_stub); + +template +void axpy(int64_t n, scalar_t a, const scalar_t *x, int64_t incx, scalar_t *y, int64_t incy){ + if(n == 1) + { + incx = 1; + incy = 1; + } + axpy_stub( + kCPU, c10::CppTypeToScalarType::value, + n, a, x, incx, y, incy); +} + +void axpy(int64_t n, double a, const double *x, int64_t incx, double *y, int64_t incy); +void axpy(int64_t n, float a, const float *x, int64_t incx, float *y, int64_t incy); +void axpy(int64_t n, c10::complex a, const c10::complex *x, int64_t incx, c10::complex *y, int64_t incy); +void axpy(int64_t n, c10::complex a, const c10::complex *x, int64_t incx, c10::complex *y, int64_t incy); + +using copy_fn = void(*)(at::ScalarType type, int64_t n, const void *x, int64_t incx, void *y, int64_t incy); + +DECLARE_DISPATCH(copy_fn, copy_stub); + +template +void copy(int64_t n, const scalar_t *x, int64_t incx, scalar_t *y, int64_t incy) { + if(n == 1) + { + incx = 1; + incy = 1; + } + copy_stub( + kCPU, c10::CppTypeToScalarType::value, + n, x, incx, y, incy); +} + 
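+// Illustrative call (assumed usage, not part of this header): copy n floats
+// between two contiguous buffers:
+//   at::native::cpublas::copy(n, src, /*incx=*/1, dst, /*incy=*/1);
+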
+void copy(int64_t n, const double *x, int64_t incx, double *y, int64_t incy); +void copy(int64_t n, const float *x, int64_t incx, float *y, int64_t incy); +void copy(int64_t n, const c10::complex *x, int64_t incx, c10::complex *y, int64_t incy); +void copy(int64_t n, const c10::complex *x, int64_t incx, c10::complex *y, int64_t incy); + +} // namespace at::native::cpublas diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/CompositeRandomAccessorCommon.h b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/CompositeRandomAccessorCommon.h new file mode 100644 index 0000000000000000000000000000000000000000..9111c3515afcefec2d81a261737ec28bcae00cdc --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/CompositeRandomAccessorCommon.h @@ -0,0 +1,263 @@ +#include + +#pragma once + +namespace at::native { + +namespace { + +// operator_brackets_proxy is used in +// CompositeRandomAccessor in place of operator[]. +// For some iterators, references returned by operator[] +// could become invalid, operator_brackets_proxy tries to +// resolve that by making accessor[n] to be equivalent to +// *(accessor + n). +template +class operator_brackets_proxy { + using reference = typename std::iterator_traits::reference; + using value_type = typename std::iterator_traits::value_type; + +public: + C10_HOST_DEVICE + operator_brackets_proxy(Accessor const& accessor) + : accessor(accessor) + {} + + C10_HOST_DEVICE + operator reference() { + return *accessor; + } + + C10_HOST_DEVICE + reference operator*() { + return *accessor; + } + + C10_HOST_DEVICE + operator_brackets_proxy& operator=(value_type const& val) { + *accessor = val; + return *this; + } + +private: + Accessor accessor; +}; + +} + +// references_holder is used as a surrogate for the +// references type from std::iterator_traits in CompositeRandomAccessor. +// It is assumed in CompositeRandomAccessor that +// References = tuple, +// Values = tuple by default, +// but they could be anything as long as References could be +// cast to Values. +// If you plan to use it with STL, for example, you will need to +// define 'swap` and `get`(aka std::get) methods. +template +class references_holder { +public: + using values = Values; + using references = References; + + C10_HOST_DEVICE + references_holder(references refs) + : refs{std::move(refs)} + {} + + C10_HOST_DEVICE + operator references() { + return refs; + } + + C10_HOST_DEVICE + operator values() { + return refs; + } + + C10_HOST_DEVICE + references_holder& operator=(values vals) { + refs = vals; + return *this; + } + + C10_HOST_DEVICE + references& data() { + return refs; + } + +protected: + references refs; +}; + +// CompositeRandomAccessor is essentially a simplified version of +// a random access iterator over two random access iterators. +// TupleInfo should contain a variadic type `tuple`, and a method `tie`, +// which constructs a tuple of references from a variadic list of arguments. 
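+//
+// A minimal sketch of a conforming TupleInfo (illustrative only; the real
+// definitions live elsewhere):
+//   struct StdTupleInfo {
+//     template <typename... Ts>
+//     using tuple = std::tuple<Ts...>;
+//     template <typename... Ts>
+//     static std::tuple<Ts&...> tie(Ts&... args) {
+//       return std::tie(args...);
+//     }
+//   };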
+template +class CompositeRandomAccessor { + using self_type = CompositeRandomAccessor; + + using key_accessor_value_type = + typename std::iterator_traits::value_type; + using value_accessor_value_type = + typename std::iterator_traits::value_type; + using key_accessor_reference_type = + typename std::iterator_traits::reference; + using value_accessor_reference_type = + typename std::iterator_traits::reference; + + using composite_value_type = typename TupleInfo::template tuple< + key_accessor_value_type, + value_accessor_value_type>; + using composite_reference = typename TupleInfo::template tuple< + key_accessor_reference_type, + value_accessor_reference_type>; + +public: + using value_type = composite_value_type; + using reference = references_holder; + // Note that CompositeRandomAccessor does not hold key and values + // in a specific datastructure, which means that a pointer to a (key, value) + // is not defined. Hence we just use a pointer type of the KeyAccessor. + using pointer = typename std::iterator_traits::pointer; + using difference_type = typename std::iterator_traits::difference_type; + using iterator_category = std::random_access_iterator_tag; + + C10_HOST_DEVICE + CompositeRandomAccessor() = default; + + C10_HOST_DEVICE + CompositeRandomAccessor(KeyAccessor keys, ValueAccessor values) + : keys(keys), values(values) + {} + + // Pointer-like operations { + C10_HOST_DEVICE + reference operator*() const { + return TupleInfo::tie(*keys, *values); + } + + // operator->() is supposed to return a pointer type. + // Since CompositeRandomAccessor does not hold pointers to pairs, + // we just return a pointer to a key. + C10_HOST_DEVICE + auto* operator->() const { + return keys.operator->(); + } + + C10_HOST_DEVICE + reference operator[](difference_type idx) { + return operator_brackets_proxy( + CompositeRandomAccessor(keys + idx, values + idx) + ); + } + // } + + // Prefix/postfix increment/decrement { + C10_HOST_DEVICE + CompositeRandomAccessor& operator++() { + ++keys; + ++values; + return *this; + } + + C10_HOST_DEVICE + CompositeRandomAccessor operator++(int) { + CompositeRandomAccessor copy(*this); + ++*this; + return copy; + } + + C10_HOST_DEVICE + CompositeRandomAccessor& operator--() { + --keys; + --values; + return *this; + } + + C10_HOST_DEVICE + CompositeRandomAccessor operator--(int) { + CompositeRandomAccessor copy(*this); + --*this; + return copy; + } + // } + + // Arithmetic operations { + C10_HOST_DEVICE + CompositeRandomAccessor& operator+=(difference_type offset) { + keys += offset; + values += offset; + return *this; + } + + C10_HOST_DEVICE + CompositeRandomAccessor operator+(difference_type offset) const { + return CompositeRandomAccessor(keys + offset, values + offset); + } + + C10_HOST_DEVICE + friend CompositeRandomAccessor operator+( + difference_type offset, + const CompositeRandomAccessor& accessor + ) { + return accessor + offset; + } + + C10_HOST_DEVICE + CompositeRandomAccessor& operator-=(difference_type offset) { + keys -= offset; + values -= offset; + return *this; + } + + C10_HOST_DEVICE + CompositeRandomAccessor operator-(difference_type offset) const { + return CompositeRandomAccessor(keys - offset, values - offset); + } + + C10_HOST_DEVICE + difference_type operator-(const CompositeRandomAccessor& other) const { + return keys - other.keys; + } + // } + + // Comparison operators { + C10_HOST_DEVICE + bool operator==(const CompositeRandomAccessor& other) const { + return keys == other.keys; + } + + C10_HOST_DEVICE + bool operator!=(const 
CompositeRandomAccessor& other) const { + return keys != other.keys; + } + + C10_HOST_DEVICE + bool operator<(const CompositeRandomAccessor& other) const { + return keys < other.keys; + } + + C10_HOST_DEVICE + bool operator<=(const CompositeRandomAccessor& other) const { + return keys <= other.keys; + } + + C10_HOST_DEVICE + bool operator>(const CompositeRandomAccessor& other) const { + return keys > other.keys; + } + + C10_HOST_DEVICE + bool operator>=(const CompositeRandomAccessor& other) const { + return keys >= other.keys; + } + // } + +protected: + KeyAccessor keys; + ValueAccessor values; +}; + +} // namespace at::native diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/ConvolutionMM3d.h b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/ConvolutionMM3d.h new file mode 100644 index 0000000000000000000000000000000000000000..3de6763015c6616599a604ee169dacc55985a385 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/ConvolutionMM3d.h @@ -0,0 +1,14 @@ +#include + +namespace at::native { + +std::tuple slow_conv3d_backward_cpu( + const Tensor& grad_output, + const Tensor& self, + const Tensor& weight, + IntArrayRef kernel_size, + IntArrayRef stride, + IntArrayRef padding, + std::array output_mask); + +} // namespace at::native diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/Copy.h b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/Copy.h new file mode 100644 index 0000000000000000000000000000000000000000..14abb32fa5ad4ba3cd8c78084569b313a4a692cd --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/Copy.h @@ -0,0 +1,20 @@ +#pragma once + +#include + +namespace at { + +class Tensor; +struct TensorIterator; +class TensorBase; + +namespace native { + +using copy_fn = void (*)(TensorIterator&, bool non_blocking); + +DECLARE_DISPATCH(copy_fn, copy_stub); + +TORCH_API void copy_ignoring_overlaps(const TensorBase &dst, const TensorBase &src); + +} // namespace native +} // namespace at diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/DilatedConvolutionUtils.h b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/DilatedConvolutionUtils.h new file mode 100644 index 0000000000000000000000000000000000000000..cd580020374a66aa058938e1186fbfd577a76980 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/DilatedConvolutionUtils.h @@ -0,0 +1,229 @@ +#pragma once + +#include +#include + +#include +#include +#include + +#define TORCH_CHECK_DIM_SIZE(T, DIM, DIM_SIZE, SIZE) \ + TORCH_CHECK( \ + T.dim() == DIM && T.size(DIM_SIZE) == SIZE, \ + "Need " #T " of dimension ", \ + DIM, \ + " and " #T ".size[", \ + DIM_SIZE, \ + "] == ", \ + SIZE, \ + " but got input to be of shape ", \ + T.sizes()) + +namespace at::native::internal { +namespace { +inline bool all_positive(IntArrayRef& arr) { + return std::all_of( + arr.begin(), arr.end(), [](int64_t item) { return item > 0; }); +} + +inline bool all_nonnegative(std::vector& arr) { + return std::all_of( + arr.begin(), arr.end(), [](int64_t item) { return item >= 0; }); +} + +} // namespace + +// calculate the rear part of output tensor sizes +template +std::vector get_output_size( + const Tensor& input, + IntArrayRef kernel_size, 
+ IntArrayRef stride_size, + IntArrayRef pad_size, + IntArrayRef dilation_size) { + std::vector sizes; + for (const auto index : c10::irange(dim)) { + sizes.push_back( + div_rtn( + input.size(index + input.dim() - dim) + 2 * pad_size[index] - + (dilation_size[index] * (kernel_size[index] - 1) + 1), + stride_size[index]) + + 1); + } + return sizes; +} + +// calculate the sizes of output tensor +template +std::vector get_output_size( + const Tensor& input, + const Tensor& weight, + IntArrayRef kernel_size, + IntArrayRef stride_size, + IntArrayRef pad_size, + IntArrayRef dilation_size) { + auto output_size = get_output_size( + input, kernel_size, stride_size, pad_size, dilation_size); + output_size.insert(output_size.begin(), weight.size(0)); + if (input.dim() == dim + 2) { + output_size.insert(output_size.begin(), input.size(0)); + } + return output_size; +} +/* + slow_conv_dilated_shape_check - check user-input to dilated convolution + forward and backward functions. +*/ +template +void slow_conv_dilated_shape_check( + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + const Tensor& grad_output, + IntArrayRef kernel_size, + IntArrayRef stride_size, + IntArrayRef pad_size, + IntArrayRef dilation_size) { + /* + When the following tensors are defined: + + bias, grad_weight, grad_output + + then these are assumed to be contiguous without checking + because of these tensors are made contiguous by calling + .contiguous() method or by resizing of zero-sized tensors in + forward/backward functions. + + When grad_weight is defined then it is assumed without + checking to have the same shape as weight, see backward + functions. + */ + // Check size arguments + TORCH_CHECK( + kernel_size.size() == dim, + "kernel sizes length should be ", + dim, + ", but got ", + kernel_size.size()); + TORCH_CHECK( + stride_size.size() == dim, + "strides length should be ", + dim, + ", but got ", + stride_size.size()); + TORCH_CHECK( + dilation_size.size() == dim, + "dilations length should be ", + dim, + ", but got ", + dilation_size.size()); + TORCH_CHECK( + pad_size.size() == dim, + "pads length should be ", + dim, + ", but got ", + pad_size.size()); + + TORCH_CHECK( + all_positive(kernel_size), + "kernel size should be greater than zero, but got ", + kernel_size); + TORCH_CHECK( + all_positive(stride_size), + "stride should be greater than zero, but got ", + stride_size); + TORCH_CHECK( + all_positive(dilation_size), + "dilation should be greater than zero, but got ", + dilation_size); + + // check input + TORCH_CHECK(input.defined(), "input must be defined"); + bool is_batch = input.dim() == dim + 2; + int64_t n = (is_batch ? 
2 : 1); + int64_t ndim = n + dim; + if (!is_batch) { + // input dim has to be dim + 1 if not batched + TORCH_CHECK( + input.dim() == dim + 1, + "input must be 4D or 5D tensor but got ", + input.dim(), + "D tensor"); + } + + // check output sizes + auto output_size = get_output_size( + input, kernel_size, stride_size, pad_size, dilation_size); + + TORCH_CHECK( + all_nonnegative(output_size), + "calculated output size ", + output_size, + " is too small (all sizes must be non-negative)"); + + // check weight + TORCH_CHECK(weight.defined(), "weight must be defined"); + TORCH_CHECK( + weight.dim() == dim + 2, + "weight must be ", + dim + 2, + "D tensor but got ", + weight.dim(), + "D tensor dim=", + dim); + TORCH_CHECK( + weight.sizes().slice(2) == kernel_size, + "weight[2:] shape ", + weight.sizes().slice(2), + " must be equal to kernel_size ", + kernel_size); + + TORCH_CHECK_DIM_SIZE(input, input.dim(), (is_batch ? 1 : 0), weight.size(1)); + + // check bias when present + if (bias.defined()) { + TORCH_CHECK( + bias.dim() == 1, + "bias must be 1D tensor but got ", + bias.dim(), + "D tensor"); + TORCH_CHECK_DIM_SIZE(bias, 1, 0, weight.size(0)); + } + + // check grad_output when present + if (grad_output.defined()) { + TORCH_CHECK( + grad_output.dim() == ndim, + "grad_output must be ", + ndim, + "D tensor but got ", + grad_output.dim(), + "D tensor"); + if (is_batch) { + TORCH_CHECK( + grad_output.size(0) == input.size(0), + "grad_output.size(0)=", + grad_output.size(0), + " must be input.size(0)=", + input.size(0)); + } + TORCH_CHECK( + grad_output.size(n - 1) == weight.size(0), + "grad_output.size(", + n - 1, + ")=", + grad_output.size(n - 1), + " must be weight.size(0)=", + weight.size(0)); + TORCH_CHECK( + grad_output.sizes().slice(n) == output_size, + "grad_output[", + n, + ":] shape", + grad_output.sizes().slice(n), + " must be equal to output size ", + output_size); + } +} + +} // namespace at::native::internal diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/ForeachUtils.h b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/ForeachUtils.h new file mode 100644 index 0000000000000000000000000000000000000000..9c22c35ee940138219e1c905eca2ed03e2ed1bf4 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/ForeachUtils.h @@ -0,0 +1,371 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#else +#include +#endif + +#include +#include + +namespace at::native { +namespace { +// Check if tensor list has either a boolean tensor or a integer tensor +inline bool has_integral_tensor(TensorList tensors, const bool includeBool) { + return std::any_of( + tensors.begin(), tensors.end(), [&includeBool](const auto& t) { + return at::isIntegralType(t.scalar_type(), includeBool); + }); +} +// check if tensor list has bool tensors +inline bool has_bool_tensor(TensorList tensors) { + return std::any_of(tensors.begin(), tensors.end(), [](const auto& t) -> bool { + return t.scalar_type() == ScalarType::Bool; + }); +} + +// Check foreach API restrictions +// - Tensor lists must be non-empty. +// - All TensorLists and ScalarLists must have the same number of elements. +// - Corresponding tensors must have the same size. 
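+//
+// (Illustrative call flow, not code in this header: a foreach op first runs
+// check_foreach_api_restrictions on its operand lists and only then consults
+// can_use_fast_route below to choose between the fused fast path and the
+// per-tensor slow path.)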
+inline void check_foreach_api_restrictions(TensorList tensors) { + TORCH_CHECK(!tensors.empty(), "Tensor list must have at least one tensor."); +} + +inline void check_foreach_api_restrictions( + TensorList tensors, + ArrayRef scalars) { + check_foreach_api_restrictions(tensors); + TORCH_CHECK( + tensors.size() == scalars.size(), + "Tensor list must have same number of elements as scalar list."); +} + +inline void check_foreach_api_restrictions( + TensorList tensors1, + TensorList tensors2) { + TORCH_CHECK(!tensors1.empty(), "Tensor list must have at least one tensor."); + TORCH_CHECK(!tensors2.empty(), "Tensor list must have at least one tensor."); + TORCH_CHECK( + tensors1.size() == tensors2.size(), + "Tensor lists must have the same number of tensors, got ", + tensors1.size(), + " and ", + tensors2.size()); +} + +inline void check_foreach_api_restrictions( + TensorList tensors1, + TensorList tensors2, + TensorList tensors3) { + TORCH_CHECK(!tensors1.empty(), "Tensor list must have at least one tensor."); + TORCH_CHECK(!tensors2.empty(), "Tensor list must have at least one tensor."); + TORCH_CHECK(!tensors3.empty(), "Tensor list must have at least one tensor."); + TORCH_CHECK( + tensors1.size() == tensors2.size(), + "Tensor lists must have the same number of tensors, got ", + tensors1.size(), + " and ", + tensors2.size()); + TORCH_CHECK( + tensors1.size() == tensors3.size(), + "Tensor lists must have the same number of tensors, got ", + tensors1.size(), + " and ", + tensors3.size()); +} + +inline void check_foreach_api_restrictions( + TensorList tensors1, + TensorList tensors2, + TensorList tensors3, + ArrayRef scalars) { + check_foreach_api_restrictions(tensors1, tensors2, tensors3); + TORCH_CHECK( + tensors1.size() == scalars.size(), + "Tensor list must have same number of elements as scalar list, got ", + tensors1.size(), + " and ", + scalars.size()); +} + +// Helper function called in check_fast_path_restrictions to check whether all +// corresponding tensors (aligning in index across the tensorLists) share the +// same device and dtype. +inline bool _check_tensors_share_device_and_dtype( + ArrayRef tensorLists) { + const auto expected_dtype = tensorLists[0][0].dtype(); + const auto expected_device = tensorLists[0][0].device(); + + auto is_tensor_okay = [&](const Tensor& tensor) { + return tensor.dtype() == expected_dtype && + tensor.device() == expected_device && tensor.layout() == at::kStrided && + tensor.is_non_overlapping_and_dense(); + }; + + for (const auto& tensorList : tensorLists) { + for (const auto& tensor : tensorList) { + if (!is_tensor_okay(tensor)) { + return false; + } + } + } + + return true; +} + +// Helper function called in check_fast_path_restrictions to check if +// corresponding tensors in tensor lists have the same sizes and strides. +inline bool _check_tensors_share_sizes_and_strides( + ArrayRef tensorLists) { + for (const auto i : c10::irange(1, tensorLists.size())) { + for (const auto j : c10::irange(tensorLists[0].size())) { + if (tensorLists[0][j].sizes() != tensorLists[i][j].sizes() || + tensorLists[0][j].strides() != tensorLists[i][j].strides()) { + return false; + } + } + } + + return true; +} + +// Helper function called in check_fast_path_restrictions to check whether +// all tensors type promote properly with the scalars in scalarList. This +// function assumes that _check_tensors_share_device_and_dtype has already been +// called so that all corresponding tensors in tensorLists have the same dtype. 
+// Then, it is sufficient to check the type promotion with just one tensorList.
+inline bool _check_tensors_do_type_promotion_with_scalars(
+    TensorList tensorList,
+    ArrayRef<Scalar> scalarList = {},
+    bool does_op_promote_integer_inputs_to_float = false) {
+  for (const auto i : c10::irange(tensorList.size())) {
+    // For division, integer inputs will result in float.
+    if (does_op_promote_integer_inputs_to_float) {
+      if (at::isIntegralType(
+              tensorList[i].scalar_type(), /*includeBool*/ true)) {
+        return false;
+      }
+    }
+    if (!scalarList.empty()) {
+      const auto& scalar =
+          scalarList.size() == 1 ? scalarList[0] : scalarList[i];
+      const auto& tensor = tensorList[i];
+      // note(mkozuki): This check might be responsible for
+      // `_foreach_add(bool_tensors, bool_tensors)` being pushed to slow path.
+      if (tensor.scalar_type() != at::native::result_type(scalar, tensor)) {
+        return false;
+      }
+    }
+  }
+
+  return true;
+}
+
+// To go via 'fast' path, several conditions must be satisfied
+// - All tensors in all lists must have the same dtype.
+// - All tensors must be on the same device
+// - All tensors must have strided layout
+// - All tensors must be non-overlapping and dense
+// - Resulting tensor must have the same dtype as the input one
+
+// Please make sure to call check_foreach_api_restrictions before calling this
+// method. There is a set of preconditions that have to be satisfied.
+inline bool check_fast_path_restrictions(
+    ArrayRef<TensorList> tensorLists,
+    ArrayRef<Scalar> scalarList = {},
+    bool does_op_promote_integer_inputs_to_float = false) {
+  return _check_tensors_share_device_and_dtype(tensorLists) &&
+      _check_tensors_share_sizes_and_strides(tensorLists) &&
+      _check_tensors_do_type_promotion_with_scalars(
+          tensorLists[0],
+          scalarList,
+          does_op_promote_integer_inputs_to_float);
+}
+
+inline std::vector<Scalar> convert_tensor_to_scalar_list(
+    const Tensor& scalarList_,
+    int64_t expect_length) {
+  std::vector<Scalar> scalarList;
+  TORCH_CHECK(
+      scalarList_.device() == c10::kCPU,
+      "Expected scalars to be on CPU, got ",
+      scalarList_.device(),
+      " instead.");
+  TORCH_CHECK(
+      scalarList_.is_contiguous(), "Expected scalars to be contiguous.");
+  TORCH_CHECK(
+      scalarList_.dim() == 1,
+      "Expected packed scalar Tensor to be of dimension 1. Got ",
Got ", + scalarList_.dim(), + " instead."); + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4( + kComplexHalf, + kHalf, + kBool, + kBFloat16, + scalarList_.scalar_type(), + "convert_tensor_to_scalar_list", + [&]() { + const scalar_t* scalar_data = scalarList_.data_ptr(); + TORCH_CHECK( + (expect_length == scalarList_.size(0)), + "Expected length of scalars to match input of length ", + expect_length, + " but got ", + scalarList_.size(0), + " instead."); + for (int64_t i = 0; i < scalarList_.size(0); i++) { + scalarList.emplace_back(scalar_data[i]); + } + }); + return scalarList; +} + +inline bool can_use_fast_route( + ArrayRef tensorLists, + ArrayRef scalarList = {}, + bool does_op_promote_integer_inputs_to_float = false) { + return check_fast_path_restrictions( + tensorLists, scalarList, does_op_promote_integer_inputs_to_float); +} + +inline bool can_use_fast_route( + TensorList tensors1, + TensorList tensors2, + bool does_op_promote_integer_inputs_to_float = false) { + return can_use_fast_route( + {tensors1, tensors2}, {}, does_op_promote_integer_inputs_to_float); +} + +using DeviceDtypeKey = std::pair; +using IndicesT = std::vector; +using nested_optional_tensorvec_t = + std::vector>>; +using TensorsAndIndicesT = std::pair; +using FlatMap = std::unordered_map< + DeviceDtypeKey, + TensorsAndIndicesT, + ParamsHash>; + +inline FlatMap _group_tensors_by_first_tensors_device_and_dtype( + const nested_optional_tensorvec_t& nested_tensorlist, + const bool with_indices) { + FlatMap grouped_tensors_with_indices; + + TORCH_CHECK(!nested_tensorlist.empty()); + TORCH_CHECK(!nested_tensorlist[0].empty()); + const auto num_lists = nested_tensorlist.size(); + const auto num_tensors = nested_tensorlist[0].size(); + + TORCH_CHECK(std::all_of( + nested_tensorlist.cbegin(), + nested_tensorlist.cend(), + [&](const auto& tensorlist) -> bool { + // note(crcrpar): Allow empty tensorlists following + // ref: + // https://github.com/pytorch/pytorch/blob/85885301fd3c6adb8b9dc3cf7afadf6945566684/torch/utils/_foreach_utils.py#L21-L24 + return tensorlist.size() == num_tensors || tensorlist.size() == 0; + })); + + for (const auto& tensor_index : c10::irange(num_tensors)) { + const auto key = [&]() -> DeviceDtypeKey { + const auto t = nested_tensorlist[0][tensor_index]; + TORCH_CHECK( + t.has_value(), + "Tensors of the first list of nested Tensor lists are supposed to be defined but ", + "the ", + tensor_index, + "-th Tensor is not."); + return {t->device(), t->scalar_type()}; + }(); + TORCH_CHECK( + std::all_of( + nested_tensorlist.cbegin(), + nested_tensorlist.cend(), + [&](const auto& tensorlist) -> bool { + if (tensorlist.size() == 0) { + return true; + } + const auto& tensor = tensorlist[tensor_index]; + // note(crcrpar): Currently the scope of this function is + // optimizers so there could be `state_steps` and other scalars + // whose elements are float tensors no matter what the parameter's + // dtype is. + if (!tensor.has_value()) { + return true; + } else { + const auto s = tensor->scalar_type(); + const auto d = tensor->device(); + // Note: `step` or `state_step` is float32 by default. + if (key.first == d) { + return key.second == s || s == at::ScalarType::Float || + s == at::ScalarType::Double; + } else if (d.is_cpu()) { + // note(crcrpar): There are some test cases (e.g. + // TestOptim::test_adam) where state_steps are on CPU and the + // others are on CUDA. Currently a state_step Tensor has the + // dtype of float. 
+ return s == at::ScalarType::Float || + s == at::ScalarType::Double; + } else { + return false; + } + } + }), + "Tensors of the same index must be on the same device and the same dtype except `step` tensors that can be CPU and float32/64 notwithstanding"); + if (!grouped_tensors_with_indices.count(key)) { + grouped_tensors_with_indices.insert( + {key, + TensorsAndIndicesT{ + [&]() -> nested_optional_tensorvec_t { + nested_optional_tensorvec_t nested_tensorvec; + nested_tensorvec.reserve(num_lists); + for (const auto& i : c10::irange(num_lists)) { + std::vector> tensors; + if (!nested_tensorlist[i].empty()) { + // NB: num_tensors is the max possible length for any of + // the inner lists of tensor references. Reserving the max + // trades memory for perf. This should not have significant + // impact. + tensors.reserve(num_tensors); + } + nested_tensorvec.emplace_back(tensors); + } + return nested_tensorvec; + }(), + [&]() -> IndicesT { + if (!with_indices) { + return {}; + } else { + IndicesT indices; + indices.reserve(num_tensors); + return indices; + } + }()}}); + } + for (const auto& list_index : c10::irange(num_lists)) { + if (!nested_tensorlist[list_index].empty()) { + grouped_tensors_with_indices[key].first[list_index].emplace_back( + nested_tensorlist[list_index][tensor_index]); + } + } + if (with_indices) { + grouped_tensors_with_indices[key].second.emplace_back(tensor_index); + } + } + + return grouped_tensors_with_indices; +} + +} // namespace +} // namespace at::native diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/LinearAlgebra.h b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/LinearAlgebra.h new file mode 100644 index 0000000000000000000000000000000000000000..54d44b23a011c0dd79989150712fe97856832d5d --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/LinearAlgebra.h @@ -0,0 +1,18 @@ +#pragma once + +#include +#include + +namespace c10 { +class Scalar; +} + +namespace at { +struct TensorIterator; +} + +namespace at::native { + +using addr_fn = void (*)(TensorIterator &, const Scalar& beta, const Scalar& alpha); +DECLARE_DISPATCH(addr_fn, addr_stub); +} // namespace at::native diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/SortingUtils.h b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/SortingUtils.h new file mode 100644 index 0000000000000000000000000000000000000000..a0f9cfa8bf0ce554d198b7e55079ed13e1798175 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/SortingUtils.h @@ -0,0 +1,88 @@ +#pragma once + +#include +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#else +#include +#endif + +namespace at::native { + +// ensure we get good values and indices for kthvalue, mode +// this will always be with the reducing dim as 1-d +inline void _reduction_with_indices_allocate_or_resize_output( + Tensor& values, + Tensor& indices, + const Tensor& self, + int64_t dim_, + bool keepdim) { + int64_t dim = maybe_wrap_dim(dim_, self.dim(), /*wrap_scalar=*/true); + auto result_sizes = self.sizes().vec(); + if (!result_sizes.empty()) { + result_sizes[dim] = 1; + } + if (values.defined()) { + TORCH_CHECK( + self.options().type_equal(values.options()), + "output values must be of same type as input"); + if (!keepdim && values.dim() == self.dim() - 1) { + // unsqueeze to 
preserve passed in noncontiguous tensor in resize + values.unsqueeze_(dim); + } + resize_output(values, result_sizes); + } else { + values = at::empty(result_sizes, self.options()); + } + if (indices.defined()) { + TORCH_CHECK( + indices.dtype() == kLong, "output indices must be of scalar type Long"); + TORCH_CHECK( + indices.device() == self.device(), + "output indices must be on same device as input"); + if (!keepdim && indices.dim() == self.dim() - 1) { + // unsqueeze to preserve passed in noncontiguous tensor in resize + indices.unsqueeze_(dim); + } + resize_output(indices, result_sizes); + } else { + indices = at::empty(result_sizes, self.options().dtype(kLong)); + } +} + +// ensure we get good values and indices for topk +inline void _allocate_or_resize_output_with_indices( + Tensor& values, + Tensor& indices, + const Tensor& self, + int64_t dim_, + int64_t k) { + int64_t dim = maybe_wrap_dim(dim_, self.dim(), /*wrap_scalar=*/true); + auto result_sizes = self.sizes().vec(); + if (!result_sizes.empty()) { + result_sizes[dim] = k; + } + if (values.defined()) { + TORCH_CHECK( + self.options().type_equal(values.options()), + "output values must be of same type as input"); + values.resize_(result_sizes); + } else { + values = at::empty(result_sizes, self.options()); + } + if (indices.defined()) { + TORCH_CHECK( + indices.dtype() == kLong, "output indices must be of scalar type Long"); + TORCH_CHECK( + indices.device() == self.device(), + "output indices must be on same device as input"); + indices.resize_(result_sizes); + } else { + indices = at::empty(result_sizes, self.options().dtype(kLong)); + } +} + +} // namespace at::native diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/UnaryOps.h b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/UnaryOps.h new file mode 100644 index 0000000000000000000000000000000000000000..91d4d84d4630c0ab73168662aa97388bac84b0e6 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/UnaryOps.h @@ -0,0 +1,130 @@ +#pragma once + +#include +#include +#include +#include + +namespace at { +class Tensor; +class TensorBase; +struct TensorIteratorBase; +} + +namespace at::native { + +using unary_fn = void(*)(TensorIteratorBase&); +using unary_fn_with_scalar = void(*)(TensorIteratorBase&, const Scalar& a); + +inline namespace CPU_CAPABILITY { +void conj_kernel(TensorIteratorBase &iter); +void neg_kernel(TensorIteratorBase &iter); +void reciprocal_kernel(TensorIteratorBase &iter); +void rsqrt_kernel(TensorIteratorBase& iter); +void sqrt_kernel(TensorIteratorBase& iter); +} // namespace CPU_CAPABILITY + +DECLARE_DISPATCH(unary_fn, abs_stub); +DECLARE_DISPATCH(unary_fn, angle_stub); +DECLARE_DISPATCH(unary_fn, conj_physical_stub); +DECLARE_DISPATCH(unary_fn, acos_stub); +DECLARE_DISPATCH(unary_fn, acosh_stub); +DECLARE_DISPATCH(unary_fn, asinh_stub); +DECLARE_DISPATCH(unary_fn, atanh_stub); +DECLARE_DISPATCH(unary_fn, asin_stub); +DECLARE_DISPATCH(unary_fn, atan_stub); +DECLARE_DISPATCH(unary_fn, bitwise_not_stub); +DECLARE_DISPATCH(unary_fn, logical_not_stub); +DECLARE_DISPATCH(unary_fn, ceil_stub); +DECLARE_DISPATCH(unary_fn, cos_stub); +DECLARE_DISPATCH(unary_fn, cosh_stub); +DECLARE_DISPATCH(unary_fn, digamma_stub); +DECLARE_DISPATCH(unary_fn, special_entr_stub); +DECLARE_DISPATCH(unary_fn, special_erfcx_stub); +DECLARE_DISPATCH(unary_fn, erf_stub); +DECLARE_DISPATCH(unary_fn, erfc_stub); +DECLARE_DISPATCH(unary_fn, 
erfinv_stub);
+DECLARE_DISPATCH(unary_fn, exp_stub);
+DECLARE_DISPATCH(unary_fn, exp2_stub);
+DECLARE_DISPATCH(unary_fn, expm1_stub);
+DECLARE_DISPATCH(unary_fn, floor_stub);
+DECLARE_DISPATCH(unary_fn, frac_stub);
+DECLARE_DISPATCH(unary_fn, frexp_stub);
+DECLARE_DISPATCH(unary_fn, i0_stub);
+DECLARE_DISPATCH(unary_fn, special_i0e_stub);
+DECLARE_DISPATCH(unary_fn, special_i1_stub);
+DECLARE_DISPATCH(unary_fn, special_i1e_stub);
+DECLARE_DISPATCH(unary_fn, log_stub);
+DECLARE_DISPATCH(unary_fn, log10_stub);
+DECLARE_DISPATCH(unary_fn, log1p_stub);
+DECLARE_DISPATCH(unary_fn, log2_stub);
+DECLARE_DISPATCH(unary_fn, special_ndtri_stub);
+DECLARE_DISPATCH(unary_fn, special_log_ndtr_stub);
+DECLARE_DISPATCH(unary_fn, neg_stub);
+
+DECLARE_DISPATCH(unary_fn, reciprocal_stub);
+DECLARE_DISPATCH(unary_fn, round_stub);
+DECLARE_DISPATCH(unary_fn, rsqrt_stub);
+DECLARE_DISPATCH(unary_fn, sigmoid_stub);
+DECLARE_DISPATCH(unary_fn_with_scalar, logit_stub);
+DECLARE_DISPATCH(unary_fn, sign_stub);
+DECLARE_DISPATCH(unary_fn, signbit_stub);
+DECLARE_DISPATCH(unary_fn, sgn_stub);
+DECLARE_DISPATCH(unary_fn, sin_stub);
+DECLARE_DISPATCH(unary_fn, sinc_stub);
+DECLARE_DISPATCH(unary_fn, sinh_stub);
+DECLARE_DISPATCH(unary_fn, sqrt_stub);
+DECLARE_DISPATCH(unary_fn, tan_stub);
+DECLARE_DISPATCH(unary_fn, tanh_stub);
+DECLARE_DISPATCH(unary_fn, trigamma_stub);
+DECLARE_DISPATCH(unary_fn, trunc_stub);
+DECLARE_DISPATCH(unary_fn, lgamma_stub);
+DECLARE_DISPATCH(unary_fn, special_airy_ai_stub);
+DECLARE_DISPATCH(unary_fn, special_bessel_j0_stub);
+DECLARE_DISPATCH(unary_fn, special_bessel_j1_stub);
+DECLARE_DISPATCH(unary_fn, special_bessel_y0_stub);
+DECLARE_DISPATCH(unary_fn, special_bessel_y1_stub);
+DECLARE_DISPATCH(unary_fn, special_modified_bessel_i0_stub);
+DECLARE_DISPATCH(unary_fn, special_modified_bessel_i1_stub);
+DECLARE_DISPATCH(unary_fn, special_modified_bessel_k0_stub);
+DECLARE_DISPATCH(unary_fn, special_modified_bessel_k1_stub);
+DECLARE_DISPATCH(unary_fn, special_scaled_modified_bessel_k0_stub);
+DECLARE_DISPATCH(unary_fn, special_scaled_modified_bessel_k1_stub);
+DECLARE_DISPATCH(unary_fn, special_spherical_bessel_j0_stub);
+
+// NB: these are actually defined in Distribution
+DECLARE_DISPATCH(void(*)(const TensorBase&, const TensorBase&, c10::optional<Generator>), bernoulli_tensor_stub);
+DECLARE_DISPATCH(void(*)(const TensorBase&, const double, c10::optional<Generator>), bernoulli_scalar_stub);
+DECLARE_DISPATCH(void(*)(TensorIteratorBase&, const double, const double, c10::optional<Generator>), cauchy_stub);
+DECLARE_DISPATCH(void(*)(TensorIteratorBase&, const double, c10::optional<Generator>), exponential_stub);
+DECLARE_DISPATCH(void(*)(TensorIteratorBase&, const double, c10::optional<Generator>), geometric_stub);
+DECLARE_DISPATCH(void(*)(TensorIteratorBase&, const double, const double, c10::optional<Generator>), log_normal_stub);
+DECLARE_DISPATCH(void(*)(TensorIteratorBase&, const double, const double, c10::optional<Generator>), uniform_stub);
+DECLARE_DISPATCH(void(*)(const TensorBase&, const double, const double, c10::optional<Generator>), normal_stub);
+DECLARE_DISPATCH(void(*)(TensorIteratorBase&, const uint64_t, const int64_t, c10::optional<Generator>), random_from_to_stub);
+DECLARE_DISPATCH(void(*)(TensorIteratorBase&, c10::optional<Generator>), random_full_64_bits_range_stub);
+DECLARE_DISPATCH(void(*)(TensorIteratorBase&, c10::optional<Generator>), random_stub);
+
+DECLARE_DISPATCH(void(*)(TensorIteratorBase&, const int64_t, const double), kaiser_window_stub);
+DECLARE_DISPATCH(void(*)(TensorIteratorBase&, const int64_t), polygamma_stub);
+DECLARE_DISPATCH(void(*)(TensorIteratorBase&, const
Scalar& a, const Scalar& b), clamp_stub); +DECLARE_DISPATCH( + void (*)(Tensor&, const Tensor&, int64_t, c10::optional), + multinomial_with_replacement_stub); +DECLARE_DISPATCH( + void (*)( + TensorIteratorBase&, + c10::optional, + c10::optional, + c10::optional), + nan_to_num_stub); +DECLARE_DISPATCH(void (*)(TensorIteratorBase&, int64_t), round_decimals_stub); + +// Missing unary functions +// digamma +// lgamma +// erfinv +// clone +// contiguous +// zero +} // namespace at::native diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/AtomicAddFloat.h b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/AtomicAddFloat.h new file mode 100644 index 0000000000000000000000000000000000000000..5b24ee4821c45baab25f37a3bfa3399eff8a1716 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/AtomicAddFloat.h @@ -0,0 +1,37 @@ +#ifndef ATOMIC_ADD_FLOAT +#define ATOMIC_ADD_FLOAT + +#if (defined(__x86_64__) || defined(__i386__) || defined(__aarch64__)) +#include +#else +#define _mm_pause() +#endif + +#include + +static inline void cpu_atomic_add_float(float* dst, float fvalue) +{ + typedef union { + unsigned intV; + float floatV; + } uf32_t; + + uf32_t new_value, old_value; + std::atomic* dst_intV = (std::atomic*)(dst); + + old_value.floatV = *dst; + new_value.floatV = old_value.floatV + fvalue; + + unsigned* old_intV = (unsigned*)(&old_value.intV); + while (!std::atomic_compare_exchange_strong(dst_intV, old_intV, new_value.intV)) { +#ifdef __aarch64__ + __asm__ __volatile__("yield;" : : : "memory"); +#else + _mm_pause(); +#endif + old_value.floatV = *dst; + new_value.floatV = old_value.floatV + fvalue; + } +} + +#endif diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/CatKernel.h b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/CatKernel.h new file mode 100644 index 0000000000000000000000000000000000000000..aedb4aec4f574700ab1060dd17d0c5dcd9846f79 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/CatKernel.h @@ -0,0 +1,12 @@ +#pragma once + +#include +#include +#include + +namespace at { namespace native { + +using cat_serial_fn = void(*)(const Tensor &, const MaterializedITensorListRef&, int64_t); +DECLARE_DISPATCH(cat_serial_fn, cat_serial_stub); + +}} // namespace at::native diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/ChannelShuffleKernel.h b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/ChannelShuffleKernel.h new file mode 100644 index 0000000000000000000000000000000000000000..10e592cf59eb751bbd556597905b4c4279229eaa --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/ChannelShuffleKernel.h @@ -0,0 +1,14 @@ +#pragma once +#include +#include + +namespace at { +class TensorBase; +} + +namespace at { namespace native { + +using channel_shuffle_fn = void(*)(TensorBase&, const TensorBase&, int64_t); +DECLARE_DISPATCH(channel_shuffle_fn, channel_shuffle_kernel); + +}} // at::native diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/DepthwiseConvKernel.h b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/DepthwiseConvKernel.h 
new file mode 100644 index 0000000000000000000000000000000000000000..80970074b8e6c99d079f26aa6f576e67228a04f7 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/DepthwiseConvKernel.h @@ -0,0 +1,21 @@ +#pragma once + +#include +#include + +/* + Depthwise 3x3 Winograd convolution operator +*/ + +namespace at { +class Tensor; + +namespace native { + +using convolution_depthwise3x3_winograd_fn = + Tensor (*)(const Tensor &, const Tensor &, const Tensor &, IntArrayRef, IntArrayRef, int64_t); + +DECLARE_DISPATCH(convolution_depthwise3x3_winograd_fn, convolution_depthwise3x3_winograd_stub); + +} // namespace native +} // namespace at diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/Intrinsics.h b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/Intrinsics.h new file mode 100644 index 0000000000000000000000000000000000000000..f3b35328f1882729a9158eaed7eb2abf77097484 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/Intrinsics.h @@ -0,0 +1,33 @@ +#pragma once + +#if defined(__clang__) && (defined(__x86_64__) || defined(__i386__)) +/* Clang-compatible compiler, targeting x86/x86-64 */ +#include +#elif defined(_MSC_VER) +/* Microsoft C/C++-compatible compiler */ +#include +#if _MSC_VER <= 1900 +#define _mm256_extract_epi64(X, Y) (((uint64_t*)&X)[Y]) +#endif +#elif defined(__GNUC__) && (defined(__x86_64__) || defined(__i386__)) +/* GCC-compatible compiler, targeting x86/x86-64 */ +#include +#elif defined(__GNUC__) && defined(__ARM_NEON__) +/* GCC-compatible compiler, targeting ARM with NEON */ +#include +#elif defined(__GNUC__) && defined(__IWMMXT__) +/* GCC-compatible compiler, targeting ARM with WMMX */ +#include +#elif (defined(__GNUC__) || defined(__xlC__)) && \ + (defined(__VEC__) || defined(__ALTIVEC__)) +/* XLC or GCC-compatible compiler, targeting PowerPC with VMX/VSX */ +#include +/* We need to undef those tokens defined by to avoid conflicts + with the C++ types. => Can still use __bool/__vector */ +#undef bool +#undef vector +#undef pixel +#elif defined(__GNUC__) && defined(__SPE__) +/* GCC-compatible compiler, targeting PowerPC with SPE */ +#include +#endif diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/Loops.h b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/Loops.h new file mode 100644 index 0000000000000000000000000000000000000000..08c3bbe43500147f540406bb3cfe38fcc9c8968b --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/Loops.h @@ -0,0 +1,394 @@ +#pragma once + +// This file provides two functions to help write elementwise kernels: +// +// cpu_kernel(TensorIterator iter, ) +// cpu_kernel_vec(TensorIterator iter, , ) +// +// Both functions may generate vectorized code. The cpu_kernel implementation +// relies on the compiler's auto-vectorization. The cpu_kernel_vec +// implementation uses x86 SIMD intrinsics when available. These functions +// are only intended to be used in the ATen/native/cpu subdirectory, since files +// in other directories are not compiled with AVX/AVX2 enabled. See README.md +// for more details. 
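+// Both helpers consume an already-configured TensorIterator. A minimal,
+// assumed-typical construction (the tensor names are illustrative only):
+//
+//   at::Tensor a = at::rand({16});
+//   at::Tensor b = at::rand({16});
+//   at::Tensor out = at::empty_like(a);
+//   auto iter = at::TensorIteratorConfig()
+//                   .add_output(out)
+//                   .add_input(a)
+//                   .add_input(b)
+//                   .build();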
+// +// For example, to write a multiplication kernel for float: +// +// cpu_kernel(iter, [](float a, float b) { return a * b; }); +// +// Or you may write: +// +// cpu_kernel_vec(iter, +// [](float a, float b) { return a * b; }, +// [](Vectorized a, Vectorized b) { return a * b; }); +// +// See BinaryOpsKernel.cpp for the complete implementation +// +// + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace at { namespace native { inline namespace CPU_CAPABILITY { + +using namespace vec; + +template +typename traits::ArgsTuple +dereference_impl(char* C10_RESTRICT data[], const int64_t* strides, int64_t i, + std::index_sequence) { + return std::make_tuple( + c10::load::type>( + data[INDEX] + i * strides[INDEX])...); +} + +template +typename traits::ArgsTuple +dereference(char* C10_RESTRICT data[], const int64_t* strides, int64_t i) { + using Indices = std::make_index_sequence; + return dereference_impl(data, strides, i, Indices{}); +} + +template +typename traits::ArgsTuple +dereference_vec_impl(char* C10_RESTRICT data[], + const typename traits::result_type& opt_scalar, + size_t S, + int64_t i, + std::index_sequence) { + using Vec = typename traits::result_type; + using scalar_t = typename Vec::value_type; + return std::make_tuple( + S == INDEX + 1 ? + opt_scalar : + Vec::loadu(data[INDEX] + i * sizeof(scalar_t))...); +} + +template +typename traits::ArgsTuple +dereference_vec(char* C10_RESTRICT data[], const typename traits::result_type& opt_scalar, size_t S, int64_t i) { + using Indices = std::make_index_sequence; + return dereference_vec_impl(data, opt_scalar, S, i, Indices{}); +} + +template ::result_type>::value>::type* = nullptr> +static inline void +execute_op(char* C10_RESTRICT data[], const int64_t* strides, int64_t i, int64_t n, func_t&& op) { + using traits = function_traits; + using result_type = typename traits::result_type; + for (; i < n; i++) { + result_type* out_ptr = (result_type*)(data[0] + i * strides[0]); + *out_ptr = c10::guts::apply(std::forward(op), dereference( + &data[1], + &strides[1], + i)); + } +} + +template ::result_type>::value>::type* = nullptr> +static inline void +execute_op(char* C10_RESTRICT data[], const int64_t* strides, int64_t i, int64_t n, func_t&& op) { + using traits = function_traits; + for (; i < n; i++) { + c10::guts::apply(std::forward(op), dereference( + &data[0], + &strides[0], + i)); + } +} + +// Basic loop operation (one output, N inputs). May be auto-vectorized +// by the compiler. Supports inputs and outputs of different types. +template +static inline void +basic_loop(char* C10_RESTRICT data[], const int64_t* strides_, int64_t i, int64_t n, func_t&& op) { + using traits = function_traits; + constexpr int ntensors = traits::arity + 1; + + // Copying strides to temporary array helps auto vectorization in older GCC + // versions. 
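+  // (Operand layout here and in execute_op above: data[0]/strides[0] address
+  // the output; data[1..arity] are the inputs, in TensorIterator order.)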
+ int64_t strides[ntensors]; + for (const auto arg : c10::irange(ntensors)) { + strides[arg] = strides_[arg]; + } + + execute_op(data, strides, i, n, std::forward(op)); +} + +// the recursive variadic template for iterating over the returned tuple +template +struct TupleOutput { + static void handle(char *C10_RESTRICT data[], const int64_t *strides, int64_t i, + const T &tuple) { + TupleOutput::handle(data, strides, i, tuple); + + auto output = std::get(tuple); + using output_type = decltype(output); + output_type * out_ptr = (output_type *)(data[N - 1] + i * strides[N - 1]); + *out_ptr = output; + } +}; + +// Base case for the above recursive template +template +struct TupleOutput { + static void handle(char *C10_RESTRICT data[], const int64_t *strides, int64_t i, + const T &tuple) { + auto output = std::get<0>(tuple); + using output_type = decltype(output); + output_type* out_ptr = (output_type *)(data[0] + i * strides[0]); + *out_ptr = output; + } +}; + +template +void handle_tuple_outputs(char* C10_RESTRICT data[], + const int64_t* strides, + int64_t i, + const std::tuple &tuple) { + TupleOutput::handle(data, strides, i, tuple); +} + +// Loop operation for `cpu_kernel_multiple_outputs`. +// 1. Use `c10::guts::apply` to make dynamic method invocation +// for the lambda passed in `cpu_kernel_multiple_outputs`. +// 2. Iterate over the members of the returned tuple, set the corresponding +// output tensor by the tuple member in `handle_tuple_outputs` function. +template +static inline void +multiple_outputs_loop(char* C10_RESTRICT data[], const int64_t* strides_, int64_t i, int64_t n, func_t&& op) { + using traits = function_traits; + + using result_type = typename traits::result_type; + constexpr int num_outputs = std::tuple_size::value; + constexpr int ntensors = traits::arity + num_outputs; + + // Copying strides to temporary array helps auto vectorization in older GCC + // versions. + int64_t strides[ntensors]; + for (const auto arg : c10::irange(ntensors)) { + strides[arg] = strides_[arg]; + } + + for (; i < n; i++) { + auto output = c10::guts::apply(op, dereference( + &data[num_outputs], + &strides[num_outputs], + i)); + handle_tuple_outputs(data, strides, i, output); + } +} + +// Explicitly vectorized loop implementation. All inputs and outputs must be +// the same type and contiguous with one exception: a single input may be +// a scalar (stride 0). It's position is indicated by the argument `S`. If `S` +// is 0, then there are no scalar inputs. +template +static inline void +vectorized_loop(char** C10_RESTRICT data_, int64_t n, int64_t S, func_t&& op, vec_func_t&& vop) { + using traits = function_traits; + using scalar_t = typename function_traits::result_type; + using Vec = Vectorized; + constexpr int ntensors = traits::arity + 1; + + char* C10_RESTRICT data[ntensors]; + for (const auto arg : c10::irange(ntensors)) { + data[arg] = data_[arg]; + } + + Vec opt_scalar = Vec(S > 0 ? 
*(scalar_t*)data[S] : scalar_t(0)); + int64_t i = 0; + for (; i <= n - 2 * Vec::size(); i += 2 * Vec::size()) { + auto args1 = dereference_vec(&data[1], opt_scalar, S, i); + auto args2 = dereference_vec(&data[1], opt_scalar, S, i + Vec::size()); + auto out1 = c10::guts::apply(std::forward(vop), std::move(args1)); + auto out2 = c10::guts::apply(std::forward(vop), std::move(args2)); + out1.store(data[0] + i * sizeof(scalar_t)); + out2.store(data[0] + (i + Vec::size()) * sizeof(scalar_t)); + } + if (i < n) { + int64_t strides[ntensors]; + for (const auto arg : c10::irange(ntensors)) { + strides[arg] = (S > 0 && arg == S) ? 0 : sizeof(scalar_t); + } + basic_loop(data, strides, i, n, std::forward(op)); + } +} + + +template +static inline void unroll_contiguous_scalar_checks( + const int64_t* /*strides*/, + std::index_sequence<>, + cb_t&& cb) { + cb(0); +} + +template +static inline void unroll_contiguous_scalar_checks( + const int64_t* strides, + std::index_sequence, + cb_t&& cb) { + if (is_contiguous_scalar(strides)) { + cb(INDEX0 + 1); + } else { + unroll_contiguous_scalar_checks(strides, std::index_sequence{}, std::forward(cb)); + } +} + +template +struct VectorizedLoop2d { + op_t op; + vop_t vop; + + using traits = function_traits; + static constexpr int ntensors = traits::arity + 1; + using data_t = std::array; + + VectorizedLoop2d(const op_t &op, vop_t vop): + op(op), vop(std::move(vop)) {} + + static void advance(data_t &data, const int64_t *outer_strides) { + for (const auto arg : c10::irange(data.size())) { + data[arg] += outer_strides[arg]; + } + } + + void operator()(char** base, const int64_t *strides, int64_t size0, int64_t size1) { + data_t data; + std::copy_n(base, ntensors, data.data()); + const int64_t *outer_strides = &strides[ntensors]; + + if (is_contiguous(strides)) { + for (const auto i C10_UNUSED : c10::irange(size1)) { + vectorized_loop(data.data(), size0, 0, op, vop); + advance(data, outer_strides); + } + } else { + using Indices = std::make_index_sequence; + unroll_contiguous_scalar_checks(strides, Indices{}, [&](size_t idx) { + if (idx) { + for (const auto i C10_UNUSED : c10::irange(size1)) { + vectorized_loop(data.data(), size0, idx, op, vop); + advance(data, outer_strides); + } + } else { + for (const auto i C10_UNUSED : c10::irange(size1)) { + basic_loop(data.data(), strides, 0, size0, op); + advance(data, outer_strides); + } + } + }); + } + } +}; + +template +VectorizedLoop2d make_vectorized_loop2d( + const op_t &op, const vop_t &vop) { + return VectorizedLoop2d(op, vop); +} + +template +void cpu_kernel(TensorIteratorBase& iter, func_t&& op, int64_t grain_size = at::internal::GRAIN_SIZE) { + using traits = function_traits; + // this could be extended to work with void return types + TORCH_INTERNAL_ASSERT(iter.ninputs() == traits::arity); + TORCH_INTERNAL_ASSERT(iter.noutputs() == 1); + // dynamic casting not currently supported on CPU + TORCH_INTERNAL_ASSERT(!needs_dynamic_casting::check(iter)); + + iter.for_each([&](char** data, const int64_t* strides, int64_t n) { + // basic loop can handle 1d slices with arbitrary strides, and 1d slices is all that + // iter.for_each is ever sending to the loop lambda + basic_loop(data, strides, 0, n, std::forward(op)); + }, grain_size); + iter.cast_outputs(); +} + +// This function helps write elementwise kernels that requires multiple outputs. +// It follows the similar structure of cpu_kernel. +// Instead of `basic_loop` function, a new `multiple_outputs_loop` function is +// manipulated to handle multiple return values. 
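+// A hedged example in the style of ATen's frexp kernel (assumed usage, not
+// lifted from this file): the lambda returns one std::tuple element per
+// output operand, stored positionally by handle_tuple_outputs:
+//
+//   cpu_kernel_multiple_outputs(iter,
+//       [](float x) -> std::tuple<float, int32_t> {
+//         int exp = 0;
+//         float mantissa = std::frexp(x, &exp);
+//         return std::make_tuple(mantissa, static_cast<int32_t>(exp));
+//       });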
+// For now `needs_dynamic_casting` check is not added as the passed lambda (`func_t`) +// of `multiple_outputs_loop` returns `std::tuple` instead of `scalar_t`. +// The `gpu_kernel_multiple_outputs` is also implemented without this check, +// We could extend `needs_dynamic_casting` to support both `std::tuple` and +// `thrust::tuple` in the future. +template +void cpu_kernel_multiple_outputs(TensorIteratorBase& iter, func_t&& op, int64_t grain_size = at::internal::GRAIN_SIZE) { + using traits = function_traits; + TORCH_INTERNAL_ASSERT(iter.ninputs() == traits::arity); + + iter.for_each([&](char** data, const int64_t* strides, int64_t n) { + multiple_outputs_loop(data, strides, 0, n, std::forward(op)); + }, grain_size); + iter.cast_outputs(); +} + +template +void cpu_kernel_vec(TensorIteratorBase& iter, func_t&& op, vec_func_t&& vop, int64_t grain_size = at::internal::GRAIN_SIZE) { + using traits = function_traits; + // this could be extended to work with void return types + TORCH_INTERNAL_ASSERT(iter.ninputs() == traits::arity); + TORCH_INTERNAL_ASSERT(iter.noutputs() == 1); + // dynamic casting not currently supported on CPU, but some kernels (like Fill) + // explicitly dynamic_cast, so we give the opt-out of checking. + if constexpr (check_dynamic_cast) { + TORCH_INTERNAL_ASSERT(!needs_dynamic_casting::check(iter)); + } + + iter.for_each(make_vectorized_loop2d(op, vop), grain_size); + iter.cast_outputs(); +} + +template +void cpu_serial_kernel(TensorIteratorBase& iter, func_t&& op, const Range& range) { + using traits = function_traits; + constexpr bool result_void = std::is_void::value; + TORCH_INTERNAL_ASSERT(iter.ninputs() == traits::arity && + ((result_void && iter.noutputs() == 0) || (!result_void && iter.noutputs() == 1))); + // dynamic casting not currently supported on CPU + TORCH_INTERNAL_ASSERT(!needs_dynamic_casting::check(iter)); + + iter.serial_for_each([&](char** data, const int64_t* strides, int64_t n) { + basic_loop(data, strides, 0, n, std::forward(op)); + }, range); + iter.cast_outputs(); +} + +template +void cpu_serial_kernel(TensorIteratorBase& iter, func_t&& op) { + cpu_serial_kernel(iter, op, {0, iter.numel()}); +} + +template +void cpu_serial_kernel_vec(TensorIteratorBase& iter, func_t&& op, vec_func_t&& vop, const Range& range) { + using traits = function_traits; + // this could be extended to work with void return types + TORCH_INTERNAL_ASSERT(iter.ninputs() == traits::arity); + TORCH_INTERNAL_ASSERT(iter.noutputs() == 1); + // dynamic casting not currently supported on CPU + TORCH_INTERNAL_ASSERT(!needs_dynamic_casting::check(iter)); + + iter.serial_for_each(make_vectorized_loop2d(op, vop), range); + iter.cast_outputs(); +} + +template +void cpu_serial_kernel_vec(TensorIteratorBase& iter, func_t&& op, vec_func_t&& vop) { + cpu_serial_kernel_vec(iter, op, vop, {0, iter.numel()}); +} + +}}} // namespace at::native:: diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/MaxUnpoolKernel.h b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/MaxUnpoolKernel.h new file mode 100644 index 0000000000000000000000000000000000000000..1c6507909ca4aa7e49fbaa420e407b211023b1b7 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/MaxUnpoolKernel.h @@ -0,0 +1,14 @@ +#pragma once +#include + +namespace at { +class Tensor; + +namespace native { + +using max_unpooling_fn = void(*)(Tensor&, const Tensor&, const Tensor&); + 
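+// The stubs declared below follow ATen's usual DispatchStub wiring. A hedged
+// sketch of the other half (normally in the kernel's .cpp; the impl name is
+// illustrative, not taken from this file):
+//
+//   DEFINE_DISPATCH(max_unpool2d_kernel);  // in the op's .cpp
+//
+//   static void max_unpool2d_kernel_impl(
+//       Tensor& output, const Tensor& input, const Tensor& indices) { /* ... */ }
+//   REGISTER_DISPATCH(max_unpool2d_kernel, &max_unpool2d_kernel_impl);
+//
+//   // call site: the stub invokes the implementation registered for the
+//   // device type / CPU capability:
+//   max_unpool2d_kernel(kCPU, output, self, indices);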
+DECLARE_DISPATCH(max_unpooling_fn, max_unpool2d_kernel); +DECLARE_DISPATCH(max_unpooling_fn, max_unpool3d_kernel); + +}} // at::native diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/ReduceUtils.h b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/ReduceUtils.h new file mode 100644 index 0000000000000000000000000000000000000000..d6afac295aff691ed2527bd5dc18e9bc6ebfe858 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/ReduceUtils.h @@ -0,0 +1,238 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at::native { +inline namespace CPU_CAPABILITY { + +using namespace vec; + +#define AT_DISPATCH_REDUCTION_TYPES(op, ...) \ + [&] { \ + switch (op) { \ + case ReductionType::SUM: { \ + static constexpr auto reduce = ReductionType::SUM; \ + return __VA_ARGS__(); \ + } \ + case ReductionType::MEAN: { \ + static constexpr auto reduce = ReductionType::MEAN; \ + return __VA_ARGS__(); \ + } \ + case ReductionType::MIN: { \ + static constexpr auto reduce = ReductionType::MIN; \ + return __VA_ARGS__(); \ + } \ + case ReductionType::MAX: { \ + static constexpr auto reduce = ReductionType::MAX; \ + return __VA_ARGS__(); \ + } \ + case ReductionType::PROD: { \ + static constexpr auto reduce = ReductionType::PROD; \ + return __VA_ARGS__(); \ + } \ + } \ + }() + +template +inline vec_scalar_t init_value() { + using acc_t = vec_scalar_t; + acc_t val; + if (reduce == ReductionType::SUM || + reduce == ReductionType::MEAN) { + val = static_cast(0); + } else if (reduce == ReductionType::PROD) { + val = static_cast(1); + } else if (reduce == ReductionType::MAX) { + val = -std::numeric_limits::infinity(); + } else { + TORCH_INTERNAL_ASSERT(reduce == ReductionType::MIN); + val = std::numeric_limits::infinity(); + } + return val; +} + +template +inline vec_scalar_t init_value(const c10::optional& initial) { + using acc_t = vec_scalar_t; + if (initial.has_value()) { + return initial.value().to(); + } else { + return init_value(); + } +} + +template +inline void init(scalar_t* out, int64_t size, const vec_scalar_t& val) { + using Vec = Vectorized>; + map( + [val](Vec x) { return Vec(val); }, + out, + out, + size); +} + +template +inline void init(scalar_t* out, int64_t size, const c10::optional& initial) { + using acc_t = vec_scalar_t; + acc_t val = init_value(initial); + init(out, size, val); +} + +// overload with `include_self`, used by scatter_reduce +template +inline void init(scalar_t* out, int64_t size, bool include_self = false) { + using acc_t = vec_scalar_t; + if (!include_self) { + acc_t val = init_value(); + init(out, size, val); + } +} + +template +inline void _init(scalar_t* self_ptr, at::opmath_type* buffer_ptr, int64_t size, bool include_self) { + if (!include_self) { + init, reduce>(buffer_ptr, size, include_self); + } else { + vec::convert(self_ptr, buffer_ptr, size); + } +} + +template +inline typename std::enable_if::value, scalar_t>::type +_max(const scalar_t& x, const scalar_t& y) { + return at::_isnan(y) ? 
y : std::max(x, y); +} + +template +inline Vectorized _max(const Vectorized& x, const Vectorized& y) { + // vec::maximum propagates NaN + return vec::maximum(x, y); +} + +template +inline typename std::enable_if::value, Vec2>::type +_max(const vec_t& x, const vec_t& y) { + // vec::maximum propagates NaN + return maximum(x, y); +} + +template +inline typename std::enable_if::value, scalar_t>::type +_min(const scalar_t& x, const scalar_t& y) { + return at::_isnan(y) ? y : std::min(x, y); +} + +template +inline Vectorized _min(const Vectorized& x, const Vectorized& y) { + // vec::minimum propagates NaN + return vec::minimum(x, y); +} + +template +inline typename std::enable_if::value, Vec2>::type +_min(const vec_t& x, const vec_t& y) { + // vec::minimum propagates NaN + return minimum(x, y); +} + +template , int> = 0> +inline void map_acc( + const Op& vec_fun, + accumut* output_data, + const accumut* input_data, + const scalar_t* input_data2, + int64_t size) { + using Vec = vec::Vectorized; + using aVec = vec::Vectorized; + int64_t d = 0; + constexpr int64_t kVecSize = Vec::size(); + constexpr int64_t kaVecSize = aVec::size(); + for (d = 0; d < size - (size % kVecSize); d += kVecSize) { + Vec data2_vec = Vec::loadu(input_data2 + d); + auto [data2_avec0, data2_avec1] = convert_to_float(data2_vec); + aVec input_vec0 = aVec::loadu(input_data + d); + aVec input_vec1 = aVec::loadu(input_data + d + kaVecSize); + vec_fun(input_vec0, data2_avec0).store(output_data + d); + vec_fun(input_vec1, data2_avec1).store(output_data + d + kaVecSize); + } + if (size - d > 0) { + int64_t tail_size = size - d; + Vec data2_vec = Vec::loadu(input_data2 + d, tail_size); + auto [data2_avec0, data2_avec1] = convert_to_float(data2_vec); + if (tail_size > kaVecSize) { + aVec input_vec0 = aVec::loadu(input_data + d); + aVec input_vec1 = aVec::loadu(input_data + d + kaVecSize, tail_size - kaVecSize); + vec_fun(input_vec0, data2_avec0).store(output_data + d); + vec_fun(input_vec1, data2_avec1).store(output_data + d + kaVecSize, tail_size - kaVecSize); + } else { + aVec input_vec0 = aVec::loadu(input_data + d, tail_size); + vec_fun(input_vec0, data2_avec0).store(output_data + d, tail_size); + } + } +} + +// for Max and Min, propagate NaN: +template +inline T update(const T& x, const T& y) { + if (reduce == ReductionType::SUM || + reduce == ReductionType::MEAN) { + return x + y; + } else if (reduce == ReductionType::PROD) { + return x * y; + } else if (reduce == ReductionType::MAX) { + return _max(x, y); + } else { + TORCH_INTERNAL_ASSERT(reduce == ReductionType::MIN); + return _min(x, y); + } +} + +template +inline void update(scalar_t* out, const scalar_t* data, int64_t K) { + using Vec = vec::Vectorized>; + map2( + [](Vec x, Vec y) { return update(x, y); }, + out, + out, + data, + K); +} + +template , int> = 0> +inline void update(at::opmath_type* out, const scalar_t* data, int64_t K) { + using opmath_t = at::opmath_type; + using Vec = vec::Vectorized; + map_acc( + [](Vec x, Vec y) { return update(x, y); }, + out, + out, + data, + K); +} + +template +inline void write(scalar_t* out, int64_t count, int64_t K) { + using Vec = vec::Vectorized>; + if (reduce == ReductionType::MEAN) { + if (count > 0) { + vec::map( + [count](Vec x) { return x / Vec(count); }, + out, + out, + K); + } + } +} + +} // namespace CPU_CAPABILITY +} // namespace at::native diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/SpmmReduceKernel.h 
b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/SpmmReduceKernel.h new file mode 100644 index 0000000000000000000000000000000000000000..cbcbf3c63d9984ab4d8727f06e50dede5d840fb8 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/SpmmReduceKernel.h @@ -0,0 +1,22 @@ +#pragma once + +#include +#include +#include + +namespace at::native { + +using spmm_reduce_fn = void(*)(const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, ReductionType op); +using spmm_reduce_arg_fn = void(*)(const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, ReductionType op); +using spmm_reduce_backward_input_fn = void(*)(const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, ReductionType op); +using spmm_reduce_backward_input_arg_fn = void(*)(const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, ReductionType op); +using spmm_reduce_backward_other_fn = void(*)(const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, ReductionType op); + +DECLARE_DISPATCH(spmm_reduce_fn, spmm_reduce_stub); +DECLARE_DISPATCH(spmm_reduce_arg_fn, spmm_reduce_arg_stub); +DECLARE_DISPATCH(spmm_reduce_backward_input_fn, spmm_reduce_backward_input_stub); +DECLARE_DISPATCH(spmm_reduce_backward_input_arg_fn, spmm_reduce_backward_input_arg_stub); +DECLARE_DISPATCH(spmm_reduce_backward_other_fn, spmm_reduce_backward_other_stub); +DECLARE_DISPATCH(spmm_reduce_backward_input_arg_fn, spmm_reduce_backward_other_arg_stub); + +} // at::native diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/avx_mathfun.h b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/avx_mathfun.h new file mode 100644 index 0000000000000000000000000000000000000000..f4fd3b7bc461fbf82e8b4a16dd9453e46e124efa --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/avx_mathfun.h @@ -0,0 +1,522 @@ +#pragma once +/* + AVX implementation of sin, cos, sincos, exp and log + + Based on "sse_mathfun.h", by Julien Pommier + http://gruntthepeon.free.fr/ssemath/ + + Copyright (C) 2012 Giovanni Garberoglio + Interdisciplinary Laboratory for Computational Science (LISC) + Fondazione Bruno Kessler and University of Trento + via Sommarive, 18 + I-38123 Trento (Italy) + + This software is provided 'as-is', without any express or implied + warranty. In no event will the authors be held liable for any damages + arising from the use of this software. + + Permission is granted to anyone to use this software for any purpose, + including commercial applications, and to alter it and redistribute it + freely, subject to the following restrictions: + + 1. The origin of this software must not be misrepresented; you must not + claim that you wrote the original software. If you use this software + in a product, an acknowledgment in the product documentation would be + appreciated but is not required. + 2. Altered source versions must be plainly marked as such, and must not be + misrepresented as being the original software. + 3. This notice may not be removed or altered from any source distribution. + + (this is the zlib license) +*/ + +#include + +/* The original source of this file has been modified. 
*/ +#if defined(CPU_CAPABILITY_AVX2) + +#if defined(__GNUC__) +# define ALIGN32_BEG __attribute__((aligned(32))) +#elif defined(_WIN32) +# define ALIGN32_BEG __declspec(align(32)) +#endif + +typedef __m256 v8sf; // vector of 8 float (avx2) +typedef __m256i v8si; // vector of 8 int (avx2) + +/* declare some AVX constants -- why can't I figure a better way to do that? */ +#define _PS256_CONST(Name, Val) \ + static const ALIGN32_BEG float _ps256_##Name[8] = { Val, Val, Val, Val, Val, Val, Val, Val } +#define _PI32_CONST256(Name, Val) \ + static const ALIGN32_BEG int _pi32_256_##Name[8] = { Val, Val, Val, Val, Val, Val, Val, Val } +#define _PS256_CONST_TYPE(Name, Type, Val) \ + static const ALIGN32_BEG Type _ps256_##Name[8] = { Val, Val, Val, Val, Val, Val, Val, Val } + +_PS256_CONST(1 , 1.0f); +_PS256_CONST(0p5, 0.5f); +/* the smallest non denormalized float number */ +_PS256_CONST_TYPE(min_norm_pos, int, 0x00800000); +_PS256_CONST_TYPE(mant_mask, int, 0x7f800000); +_PS256_CONST_TYPE(inv_mant_mask, int, ~0x7f800000); + +_PS256_CONST_TYPE(sign_mask, int, (int)0x80000000); +_PS256_CONST_TYPE(inv_sign_mask, int, ~0x80000000); + +_PI32_CONST256(0, 0); +_PI32_CONST256(1, 1); +_PI32_CONST256(inv1, ~1); +_PI32_CONST256(2, 2); +_PI32_CONST256(4, 4); +_PI32_CONST256(0x7f, 0x7f); + +_PS256_CONST(cephes_SQRTHF, 0.707106781186547524); +_PS256_CONST(cephes_log_p0, 7.0376836292E-2); +_PS256_CONST(cephes_log_p1, - 1.1514610310E-1); +_PS256_CONST(cephes_log_p2, 1.1676998740E-1); +_PS256_CONST(cephes_log_p3, - 1.2420140846E-1); +_PS256_CONST(cephes_log_p4, + 1.4249322787E-1); +_PS256_CONST(cephes_log_p5, - 1.6668057665E-1); +_PS256_CONST(cephes_log_p6, + 2.0000714765E-1); +_PS256_CONST(cephes_log_p7, - 2.4999993993E-1); +_PS256_CONST(cephes_log_p8, + 3.3333331174E-1); +_PS256_CONST(cephes_log_q1, -2.12194440e-4); +_PS256_CONST(cephes_log_q2, 0.693359375); + + +/* natural logarithm computed for 8 simultaneous float + return NaN for x <= 0 +*/ +inline v8sf log256_ps(v8sf x) { + v8si imm0; + v8sf one = *(v8sf*)_ps256_1; + + //v8sf invalid_mask = _mm256_cmple_ps(x, _mm256_setzero_ps()); + v8sf invalid_mask = _mm256_cmp_ps(x, _mm256_setzero_ps(), _CMP_LE_OS); + + x = _mm256_max_ps(x, *(v8sf*)_ps256_min_norm_pos); /* cut off denormalized stuff */ + + // can be done with AVX2 + imm0 = _mm256_srli_epi32(_mm256_castps_si256(x), 23); + + /* keep only the fractional part */ + x = _mm256_and_ps(x, *(v8sf*)_ps256_inv_mant_mask); + x = _mm256_or_ps(x, *(v8sf*)_ps256_0p5); + + // this is again another AVX2 instruction + imm0 = _mm256_sub_epi32(imm0, *(v8si*)_pi32_256_0x7f); + v8sf e = _mm256_cvtepi32_ps(imm0); + + e = _mm256_add_ps(e, one); + + /* part2: + if( x < SQRTHF ) { + e -= 1; + x = x + x - 1.0; + } else { x = x - 1.0; } + */ + //v8sf mask = _mm256_cmplt_ps(x, *(v8sf*)_ps256_cephes_SQRTHF); + v8sf mask = _mm256_cmp_ps(x, *(v8sf*)_ps256_cephes_SQRTHF, _CMP_LT_OS); + v8sf tmp = _mm256_and_ps(x, mask); + x = _mm256_sub_ps(x, one); + e = _mm256_sub_ps(e, _mm256_and_ps(one, mask)); + x = _mm256_add_ps(x, tmp); + + v8sf z = _mm256_mul_ps(x,x); + + v8sf y = *(v8sf*)_ps256_cephes_log_p0; + y = _mm256_mul_ps(y, x); + y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_log_p1); + y = _mm256_mul_ps(y, x); + y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_log_p2); + y = _mm256_mul_ps(y, x); + y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_log_p3); + y = _mm256_mul_ps(y, x); + y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_log_p4); + y = _mm256_mul_ps(y, x); + y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_log_p5); + y = _mm256_mul_ps(y, x); + y = 
_mm256_add_ps(y, *(v8sf*)_ps256_cephes_log_p6); + y = _mm256_mul_ps(y, x); + y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_log_p7); + y = _mm256_mul_ps(y, x); + y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_log_p8); + y = _mm256_mul_ps(y, x); + + y = _mm256_mul_ps(y, z); + + tmp = _mm256_mul_ps(e, *(v8sf*)_ps256_cephes_log_q1); + y = _mm256_add_ps(y, tmp); + + + tmp = _mm256_mul_ps(z, *(v8sf*)_ps256_0p5); + y = _mm256_sub_ps(y, tmp); + + tmp = _mm256_mul_ps(e, *(v8sf*)_ps256_cephes_log_q2); + x = _mm256_add_ps(x, y); + x = _mm256_add_ps(x, tmp); + x = _mm256_or_ps(x, invalid_mask); // negative arg will be NAN + return x; +} + +_PS256_CONST(exp_hi, 88.3762626647949f); +_PS256_CONST(exp_lo, -88.3762626647949f); + +_PS256_CONST(cephes_LOG2EF, 1.44269504088896341); +_PS256_CONST(cephes_exp_C1, 0.693359375); +_PS256_CONST(cephes_exp_C2, -2.12194440e-4); + +_PS256_CONST(cephes_exp_p0, 1.9875691500E-4); +_PS256_CONST(cephes_exp_p1, 1.3981999507E-3); +_PS256_CONST(cephes_exp_p2, 8.3334519073E-3); +_PS256_CONST(cephes_exp_p3, 4.1665795894E-2); +_PS256_CONST(cephes_exp_p4, 1.6666665459E-1); +_PS256_CONST(cephes_exp_p5, 5.0000001201E-1); + +inline v8sf exp256_ps(v8sf x) { + v8sf tmp = _mm256_setzero_ps(), fx; + v8si imm0; + v8sf one = *(v8sf*)_ps256_1; + + x = _mm256_min_ps(x, *(v8sf*)_ps256_exp_hi); + x = _mm256_max_ps(x, *(v8sf*)_ps256_exp_lo); + + /* express exp(x) as exp(g + n*log(2)) */ + fx = _mm256_mul_ps(x, *(v8sf*)_ps256_cephes_LOG2EF); + fx = _mm256_add_ps(fx, *(v8sf*)_ps256_0p5); + + /* how to perform a floorf with SSE: just below */ + //imm0 = _mm256_cvttps_epi32(fx); + //tmp = _mm256_cvtepi32_ps(imm0); + + tmp = _mm256_floor_ps(fx); + + /* if greater, subtract 1 */ + //v8sf mask = _mm256_cmpgt_ps(tmp, fx); + v8sf mask = _mm256_cmp_ps(tmp, fx, _CMP_GT_OS); + mask = _mm256_and_ps(mask, one); + fx = _mm256_sub_ps(tmp, mask); + + tmp = _mm256_mul_ps(fx, *(v8sf*)_ps256_cephes_exp_C1); + v8sf z = _mm256_mul_ps(fx, *(v8sf*)_ps256_cephes_exp_C2); + x = _mm256_sub_ps(x, tmp); + x = _mm256_sub_ps(x, z); + + z = _mm256_mul_ps(x,x); + + v8sf y = *(v8sf*)_ps256_cephes_exp_p0; + y = _mm256_mul_ps(y, x); + y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_exp_p1); + y = _mm256_mul_ps(y, x); + y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_exp_p2); + y = _mm256_mul_ps(y, x); + y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_exp_p3); + y = _mm256_mul_ps(y, x); + y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_exp_p4); + y = _mm256_mul_ps(y, x); + y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_exp_p5); + y = _mm256_mul_ps(y, z); + y = _mm256_add_ps(y, x); + y = _mm256_add_ps(y, one); + + /* build 2^n */ + imm0 = _mm256_cvttps_epi32(fx); + // another two AVX2 instructions + imm0 = _mm256_add_epi32(imm0, *(v8si*)_pi32_256_0x7f); + imm0 = _mm256_slli_epi32(imm0, 23); + v8sf pow2n = _mm256_castsi256_ps(imm0); + y = _mm256_mul_ps(y, pow2n); + return y; +} + +_PS256_CONST(minus_cephes_DP1, -0.78515625); +_PS256_CONST(minus_cephes_DP2, -2.4187564849853515625e-4); +_PS256_CONST(minus_cephes_DP3, -3.77489497744594108e-8); +_PS256_CONST(sincof_p0, -1.9515295891E-4); +_PS256_CONST(sincof_p1, 8.3321608736E-3); +_PS256_CONST(sincof_p2, -1.6666654611E-1); +_PS256_CONST(coscof_p0, 2.443315711809948E-005); +_PS256_CONST(coscof_p1, -1.388731625493765E-003); +_PS256_CONST(coscof_p2, 4.166664568298827E-002); +_PS256_CONST(cephes_FOPI, 1.27323954473516); // 4 / M_PI + + +/* evaluation of 8 sines at onces using AVX intrinsics + + The code is the exact rewriting of the cephes sinf function. 
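+  A hedged usage sketch (assuming AVX2 is enabled for this translation unit,
+  as the surrounding #if requires):
+
+    float in[8] = {0.f, 0.5f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f};
+    float out[8];
+    _mm256_storeu_ps(out, sin256_ps(_mm256_loadu_ps(in)));
+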
+ Precision is excellent as long as x < 8192 (I did not bother to + take into account the special handling they have for greater values + -- it does not return garbage for arguments over 8192, though, but + the extra precision is missing). + + Note that it is such that sinf((float)M_PI) = 8.74e-8, which is the + surprising but correct result. + +*/ +inline v8sf sin256_ps(v8sf x) { // any x + v8sf xmm1, xmm2 = _mm256_setzero_ps(), xmm3, sign_bit, y; + v8si imm0, imm2; + + sign_bit = x; + /* take the absolute value */ + x = _mm256_and_ps(x, *(v8sf*)_ps256_inv_sign_mask); + /* extract the sign bit (upper one) */ + sign_bit = _mm256_and_ps(sign_bit, *(v8sf*)_ps256_sign_mask); + + /* scale by 4/Pi */ + y = _mm256_mul_ps(x, *(v8sf*)_ps256_cephes_FOPI); + + /* + Here we start a series of integer operations, which are in the + realm of AVX2. + If we don't have AVX, let's perform them using SSE2 directives + */ + + /* store the integer part of y in mm0 */ + imm2 = _mm256_cvttps_epi32(y); + /* j=(j+1) & (~1) (see the cephes sources) */ + // another two AVX2 instruction + imm2 = _mm256_add_epi32(imm2, *(v8si*)_pi32_256_1); + imm2 = _mm256_and_si256(imm2, *(v8si*)_pi32_256_inv1); + y = _mm256_cvtepi32_ps(imm2); + + /* get the swap sign flag */ + imm0 = _mm256_and_si256(imm2, *(v8si*)_pi32_256_4); + imm0 = _mm256_slli_epi32(imm0, 29); + /* get the polynom selection mask + there is one polynom for 0 <= x <= Pi/4 + and another one for Pi/4 +#include +#include + +#ifdef USE_FBGEMM +#include +#endif + +namespace at { +namespace native { + +template +inline void _store(T* dst, at::vec::Vectorized src) { + src.store(dst); +} + +inline void _store(at::BFloat16* dst, at::vec::Vectorized src) { + auto res = at::vec::convert_float_bfloat16(src, src); + res.store(dst, at::vec::Vectorized::size()); +} + +inline void _store(at::Half* dst, at::vec::Vectorized src) { + auto res = at::vec::convert_float_half(src, src); + res.store(dst, at::vec::Vectorized::size()); +} + +inline namespace CPU_CAPABILITY { + +template +inline T data_index_init(T offset) { + return offset; +} + +template +inline T data_index_init(T offset, T& x, const T& X, Args&&... args) { + offset = data_index_init(offset, std::forward(args)...); + x = offset % X; + return offset / X; +} + +inline bool data_index_step() { + return true; +} + +template +inline bool data_index_step(T& x, const T& X, Args&&... args) { + if (data_index_step(std::forward(args)...)) { + x = ((x + 1) == X) ? 
0 : (x + 1);
+    return x == 0;
+  }
+  return false;
+}
+
+// Helper struct for bfloat16 vectorization
+// Useful when you need float as the intermediate dtype or accumulate dtype
+using namespace vec;
+struct Vec2 {
+  Vectorized<float> val0, val1;
+  Vec2(Vectorized<float> v0, Vectorized<float> v1) : val0(v0), val1(v1) {}
+  Vec2(float v) : val0(v), val1(v) {}
+  static Vec2 loadu(const BFloat16* ptr) {
+    auto [v0, v1] = convert_bfloat16_float(Vectorized<BFloat16>::loadu(ptr));
+    return {v0, v1};
+  }
+  static Vec2 loadu(const float* ptr) {
+    return {Vectorized<float>::loadu(ptr), Vectorized<float>::loadu(ptr + Vectorized<float>::size())};
+  }
+  void store(BFloat16* ptr) const {
+    Vectorized<BFloat16> val = convert_float_bfloat16(val0, val1);
+    val.store(ptr);
+  }
+  void store(float* ptr) const {
+    val0.store(ptr);
+    val1.store(ptr + Vectorized<float>::size());
+  }
+};
+inline Vec2 operator+(const Vec2& a, const Vec2& b) { return {a.val0 + b.val0, a.val1 + b.val1}; }
+inline Vec2 operator*(const Vec2& a, const Vec2& b) { return {a.val0 * b.val0, a.val1 * b.val1}; }
+inline Vec2 operator-(const Vec2& a, const Vec2& b) { return {a.val0 - b.val0, a.val1 - b.val1}; }
+inline Vec2 operator/(const Vec2& a, const Vec2& b) { return {a.val0 / b.val0, a.val1 / b.val1}; }
+inline Vec2 maximum(const Vec2& a, const Vec2& b) { return {vec::maximum(a.val0, b.val0), vec::maximum(a.val1, b.val1)}; }
+inline Vec2 minimum(const Vec2& a, const Vec2& b) { return {vec::minimum(a.val0, b.val0), vec::minimum(a.val1, b.val1)}; }
+
+template <typename scalar_t> struct VectorizedType { using type = Vectorized<scalar_t>; };
+template <> struct VectorizedType<BFloat16> { using type = Vec2; };
+template <typename scalar_t> using VecType = typename VectorizedType<scalar_t>::type;
+
+// Helper for mixed data type parameter Vec::load
+inline std::tuple<Vectorized<float>, Vectorized<float>> load2f(const BFloat16* ptr) {
+  return convert_bfloat16_float(Vectorized<BFloat16>::loadu(ptr));
+}
+
+inline std::tuple<Vectorized<float>, Vectorized<float>> load2f(const Half* ptr) {
+  return convert_half_float(Vectorized<Half>::loadu(ptr));
+}
+
+inline std::tuple<Vectorized<float>, Vectorized<float>> load2f(const float* ptr) {
+  using Vec = Vectorized<float>;
+  return std::make_tuple(Vec::loadu(ptr), Vec::loadu(ptr + Vec::size()));
+}
+
+inline std::tuple<Vectorized<float>, Vectorized<float>> load2f(const BFloat16* ptr, int64_t count) {
+  return convert_bfloat16_float(Vectorized<BFloat16>::loadu(ptr, count));
+}
+
+inline std::tuple<Vectorized<float>, Vectorized<float>> load2f(const Half* ptr, int64_t count) {
+  return convert_half_float(Vectorized<Half>::loadu(ptr, count));
+}
+
+inline std::tuple<Vectorized<float>, Vectorized<float>> load2f(const float* ptr, int64_t count) {
+  using Vec = Vectorized<float>;
+  if (count > Vec::size()) {
+    return std::make_tuple(Vec::loadu(ptr), Vec::loadu(ptr + Vec::size(), count - Vec::size()));
+  } else {
+    return std::make_tuple(Vec::loadu(ptr, count), Vec(0));
+  }
+}
+
+} // namespace
+
+namespace utils {
+
+template <typename T>
+T CeilLog2(const T& x) {
+  if (x <= 2) {
+    return 1;
+  }
+  // Last set bit is floor(log2(x)), floor + 1 is ceil,
+  // except when x is an exact power of 2, so subtract 1 first
+  return static_cast<T>(llvm::findLastSet(static_cast<uint64_t>(x) - 1)) + 1;
+}
+
+// matrix transpose:
+//   src has shape of M by N, with leading dimension of ld_src
+//   dst has shape of N by M, with leading dimension of ld_dst
+template <typename T>
+inline void transpose(int64_t M, int64_t N, const T* src, int64_t ld_src, T* dst, int64_t ld_dst) {
+  for (int64_t j = 0; j < N; j++) {
+    for (int64_t i = 0; i < M; i++) {
+      dst[j * ld_dst + i] = src[i * ld_src + j];
+    }
+  }
+}
+
+#ifdef USE_FBGEMM
+template <>
+inline void transpose<float>(int64_t M, int64_t N, const float* src, int64_t ld_src, float* dst, int64_t ld_dst) {
+  TORCH_CHECK(fbgemm::fbgemmSupportedCPU(), "Your CPU does not support FBGEMM.");
+  fbgemm::transpose_simd<float>(M, N, src, ld_src, dst, ld_dst);
+}
+#endif
+
+template <typename index_t, typename F>
+inline void parallel_sparse_csr(
+    const TensorAccessor<index_t, 1>& crow_acc,
+    const int64_t M,
+    const int64_t nnz,
+    const F& f) {
+  TORCH_CHECK(crow_acc.size(0) == M + 1);
+
+  // parallelizing directly over `M` may lead to load imbalance;
+  // statically determine the thread partition here to average the
+  // payload for each thread.
+  int num_threads = at::get_num_threads();
+  std::vector<int64_t> thread_splits(num_threads + 1, M);
+
+  int64_t thread_average_payload = std::max((int64_t)1, divup(nnz, num_threads));
+
+  thread_splits[0] = 0;
+  int64_t sum = 0;
+  int64_t t = 1;
+  for (const auto m : c10::irange(M)) {
+    int64_t row_start = crow_acc[m];
+    int64_t row_end = crow_acc[m + 1];
+    sum += row_end - row_start;
+    if (sum > t * thread_average_payload) {
+      thread_splits[t] = m;
+      t++;
+    }
+  }
+  // need to restore the last index,
+  // due to rounding error when calculating `thread_average_payload`.
+  thread_splits[num_threads] = M;
+
+  at::parallel_for(0, num_threads, 1, [&](int64_t cbegin, int64_t cend) {
+    int tid = at::get_thread_num();
+    int64_t begin = thread_splits[tid];
+    int64_t end = thread_splits[tid + 1];
+    f(begin, end);
+  });
+}
+
+} // namespace utils
+
+} // namespace native
+} // namespace at
diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/zmath.h b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/zmath.h
new file mode 100644
index 0000000000000000000000000000000000000000..9b52039e84f91861fcfae9b8ee21d0a9f5c363ac
--- /dev/null
+++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/zmath.h
@@ -0,0 +1,250 @@
+#pragma once
+
+// Complex number math operations that act as no-ops for other dtypes.
+#include <c10/util/complex.h>
+#include <c10/util/MathConstants.h>
+#include <ATen/NumericUtils.h>
+
+namespace at { namespace native {
+inline namespace CPU_CAPABILITY {
+
+template <typename SCALAR_TYPE, typename VALUE_TYPE = SCALAR_TYPE>
+inline VALUE_TYPE zabs (SCALAR_TYPE z) {
+  return z;
+}
+
+template<>
+inline c10::complex<float> zabs <c10::complex<float>> (c10::complex<float> z) {
+  return c10::complex<float>(std::abs(z));
+}
+
+template<>
+inline float zabs <c10::complex<float>, float> (c10::complex<float> z) {
+  return std::abs(z);
+}
+
+template<>
+inline c10::complex<double> zabs <c10::complex<double>> (c10::complex<double> z) {
+  return c10::complex<double>(std::abs(z));
+}
+
+template<>
+inline double zabs <c10::complex<double>, double> (c10::complex<double> z) {
+  return std::abs(z);
+}
+
+// This overload corresponds to non-complex dtypes.
+// The function is consistent with its NumPy equivalent
+// for non-complex dtypes where `pi` is returned for
+// negative real numbers and `0` is returned for 0 or positive
+// real numbers.
+// Note: `nan` is propagated.
+template <typename SCALAR_TYPE, typename VALUE_TYPE = SCALAR_TYPE>
+inline VALUE_TYPE angle_impl (SCALAR_TYPE z) {
+  if (at::_isnan(z)) {
+    return z;
+  }
+  return z < 0 ? c10::pi<VALUE_TYPE> : 0;
+}
+
+template<>
+inline c10::complex<float> angle_impl <c10::complex<float>> (c10::complex<float> z) {
+  return c10::complex<float>(std::arg(z), 0.0);
+}
+
+template<>
+inline float angle_impl <c10::complex<float>, float> (c10::complex<float> z) {
+  return std::arg(z);
+}
+
+template<>
+inline c10::complex<double> angle_impl <c10::complex<double>> (c10::complex<double> z) {
+  return c10::complex<double>(std::arg(z), 0.0);
+}
+
+template<>
+inline double angle_impl <c10::complex<double>, double> (c10::complex<double> z) {
+  return std::arg(z);
+}
+
+template <typename SCALAR_TYPE, typename VALUE_TYPE = SCALAR_TYPE>
+constexpr VALUE_TYPE real_impl (SCALAR_TYPE z) {
+  return z; //No-Op
+}
+
+template<>
+constexpr c10::complex<float> real_impl <c10::complex<float>> (c10::complex<float> z) {
+  return c10::complex<float>(z.real(), 0.0);
+}
+
+template<>
+constexpr float real_impl <c10::complex<float>, float> (c10::complex<float> z) {
+  return z.real();
+}
+
+template<>
+constexpr c10::complex<double> real_impl <c10::complex<double>> (c10::complex<double> z) {
+  return c10::complex<double>(z.real(), 0.0);
+}
+
+template<>
+constexpr double real_impl <c10::complex<double>, double> (c10::complex<double> z) {
+  return z.real();
+}
+
+template <typename SCALAR_TYPE, typename VALUE_TYPE = SCALAR_TYPE>
+constexpr VALUE_TYPE imag_impl (SCALAR_TYPE /*z*/) {
+  return 0;
+}
+
+template<>
+constexpr c10::complex<float> imag_impl <c10::complex<float>> (c10::complex<float> z) {
+  return c10::complex<float>(z.imag(), 0.0);
+}
+
+template<>
+constexpr float imag_impl <c10::complex<float>, float> (c10::complex<float> z) {
+  return z.imag();
+}
+
+template<>
+constexpr c10::complex<double> imag_impl <c10::complex<double>> (c10::complex<double> z) {
+  return c10::complex<double>(z.imag(), 0.0);
+}
+
+template<>
+constexpr double imag_impl <c10::complex<double>, double> (c10::complex<double> z) {
+  return z.imag();
+}
+
+template <typename TYPE>
+inline TYPE conj_impl (TYPE z) {
+  return z; //No-Op
+}
+
+template<>
+inline c10::complex<at::Half> conj_impl <c10::complex<at::Half>> (c10::complex<at::Half> z) {
+  return c10::complex<at::Half>{z.real(), -z.imag()};
+}
+
+template<>
+inline c10::complex<float> conj_impl <c10::complex<float>> (c10::complex<float> z) {
+  return c10::complex<float>(z.real(), -z.imag());
+}
+
+template<>
+inline c10::complex<double> conj_impl <c10::complex<double>> (c10::complex<double> z) {
+  return c10::complex<double>(z.real(), -z.imag());
+}
+
+template <typename TYPE>
+inline TYPE ceil_impl (TYPE z) {
+  return std::ceil(z);
+}
+
+template <>
+inline c10::complex<float> ceil_impl (c10::complex<float> z) {
+  return c10::complex<float>(std::ceil(z.real()), std::ceil(z.imag()));
+}
+
+template <>
+inline c10::complex<double> ceil_impl (c10::complex<double> z) {
+  return c10::complex<double>(std::ceil(z.real()), std::ceil(z.imag()));
+}
+
+template <typename T>
+inline c10::complex<T> sgn_impl (c10::complex<T> z) {
+  if (z == c10::complex<T>(0, 0)) {
+    return c10::complex<T>(0, 0);
+  } else {
+    return z / zabs(z);
+  }
+}
+
+template <typename TYPE>
+inline TYPE floor_impl (TYPE z) {
+  return std::floor(z);
+}
+
+template <>
+inline c10::complex<float> floor_impl (c10::complex<float> z) {
+  return c10::complex<float>(std::floor(z.real()), std::floor(z.imag()));
+}
+
+template <>
+inline c10::complex<double> floor_impl (c10::complex<double> z) {
+  return c10::complex<double>(std::floor(z.real()), std::floor(z.imag()));
+}
+
+template <typename TYPE>
+inline TYPE round_impl (TYPE z) {
+  return std::nearbyint(z);
+}
+
+template <>
+inline c10::complex<float> round_impl (c10::complex<float> z) {
+  return c10::complex<float>(std::nearbyint(z.real()), std::nearbyint(z.imag()));
+}
+
+template <>
+inline c10::complex<double> round_impl (c10::complex<double> z) {
+  return c10::complex<double>(std::nearbyint(z.real()), std::nearbyint(z.imag()));
+}
+
+template <typename TYPE>
+inline TYPE trunc_impl (TYPE z) {
+  return std::trunc(z);
+}
+
+template <>
+inline c10::complex<float> trunc_impl (c10::complex<float> z) {
+  return c10::complex<float>(std::trunc(z.real()), std::trunc(z.imag()));
+}
+
+template <>
+inline c10::complex<double> trunc_impl (c10::complex<double> z) {
+  return c10::complex<double>(std::trunc(z.real()), std::trunc(z.imag()));
+}
+
+template <typename TYPE, std::enable_if_t<!c10::is_complex<TYPE>::value, int> = 0>
+inline TYPE max_impl (TYPE a, TYPE b) {
+  if (_isnan<TYPE>(a) || _isnan<TYPE>(b)) {
+    return std::numeric_limits<TYPE>::quiet_NaN();
+  } else {
+    return std::max(a, b);
+  }
+}
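+// [Editorial illustration -- not part of the upstream header.] A minimal
+// sketch of why max_impl differs from std::max: std::max(a, b) is defined as
+// (a < b) ? b : a, and any comparison with NaN is false, so whether a NaN
+// survives depends on argument order. max_impl above propagates NaN
+// unconditionally, which is the reduction semantics wanted here:
+//
+//   float nan = std::numeric_limits<float>::quiet_NaN();
+//   std::max(1.0f, nan);   // 1.0f -- NaN silently dropped
+//   std::max(nan, 1.0f);   // NaN  -- but only by accident of argument order
+//   max_impl(1.0f, nan);   // NaN  -- propagated regardless of order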
+
+template <typename TYPE, std::enable_if_t<c10::is_complex<TYPE>::value, int> = 0>
+inline TYPE max_impl (TYPE a, TYPE b) {
+  if (_isnan<TYPE>(a)) {
+    return a;
+  } else if (_isnan<TYPE>(b)) {
+    return b;
+  } else {
+    return std::abs(a) > std::abs(b) ? a : b;
+  }
+}
+
+template <typename TYPE, std::enable_if_t<!c10::is_complex<TYPE>::value, int> = 0>
+inline TYPE min_impl (TYPE a, TYPE b) {
+  if (_isnan<TYPE>(a) || _isnan<TYPE>(b)) {
+    return std::numeric_limits<TYPE>::quiet_NaN();
+  } else {
+    return std::min(a, b);
+  }
+}
+
+template <typename TYPE, std::enable_if_t<c10::is_complex<TYPE>::value, int> = 0>
+inline TYPE min_impl (TYPE a, TYPE b) {
+  if (_isnan<TYPE>(a)) {
+    return a;
+  } else if (_isnan<TYPE>(b)) {
+    return b;
+  } else {
+    return std::abs(a) < std::abs(b) ? a : b;
+  }
+}
+
+} // end namespace
+}} //end at::native
diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/Activation.h b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/Activation.h
new file mode 100644
index 0000000000000000000000000000000000000000..5fbfe0d2c65569522dfbf878cc82b5ac66c3c4ad
--- /dev/null
+++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/Activation.h
@@ -0,0 +1,20 @@
+#pragma once
+#include <ATen/native/Activation.h>
+#include <cstdint>
+
+namespace at {
+struct TensorIteratorBase;
+class TensorBase;
+}
+
+namespace at { namespace native {
+
+void launch_glu_backward_kernel(const TensorIteratorBase& iter,
+                                int64_t gI_stride, int64_t I_stride);
+
+void launch_log_sigmoid_forward_kernel(TensorIteratorBase& iter);
+
+void GeluCUDAKernelImpl(TensorIteratorBase& it, GeluType approximate);
+void GeluBackwardCUDAKernelImpl(TensorIteratorBase& it, GeluType approximate);
+
+}} // namespace at::native
diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/CUDALoops.cuh b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/CUDALoops.cuh
new file mode 100644
index 0000000000000000000000000000000000000000..b8eb85fd4eb2eec771759f5de11e16f934b31437
--- /dev/null
+++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/CUDALoops.cuh
@@ -0,0 +1,348 @@
+#pragma once
+
+// This file provides two functions to help write GPU elementwise kernels:
+//
+//   gpu_kernel(TensorIterator iter, <lambda>)
+//   gpu_kernel_with_scalars(TensorIterator iter, <lambda>)
+//
+// The gpu_kernel_with_scalars generates specializations that support a
+// single scalar CPU argument, such as from `cuda_tensor + 5`. The CPU scalar
+// is lifted to a kernel parameter instead of copying to device memory.
+// This should be used in conjunction with TensorIterator::allow_cpu_scalars_,
+// which is the default for TensorIterator::binary_op. Otherwise, all inputs
+// and the output must be on the GPU.
+// +// For example, to write a reciprocal kernel for GPU float Tensors: +// +// gpu_kernel(iter, []GPU_LAMBDA(float a) { +// return 1.0f / a; +// }); +// +// To write a multiplication kernel for GPU float Tensors where one argument +// may be a CPU scalar: +// +// gpu_kernel_with_scalars(iter, []GPU_LAMBDA(float a, float b) { +// return a * b; +// }); +// +// See BinaryOpsKernel.cu for the complete implementation +// + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#ifdef __NVCC__ +#define ASSERT_HOST_DEVICE_LAMBDA(type) \ + static_assert( \ + __nv_is_extended_host_device_lambda_closure_type(type), \ + #type " must be a __host__ __device__ lambda") +#else +#define ASSERT_HOST_DEVICE_LAMBDA(type) +#endif + +namespace at { +namespace native { + +template +C10_LAUNCH_BOUNDS_1(num_threads()) +__global__ void vectorized_elementwise_kernel(int N, func_t f, array_t data) { + using traits = function_traits; + int remaining = N - block_work_size() * blockIdx.x; + + if (remaining < block_work_size()) { // if this block handles the reminder, + // just do a naive unrolled loop + auto input_calc = TrivialOffsetCalculator(); + auto output_calc = TrivialOffsetCalculator<1>(); + auto loader = memory::LoadWithoutCast(); + auto storer = memory::StoreWithoutCast(); + auto policy = memory::policies::unroll< + array_t, + decltype(input_calc), + decltype(output_calc), + memory::LoadWithoutCast, + memory::StoreWithoutCast>( + data, remaining, input_calc, output_calc, loader, storer); + elementwise_kernel_helper(f, policy); + } else { // if this block has a full `block_work_size` data to handle, use + // vectorized memory access + elementwise_kernel_helper( + f, memory::policies::vectorized(data)); + } +} + +template < + typename func_t, + typename array_t, + typename inp_calc_t, + typename out_calc_t, + typename loader_t, + typename storer_t> +C10_LAUNCH_BOUNDS_1(num_threads()) +__global__ void unrolled_elementwise_kernel( + int N, + func_t f, + array_t data, + inp_calc_t ic, + out_calc_t oc, + loader_t l, + storer_t s) { + int remaining = N - block_work_size() * blockIdx.x; + auto policy = memory::policies:: + unroll( + data, remaining, ic, oc, l, s); + elementwise_kernel_helper(f, policy); +} + +// this function assume trivial 1d and no dynamic casting +template +static inline void launch_vectorized_kernel( + int64_t N, + const func_t& f, + array_t data) { + TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits::max()); + using traits = function_traits; + int64_t grid = (N + block_work_size() - 1) / block_work_size(); + auto stream = at::cuda::getCurrentCUDAStream(); + int vec_size = memory::can_vectorize_up_to(data); + + switch (vec_size) { + case 4: + vectorized_elementwise_kernel<4, func_t, array_t> + <<>>(N, f, data); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + break; + case 2: + vectorized_elementwise_kernel<2, func_t, array_t> + <<>>(N, f, data); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + break; + case 1: { + auto input_calc = TrivialOffsetCalculator(); + auto output_calc = TrivialOffsetCalculator<1>(); + auto loader = memory::LoadWithoutCast(); + auto storer = memory::StoreWithoutCast(); + unrolled_elementwise_kernel + <<>>( + N, f, data, input_calc, output_calc, loader, storer); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + break; + } + default: + TORCH_INTERNAL_ASSERT(false, "Unexpected vectorization size"); + } +} + +template < + typename func_t, + typename array_t, + typename inp_calc_t, + typename out_calc_t, + typename loader_t, + typename storer_t> 
+static inline void launch_unrolled_kernel( + int64_t N, + const func_t& f, + array_t data, + inp_calc_t ic, + out_calc_t oc, + loader_t l, + storer_t s) { + TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits::max()); + int64_t grid = (N + block_work_size() - 1) / block_work_size(); + auto stream = at::cuda::getCurrentCUDAStream(); + unrolled_elementwise_kernel + <<>>(N, f, data, ic, oc, l, s); + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + +template +C10_LAUNCH_BOUNDS_2(nt, 4) +__global__ void elementwise_kernel(int N, func_t f) { + int tid = threadIdx.x; + int nv = nt * vt; + int idx = nv * blockIdx.x + tid; +#pragma unroll + for (int i = 0; i < vt; i++) { + if (idx < N) { + f(idx); + idx += nt; + } + } +} + +template +static void launch_legacy_kernel(int64_t N, const func_t& f) { + TORCH_INTERNAL_ASSERT(N >= 0 && N <= std::numeric_limits::max()); + if (N == 0) { + return; + } + dim3 block(nt); + dim3 grid((N + block.x * vt - 1) / (block.x * vt)); + auto stream = at::cuda::getCurrentCUDAStream(); + elementwise_kernel<<>>(N, f); + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + +template +C10_HOST_DEVICE typename traits::result_type invoke_impl( + const func_t& f, + char* const C10_RESTRICT data[], + const index_t strides[], + int i, + std::index_sequence) { + (void)strides; + (void)i; + return f(c10::load::type>( + data[INDEX] + i * strides[INDEX])...); +} + +template < + typename func_t, + typename index_t, + typename traits = function_traits> +C10_HOST_DEVICE typename traits::result_type invoke( + const func_t& f, + char* const C10_RESTRICT data[], + const index_t strides[], + int i) { + using Indices = std::make_index_sequence; + return invoke_impl(f, data, strides, i, Indices{}); +} + +template +C10_HOST_DEVICE typename traits::result_type invoke_impl( + const func_t& f, + char* const C10_RESTRICT data[], + const index_t strides[], + const ScalarType dtypes[], + int i, + std::index_sequence) { + (void)strides; + (void)i; + return f(c10::fetch_and_cast::type>( + dtypes[I], data[I] + i * strides[I])...); +} + +template < + typename func_t, + typename index_t, + typename traits = function_traits> +C10_HOST_DEVICE typename traits::result_type invoke( + const func_t& f, + char* const C10_RESTRICT data[], + const index_t strides[], + const ScalarType dtypes[], + int i) { + using Indices = std::make_index_sequence; + return invoke_impl(f, data, strides, dtypes, i, Indices{}); +} + +template +void gpu_kernel_impl_nocast(TensorIteratorBase& iter, const func_t& f) { + using traits = function_traits; + using arg0_t = typename traits::result_type; + constexpr int ntensors = traits::arity + 1; + + TORCH_INTERNAL_ASSERT(iter.can_use_32bit_indexing()); + TORCH_INTERNAL_ASSERT(iter.ninputs() == traits::arity); + TORCH_INTERNAL_ASSERT(iter.noutputs() == 1); + TORCH_INTERNAL_ASSERT(!needs_dynamic_casting::check(iter)); + + at::detail::Array data; + for (int i = 0; i < ntensors; i++) { + data[i] = (char*)iter.data_ptr(i); + } + + int64_t numel = iter.numel(); + + bool contiguous = iter.is_contiguous(); + + if (contiguous) { + return launch_vectorized_kernel(numel, f, data); + } + auto offset_calc = ::make_offset_calculator(iter); + constexpr int unroll_factor = sizeof(arg0_t) >= 4 ? 
2 : 4; + launch_legacy_kernel<128, unroll_factor>(numel, [=] GPU_LAMBDA(int idx) { + auto offsets = offset_calc.get(idx); + arg0_t* out = (arg0_t*)(data[0] + offsets[0]); + *out = invoke(f, &data.data[1], &offsets.data[1], 1); + }); +} + +template +void gpu_kernel_impl(TensorIteratorBase& iter, const func_t& f) { + if (!needs_dynamic_casting::check(iter)) { + return gpu_kernel_impl_nocast(iter, f); + } + using traits = function_traits; + using arg0_t = typename traits::result_type; + constexpr int ntensors = traits::arity + 1; + + TORCH_INTERNAL_ASSERT(iter.can_use_32bit_indexing()); + TORCH_INTERNAL_ASSERT(iter.ninputs() == traits::arity); + TORCH_INTERNAL_ASSERT(iter.noutputs() == 1); + + at::detail::Array data; + for (int i = 0; i < ntensors; i++) { + data[i] = (char*)iter.data_ptr(i); + } + + int64_t numel = iter.numel(); + + bool contiguous = iter.is_contiguous(); + + if (contiguous) { +#ifdef USE_ROCM + at::detail::Array dtypes; + auto inner_strides = iter.get_inner_strides(); + at::detail::Array strides; + for (int i = 0; i < ntensors; i++) { + dtypes[i] = iter.dtype(i); + strides[i] = inner_strides[i]; + } + launch_legacy_kernel<512, 1>(numel, [=]GPU_LAMBDA(int idx) { + void* out = data[0] + strides[0] * idx; + arg0_t result = invoke(f, &data.data[1], &strides.data[1], &dtypes.data[1], idx); + c10::cast_and_store(dtypes[0], out, result); + }); +#else + auto loader = memory::LoadWithCast(iter); + auto storer = memory::StoreWithCast<1>(iter); + auto input_offset_calculator = TrivialOffsetCalculator(); + auto output_offset_calculator = TrivialOffsetCalculator<1>(); + launch_unrolled_kernel( + numel, + f, + data, + input_offset_calculator, + output_offset_calculator, + loader, + storer); +#endif + } else { + at::detail::Array dtypes; + for (int i = 0; i < ntensors; i++) { + dtypes[i] = iter.dtype(i); + } + auto offset_calc = ::make_offset_calculator(iter); + launch_legacy_kernel<128, 4>(numel, [=] GPU_LAMBDA(int idx) { + auto offsets = offset_calc.get(idx); + void* out = data[0] + offsets[0]; + arg0_t result = invoke(f, &data.data[1], &offsets.data[1], &dtypes.data[1], 1); + c10::cast_and_store(dtypes[0], out, result); + }); + } +} + +} // namespace native +} // namespace at diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/Copy.h b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/Copy.h new file mode 100644 index 0000000000000000000000000000000000000000..5639567d666686dd81ca5b4b032fb44f039eb782 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/Copy.h @@ -0,0 +1,10 @@ +#pragma once + +namespace at { +struct TensorIteratorBase; + +namespace native { + +void direct_copy_kernel_cuda(TensorIteratorBase &iter); + +}} // namespace at::native diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/CuFFTUtils.h b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/CuFFTUtils.h new file mode 100644 index 0000000000000000000000000000000000000000..4b02f914d7e20ff914e248d203be3f9434bacb3b --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/CuFFTUtils.h @@ -0,0 +1,73 @@ +#pragma once + +#include + +#include +#include +#include +#include +#include + +namespace at { namespace native { + +// This means that max dim is 3 + 2 = 5 with batch dimension and possible +// complex dimension 
+constexpr int max_rank = 3; + +static inline std::string _cudaGetErrorEnum(cufftResult error) +{ + switch (error) + { + case CUFFT_SUCCESS: + return "CUFFT_SUCCESS"; + case CUFFT_INVALID_PLAN: + return "CUFFT_INVALID_PLAN"; + case CUFFT_ALLOC_FAILED: + return "CUFFT_ALLOC_FAILED"; + case CUFFT_INVALID_TYPE: + return "CUFFT_INVALID_TYPE"; + case CUFFT_INVALID_VALUE: + return "CUFFT_INVALID_VALUE"; + case CUFFT_INTERNAL_ERROR: + return "CUFFT_INTERNAL_ERROR"; + case CUFFT_EXEC_FAILED: + return "CUFFT_EXEC_FAILED"; + case CUFFT_SETUP_FAILED: + return "CUFFT_SETUP_FAILED"; + case CUFFT_INVALID_SIZE: + return "CUFFT_INVALID_SIZE"; + case CUFFT_UNALIGNED_DATA: + return "CUFFT_UNALIGNED_DATA"; + case CUFFT_INCOMPLETE_PARAMETER_LIST: + return "CUFFT_INCOMPLETE_PARAMETER_LIST"; + case CUFFT_INVALID_DEVICE: + return "CUFFT_INVALID_DEVICE"; + case CUFFT_PARSE_ERROR: + return "CUFFT_PARSE_ERROR"; + case CUFFT_NO_WORKSPACE: + return "CUFFT_NO_WORKSPACE"; + case CUFFT_NOT_IMPLEMENTED: + return "CUFFT_NOT_IMPLEMENTED"; +#if !defined(USE_ROCM) + case CUFFT_LICENSE_ERROR: + return "CUFFT_LICENSE_ERROR"; +#endif + case CUFFT_NOT_SUPPORTED: + return "CUFFT_NOT_SUPPORTED"; + default: + std::ostringstream ss; + ss << "unknown error " << error; + return ss.str(); + } +} + +static inline void CUFFT_CHECK(cufftResult error) +{ + if (error != CUFFT_SUCCESS) { + std::ostringstream ss; + ss << "cuFFT error: " << _cudaGetErrorEnum(error); + AT_ERROR(ss.str()); + } +} + +}} // at::native diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/IndexKernel.h b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/IndexKernel.h new file mode 100644 index 0000000000000000000000000000000000000000..edd9190deb0dba12979556a9f1bc12705f5801b4 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/IndexKernel.h @@ -0,0 +1,16 @@ +#pragma once +#include +#include + +namespace at { +struct TensorIteratorBase; +class TensorBase; +} + +namespace at { +namespace native { +/// @param maskPrefixSum[in,out] +void launch_masked_scatter_kernel( + const TensorBase &self, const TensorBase &mask, + const TensorBase &maskPrefixSum, const TensorBase &source); +}} diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/LaunchUtils.h b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/LaunchUtils.h new file mode 100644 index 0000000000000000000000000000000000000000..c9640b15b18c8a2d6d4f3dd92379701ae1ec5164 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/LaunchUtils.h @@ -0,0 +1,18 @@ +#pragma once +#include + +namespace at { +namespace native { + +// returns 2**floor(log2(n)) +static int lastPow2(unsigned int n) { + n |= (n >> 1); + n |= (n >> 2); + n |= (n >> 4); + n |= (n >> 8); + n |= (n >> 16); + return std::max(1, n - (n >> 1)); +} + +} // namespace native +} // namespace at diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/Loops.cuh b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/Loops.cuh new file mode 100644 index 0000000000000000000000000000000000000000..cb14f275e21718db44bc5f175ce6e650426966ff --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/Loops.cuh @@ -0,0 
+1,326 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +#include + +#include + + +namespace at { namespace native { + +template +static OffsetCalculator make_input_offset_calculator(const TensorIteratorBase& iter) { + // array size can not be 0, this happens when N == 0 + constexpr int array_size = std::max(N, 1); + TORCH_INTERNAL_ASSERT(N == iter.ntensors() - iter.noutputs()); + std::array strides; + int64_t element_sizes[array_size]; + for (int i = 0; i < N; i++) { + strides[i] = iter.strides(i + iter.noutputs()).data(); + element_sizes[i] = iter.element_size(i + iter.noutputs()); + } + return OffsetCalculator(iter.ndim(), iter.shape().data(), strides.data(), element_sizes); +} + +template +static OffsetCalculator make_output_offset_calculator(const TensorIteratorBase& iter) { + TORCH_INTERNAL_ASSERT(num_outputs == iter.noutputs()); + std::array strides; + int64_t element_sizes[num_outputs]; + for (int i = 0; i < num_outputs; i++) { + strides[i] = iter.strides(i).data(); + element_sizes[i] = iter.element_size(i); + } + return OffsetCalculator(iter.ndim(), iter.shape().data(), strides.data(), element_sizes); +} + +template +__device__ inline void elementwise_kernel_helper(func_t f, policy_t policy) { + using traits = function_traits; + using return_t = typename traits::result_type; + using args_t = typename traits::ArgsTuple; + + int idx = blockIdx.x; + + return_t results[thread_work_size()]; + args_t args[thread_work_size()]; + + // load + policy.load(args, idx); + + // compute + #pragma unroll + for (int i = 0; i < thread_work_size(); i++) { + if (policy.check_inbounds(i)) { + results[i] = c10::guts::apply(f, args[i]); + } + } + + // store + policy.store(results, idx); +} + +}} // namespace at::native + +#include + +namespace at:: native { + +template +void gpu_kernel_nocast(TensorIteratorBase& iter, const func_t& f) { + + for (int arg = 0; arg < iter.ntensors(); arg++) { + TORCH_INTERNAL_ASSERT( + iter.device(arg).is_cuda(), + "argument ", arg, ": expected a CUDA device but found ", iter.device(arg)); + } + + if (iter.numel() == 0) { + return; + } + + if (!iter.can_use_32bit_indexing()) { + for (auto& sub_iter : iter.with_32bit_indexing()) { + gpu_kernel_nocast(sub_iter, f); + } + return; + } + + gpu_kernel_impl_nocast(iter, f); +} + +template +void gpu_kernel(TensorIteratorBase& iter, const func_t& f) { + + for (int arg = 0; arg < iter.ntensors(); arg++) { + TORCH_INTERNAL_ASSERT( + iter.device(arg).is_cuda(), + "argument ", arg, ": expected a CUDA device but found ", iter.device(arg)); + } + + if (iter.numel() == 0) { + return; + } + + if (!iter.can_use_32bit_indexing()) { + for (auto& sub_iter : iter.with_32bit_indexing()) { + gpu_kernel(sub_iter, f); + } + return; + } + + gpu_kernel_impl(iter, f); +} + +template +struct AUnaryFunctor { + using traits = function_traits; + using opmath_arg1_t = typename traits::template arg<0>::type; + __device__ return_t operator()(arg2_t b) const { + return f(a, b); + } + // NB: scalar is stored in higher precision! + AUnaryFunctor(func_t f_, opmath_arg1_t a_): f(f_), a(a_) {} + private: + func_t f; + opmath_arg1_t a; +}; + +template +struct BUnaryFunctor { + using traits = function_traits; + using opmath_arg2_t = typename traits::template arg<1>::type; + __device__ return_t operator()(arg1_t a) const { + return f(a, b); + } + // NB: scalar is stored in higher precision! 
+ BUnaryFunctor(func_t f_, opmath_arg2_t b_): f(f_), b(b_) {} + private: + func_t f; + opmath_arg2_t b; +}; + +// Though seemingly noop, this inserts casts from arg1_t to func_t's type +// (which may be higher precision), as well as casts to return_t +template +struct BinaryFunctor { + __device__ return_t operator()(arg1_t a, arg2_t b) const { + return f(a, b); + } + BinaryFunctor(func_t f_): f(f_) {} + private: + func_t f; +}; + +// Unlike gpu_kernel_with_scalars, this allows you to pass a func_t which +// accepts inputs at higher precision (typically opmath_t), but then +// ensure that we load from memory at the correct precision (scalar_t) +// to avoid expensive loads. For the whole sordid story see +// https://dev-discuss.pytorch.org/t/cuda-loops-case-study-code-generation-vs-templates/302 +template +void opmath_gpu_kernel_with_scalars(TensorIteratorBase& iter, const func_t& f) { + TORCH_INTERNAL_ASSERT(iter.ntensors() == 3); + + using traits = function_traits; + using opmath_arg1_t = typename traits::template arg<0>::type; + using opmath_arg2_t = typename traits::template arg<1>::type; + static_assert( + traits::arity == 2, + "gpu_kernel_with_scalars only supports two input arguments"); + + if (iter.is_cpu_scalar(1)) { + AUnaryFunctor af(f, iter.scalar_value(1)); + iter.remove_operand(1); + // TODO: When all kernels that use gpu_kernel_with_scalars are + // ported to structured, this device guard can be deleted. This + // works around incorrect device guard generation for pre-structured + // kernels device guards, but structured kernels do it right and + // we can assume the device is already set correctly + const OptionalDeviceGuard device_guard(iter.device(1)); + gpu_kernel(iter, af); + } else if (iter.is_cpu_scalar(2)) { + BUnaryFunctor bf(f, iter.scalar_value(2)); + iter.remove_operand(2); + gpu_kernel(iter, bf); + } else { + gpu_kernel(iter, BinaryFunctor(f)); + } +} + +template +void opmath_symmetric_gpu_kernel_with_scalars(TensorIteratorBase& iter, const func_t& f) { + // Use symmetric property of the functor to reduce number of kernels, + // requires f(a, b) == f(b, a) + TORCH_INTERNAL_ASSERT(iter.ntensors() == 3); + + using traits = function_traits; + using opmath_arg_t = typename traits::template arg<0>::type; + static_assert( + traits::arity == 2, + "gpu_kernel_with_scalars only supports two input arguments"); + static_assert(std::is_same::type>::value, + "f is not symmetric"); + + OptionalDeviceGuard device_guard; + opmath_arg_t scalar_val{}; + + if (iter.is_cpu_scalar(1)) { + scalar_val = iter.scalar_value(1); + iter.remove_operand(1); + + // TODO: When all kernels that use gpu_kernel_with_scalars are + // ported to structured, this device guard can be deleted. 
This + // works around incorrect device guard generation for pre-structured + // kernels device guards, but structured kernels do it right and + // we can assume the device is already set correctly + device_guard.reset_device(iter.device(1)); + } else if (iter.is_cpu_scalar(2)) { + scalar_val = iter.scalar_value(2); + iter.remove_operand(2); + } + + if (iter.ninputs() == 2) { + gpu_kernel(iter, BinaryFunctor(f)); + } else { + AUnaryFunctor unary_f(f, scalar_val); + gpu_kernel(iter, unary_f); + } +} + +// Legacy variant that assumes that func_t has the correct types +// that we expect to load from memory +template +void gpu_kernel_with_scalars(TensorIteratorBase& iter, const func_t& f) { + using traits = function_traits; + static_assert( + traits::arity == 2, + "gpu_kernel_with_scalars only supports two input arguments"); + using arg1_t = typename traits::template arg<0>::type; + using arg2_t = typename traits::template arg<1>::type; + using return_t = typename traits::result_type; + opmath_gpu_kernel_with_scalars(iter, f); +} + +namespace { // functions for `gpu_kernel_multiple_outputs`. + +// check the return type is `thrust::tuple`, not `std::tuple`. +template struct is_tuple: std::false_type {}; + +template struct is_tuple>: std::true_type {}; + +template +C10_LAUNCH_BOUNDS_1(num_threads()) +__global__ void unrolled_elementwise_kernel_for_multi_outputs(int N, func_t f, array_t data, inp_calc_t ic, out_calc_t oc) { + int remaining = N - block_work_size() * blockIdx.x; + elementwise_kernel_helper(f, memory::policies::multi_outputs_unroll(data, remaining, ic, oc)); +} + +template +static inline void launch_unrolled_kernel_for_multi_outputs(int64_t N, const func_t& f, array_t data, inp_calc_t ic, out_calc_t oc) { + TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits::max()); + int64_t grid = (N + block_work_size() - 1) / block_work_size(); + auto stream = at::cuda::getCurrentCUDAStream(); + unrolled_elementwise_kernel_for_multi_outputs<<>>(N, f, data, ic, oc); + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + +template +void gpu_kernel_multiple_outputs_impl(TensorIteratorBase& iter, const func_t& f) { + using traits = function_traits; + using output_t = typename traits::result_type; + static_assert(is_tuple::value, "f's return type must be `thrust::tuple`"); + constexpr int num_outputs = thrust::tuple_size::value; + constexpr int num_inputs = traits::arity; + constexpr int ntensors = num_outputs + num_inputs; + + TORCH_INTERNAL_ASSERT(iter.can_use_32bit_indexing()); + TORCH_INTERNAL_ASSERT(iter.ntensors() == ntensors); + + at::detail::Array data; + for (int i = 0; i < ntensors; i++) { + data[i] = (char*)iter.data_ptr(i); + } + + int64_t numel = iter.numel(); + + if (iter.is_contiguous()) { + auto input_calc = TrivialOffsetCalculator(); + auto output_calc = TrivialOffsetCalculator(); + launch_unrolled_kernel_for_multi_outputs(numel, f, data, input_calc, output_calc); + } else { + auto input_calc = make_input_offset_calculator(iter); + auto output_calc = make_output_offset_calculator(iter); + launch_unrolled_kernel_for_multi_outputs(numel, f, data, input_calc, output_calc); + } +} +} // namespace + +template +void gpu_kernel_multiple_outputs(TensorIteratorBase& iter, const func_t& f) { + ASSERT_HOST_DEVICE_LAMBDA(func_t); + + for (int arg = 0; arg < iter.ntensors(); arg++) { + TORCH_INTERNAL_ASSERT(iter.device(arg).is_cuda()); + } + + if (iter.numel() == 0) { + return; + } + + if (!iter.can_use_32bit_indexing()) { + for (auto& sub_iter : iter.with_32bit_indexing()) { + 
gpu_kernel_multiple_outputs(sub_iter, f); + } + return; + } + + gpu_kernel_multiple_outputs_impl(iter, f); +} + +} //namespace at::native diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/Normalization.cuh b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/Normalization.cuh new file mode 100644 index 0000000000000000000000000000000000000000..ab2b316bc8a4b99a1af9df2b380bda7d91797ab0 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/Normalization.cuh @@ -0,0 +1,1742 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#else +#include +#include +#include +#endif + +namespace at { namespace native { + +// The maximum number of threads in a block +#if defined(USE_ROCM) +constexpr int MAX_BLOCK_SIZE = 256; +#else +constexpr int MAX_BLOCK_SIZE = 512; +#endif + +constexpr unsigned MAX_GRID_SIZE = 65535u; + +// Number of threads in a block given an input size up to MAX_BLOCK_SIZE +static int getNumThreads(int nElem) { +#if defined(USE_ROCM) + int threadSizes[5] = { 16, 32, 64, 128, MAX_BLOCK_SIZE }; +#else + int threadSizes[5] = { 32, 64, 128, 256, MAX_BLOCK_SIZE }; +#endif + for (int i = 0; i != 5; ++i) { + if (nElem <= threadSizes[i]) { + return threadSizes[i]; + } + } + return MAX_BLOCK_SIZE; +} + +// Returns the index of the most significant 1 bit in `val`. +__device__ __forceinline__ int getMSB(int val) { + return 31 - __clz(val); +} + +template +struct Float2 { + accscalar_t v1, v2; + __device__ Float2() {} + __device__ Float2(scalar_t v1, scalar_t v2) : v1(static_cast(v1)), v2(static_cast(v2)) {} + __device__ Float2(int v) : v1(static_cast(v)), v2(static_cast(v)) {} + __device__ Float2& operator+=(const Float2& a) { + v1 += a.v1; + v2 += a.v2; + return *this; + } + __device__ friend Float2 operator+(Float2 a, const Float2& b) { + a += b; + return a; + } +}; + +template +struct GradOp { + __device__ GradOp(accscalar_t m, const PTA& i, const PTA& g) + : mean(m), input(i), grad_output(g) {} + __device__ __forceinline__ Float2 operator()(int batch, int plane, int n) { + accscalar_t g = grad_output[batch][plane][n]; + accscalar_t c = static_cast(input[batch][plane][n]) - mean; + return Float2(g, g * c); + } + const accscalar_t mean; + const PTA& input; + const PTA& grad_output; +}; + +template +struct SumReduceOp { + __device__ __forceinline__ acc_t combine(acc_t a, acc_t b) const { return a + b; } + + __device__ __forceinline__ acc_t warp_shfl_down(acc_t data, int offset) const { + return WARP_SHFL_DOWN(data, offset); + } +}; + +template +struct SumReduceOp> { + using acc_t = Float2; + + __device__ __forceinline__ acc_t combine(acc_t a, acc_t b) const { return a + b; } + + __device__ __forceinline__ acc_t warp_shfl_down(acc_t data, int offset) const { + return {WARP_SHFL_DOWN(data.v1, offset), WARP_SHFL_DOWN(data.v2, offset)}; + } +}; + +// Sum across (batch, x/y/z) applying Op() pointwise +// this works by first having each thread sum it's part +// of the data. Then there is a double-shuffling reduction. +// First each warp (of C10_WARP_SIZE threads) uses warpSum to reduce its +// data to the "warp leader", who writes its value into shared memory. +// Then a single warp reads the remaining (at most C10_WARP_SIZE) items +// and reduces them using another warpSum. 
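+// [Editorial walk-through, assuming the usual C10_WARP_SIZE of 32 on NVIDIA
+// hardware.] Worked example of the two-stage shape: a full 512-thread block
+// (the non-ROCm MAX_BLOCK_SIZE above) holds 512 / 32 = 16 warps. Stage one
+// leaves 16 per-warp partial sums in shared memory; stage two lets the first
+// warp (16 of its 32 lanes active) reduce those 16 values to a single sum,
+// which is then broadcast to the whole block through shared memory.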
+// The implicit assumption is that there are no more +// than C10_WARP_SIZE**2 threads. +template +__device__ scalar_t reduce(Op op, PTA tensor, int plane) { + // first the reductions each thread does separately + scalar_t sum = static_cast(0); + for (int batch = threadIdx.y; batch < tensor.size(0); batch += blockDim.y) { + for (int x = threadIdx.x; x < tensor.size(2); x += blockDim.x) { + sum += op(batch, plane, x); + } + } + __shared__ scalar_t shared[C10_WARP_SIZE]; + SumReduceOp reduce_op; + sum = cuda_utils::BlockReduce, cuda_utils::Block2D>(sum, reduce_op, 0, shared); + if (threadIdx.x == 0 && threadIdx.y == 0) { + shared[0] = sum; + } + __syncthreads(); + // Everyone picks it up, should be broadcast into the whole grad_input + return shared[0]; +} + +constexpr int ELEMENTS_PER_ITER = 4; // enables concurrency within each thread to hide latency +constexpr int ELEMENTS_PER_THREAD = 16; +constexpr int OPTIMAL_TILE_W = 32; +constexpr int MAX_H_BLOCK = 128; + +__host__ void flexible_launch_configs( + const int reduction, + const int stride, + dim3 &block, + dim3 &grid, + const bool coop_flag = false) { + int block_x = std::min(lastPow2(stride), OPTIMAL_TILE_W); + int block_y = std::min(lastPow2(at::ceil_div(reduction , ELEMENTS_PER_THREAD)), + MAX_BLOCK_SIZE / block_x); + if (block_x * block_y != MAX_BLOCK_SIZE) { + block_x = std::min(lastPow2(stride), MAX_BLOCK_SIZE / block_y); + } + + int grid_x = at::ceil_div(stride, block_x); + int grid_y = std::min(at::ceil_div(reduction, block_y * ELEMENTS_PER_THREAD), MAX_H_BLOCK); + if (coop_flag) { + // it's not worth having a grid reduction if the reduction dimension is not big enough + grid_y = grid_y < 8 ? 1 : grid_y; + } + + block.x = block_x; + block.y = block_y; + block.z = 1; + grid.x = grid_x; + grid.y = grid_y; + grid.z = 1; +} + +template +__device__ __forceinline__ void welford_merge_element(C& count, + T& mean, + T& m2n, + const C& count_new, + const T& mean_new, + const T& m2n_new) { + T factor = T(1.0) / ::max(1, (count + count_new)); + T delta0 = mean - mean_new; + mean = (mean_new * count_new + mean * count) * factor; + m2n += m2n_new + delta0 * delta0 * count_new * count * factor; + count += count_new; +} + +// merge mean/m2n among threadIdx.y within block +template +__device__ __forceinline__ void welford_merge_block_vertical(C& count, + T& mean, + T& m2n, + C* shmem_count, + T* shmem_mean, + T* shmem_m2n) { + // write to shared memory + auto address_base = threadIdx.x + threadIdx.y * blockDim.x; + +#pragma unroll + for (int offset = blockDim.y/2; offset > 0; offset >>= 1) { + if (threadIdx.y < offset*2) { + shmem_mean[address_base] = mean; + shmem_m2n[address_base] = m2n; + shmem_count[address_base] = count; + } + __syncthreads(); + if (threadIdx.y < offset && threadIdx.y + offset < blockDim.y) { + auto address = address_base + offset * blockDim.x; + // read shared memory back to register for reduction + auto count_new = shmem_count[address]; + auto mean_new = shmem_mean[address]; + auto m2n_new = shmem_m2n[address]; + + welford_merge_element(count, mean, m2n, count_new, mean_new, m2n_new); + } + } +} + +template +__global__ void batch_norm_transform_input_kernel( + const GenericPackedTensorAccessor input, + GenericPackedTensorAccessor output, + const GenericPackedTensorAccessor::type, 1, RestrictPtrTraits, index_t> mean_, + const GenericPackedTensorAccessor::type, 1, RestrictPtrTraits, index_t> var_or_invstd, + const GenericPackedTensorAccessor weight, + const GenericPackedTensorAccessor bias, + stat_accscalar_t epsilon) { + 
+ index_t plane = blockIdx.x; + + if (plane >= input.size(1)) { + return; + } + + stat_accscalar_t gamma = weight.size(0) > 0 ? static_cast(weight[plane]) : static_cast(1); + stat_accscalar_t beta = bias.size(0) > 0 ? static_cast(bias[plane]) : static_cast(0); + stat_accscalar_t mean = static_cast(mean_[plane]); + stat_accscalar_t invstd; + if (train) { + invstd = var_or_invstd[plane]; + } else { + invstd = static_cast(1) / device_sqrt(static_cast(var_or_invstd[plane]) + epsilon); + } + + index_t bs = input.size(0); + index_t fs = input.size(2); + + index_t bstep = blockDim.y * gridDim.y; + for (index_t batch = threadIdx.y + blockIdx.y * blockDim.y; batch < bs; batch += bstep) { + auto o = output[batch][plane]; + auto i = input[batch][plane]; + for (index_t feature = threadIdx.x; feature < fs; feature += blockDim.x) { + o[feature] = static_cast(gamma * (i[feature] - mean) * invstd + beta); + } + } +} + +struct InvStd { + template + __device__ __forceinline__ T operator()(T var, double epsilon) const { + T invstd = 0; + if (var != static_cast(0) || epsilon != static_cast(0)) { + invstd = static_cast(1) / device_sqrt(var + epsilon); + } + return invstd; + } +}; + +struct Var { + template + __device__ __forceinline__ T operator()(T var, double epsilon) const { + return var; + } +}; + +template +__global__ void batch_norm_collect_statistics_kernel( + const GenericPackedTensorAccessor input, + const stat_accscalar_t epsilon, + const stat_accscalar_t momentum, + GenericPackedTensorAccessor save_mean, + GenericPackedTensorAccessor save_transformed_var) { + + __shared__ int shared_n[2 * 2 * C10_WARP_SIZE + C10_WARP_SIZE]; + + int plane = blockIdx.x; + int N = input.size(0) * input.size(2); + int tid = threadIdx.x + threadIdx.y * blockDim.x; + + // Compute the mean and variance across (batch, x/y/z) + // this uses the Welford (in the for loop)/parallel algorithm (to sum across the block) + // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_Online_algorithm + // and the parallel algorithm on the same page. + // We use two shuffles to reduce across the entire block. + // https://devblogs.nvidia.com/faster-parallel-reductions-kepler/ has a description. 
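+  // [Editorial note.] The merge step used below combines two Welford states
+  // (n_a, avg_a, M2_a) and (n_b, avg_b, M2_b) with the standard Chan et al.
+  // parallel-variance formulas:
+  //
+  //   n   = n_a + n_b
+  //   avg = (n_a * avg_a + n_b * avg_b) / n
+  //   M2  = M2_a + M2_b + (avg_a - avg_b)^2 * n_a * n_b / n
+  //
+  // which is exactly what the WARP_SHFL_XOR loops below compute, with
+  // factor = 1 / max(1, n_a + n_b) guarding the empty case.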
+ stat_accscalar_t* shared_avg_var = (stat_accscalar_t*) &shared_n[C10_WARP_SIZE]; + + // first the reductions each thread does separately + stat_accscalar_t avg = 0; + stat_accscalar_t var_n = 0; + int n = 0; + for (int batch = threadIdx.y; batch < input.size(0); batch += blockDim.y) { + for (int x = threadIdx.x; x < input.size(2); x += blockDim.x) { + stat_accscalar_t v = input[batch][plane][x]; + stat_accscalar_t d1 = v - avg; + n++; + avg += d1 / n; + var_n += d1 * (v - avg); + } + } + + // first warpSum to get one value per thread to + // one value per warp + for (int i = 0; i < getMSB(C10_WARP_SIZE); ++i) { + stat_accscalar_t o_avg = WARP_SHFL_XOR(avg, 1 << i, C10_WARP_SIZE); + int o_n = WARP_SHFL_XOR(n, 1 << i, C10_WARP_SIZE); + stat_accscalar_t factor = 1.0 / fmaxf(1.0, n+o_n); + var_n += WARP_SHFL_XOR(var_n, 1 << i, C10_WARP_SIZE) + (avg - o_avg) * (avg - o_avg) * n * o_n * factor; + avg = (n * avg + o_n * o_avg) * factor; + n += o_n; + } + + // this writes each warps item into shared memory + // there are at most C10_WARP_SIZE items left because + // there are at most C10_WARP_SIZE**2 threads at the beginning + __syncthreads(); + if (tid % C10_WARP_SIZE == 0) { + shared_n[tid / C10_WARP_SIZE] = n; + shared_avg_var[tid / C10_WARP_SIZE * 2] = avg; + shared_avg_var[tid / C10_WARP_SIZE * 2 + 1] = var_n; + } + __syncthreads(); + // now have a second warpSum to reduce the intermediate values + // from shared memory to a single number. The very first + // thread writes it to shared memory. + + if (tid < C10_WARP_SIZE) { + n = (tid < blockDim.x * blockDim.y / C10_WARP_SIZE ? shared_n[tid] : 0); + avg = (tid < blockDim.x * blockDim.y / C10_WARP_SIZE ? shared_avg_var[2 * tid] : stat_accscalar_t(0)); + var_n = (tid < blockDim.x * blockDim.y / C10_WARP_SIZE ? shared_avg_var[2 * tid + 1] : stat_accscalar_t(0)); + } + for (int i = 0; i < getMSB(C10_WARP_SIZE); ++i) { + stat_accscalar_t o_avg = WARP_SHFL_XOR(avg, 1 << i, C10_WARP_SIZE); + int o_n = WARP_SHFL_XOR(n, 1 << i, C10_WARP_SIZE); + stat_accscalar_t factor = 1.0 / fmaxf(1.0, n+o_n); + var_n += WARP_SHFL_XOR(var_n, 1 << i, C10_WARP_SIZE) + (avg - o_avg) * (avg - o_avg) * n * o_n * factor; + avg = (n * avg + o_n * o_avg) * factor; + n += o_n; + } + + // Save the mean, variance, and moving averages + if (tid == 0) { + if (save_mean.data() != NULL) { + save_mean[plane] = avg; + } + if (save_transformed_var.data() != NULL) { + save_transformed_var[plane] = VarTransform{}(var_n / N, epsilon); + } + } + +} + +template +__global__ void batch_norm_backward_kernel( + const GenericPackedTensorAccessor input, + const GenericPackedTensorAccessor grad_output, + GenericPackedTensorAccessor grad_input, + GenericPackedTensorAccessor grad_weight, + GenericPackedTensorAccessor grad_bias, + const GenericPackedTensorAccessor weight, + const GenericPackedTensorAccessor running_mean, + const GenericPackedTensorAccessor running_var, + const GenericPackedTensorAccessor save_mean, + const GenericPackedTensorAccessor save_invstd, + bool train, + stat_accscalar_t epsilon) { + + index_t plane = blockIdx.x; + index_t N = grad_output.size(0) * grad_output.size(2); + + stat_accscalar_t mean, invstd; + if (train) { + mean = save_mean[plane]; + invstd = save_invstd[plane]; + } else { + mean = static_cast(running_mean[plane]); + invstd = static_cast(1) / device_sqrt(static_cast(running_var[plane]) + epsilon); + } + + stat_accscalar_t weight_val = weight.size(0) > 0 ? 
static_cast(weight[plane]) : stat_accscalar_t(1); + stat_accscalar_t norm = stat_accscalar_t(1) / N; + + // Compute two values across (batch, x/y/z) in one pass: + // 1. Sum(grad_output) + // 2. DotProduct(input - mean, grad_output) + GradOp> g(mean, input, grad_output); + auto res = reduce>(g, grad_output, plane); + + stat_accscalar_t grad_output_sum = res.v1; + stat_accscalar_t dot_p = res.v2; + + stat_accscalar_t grad_mean = grad_output_sum * norm; + stat_accscalar_t proj_scale = dot_p * norm * invstd * invstd; + stat_accscalar_t grad_scale = invstd * weight_val; + + if (grad_input.data() != NULL) { + for (int batch = threadIdx.y; batch < grad_output.size(0); batch += blockDim.y) { + for (int x = threadIdx.x; x < grad_output.size(2); x += blockDim.x) { + input_scalar_t go = grad_output[batch][plane][x]; + if (train) { + stat_accscalar_t inp = input[batch][plane][x]; + stat_accscalar_t proj = (inp - mean) * proj_scale; + grad_input[batch][plane][x] = static_cast((go - proj - grad_mean) * grad_scale); + } else { + grad_input[batch][plane][x] = static_cast(go * grad_scale); + } + } + } + } + + if (grad_weight.size(0) > 0) { + if (threadIdx.x == 0) { + grad_weight[plane] = static_cast(dot_p * invstd); + } + } + + if (grad_bias.size(0) > 0) { + if (threadIdx.x == 0) { + grad_bias[plane] = static_cast(grad_output_sum); + } + } +} + +template +__global__ void batch_norm_reduce_statistics_kernel( + const GenericPackedTensorAccessor vec_mean, + const GenericPackedTensorAccessor vec_invstd, + GenericPackedTensorAccessor mean, + GenericPackedTensorAccessor invstd, + GenericPackedTensorAccessor running_mean, + GenericPackedTensorAccessor running_var, + const accscalar_t epsilon, + const accscalar_t momentum, + const GenericPackedTensorAccessor counts) { + + int feature_size = vec_mean.size(1); + int world_size = vec_mean.size(0); + + int bid = blockIdx.x; + int tid = threadIdx.x; + + // first the reductions each thread does separately + for (int i = bid*blockDim.x+tid; i < feature_size; i += gridDim.x*blockDim.x) { + accscalar_t avg = 0; + accscalar_t var_n = 0; + index_t n = 0; + for (int j = 0; j < world_size; j++) { + scalar_t count = counts[j]; + accscalar_t m = vec_mean[j][i]; + accscalar_t v = accscalar_t(1.0) / (vec_invstd[j][i]); + v = (v * v - epsilon) * count; + accscalar_t factor = 1.0 / (n + count); + var_n += v + (avg - m) * (avg - m) * n * count * factor; + avg = n * factor * avg + count * factor * m; + n += count; + } + mean[i] = avg; + invstd[i] = static_cast(1) / device_sqrt(var_n / n + epsilon); + if (running_mean.data() != NULL) { + running_mean[i] = static_cast((1 - momentum) * running_mean[i] + momentum * avg); + } + accscalar_t unbiasedVar = var_n / (n - 1); + if (running_var.data() != NULL) { + running_var[i] = static_cast((1 - momentum) * running_var[i] + momentum * unbiasedVar); + } + } + +} + +template +__global__ void batch_norm_backward_reduce_kernel( + const GenericPackedTensorAccessor input, + const GenericPackedTensorAccessor grad_output, + GenericPackedTensorAccessor mean, + GenericPackedTensorAccessor invstd, + GenericPackedTensorAccessor sum_dy, + GenericPackedTensorAccessor sum_dy_xmu, + GenericPackedTensorAccessor grad_weight, + GenericPackedTensorAccessor grad_bias) { + + index_t plane = blockIdx.x; + + stat_accscalar_t r_mean = mean[plane]; + stat_accscalar_t factor = invstd[plane]; + + GradOp> g(r_mean, input, grad_output); + auto res = reduce>(g, grad_output, plane); + + if (threadIdx.x == 0) { + if (grad_weight.size(0) > 0) { + grad_weight[plane] = 
static_cast(res.v2 * factor); + } + if (grad_bias.size(0) > 0) { + grad_bias[plane] = static_cast(res.v1); + } + if (sum_dy.size(0) > 0) { + sum_dy[plane] = static_cast(res.v1); + } + if (sum_dy_xmu.size(0) > 0) { + sum_dy_xmu[plane] = static_cast(res.v2); + } + } +} + +template +__device__ __forceinline__ void batch_norm_backward_elemt_kernel_impl( + const GenericPackedTensorAccessor input, + const GenericPackedTensorAccessor grad_output, + const GenericPackedTensorAccessor mean, + const GenericPackedTensorAccessor invstd, + const GenericPackedTensorAccessor weight, + const GenericPackedTensorAccessor sum_dy, + const GenericPackedTensorAccessor sum_dy_xmu, + GenericPackedTensorAccessor grad_input, + const stat_accscalar_t norm_fct) { + index_t plane = blockIdx.x; + + if (plane >= input.size(1)) { + return; + } + + stat_accscalar_t m_c = mean[plane]; + stat_accscalar_t m_dy_c = sum_dy[plane] * norm_fct; + stat_accscalar_t factor_1_c = invstd[plane]; + stat_accscalar_t factor_2_c = weight.size(0) > 0 ? static_cast(weight[plane]) : stat_accscalar_t(1); + factor_2_c *= factor_1_c; + factor_1_c = factor_1_c * factor_1_c * sum_dy_xmu[plane] * norm_fct; + + index_t bs = input.size(0); + index_t fs = input.size(2); + + index_t bstep = blockDim.y * gridDim.y; + for (index_t batch = threadIdx.y + blockIdx.y * blockDim.y; batch < bs; batch += bstep) { + auto g_i = grad_input[batch][plane]; + auto g_o = grad_output[batch][plane]; + auto i = input[batch][plane]; + for (index_t feature = threadIdx.x; feature < fs; feature += blockDim.x) { + g_i[feature] = static_cast((g_o[feature] - m_dy_c - (i[feature] - m_c) * factor_1_c) * factor_2_c); + } + } +} + +template +__global__ void batch_norm_backward_elemt_kernel( + const GenericPackedTensorAccessor input, + const GenericPackedTensorAccessor grad_output, + const GenericPackedTensorAccessor mean, + const GenericPackedTensorAccessor invstd, + const GenericPackedTensorAccessor weight, + const GenericPackedTensorAccessor sum_dy, + const GenericPackedTensorAccessor sum_dy_xmu, + GenericPackedTensorAccessor grad_input, + const int* __restrict__ numel, const int world_size) { + int64_t total_numel = 0; + for (int i = 0; i < world_size; i ++) { + total_numel += numel[i]; + } + + const stat_accscalar_t norm_fct = + static_cast(1) / static_cast(total_numel); + batch_norm_backward_elemt_kernel_impl( + input, grad_output, mean, invstd, weight, sum_dy, sum_dy_xmu, grad_input, norm_fct); +} + +template +__global__ void batch_norm_backward_elemt_kernel( + const GenericPackedTensorAccessor input, + const GenericPackedTensorAccessor grad_output, + const GenericPackedTensorAccessor mean, + const GenericPackedTensorAccessor invstd, + const GenericPackedTensorAccessor weight, + const GenericPackedTensorAccessor sum_dy, + const GenericPackedTensorAccessor sum_dy_xmu, + GenericPackedTensorAccessor grad_input, + const stat_accscalar_t norm_fct) { + batch_norm_backward_elemt_kernel_impl( + input, grad_output, mean, invstd, weight, sum_dy, sum_dy_xmu, grad_input, norm_fct); +} + +template class PtrTraits = DefaultPtrTraits, typename index_t = int64_t> +static GenericPackedTensorAccessor get_packed_accessor( + const Tensor& t, c10::string_view var_name) { + constexpr auto expect_type = c10::CppTypeToScalarType::value; + const auto actual_type = t.scalar_type(); + TORCH_CHECK(actual_type == expect_type, "Expected ", var_name, + " to have type ", expect_type, " but got ", actual_type); + return t.generic_packed_accessor(); +} + +template class PtrTraits = DefaultPtrTraits, typename 
index_t = int64_t> +static GenericPackedTensorAccessor packed_accessor_or_dummy( + const Tensor& t, c10::string_view var_name) { + if (!t.defined()) { + const std::array zeros{{0}}; + return GenericPackedTensorAccessor(nullptr, zeros.data(), zeros.data()); + } + return get_packed_accessor(t, var_name); +} + +template +std::tuple batch_norm_backward_cuda_template(const Tensor& grad_out_, const Tensor& input_, const Tensor& weight_, + const Tensor& running_mean_, const Tensor& running_var_, const Tensor& save_mean_, const Tensor& save_invstd_, + bool train, double epsilon, std::array grad_input_mask) { + + using accscalar_t = at::acc_type; + Tensor grad_input_; + Tensor grad_input_reshaped; + Tensor grad_weight_; + Tensor grad_bias_; + auto input_reshaped = input_.reshape({input_.size(0), input_.size(1), -1}); + auto grad_output_reshaped = grad_out_.reshape(input_reshaped.sizes()); + + if (grad_input_mask[0]) { + grad_input_ = at::empty_like(input_, LEGACY_CONTIGUOUS_MEMORY_FORMAT); + grad_input_reshaped = grad_input_.view(input_reshaped.sizes()); + } + if (grad_input_mask[1]) { + grad_weight_ = at::empty_like(weight_, LEGACY_CONTIGUOUS_MEMORY_FORMAT); + } + if (grad_input_mask[2]) { + grad_bias_ = at::empty_like(weight_, LEGACY_CONTIGUOUS_MEMORY_FORMAT); + } + + auto input = get_packed_accessor< + input_scalar_t, 3, DefaultPtrTraits, index_t>(input_reshaped, "input"); + auto grad_output = get_packed_accessor< + input_scalar_t, 3, DefaultPtrTraits, index_t>(grad_output_reshaped, "grad_output"); + auto grad_input = packed_accessor_or_dummy< + input_scalar_t, 3, DefaultPtrTraits, index_t>(grad_input_reshaped, "grad_input"); + auto weight = packed_accessor_or_dummy< + stat_scalar_t, 1, DefaultPtrTraits, index_t>(weight_, "weight"); + auto grad_weight = packed_accessor_or_dummy< + stat_scalar_t, 1, DefaultPtrTraits, index_t>(grad_weight_, "grad_weight"); + auto grad_bias = packed_accessor_or_dummy< + stat_scalar_t, 1, DefaultPtrTraits, index_t>(grad_bias_, "grad_bias"); + auto running_mean = packed_accessor_or_dummy< + stat_scalar_t, 1, DefaultPtrTraits, index_t>(running_mean_, "running_mean"); + auto running_var = packed_accessor_or_dummy< + stat_scalar_t, 1, DefaultPtrTraits, index_t>(running_var_, "running_var"); + auto save_mean = packed_accessor_or_dummy< + accscalar_t, 1, DefaultPtrTraits, index_t>(save_mean_, "save_mean"); + auto save_invstd = packed_accessor_or_dummy< + accscalar_t, 1, DefaultPtrTraits, index_t>(save_invstd_, "save_invstd"); + + auto stream = at::cuda::getCurrentCUDAStream(); + dim3 blocks(input.size(1)); + int tf = getNumThreads(input.size(2)); + dim3 threads(tf, std::max(1, MAX_BLOCK_SIZE/tf)); + + batch_norm_backward_kernel <<>> + (input, grad_output, grad_input, grad_weight, grad_bias, weight, running_mean, running_var, + save_mean, save_invstd, train, epsilon); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + + return std::make_tuple(grad_input_, grad_weight_, grad_bias_); +} + +template +void batch_norm_stats_cuda_template( + const Tensor& out_mean, const Tensor& out_invstd, const Tensor& input_, double epsilon) { + + using accscalar_t = at::acc_type; + int64_t n_input = input_.size(1); + Tensor dummy_mean_; + Tensor dummy_var_; + auto input_reshaped = input_.reshape({input_.size(0), input_.size(1), -1}); // internally we merge the feature dimensions + + resize_output(out_mean, {n_input}); + resize_output(out_invstd, {n_input}); + auto input = get_packed_accessor< + scalar_t, 3, RestrictPtrTraits, index_t>(input_reshaped, "input"); + TORCH_INTERNAL_ASSERT(out_invstd.dim() == 1 
&& out_invstd.is_contiguous() && + out_invstd.sizes()[0]); + TORCH_INTERNAL_ASSERT(out_mean.dim() == 1 && out_mean.is_contiguous() && + out_mean.sizes()[0]); + + auto mean = packed_accessor_or_dummy< + accscalar_t, 1, RestrictPtrTraits, index_t>(out_mean, "out_mean"); + auto invstd = packed_accessor_or_dummy< + accscalar_t, 1, RestrictPtrTraits, index_t>(out_invstd, "out_invstd"); + auto stream = at::cuda::getCurrentCUDAStream(); + + dim3 blocks(input.size(1)); + int tf = getNumThreads(input.size(2)); + dim3 threads(tf, std::max(1, MAX_BLOCK_SIZE/tf)); + batch_norm_collect_statistics_kernel <<>> + (input, epsilon, 0.0, mean, invstd); + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + +template +void batch_norm_elemt_cuda_template(const Tensor& output_, const Tensor& input_, const Tensor& weight_, + const Tensor& bias_, const Tensor& mean_, const Tensor& invstd_) { + + using stat_accscalar_t = at::acc_type; + int64_t n_input = input_.size(1); + auto input_reshaped = input_.reshape({input_.size(0), input_.size(1), -1}); // internally we merge the feature dimensions + auto output_reshaped = output_.view({input_.size(0), input_.size(1), -1}); + + auto input = get_packed_accessor< + input_scalar_t, 3, RestrictPtrTraits, index_t>(input_reshaped, "input"); + auto output = get_packed_accessor< + input_scalar_t, 3, RestrictPtrTraits, index_t>(output_reshaped, "output"); + auto weight = packed_accessor_or_dummy< + stat_scalar_t, 1, RestrictPtrTraits, index_t>(weight_, "weight"); + auto bias = packed_accessor_or_dummy< + stat_scalar_t, 1, RestrictPtrTraits, index_t>(bias_, "bias"); + auto mean = packed_accessor_or_dummy< + stat_accscalar_t, 1, RestrictPtrTraits, index_t>(mean_, "mean"); + auto invstd = packed_accessor_or_dummy< + stat_accscalar_t, 1, RestrictPtrTraits, index_t>(invstd_, "invstd"); + auto stream = at::cuda::getCurrentCUDAStream(); + + // NOTE: We use transform_input_kernel in training mode, which ignores epsilon + const double dummy_epsilon = 1e-5; + + // The input_transform kernel is pointwise, but we need to balance reading parameters (save_var/mean, + // weight/bias) - which we only do once and have a for loop afterwards - with having many threads and blocks + // and good occupancy. Quiet likely, we could go with even more blocks than 1024. + // The various planes are independent, so we use blocks for them. 
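+  // [Editorial walk-through of the launch-config arithmetic below, assuming
+  // the non-ROCm thread sizes {32, 64, 128, 256, 512} from getNumThreads.]
+  // E.g. 64 planes, batch 8, 4096 features per plane:
+  //   tf = max(getNumThreads(4096 / 4), min(getNumThreads(4096), 64))
+  //      = max(512, 64) = 512 threads along the feature axis,
+  //   tb = max(64 / 512, 1) = 1 batch row per block,
+  //   grid = (64, min(256*1024 / 64, ceil(8 / 1))) = (64, 8),
+  // i.e. one column of blocks per independent plane, as noted above.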
+  int tf = std::max(getNumThreads(input.size(2)/4),
+                    std::min(getNumThreads(input.size(2)), 64));
+  int tb = std::max(64/tf, 1);
+  dim3 blocks_trans(input.size(1), std::max<int>(1, std::min((256*1024)/input.size(1),
+                                                             (input.size(0)+tb-1)/tb)));
+  blocks_trans.y = std::min(blocks_trans.y, MAX_GRID_SIZE);
+  dim3 threads_trans(tf, tb);
+  batch_norm_transform_input_kernel <<<blocks_trans, threads_trans, 0, stream>>>
+    (input, output, mean, invstd, weight, bias, dummy_epsilon);
+  C10_CUDA_KERNEL_LAUNCH_CHECK();
+}
+
+template <typename scalar_t, typename accscalar_t, typename index_t>
+std::tuple<Tensor, Tensor> batch_norm_gather_stats_cuda_template(const Tensor& mean_, const Tensor& invstd_,
+    const Tensor& running_mean_, const Tensor& running_var_,
+    double momentum, double epsilon, const Tensor& counts_) {
+
+  Tensor save_mean_;
+  Tensor save_invstd_;
+
+  auto features = mean_.size(1);
+  auto input_options = mean_.options();
+  if (mean_.scalar_type() == at::ScalarType::Half || mean_.scalar_type() == at::ScalarType::BFloat16) {
+    input_options = input_options.dtype(ScalarType::Float);
+  }
+  save_mean_ = at::empty({features}, input_options);
+  save_invstd_ = at::empty({features}, input_options);
+
+  auto mean = packed_accessor_or_dummy<
+      accscalar_t, 2, RestrictPtrTraits, index_t>(mean_, "mean");
+  auto invstd = packed_accessor_or_dummy<
+      accscalar_t, 2, RestrictPtrTraits, index_t>(invstd_, "invstd");
+  auto running_mean = packed_accessor_or_dummy<
+      scalar_t, 1, RestrictPtrTraits, index_t>(running_mean_, "running_mean");
+  auto running_var = packed_accessor_or_dummy<
+      scalar_t, 1, RestrictPtrTraits, index_t>(running_var_, "running_var");
+  auto counts = packed_accessor_or_dummy<
+      scalar_t, 1, RestrictPtrTraits, index_t>(counts_, "counts");
+
+  auto save_mean = get_packed_accessor<
+      accscalar_t, 1, RestrictPtrTraits, index_t>(save_mean_, "save_mean");
+  auto save_invstd = get_packed_accessor<
+      accscalar_t, 1, RestrictPtrTraits, index_t>(save_invstd_, "save_invstd");
+  auto stream = at::cuda::getCurrentCUDAStream();
+
+  int block = getNumThreads(features);
+  int grid = std::max<int>(1, features/block);
+  batch_norm_reduce_statistics_kernel <<<grid, block, 0, stream>>>
+      (mean, invstd, save_mean, save_invstd, running_mean, running_var, epsilon, momentum, counts);
+  C10_CUDA_KERNEL_LAUNCH_CHECK();
+
+  return std::make_tuple(save_mean_, save_invstd_);
+}
+
+template <typename input_scalar_t, typename stat_scalar_t, typename index_t>
+std::tuple<Tensor, Tensor, Tensor, Tensor> batch_norm_backward_reduce_cuda_template(const Tensor& grad_out_, const Tensor& input_,
+    const Tensor& mean_, const Tensor& invstd_, const Tensor& weight_,
+    const bool input_g, const bool weight_g, const bool bias_g) {
+
+  using stat_accscalar_t = at::acc_type<stat_scalar_t, true>;
+  int64_t n_input = input_.size(1);
+  Tensor sum_dy_;
+  Tensor sum_dy_xmu_;
+  Tensor grad_weight_;
+  Tensor grad_bias_;
+  auto input_reshaped = input_.reshape({input_.size(0), input_.size(1), -1}); // internally we merge the feature dimensions
+  auto grad_output_reshaped = grad_out_.reshape(input_reshaped.sizes());
+
+  if (input_g) {
+    sum_dy_ = at::empty_like(mean_, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
+    sum_dy_xmu_ = at::empty_like(mean_, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
+  }
+  if (weight_g) {
+    grad_weight_ = at::empty({n_input}, weight_.options());
+  }
+  if (bias_g) {
+    grad_bias_ = at::empty({n_input}, weight_.options());
+  }
+
+  auto input = get_packed_accessor<
+      input_scalar_t, 3, DefaultPtrTraits, index_t>(input_reshaped, "input");
+  auto grad_output = get_packed_accessor<
+      input_scalar_t, 3, DefaultPtrTraits, index_t>(grad_output_reshaped, "grad_output");
+  auto grad_weight = packed_accessor_or_dummy<
+      stat_scalar_t, 1, DefaultPtrTraits, index_t>(grad_weight_, "grad_weight");
+  auto grad_bias =
packed_accessor_or_dummy< + stat_scalar_t, 1, DefaultPtrTraits, index_t>(grad_bias_, "grad_bias"); + auto mean = packed_accessor_or_dummy< + stat_accscalar_t, 1, DefaultPtrTraits, index_t>(mean_, "mean"); + auto invstd = packed_accessor_or_dummy< + stat_accscalar_t, 1, DefaultPtrTraits, index_t>(invstd_, "invstd"); + auto sum_dy = packed_accessor_or_dummy< + stat_accscalar_t, 1, DefaultPtrTraits, index_t>(sum_dy_, "sum_dy"); + auto sum_dy_xmu = packed_accessor_or_dummy< + stat_accscalar_t, 1, DefaultPtrTraits, index_t>(sum_dy_xmu_, "sum_dy_xmu"); + + auto batch_size = input_reshaped.size(0); + auto feature_size = input_reshaped.size(2); + auto stream = at::cuda::getCurrentCUDAStream(); + + int warp_size = at::cuda::warp_size(); + int block_y = std::min(lastPow2(batch_size), MAX_BLOCK_SIZE/warp_size); + // We want block_x to be at least a warp width + int block_x = std::min(std::max(getNumThreads(feature_size), warp_size), MAX_BLOCK_SIZE/block_y); + const dim3 block(block_x, block_y); + const dim3 grid(n_input); + + batch_norm_backward_reduce_kernel <<>> + (input, grad_output, mean, invstd, sum_dy, sum_dy_xmu, grad_weight, grad_bias); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + + return std::make_tuple(sum_dy_, sum_dy_xmu_, grad_weight_, grad_bias_); +} + +template +Tensor batch_norm_backward_elemt_cuda_template(const Tensor& grad_out_, const Tensor& input_, + const Tensor& mean_, const Tensor& invstd_, + const Tensor& weight_, const Tensor& sum_dy_, const Tensor& sum_dy_xmu_) { + + using stat_accscalar_t = at::acc_type; + int64_t n_input = input_.size(1); + auto input_reshaped = input_.reshape({input_.size(0), input_.size(1), -1}); // internally we merge the feature dimensions + auto grad_output_reshaped = grad_out_.reshape(input_reshaped.sizes()); + auto grad_input_reshaped = at::empty_like(input_reshaped, LEGACY_CONTIGUOUS_MEMORY_FORMAT); + + auto input = get_packed_accessor< + input_scalar_t, 3, DefaultPtrTraits, index_t>(input_reshaped, "input"); + auto grad_input = get_packed_accessor< + input_scalar_t, 3, DefaultPtrTraits, index_t>(grad_input_reshaped, "grad_input"); + auto grad_output = get_packed_accessor< + input_scalar_t, 3, DefaultPtrTraits, index_t>(grad_output_reshaped, "grad_output"); + auto mean = packed_accessor_or_dummy< + stat_accscalar_t, 1, DefaultPtrTraits, index_t>(mean_, "mean"); + auto invstd = packed_accessor_or_dummy< + stat_accscalar_t, 1, DefaultPtrTraits, index_t>(invstd_, "invstd"); + auto weight = packed_accessor_or_dummy< + stat_scalar_t, 1, DefaultPtrTraits, index_t>(weight_, "weight"); + auto sum_dy = packed_accessor_or_dummy< + stat_accscalar_t, 1, DefaultPtrTraits, index_t>(sum_dy_, "sum_dy"); + auto sum_dy_xmu = packed_accessor_or_dummy< + stat_accscalar_t, 1, DefaultPtrTraits, index_t>(sum_dy_xmu_, "sum_dy_xmu"); + + auto stream = at::cuda::getCurrentCUDAStream(); + + // The kernel is pointwise, but we need to balance reading parameters (save_var/mean, + // weight/bias) - which we only do once and have a for loop afterwards - with having many threads and blocks + // and good occupancy. Quiet likely, we could go with even more blocks than 1024. + // The various planes are independent, so we use blocks for them. 
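+  // [Illustrative note, not from the original source] For reference, the
+  // elementwise backward that this kernel implements (inferred from the
+  // channels-last variant later in this file) is, per channel:
+  //   grad_input = (grad_output - norm_fct * sum_dy
+  //                 - (input - mean) * invstd^2 * norm_fct * sum_dy_xmu)
+  //                * invstd * weight
+  // with norm_fct = 1 / reduction_size, i.e. the train-mode batch-norm input
+  // gradient with the batch statistics treated as constants.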
+ int tf = std::max(getNumThreads(input.size(2)/4), + std::min(getNumThreads(input.size(2)), 64)); + int tb = std::max(64/tf, 1); + dim3 blocks_trans(input.size(1), std::max(1, std::min((256*1024)/input.size(1), + (input.size(0)+tb-1)/tb))); + blocks_trans.y = std::min(blocks_trans.y, MAX_GRID_SIZE); + dim3 threads_trans(tf, tb); + auto reduction_size = input_.numel() / n_input; + auto norm_fct = static_cast(1.0 / reduction_size); + batch_norm_backward_elemt_kernel + <<>> + (input, grad_output, mean, invstd, weight, sum_dy, sum_dy_xmu, grad_input, norm_fct); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + + return grad_input_reshaped.view(input_.sizes()); +} + +template +Tensor batch_norm_backward_elemt_cuda_template(const Tensor& grad_out_, const Tensor& input_, + const Tensor& mean_, const Tensor& invstd_, + const Tensor& weight_, const Tensor& sum_dy_, const Tensor& sum_dy_xmu_, const Tensor& count) { + + using stat_accscalar_t = at::acc_type; + int64_t n_input = input_.size(1); + auto input_reshaped = input_.reshape({input_.size(0), input_.size(1), -1}); // internally we merge the feature dimensions + auto grad_output_reshaped = grad_out_.reshape(input_reshaped.sizes()); + auto grad_input_reshaped = at::empty_like(input_reshaped, LEGACY_CONTIGUOUS_MEMORY_FORMAT); + + auto input = get_packed_accessor< + input_scalar_t, 3, DefaultPtrTraits, index_t>(input_reshaped, "input"); + auto grad_input = get_packed_accessor< + input_scalar_t, 3, DefaultPtrTraits, index_t>(grad_input_reshaped, "grad_input"); + auto grad_output = get_packed_accessor< + input_scalar_t, 3, DefaultPtrTraits, index_t>(grad_output_reshaped, "grad_output"); + auto mean = packed_accessor_or_dummy< + stat_accscalar_t, 1, DefaultPtrTraits, index_t>(mean_, "mean"); + auto invstd = packed_accessor_or_dummy< + stat_accscalar_t, 1, DefaultPtrTraits, index_t>(invstd_, "invstd"); + auto weight = packed_accessor_or_dummy< + stat_scalar_t, 1, DefaultPtrTraits, index_t>(weight_, "weight"); + auto sum_dy = packed_accessor_or_dummy< + stat_accscalar_t, 1, DefaultPtrTraits, index_t>(sum_dy_, "sum_dy"); + auto sum_dy_xmu = packed_accessor_or_dummy< + stat_accscalar_t, 1, DefaultPtrTraits, index_t>(sum_dy_xmu_, "sum_dy_xmu"); + + auto stream = at::cuda::getCurrentCUDAStream(); + + // The kernel is pointwise, but we need to balance reading parameters (save_var/mean, + // weight/bias) - which we only do once and have a for loop afterwards - with having many threads and blocks + // and good occupancy. Quiet likely, we could go with even more blocks than 1024. + // The various planes are independent, so we use blocks for them. 
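+  // [Illustrative note, not from the original source] Unlike the overload
+  // above, this variant receives a `count` tensor holding one element count
+  // per participating rank (the distributed / SyncBatchNorm path). The kernel
+  // is handed count.const_data_ptr() plus count.numel() and derives
+  // norm_fct = 1 / sum(counts) itself, mirroring the channels-last kernel
+  // below that accumulates `total_numel += numel[i]` over world_size entries.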
+ int tf = std::max(getNumThreads(input.size(2)/4), + std::min(getNumThreads(input.size(2)), 64)); + int tb = std::max(64/tf, 1); + dim3 blocks_trans(input.size(1), std::max(1, std::min((256*1024)/input.size(1), + (input.size(0)+tb-1)/tb))); + blocks_trans.y = std::min(blocks_trans.y, MAX_GRID_SIZE); + dim3 threads_trans(tf, tb); + batch_norm_backward_elemt_kernel <<>> + (input, grad_output, mean, invstd, weight, sum_dy, sum_dy_xmu, grad_input, count.const_data_ptr(), count.numel()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + + return grad_input_reshaped.view(input_.sizes()); +} + +// welford kernel for c last tensor calculating mean/biased_variance/unbiased_variance +// original apex name: welford_kernel_c_last +template + +__global__ void +batch_norm_collect_statistics_channels_last_kernel( + const scalar_t* __restrict__ input, + accscalar_t* __restrict__ out_mean, + accscalar_t* __restrict__ out_invstd, + volatile accscalar_t* staging_data, + int* semaphores, + const int reduction_size, + const int stride, + accscalar_t epsilon) { + // hide latency with concurrency + accscalar_t x_mean[PARALLEL_LOADS]; + accscalar_t m_2_n[PARALLEL_LOADS]; + int count[PARALLEL_LOADS]; + +#pragma unroll + for (int i = 0; i < PARALLEL_LOADS; i++) { + x_mean[i] = accscalar_t(0); + m_2_n[i] = accscalar_t(0); + count[i] = accscalar_t(0); + } + // tensor dimension (m,c) + + // loop along m dimension + int inner_loop_stride = blockDim.y * gridDim.y; + + // offset along m dimension + int m_offset = blockIdx.y * blockDim.y + threadIdx.y; + int c_offset = blockIdx.x * blockDim.x + threadIdx.x; + + int loop_count = 1 + (reduction_size - 1) / (inner_loop_stride * PARALLEL_LOADS); + int address_base = m_offset * stride + c_offset; + int address_increment = inner_loop_stride * stride; + + for (int i = 0; i < loop_count; i++) { + accscalar_t x_math[PARALLEL_LOADS]; + accscalar_t x_count_inv[PARALLEL_LOADS]; + accscalar_t is_valid[PARALLEL_LOADS]; + + // load multiple data in +#pragma unroll + for (int j = 0; j < PARALLEL_LOADS; j++) { + if (c_offset < stride && m_offset < reduction_size) { + x_math[j] = input[address_base]; + count[j]++; + x_count_inv[j] = accscalar_t(1) / count[j]; + is_valid[j] = accscalar_t(1); + } else { + x_math[j] = accscalar_t(0); + x_count_inv[j] = accscalar_t(0); + is_valid[j] = accscalar_t(0); + } + m_offset += inner_loop_stride; + address_base += address_increment; + } + + // calculate mean/m2n with welford +#pragma unroll + for (int j = 0; j < PARALLEL_LOADS; j++) { + accscalar_t delta0 = x_math[j] - x_mean[j]; + x_mean[j] += delta0 * x_count_inv[j]; + accscalar_t delta1 = x_math[j] - x_mean[j]; + m_2_n[j] += delta0 * delta1 * is_valid[j]; + } + } + + // thread reduction to accumulate mean/m_2_n/count between PARALLEL_LOADS +#pragma unroll + for (int j = 1; j < PARALLEL_LOADS; j++) { + welford_merge_element(count[0], x_mean[0], m_2_n[0], count[j], x_mean[j], m_2_n[j]); + } + + // release x_mean / m_2_n + auto mean_th = x_mean[0]; + auto m2_th = m_2_n[0]; + auto count_th = count[0]; + + // block-wise reduction with shared memory (since reduction cannot be done within a warp) + static __shared__ accscalar_t shmem_mean[MAX_BLOCK_SIZE]; + static __shared__ accscalar_t shmem_m2n[MAX_BLOCK_SIZE]; + static __shared__ int shmem_count[MAX_BLOCK_SIZE]; + + welford_merge_block_vertical(count_th, mean_th, m2_th, shmem_count, shmem_mean, shmem_m2n); + + if (gridDim.y > 1) { + volatile accscalar_t* staging_mean = staging_data; + volatile accscalar_t* staging_m2n = &staging_data[stride*gridDim.y]; + volatile 
int* staging_count = reinterpret_cast(&staging_m2n[stride*gridDim.y]); + + address_base = c_offset + blockIdx.y * stride; + // write data to staging_data; + if (threadIdx.y == 0 && c_offset < stride) { + staging_mean[address_base] = mean_th; + staging_m2n[address_base] = m2_th; + staging_count[address_base] = count_th; + } + + __threadfence(); + __syncthreads(); // ensuring writes to staging_ is visible to all blocks + + __shared__ bool is_last_block_done; + // mark block done + if (threadIdx.x == 0 && threadIdx.y == 0) { + int old = atomicAdd(&semaphores[blockIdx.x], 1); + is_last_block_done = (old == (gridDim.y-1)); + } + + __syncthreads(); + + // check that all data is now available in global memory + if (is_last_block_done) { + count_th = 0; + mean_th = accscalar_t(0.0); + m2_th = accscalar_t(0.0); + + for (int y = threadIdx.y; y < gridDim.y; y += blockDim.y) { + address_base = c_offset + y * stride; + int count_new = c_offset < stride ? staging_count[address_base] : 0; + accscalar_t mean_new = c_offset < stride ? staging_mean[address_base] : accscalar_t(0.0); + accscalar_t m2n_new = c_offset < stride ? staging_m2n[address_base] : accscalar_t(0.0); + + welford_merge_element(count_th, mean_th, m2_th, count_new, mean_new, m2n_new); + } + + welford_merge_block_vertical(count_th, mean_th, m2_th, shmem_count, shmem_mean, shmem_m2n); + if (threadIdx.y == 0 && c_offset < stride) { + out_mean[c_offset] = static_cast(mean_th); + out_invstd[c_offset] = VarTransform{}(m2_th/count_th, epsilon); + } + } + } else { + if (blockIdx.y == 0 && threadIdx.y == 0 && c_offset < stride) { + out_mean[c_offset] = static_cast(mean_th); + out_invstd[c_offset] = VarTransform{}(m2_th/count_th, epsilon); + } + } +} + +// elementwise BN kernel +// original apex name: batchnorm_forward_c_last_kernel +template < + typename scalar_t, + typename accscalar_t, + typename layerscalar_t, + int PARALLEL_LOADS> +__global__ void batch_norm_transform_input_channels_last_kernel( + const scalar_t* __restrict__ input, + const scalar_t* __restrict__ z, + const accscalar_t* __restrict__ mean, + const accscalar_t* __restrict__ inv_std, + const layerscalar_t* __restrict__ weight, + const layerscalar_t* __restrict__ shift, + scalar_t* __restrict__ out, + const int reduction_size, + const int stride, + const bool fuse_relu) { + // tensor dimension (m,c) + // loop along m dimension + int inner_loop_stride = blockDim.y * gridDim.y; + + // offset along m dimension + int m_offset = blockIdx.y * blockDim.y + threadIdx.y; + int c_offset = blockIdx.x * blockDim.x + threadIdx.x; + + if (c_offset >= stride || m_offset >= reduction_size) { + return; + } + + auto m_c = mean[c_offset]; + auto inv_std_c = static_cast(inv_std[c_offset]); + auto w_c = weight == nullptr ? accscalar_t(1.0) : static_cast(weight[c_offset]); + auto s_c = shift == nullptr ? accscalar_t(0.0) : static_cast(shift[c_offset]); + + int loop_count = 1 + (reduction_size - 1) / (inner_loop_stride * PARALLEL_LOADS); + int address_base = m_offset * stride + c_offset; + int address_increment = inner_loop_stride * stride; + + for (int i = 0; i < loop_count; i++) { +#pragma unroll + for (int j = 0; j < PARALLEL_LOADS; j++) { + if (c_offset < stride && m_offset < reduction_size) { + auto tmp = w_c * (static_cast(input[address_base]) - m_c ) * inv_std_c + s_c; + if (z != nullptr) { + tmp += z[address_base]; + } + out[address_base] = (fuse_relu && tmp <= accscalar_t(0.0) ? 
scalar_t(0.0) : static_cast(tmp)); + } + m_offset += inner_loop_stride; + address_base += address_increment; + } + } +} + +template +__device__ __forceinline__ void merge_block_vertical_backward(T& sum_dy, + T& sum_dy_xmu, + T* shmem_sum_dy, + T* shmem_sum_dy_xmu) { + // write to shared memory + auto address_base = threadIdx.x + threadIdx.y * blockDim.x; + +#pragma unroll + for (int offset = blockDim.y/2; offset > 0; offset >>= 1) { + if (threadIdx.y < offset*2) { + shmem_sum_dy[address_base] = sum_dy; + shmem_sum_dy_xmu[address_base] = sum_dy_xmu; + } + __syncthreads(); + if (threadIdx.y < offset && threadIdx.y + offset < blockDim.y) { + auto address = address_base + offset * blockDim.x; + + sum_dy += shmem_sum_dy[address]; + sum_dy_xmu += shmem_sum_dy_xmu[address]; + } + } +} + +// batchnorm backward kernel for c last tensor +// original apex name: reduce_bn_c_last_kernel +template < + int PARALLEL_LOADS, + typename scalar_t, + typename accscalar_t, + typename layerscalar_t> +__global__ void batch_norm_backward_reduce_channels_last_kernel( + const scalar_t* __restrict__ input, + const scalar_t* __restrict__ grad_output, + const accscalar_t* __restrict__ mean, + const accscalar_t* __restrict__ inv_std, + accscalar_t* __restrict__ sum_dy_o, + accscalar_t* __restrict__ sum_dy_xmu_o, + layerscalar_t* __restrict__ grad_weight, + layerscalar_t* __restrict__ grad_bias, + volatile accscalar_t* staging_data, + int* semaphores, + const int reduction_size, + const int stride) { + + // hide latency with concurrency + accscalar_t sum_dy[PARALLEL_LOADS]; + accscalar_t sum_dy_xmu[PARALLEL_LOADS]; + +#pragma unroll + for (int i = 0; i < PARALLEL_LOADS; i++) { + sum_dy[i] = accscalar_t(0); + sum_dy_xmu[i] = accscalar_t(0); + } + // tensor dimension (m,c) + + // loop along m dimension + int inner_loop_stride = blockDim.y * gridDim.y; + + // offset along m dimension + int m_offset = blockIdx.y * blockDim.y + threadIdx.y; + int c_offset = blockIdx.x * blockDim.x + threadIdx.x; + + if (c_offset >= stride || m_offset >= reduction_size) { + return; + } + + int loop_count = 1 + (reduction_size - 1) / (inner_loop_stride * PARALLEL_LOADS); + int address_base = m_offset * stride + c_offset; + int address_increment = inner_loop_stride * stride; + + auto r_mean = mean[c_offset]; + auto factor = inv_std[c_offset]; + + for (int i = 0; i < loop_count; i++) { + accscalar_t x_input[PARALLEL_LOADS]; + accscalar_t x_grad_output[PARALLEL_LOADS]; + + // load multiple data in +#pragma unroll + for (int j = 0; j < PARALLEL_LOADS; j++) { + if (c_offset < stride && m_offset < reduction_size) { + x_input[j] = input[address_base]; + x_grad_output[j] = grad_output[address_base]; + } else { + x_input[j] = accscalar_t(0); + x_grad_output[j] = accscalar_t(0); + } + m_offset += inner_loop_stride; + address_base += address_increment; + } + + // calculate sum_dy / sum_dy_xmu +#pragma unroll + for (int j = 0; j < PARALLEL_LOADS; j++) { + sum_dy[j] += x_grad_output[j]; + sum_dy_xmu[j] += x_grad_output[j] * (x_input[j] - r_mean); + } + } + + // thread reduction to accumulate sum_dy / sum_dy_xmu between PARALLEL_LOADS +#pragma unroll + for (int j = 1; j < PARALLEL_LOADS; j++) { + sum_dy[0] += sum_dy[j]; + sum_dy_xmu[0] += sum_dy_xmu[j]; + } + + // release array of registers + auto sum_dy_th = sum_dy[0]; + auto sum_dy_xmu_th = sum_dy_xmu[0]; + + // block-wise reduction with shared memory (since reduction cannot be done within a warp) + static __shared__ accscalar_t shmem_sum_dy[MAX_BLOCK_SIZE]; + static __shared__ accscalar_t 
shmem_sum_dy_xmu[MAX_BLOCK_SIZE]; + + merge_block_vertical_backward(sum_dy_th, sum_dy_xmu_th, shmem_sum_dy, shmem_sum_dy_xmu); + + if (gridDim.y > 1) { + volatile accscalar_t* staging_sum_dy = staging_data; + volatile accscalar_t* staging_sum_dy_xmu = &staging_data[stride*gridDim.y]; + + address_base = c_offset + blockIdx.y * stride; + // write data to staging_data; + if (threadIdx.y == 0 && c_offset < stride) { + staging_sum_dy[address_base] = sum_dy_th; + staging_sum_dy_xmu[address_base] = sum_dy_xmu_th; + } + + __threadfence(); + __syncthreads(); // ensuring writes to staging_ is visible to all blocks + + __shared__ bool is_last_block_done; + // mark block done + if (threadIdx.x == 0 && threadIdx.y == 0) { + int old = atomicAdd(&semaphores[blockIdx.x], 1); + is_last_block_done = (old == (gridDim.y-1)); + } + + __syncthreads(); + + // check that all data is now available in global memory + if (is_last_block_done) { + sum_dy_th = accscalar_t(0.0); + sum_dy_xmu_th = accscalar_t(0.0); + + for (int y = threadIdx.y; y < gridDim.y; y += blockDim.y) { + address_base = c_offset + y * stride; + sum_dy_th += (c_offset < stride ? staging_sum_dy[address_base] : accscalar_t(0.0)); + sum_dy_xmu_th += (c_offset < stride ? staging_sum_dy_xmu[address_base] : accscalar_t(0.0)); + } + + merge_block_vertical_backward(sum_dy_th, sum_dy_xmu_th, shmem_sum_dy, shmem_sum_dy_xmu); + if (threadIdx.y == 0 && c_offset < stride) { + if (grad_bias != nullptr) { + grad_bias[c_offset] = static_cast(sum_dy_th); + } + if (grad_weight != nullptr) { + grad_weight[c_offset] = static_cast(sum_dy_xmu_th * factor); + } + //mean_dy[c_offset] = sum_dy_th / reduction_size; + //mean_dy_xmu[c_offset] = sum_dy_xmu_th / reduction_size; + sum_dy_o[c_offset] = sum_dy_th; + sum_dy_xmu_o[c_offset] = sum_dy_xmu_th; + } + } + } else { + if (blockIdx.y == 0 && threadIdx.y == 0 && c_offset < stride) { + if (grad_bias != nullptr) { + grad_bias[c_offset] = static_cast(sum_dy_th); + } + if (grad_weight != nullptr) { + grad_weight[c_offset] = static_cast(sum_dy_xmu_th * factor); + } + //mean_dy[c_offset] = sum_dy_th / reduction_size; + //mean_dy_xmu[c_offset] = sum_dy_xmu_th / reduction_size; + sum_dy_o[c_offset] = sum_dy_th; + sum_dy_xmu_o[c_offset] = sum_dy_xmu_th; + } + } +} + +// elementwise BN kernel +// original apex name: batchnorm_backward_c_last_kernel +template < + int PARALLEL_LOADS, + typename scalar_t, + typename accscalar_t, + typename layerscalar_t> +__device__ __forceinline__ void batch_norm_backward_elemt_channels_last_kernel_impl( + const scalar_t* __restrict__ grad_output, + const scalar_t* __restrict__ input, + const accscalar_t* __restrict__ mean, + const accscalar_t* __restrict__ inv_std, + const layerscalar_t* __restrict__ weight, + const accscalar_t* __restrict__ sum_dy, + const accscalar_t* __restrict__ sum_dy_xmu, + scalar_t* __restrict__ grad_input, + const accscalar_t norm_fct, + const int reduction_size, + const int stride) { + // tensor dimension (m,c) + // loop along m dimension + int inner_loop_stride = blockDim.y * gridDim.y; + + // offset along m dimension + int m_offset = blockIdx.y * blockDim.y + threadIdx.y; + int c_offset = blockIdx.x * blockDim.x + threadIdx.x; + + if (c_offset >= stride || m_offset >= reduction_size) { + return; + } + + auto m_c = mean[c_offset]; + auto m_dy_c = sum_dy[c_offset] * norm_fct; + auto factor_1_c = inv_std[c_offset]; + auto factor_2_c = (weight == nullptr? 
accscalar_t(1.0) : static_cast(weight[c_offset])) * factor_1_c; + factor_1_c = factor_1_c * factor_1_c * sum_dy_xmu[c_offset] * norm_fct; + + int loop_count = 1 + (reduction_size - 1) / (inner_loop_stride * PARALLEL_LOADS); + int address_base = m_offset * stride + c_offset; + int address_increment = inner_loop_stride * stride; + + for (int i = 0; i < loop_count; i++) { +#pragma unroll + for (int j = 0; j < PARALLEL_LOADS; j++) { + if (c_offset < stride && m_offset < reduction_size) { + grad_input[address_base] = static_cast( + (static_cast(grad_output[address_base]) - m_dy_c - + (static_cast(input[address_base]) - m_c) * factor_1_c) + * factor_2_c); + } + m_offset += inner_loop_stride; + address_base += address_increment; + } + } +} + +template < + int PARALLEL_LOADS, + typename scalar_t, + typename accscalar_t, + typename layerscalar_t> +__global__ void batch_norm_backward_elemt_channels_last_kernel( + const scalar_t* __restrict__ grad_output, + const scalar_t* __restrict__ input, + const accscalar_t* __restrict__ mean, + const accscalar_t* __restrict__ inv_std, + const layerscalar_t* __restrict__ weight, + const accscalar_t* __restrict__ sum_dy, + const accscalar_t* __restrict__ sum_dy_xmu, + const int* __restrict__ numel, + scalar_t* __restrict__ grad_input, + const int64_t world_size, + const int reduction_size, + const int stride) { + + int64_t total_numel = 0; + for (int i = 0; i < world_size; i++) { + total_numel += numel[i]; + } + + auto norm_fct = static_cast(1) / static_cast(total_numel); + batch_norm_backward_elemt_channels_last_kernel_impl( + grad_output, input, mean, inv_std, weight, sum_dy, sum_dy_xmu, + grad_input, norm_fct, reduction_size, stride); +} + +template < + int PARALLEL_LOADS, + typename scalar_t, + typename accscalar_t, + typename layerscalar_t> +__global__ void batch_norm_backward_elemt_channels_last_kernel( + const scalar_t* __restrict__ grad_output, + const scalar_t* __restrict__ input, + const accscalar_t* __restrict__ mean, + const accscalar_t* __restrict__ inv_std, + const layerscalar_t* __restrict__ weight, + const accscalar_t* __restrict__ sum_dy, + const accscalar_t* __restrict__ sum_dy_xmu, + scalar_t* __restrict__ grad_input, + const accscalar_t norm_fct, + const int reduction_size, + const int stride) { + batch_norm_backward_elemt_channels_last_kernel_impl( + grad_output, input, mean, inv_std, weight, sum_dy, sum_dy_xmu, + grad_input, norm_fct, reduction_size, stride); +} + +template +void batch_norm_stats_channels_last_cuda_template( + const Tensor& out_mean, const Tensor& out_invstd, const Tensor& input, double epsilon) { + using accscalar_t = at::acc_type; + + const auto stride = input.sizes()[1]; + const auto reduction_size = input.numel() / stride; + + resize_output(out_mean, {stride}); + resize_output(out_invstd, {stride}); + TORCH_INTERNAL_ASSERT(out_invstd.dim() == 1 && out_invstd.is_contiguous() && + out_invstd.sizes()[0]); + TORCH_INTERNAL_ASSERT(out_mean.dim() == 1 && out_mean.is_contiguous() && + out_mean.sizes()[0]); + + dim3 block; + dim3 grid; + flexible_launch_configs(reduction_size, stride, block, grid, true); + + at::Tensor staging_data; + at::Tensor semaphores; + if (grid.y > 1) { + staging_data = at::empty({4*stride*grid.y}, out_mean.options()); + semaphores = at::zeros({grid.x}, input.options().dtype(at::kInt)); + } + + accscalar_t* staging_data_ptr = grid.y > 1 ? staging_data.mutable_data_ptr() : nullptr; + int* semaphores_ptr = grid.y > 1 ? 
semaphores.mutable_data_ptr() : nullptr; + batch_norm_collect_statistics_channels_last_kernel + <<>>( + input.const_data_ptr(), + out_mean.mutable_data_ptr(), + out_invstd.mutable_data_ptr(), + staging_data_ptr, + semaphores_ptr, + reduction_size, + stride, + epsilon); + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + +void batch_norm_elemt_channels_last_cuda_template( + const at::Tensor& output, + const at::Tensor& input, + const at::Tensor& weight, + const at::Tensor& shift, // bias of BN + const at::Tensor& mean, + const at::Tensor& inv_std, + const at::optional& z = c10::nullopt, // bias after BN + const bool fuse_relu = false) { + const auto stride = input.sizes()[1]; + const auto reduction_size = input.numel() / stride; + + dim3 block; + dim3 grid; + flexible_launch_configs(reduction_size, stride, block, grid); + + auto stream = at::cuda::getCurrentCUDAStream(); + const auto second_dtype = weight.defined() ? weight.scalar_type() : + (shift.defined() ? shift.scalar_type() : input.scalar_type()); + + if (input.scalar_type() != second_dtype) { + AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), "batchnorm_forward", [&] { + using accscalar_t = at::acc_type; + batch_norm_transform_input_channels_last_kernel + <<>>( + input.const_data_ptr(), + z.has_value() ? z.value().const_data_ptr() : nullptr, + mean.const_data_ptr(), + inv_std.const_data_ptr(), + weight.defined() ? weight.const_data_ptr() : nullptr, + shift.defined() ? shift.const_data_ptr() : nullptr, + output.mutable_data_ptr(), + reduction_size, + stride, + fuse_relu); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); + } else { + if (weight.defined()){ + TORCH_CHECK(input.scalar_type() == weight.scalar_type(), "batchnorm_forward: input.scalar_type() ", input.scalar_type(), + " is not supported with weight.scalar_type() ", weight.scalar_type()); + } + AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), "batchnorm_forward", [&] { + using accscalar_t = at::acc_type; + batch_norm_transform_input_channels_last_kernel + <<>>( + input.const_data_ptr(), + z.has_value() ? z.value().const_data_ptr() : nullptr, + mean.const_data_ptr(), + inv_std.const_data_ptr(), + weight.defined() ? weight.const_data_ptr() : nullptr, + shift.defined() ? 
shift.const_data_ptr(): nullptr, + output.mutable_data_ptr(), + reduction_size, + stride, + fuse_relu); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); + } +} + +std::tuple +batch_norm_backward_reduce_cuda_channels_last_template(const at::Tensor& grad_output, + const at::Tensor& input, + const at::Tensor& mean, + const at::Tensor& inv_std, + const at::Tensor& weight, + const bool input_g, const bool weight_g, const bool bias_g) { + const auto stride = input.sizes()[1]; + const auto reduction_size = input.numel() / stride; + + at::Tensor sumn_dy = at::empty({stride}, mean.options()); + at::Tensor sum_dy_xmu = at::empty({stride}, mean.options()); + + at::Tensor grad_weight; + at::Tensor grad_bias; + if (weight.defined()) { + grad_weight = at::empty({stride}, weight.options()); + grad_bias = at::empty({stride}, weight.options()); + } else { + // because I cannot return an uninitialized at::Tensor + grad_weight = at::empty({0}, mean.options()); + grad_bias = at::empty({0}, mean.options()); + } + + dim3 block; + dim3 grid; + flexible_launch_configs(reduction_size, stride, block, grid, true); + + at::Tensor staging_data; + at::Tensor semaphores; + if (grid.y > 1) { + staging_data = at::empty({2*stride*grid.y}, mean.options()); + semaphores = at::zeros({grid.x}, input.options().dtype(at::kInt)); + } + auto stream = at::cuda::getCurrentCUDAStream(); + + if (weight.defined() && input.scalar_type() != weight.scalar_type()) { + AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), "batchnorm_backward_reduce", [&] { + using accscalar_t = at::acc_type; + accscalar_t* staging_data_ptr = grid.y > 1 ? staging_data.mutable_data_ptr() : nullptr; + int* semaphores_ptr = grid.y > 1 ? semaphores.mutable_data_ptr() : nullptr; + batch_norm_backward_reduce_channels_last_kernel + <<>>( + input.const_data_ptr(), + grad_output.const_data_ptr(), + mean.const_data_ptr(), + inv_std.const_data_ptr(), + sumn_dy.mutable_data_ptr(), + sum_dy_xmu.mutable_data_ptr(), + grad_weight.mutable_data_ptr(), + grad_bias.mutable_data_ptr(), + staging_data_ptr, + semaphores_ptr, + reduction_size, + stride); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); + } else { + if (weight.defined()) { + TORCH_CHECK(input.scalar_type() == weight.scalar_type(), "batchnorm_backward_reduce: input.scalar_type() ", input.scalar_type(), + " is not supported with weight.scalar_type() ", weight.scalar_type()); + } + AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), "batchnorm_backward_reduce", [&] { + using accscalar_t = at::acc_type; + accscalar_t* staging_data_ptr = grid.y > 1 ? staging_data.mutable_data_ptr() : nullptr; + int* semaphores_ptr = grid.y > 1 ? semaphores.mutable_data_ptr() : nullptr; + batch_norm_backward_reduce_channels_last_kernel + <<>>( + input.const_data_ptr(), + grad_output.const_data_ptr(), + mean.const_data_ptr(), + inv_std.const_data_ptr(), + sumn_dy.mutable_data_ptr(), + sum_dy_xmu.mutable_data_ptr(), + weight.defined() ? grad_weight.mutable_data_ptr() : nullptr, + weight.defined() ? 
grad_bias.mutable_data_ptr<scalar_t>() : nullptr,
+          staging_data_ptr,
+          semaphores_ptr,
+          reduction_size,
+          stride);
+      C10_CUDA_KERNEL_LAUNCH_CHECK();
+    });
+  }
+
+  return std::make_tuple(sumn_dy, sum_dy_xmu, grad_weight, grad_bias);
+}
+
+at::Tensor batch_norm_backward_elemt_channels_last_cuda_template(
+    const at::Tensor& grad_output,
+    const at::Tensor& input,
+    const at::Tensor& mean,
+    const at::Tensor& inv_std,
+    const at::Tensor& weight,
+    const at::Tensor& sum_dy,
+    const at::Tensor& sum_dy_xmu,
+    const at::Tensor& count) {
+  const auto stride = input.sizes()[1];
+  const auto reduction_size = input.numel() / stride;
+
+  // Input is guaranteed to be channels-last compatible
+  at::Tensor grad_input = at::empty_like(input);
+
+  dim3 block;
+  dim3 grid;
+  flexible_launch_configs(reduction_size, stride, block, grid);
+
+  auto stream = at::cuda::getCurrentCUDAStream();
+
+  if (weight.defined() && weight.scalar_type() != input.scalar_type()) {
+    AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), "batchnorm_backward_element", [&] {
+      using accscalar_t = at::acc_type<scalar_t, true>;
+      batch_norm_backward_elemt_channels_last_kernel
+          <<<grid, block, 0, stream>>>(
+          grad_output.const_data_ptr<scalar_t>(),
+          input.const_data_ptr<scalar_t>(),
+          mean.const_data_ptr<accscalar_t>(),
+          inv_std.const_data_ptr<accscalar_t>(),
+          weight.const_data_ptr<accscalar_t>(),
+          sum_dy.const_data_ptr<accscalar_t>(),
+          sum_dy_xmu.const_data_ptr<accscalar_t>(),
+          count.const_data_ptr<int>(),
+          grad_input.mutable_data_ptr<scalar_t>(),
+          count.numel(),
+          reduction_size,
+          stride);
+      C10_CUDA_KERNEL_LAUNCH_CHECK();
+    });
+  } else {
+    if (weight.defined()) {
+      TORCH_CHECK(input.scalar_type() == weight.scalar_type(), "batchnorm_backward_element: input.scalar_type() ", input.scalar_type(),
+                  " is not supported with weight.scalar_type() ", weight.scalar_type());
+    }
+    AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, input.scalar_type(), "batchnorm_backward_element", [&] {
+      using accscalar_t = at::acc_type<scalar_t, true>;
+      batch_norm_backward_elemt_channels_last_kernel
+          <<<grid, block, 0, stream>>>(
+          grad_output.const_data_ptr<scalar_t>(),
+          input.const_data_ptr<scalar_t>(),
+          mean.const_data_ptr<accscalar_t>(),
+          inv_std.const_data_ptr<accscalar_t>(),
+          weight.defined() ?
weight.const_data_ptr<scalar_t>() : nullptr,
+          sum_dy.const_data_ptr<accscalar_t>(),
+          sum_dy_xmu.const_data_ptr<accscalar_t>(),
+          count.const_data_ptr<int>(),
+          grad_input.mutable_data_ptr<scalar_t>(),
+          count.numel(),
+          reduction_size,
+          stride);
+      C10_CUDA_KERNEL_LAUNCH_CHECK();
+    });
+  }
+
+  return grad_input;
+}
+
+at::Tensor batch_norm_backward_elemt_channels_last_cuda_template(
+    const at::Tensor& grad_output,
+    const at::Tensor& input,
+    const at::Tensor& mean,
+    const at::Tensor& inv_std,
+    const at::Tensor& weight,
+    const at::Tensor& sum_dy,
+    const at::Tensor& sum_dy_xmu) {
+  const auto stride = input.sizes()[1];
+  const auto reduction_size = input.numel() / stride;
+  auto norm_fct = 1.0 / reduction_size;
+
+  // Input is guaranteed to be channels-last compatible
+  at::Tensor grad_input = at::empty_like(input);
+
+  dim3 block;
+  dim3 grid;
+  flexible_launch_configs(reduction_size, stride, block, grid);
+
+  auto stream = at::cuda::getCurrentCUDAStream();
+
+  AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), "batchnorm_backward_element", [&] {
+    using accscalar_t = at::acc_type<scalar_t, true>;
+
+    if (weight.defined() && weight.scalar_type() != input.scalar_type()) {
+      batch_norm_backward_elemt_channels_last_kernel
+          <<<grid, block, 0, stream>>>(
+          grad_output.const_data_ptr<scalar_t>(),
+          input.const_data_ptr<scalar_t>(),
+          mean.const_data_ptr<accscalar_t>(),
+          inv_std.const_data_ptr<accscalar_t>(),
+          weight.const_data_ptr<accscalar_t>(),
+          sum_dy.const_data_ptr<accscalar_t>(),
+          sum_dy_xmu.const_data_ptr<accscalar_t>(),
+          grad_input.mutable_data_ptr<scalar_t>(),
+          static_cast<accscalar_t>(norm_fct),
+          reduction_size,
+          stride);
+      C10_CUDA_KERNEL_LAUNCH_CHECK();
+    } else {
+      batch_norm_backward_elemt_channels_last_kernel
+          <<<grid, block, 0, stream>>>(
+          grad_output.const_data_ptr<scalar_t>(),
+          input.const_data_ptr<scalar_t>(),
+          mean.const_data_ptr<accscalar_t>(),
+          inv_std.const_data_ptr<accscalar_t>(),
+          weight.defined() ? weight.const_data_ptr<scalar_t>() : nullptr,
+          sum_dy.const_data_ptr<accscalar_t>(),
+          sum_dy_xmu.const_data_ptr<accscalar_t>(),
+          grad_input.mutable_data_ptr<scalar_t>(),
+          static_cast<accscalar_t>(norm_fct),
+          reduction_size,
+          stride);
+      C10_CUDA_KERNEL_LAUNCH_CHECK();
+    }
+  });
+
+  return grad_input;
+}
+
+} } // namespace at::native
diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/Pow.cuh b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/Pow.cuh
new file mode 100644
index 0000000000000000000000000000000000000000..9530b0ede27459d33fe9c8a01b71129621da499c
--- /dev/null
+++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/Pow.cuh
@@ -0,0 +1,58 @@
+#pragma once
+#include
+#include
+
+namespace at { namespace native {
+
+namespace {
+
+
+// SFINAE doesn't work well with NVCC under Windows for math functions like pow and sqrt.
+// So we need to define the functions with the explicit function signatures.
+// As for pow, the following signatures are defined as the device function: +// pow(float, int) +// pow(double, int) +// pow(float, float) +// pow(double, double) +#ifdef _MSC_VER +// Functions for pow +// pow for at::Half +static inline __host__ __device__ at::Half pow_(at::Half base, at::Half exp) { + return static_cast(std::pow(static_cast(base), static_cast(exp))); +} +// pow for at::BFloat16 +static inline __host__ __device__ at::BFloat16 pow_(at::BFloat16 base, at::BFloat16 exp) { + return static_cast(std::pow(static_cast(base), static_cast(exp))); +} +// pow (floating, floating/int) +template +static inline __host__ __device__ typename std::enable_if::value && (std::is_same::value || std::is_same::value), Base_type>::type + pow_(Base_type base, Exp_type exp) { + return std::pow(base, exp); +} +// pow (Otherwise) +template +static inline __host__ __device__ typename std::enable_if::value && !std::is_same::value, Base_type>::type + pow_(Base_type base, Exp_type exp) { + return static_cast(std::pow(static_cast(base), static_cast(exp))); +} +#else +template +static inline __host__ __device__ Base_type pow_(Base_type base, Exp_type exp) { + return ::pow(base, exp); +} +#endif + +template +static inline __host__ __device__ std::enable_if_t::value, T> pow_( + T base, T exp) { + return at::native::powi(base, exp); +} + +template +static inline __host__ __device__ c10::complex pow_(c10::complex base, c10::complex exp) { + return c10_complex_math::pow(base, exp); +} + +} // namespace +}} // namespace at::native diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/SortUtils.cuh b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/SortUtils.cuh new file mode 100644 index 0000000000000000000000000000000000000000..9d3fb6b2a47bbd65bde7628eaf0161066346efa9 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/SortUtils.cuh @@ -0,0 +1,344 @@ +#pragma once +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#define HAS_WARP_MERGE_SORT() (CUDA_VERSION >= 110600) + + +namespace at { namespace native { + +template +__device__ inline void swapVars(T& t1, T& t2) { + T tmp = t1; + t1 = t2; + t2 = tmp; +} + +template +__device__ inline void bitonicSwap(K& kA, V& vA, bool& validA, + K& kB, V& vB, bool& validB, + bool dir, + const Comparator& comp) { + // Invalid entries always sort to the end + bool swap = (comp(kA, kB) && validA) || !validB; + if (swap == dir) { + swapVars(kA, kB); + swapVars(vA, vB); + swapVars(validA, validB); + } +}; + +template +__device__ inline void bitonicSort(K *keys, + V *values, + bool *valid, + const Comparator& comp) { +#if !defined(USE_ROCM) +#pragma unroll +#endif + for (unsigned int size = 2; size < Power2SortSize; size *= 2) { + bool flag = ((threadIdx.x & (size / 2)) != 0); + +#if !defined(USE_ROCM) +#pragma unroll +#endif + for (unsigned int stride = size / 2; stride > 0; stride /= 2) { + + __syncthreads(); + + unsigned int pos = 2 * threadIdx.x - (threadIdx.x & (stride - 1)); + bitonicSwap( + keys[pos], values[pos], valid[pos], + keys[pos + stride], values[pos + stride], valid[pos + stride], + flag, comp); + } + } + +#if !defined(USE_ROCM) +#pragma unroll +#endif + for (unsigned int stride = Power2SortSize / 2; stride > 0; stride /= 2) { + + __syncthreads(); + + unsigned int pos = 2 * threadIdx.x - (threadIdx.x & (stride - 1)); + bitonicSwap( + keys[pos], values[pos], 
valid[pos],
+        keys[pos + stride], values[pos + stride], valid[pos + stride],
+        false, comp);
+  }
+
+  __syncthreads();
+
+}
+
+// at::cuda::detail::TensorInfo version
+// Sorts (key, value) pairs (in different tensors) in-place; i.e.,
+// modifies the input `keys` and `values`
+template <typename K, typename V, typename Comparator, typename IndexType, int block_dim_x, int max_block_dim_y>
+C10_LAUNCH_BOUNDS_1(block_dim_x * max_block_dim_y)
+__global__ void
+bitonicSortKVInPlace(at::cuda::detail::TensorInfo<K, IndexType> keys,
+                     IndexType keySlices,
+                     IndexType keySliceSize,
+                     IndexType keySliceStride,
+                     at::cuda::detail::TensorInfo<V, IndexType> values,
+                     IndexType valueSliceStride,
+                     Comparator comp) {
+  // Find the slice of the tensor that we are sorting
+  // NOTE: blockDim.y may be less than max_block_dim_y
+  const IndexType blockIndex = getLinearBlockId<IndexType>();
+  const IndexType linearIndex = blockIndex * blockDim.y + threadIdx.y;
+
+  // If the entire block is out of bounds exit early
+  if (blockIndex * blockDim.y >= keySlices) {
+    return;
+  }
+  // It's also possible for some rows of a block to be out of bounds
+  // but all threads need to run for __syncthreads to work.
+  const bool row_valid = linearIndex < keySlices;
+
+  constexpr int items_per_thread = 2;
+  constexpr int Power2SortSize = block_dim_x * items_per_thread;
+
+  // Storage for max_block_dim_y sorts performed in parallel
+  __shared__ K blockSharedKeys[max_block_dim_y][Power2SortSize];
+  __shared__ V blockSharedValues[max_block_dim_y][Power2SortSize];
+  __shared__ bool blockSharedValid[max_block_dim_y][Power2SortSize];
+
+  auto sharedKeys = blockSharedKeys[threadIdx.y];
+  auto sharedValues = blockSharedValues[threadIdx.y];
+  auto sharedValid = blockSharedValid[threadIdx.y];
+
+  const IndexType keyStartOffset =
+    at::cuda::detail::IndexToOffset::get(linearIndex, keys);
+  const IndexType valueStartOffset =
+    at::cuda::detail::IndexToOffset::get(linearIndex, values);
+
+  // Load 2 values per thread into the shared workspace
+  #pragma unroll
+  for (int k = 0; k < items_per_thread; ++k) {
+    auto idx = threadIdx.x + k * blockDim.x;
+    bool valid = row_valid && idx < keySliceSize;
+
+    sharedKeys[idx] = valid ?
+        keys.data[idx * keySliceStride + keyStartOffset] : K{};
+    sharedValues[idx] = valid ?
+        values.data[idx * valueSliceStride + valueStartOffset] : V{};
+    sharedValid[idx] = valid;
+  }
+
+  // Sort!
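+  // [Illustrative note, not from the original source] bitonicSort operates on
+  // a fixed power-of-two problem size (Power2SortSize == 2 * block_dim_x, two
+  // items per thread). Slots past keySliceSize were marked invalid above, and
+  // bitonicSwap always sorts invalid entries to the end, so short rows need
+  // no separate code path.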
+  bitonicSort(
+      sharedKeys, sharedValues, sharedValid, comp);
+
+  if (!row_valid) {
+    return;
+  }
+
+  // Store outputs
+  #pragma unroll
+  for (int k = 0; k < items_per_thread; ++k) {
+    auto idx = threadIdx.x + k * blockDim.x;
+    if (idx < keySliceSize) {
+      keys.data[idx * keySliceStride + keyStartOffset] = sharedKeys[idx];
+      values.data[idx * valueSliceStride + valueStartOffset] = sharedValues[idx];
+    }
+  }
+}
+
+#if HAS_WARP_MERGE_SORT()
+
+template <typename K, typename V, typename Comparator, typename IndexType, int sort_size, int max_block_dim_y>
+C10_LAUNCH_BOUNDS_1(C10_WARP_SIZE * max_block_dim_y)
+__global__ void
+warpMergeSortKVInPlace(
+    at::cuda::detail::TensorInfo<K, IndexType> keys,
+    IndexType keySlices,
+    IndexType keySliceSize,
+    IndexType keySliceStride,
+    at::cuda::detail::TensorInfo<V, IndexType> values,
+    IndexType valueSliceStride,
+    Comparator comp,
+    K invalid_key) {
+  // Find the slice of the tensor that we are sorting
+  // NOTE: blockDim.y may be less than max_block_dim_y
+  const IndexType blockIndex = getLinearBlockId<IndexType>();
+  const IndexType linearIndex = blockIndex * blockDim.y + threadIdx.y;
+
+  // If this row is out of bounds exit early
+  if (linearIndex >= keySlices) {
+    return;
+  }
+
+  const IndexType keyStartOffset =
+    at::cuda::detail::IndexToOffset::get(linearIndex, keys);
+  const IndexType valueStartOffset =
+    at::cuda::detail::IndexToOffset::get(linearIndex, values);
+
+  K *keys_slice = &keys.data[keyStartOffset];
+  V *values_slice = &values.data[valueStartOffset];
+
+  StridedRandomAccessor<K, IndexType> keys_iter(keys_slice, keySliceStride);
+  StridedRandomAccessor<V, IndexType> values_iter(values_slice, valueSliceStride);
+
+  namespace cub = ROCM_HIPCUB(at_cuda_detail::cub);
+
+  CUDA_KERNEL_ASSERT(blockDim.x == C10_WARP_SIZE);
+  CUDA_KERNEL_ASSERT(blockDim.y <= max_block_dim_y);
+  constexpr int items_per_thread = sort_size / C10_WARP_SIZE;
+  static_assert(
+      items_per_thread * C10_WARP_SIZE == sort_size,
+      "sort_size must be a multiple of C10_WARP_SIZE");
+
+
+  using LoadKeys = cub::WarpLoad<K, items_per_thread, cub::WARP_LOAD_TRANSPOSE>;
+  using LoadValues = cub::WarpLoad<V, items_per_thread, cub::WARP_LOAD_TRANSPOSE>;
+  using Sort = cub::WarpMergeSort<K, items_per_thread, C10_WARP_SIZE, V>;
+  using StoreKeys = cub::WarpStore<K, items_per_thread, cub::WARP_STORE_TRANSPOSE>;
+  using StoreValues = cub::WarpStore<V, items_per_thread, cub::WARP_STORE_TRANSPOSE>;
+
+  __shared__ union {
+    typename LoadKeys::TempStorage load_keys;
+    typename LoadValues::TempStorage load_values;
+    typename Sort::TempStorage sort;
+    typename StoreKeys::TempStorage store_keys;
+    typename StoreValues::TempStorage store_values;
+  } tmp_storage[max_block_dim_y];
+
+  auto& warp_storage = tmp_storage[threadIdx.y];
+
+  // Load inputs
+  K local_keys[items_per_thread];
+  V local_values[items_per_thread];
+
+  const auto invalid_value = V{};
+  LoadKeys(warp_storage.load_keys).Load(keys_iter, local_keys, keySliceSize, invalid_key);
+  WARP_SYNC();
+  LoadValues(warp_storage.load_values).Load(values_iter, local_values, keySliceSize, invalid_value);
+  WARP_SYNC();
+
+  // Sort! We use stable sort to ensure that invalid values are never
+  // sorted before valid values. In testing it performed the same as
+  // .Sort, so there is no down-side.
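+  // [Illustrative example, not from the original source] A typical comparator
+  // supplied by the caller is a plain device-side less-than on keys, e.g.
+  //   const auto comp = [] __device__ (const K& a, const K& b) { return a < b; };
+  // with invalid_key chosen by the caller so that padded slots compare last.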
+ Sort(warp_storage.sort).StableSort( + local_keys, local_values, comp, keySliceSize, invalid_key); + WARP_SYNC(); + + // Store outputs + StoreKeys(warp_storage.store_keys).Store(keys_iter, local_keys, keySliceSize); + WARP_SYNC(); + StoreValues(warp_storage.store_values).Store(values_iter, local_values, keySliceSize); +} + +#endif // HAS_WARP_MERGE_SORT() + +template +C10_LAUNCH_BOUNDS_1(block_size) +__global__ void +radixSortKVInPlace(at::cuda::detail::TensorInfo keys, + IndexType keySlices, + IndexType keySliceSize, + IndexType keySliceStride, + at::cuda::detail::TensorInfo values, + IndexType valueSliceStride, + bool descending) { + static_assert(block_size > 0, ""); + + // Find the slice of the tensor that we are sorting + const IndexType linearIndex = getLinearBlockId(); + // Tiling the slices could have us be out of bounds, if there are a + // lot of slices to sort + if (linearIndex >= keySlices) { + return; + } + + const IndexType keyStartOffset = + at::cuda::detail::IndexToOffset::get(linearIndex, keys); + const IndexType valueStartOffset = + at::cuda::detail::IndexToOffset::get(linearIndex, values); + + K *keys_slice = &keys.data[keyStartOffset]; + V *values_slice = &values.data[valueStartOffset]; + + StridedRandomAccessor keys_iter(keys_slice, keySliceStride); + StridedRandomAccessor values_iter(values_slice, valueSliceStride); + + namespace cub = ROCM_HIPCUB(at_cuda_detail::cub); + + using key_t = typename at::cuda::cub::detail::cuda_type::type; + using LoadKeys = cub::BlockLoad; + using LoadValues = cub::BlockLoad; + using Sort = cub::BlockRadixSort; + using StoreKeys = cub::BlockStore; + using StoreValues = cub::BlockStore; + + __shared__ union { + typename LoadKeys::TempStorage load_keys; + typename LoadValues::TempStorage load_values; + typename Sort::TempStorage sort; + typename StoreKeys::TempStorage store_keys; + typename StoreValues::TempStorage store_values; + } tmp_storage; + + // cub's Block operations operate on a fixed number of items, but the + // actual slice we are sorting might be smaller. So, we need to make + // up the difference with keys that will always sort higher. + const K invalid_key = [descending] { + using radix_t = typename cub::Traits::UnsignedBits; + union { + K key; + radix_t radix; + } tmp; + tmp.radix = descending ? + cub::Traits::LOWEST_KEY : + cub::Traits::MAX_KEY; + return tmp.key; + }(); + const V invalid_value = static_cast(0); + + // Load inputs + K local_keys[items_per_thread]; + V local_values[items_per_thread]; + + LoadKeys(tmp_storage.load_keys).Load(keys_iter, local_keys, keySliceSize, invalid_key); + __syncthreads(); + LoadValues(tmp_storage.load_values).Load(values_iter, local_values, keySliceSize, invalid_value); + __syncthreads(); + + // Sort! 
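+  // [Illustrative note, not from the original source] The descending branch
+  // below works because keys are reinterpreted as cub::Traits radix bit
+  // patterns: invalid_key was built from LOWEST_KEY when descending and
+  // MAX_KEY when ascending, so the padding always lands at the tail of the
+  // sorted slice.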
+ if (descending) { + Sort(tmp_storage.sort).SortDescending( + reinterpret_cast(local_keys), + local_values); + } else { + Sort(tmp_storage.sort).Sort( + reinterpret_cast(local_keys), + local_values); + } + __syncthreads(); + + // Store outputs + StoreKeys(tmp_storage.store_keys).Store(keys_iter, local_keys, keySliceSize); + __syncthreads(); + StoreValues(tmp_storage.store_values).Store(values_iter, local_values, keySliceSize); +} + +}} // at::native diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/Sorting.h b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/Sorting.h new file mode 100644 index 0000000000000000000000000000000000000000..bd10ffb1a0274182c77bebe1097169f891dad3d3 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/Sorting.h @@ -0,0 +1,18 @@ +#pragma once +#include + +namespace at { +class TensorBase; +} + +namespace at { +namespace native { + +void launch_kthvalue_kernel( + const TensorBase &values, const TensorBase &indices, + const TensorBase &self, int64_t dim, int64_t k); +void launch_median_kernel( + const TensorBase &vals, const TensorBase &inds, + const TensorBase &in, int64_t dim, bool ignore_nan); + +}} // namespace at::native diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/SortingCommon.cuh b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/SortingCommon.cuh new file mode 100644 index 0000000000000000000000000000000000000000..c4a8ec6864a1dd030a7a07f73ae8df3c81e9b329 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/SortingCommon.cuh @@ -0,0 +1,193 @@ +#pragma once +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace native { + +// Is this questionable namespace pollution? +#if defined(USE_ROCM) +constexpr int MAX_BLOCK_SIZE = 256; + +#else +constexpr int MAX_BLOCK_SIZE = 1024; +#endif + +// Maximum size per grid dimension that we assume (compute capability >= 2.0) +constexpr int64_t MAX_GRID_SIZE = 65535LL; + +static bool getGridFromTiles(int64_t gridTiles, dim3& grid) { + if (gridTiles > MAX_GRID_SIZE * MAX_GRID_SIZE * MAX_GRID_SIZE) { + return false; + } + + int64_t gridX = gridTiles > MAX_GRID_SIZE ? MAX_GRID_SIZE : gridTiles; + int64_t gridY = 1; + int64_t gridZ = 1; + + if (gridTiles > MAX_GRID_SIZE) { + gridTiles = ceil_div(gridTiles, MAX_GRID_SIZE); + gridY = gridTiles > MAX_GRID_SIZE ? MAX_GRID_SIZE : gridTiles; + + if (gridTiles > MAX_GRID_SIZE) { + gridTiles = ceil_div(gridTiles, MAX_GRID_SIZE); + gridZ = gridTiles > MAX_GRID_SIZE ? 
MAX_GRID_SIZE : gridTiles;
+    }
+  }
+
+  grid = dim3(gridX, gridY, gridZ);
+  return true;
+}
+
+template <typename scalar_t, bool handleNaN = false>
+struct GTOp {
+  __device__ bool operator()(const scalar_t& lhs, const scalar_t& rhs) const {
+    return (handleNaN && at::_isnan(lhs) && !at::_isnan(rhs)) || (lhs > rhs);
+  }
+};
+
+template <typename scalar_t, bool handleNaN = false>
+struct LTOp {
+  __device__ bool operator()(const scalar_t& lhs, const scalar_t& rhs) const {
+    return (handleNaN && at::_isnan(rhs) && !at::_isnan(lhs)) || (lhs < rhs);
+  }
+};
+
+template <typename index_t>
+__device__ __forceinline__ index_t getLinearBlockId() {
+  return blockIdx.z * gridDim.y * gridDim.x + blockIdx.y * gridDim.x +
+      blockIdx.x;
+}
+
+// For slice sorting in Thrust; extracts a slice index from a linear
+// index and uses that for comparison
+struct SliceComp {
+  SliceComp(int64_t size) : sliceSize(size) {}
+
+  __device__ bool operator()(const int64_t& a, const int64_t& b) const {
+    // Since the slices are guaranteed to be innermost,
+    // the segment index is obtained via plain int64_t division by sliceSize
+    int64_t segA = a / sliceSize;
+    int64_t segB = b / sliceSize;
+    return segA < segB;
+  }
+
+  const int64_t sliceSize;
+};
+
+// For sorting in Thrust; extracts a within-slice index from a linear index
+struct GlobalIndexToPerSliceIndex {
+  GlobalIndexToPerSliceIndex(int64_t size) : sliceSize(size) {}
+
+  __device__ inline void operator()(int64_t& v) const {
+    v = v % sliceSize;
+  }
+
+  const int64_t sliceSize;
+};
+
+// Returns 2^(ceil(lg(n))) from Stanford bit twiddling hacks
+static uint64_t nextHighestPowerOf2(uint64_t n) {
+  n--;
+  n |= n >> 1;
+  n |= n >> 2;
+  n |= n >> 4;
+  n |= n >> 8;
+  n |= n >> 16;
+#ifndef _MSC_VER
+  n |= n >> 32;
+#endif
+  n++;
+
+  return n;
+}
+
+
+// WARNING: This function assumes input tensors are contiguous
+template <typename scalar_t, typename index_t, typename Launcher>
+void run_launcher(
+    const TensorBase &values,
+    const TensorBase &indices,
+    const TensorBase &self,
+    int64_t dim,
+    Launcher l) {
+  auto self_info = cuda::detail::getTensorInfo<scalar_t, index_t>(self);
+  auto values_info = cuda::detail::getTensorInfo<scalar_t, index_t>(values);
+  auto indices_info = cuda::detail::getTensorInfo<int64_t, index_t>(indices);
+
+  int64_t slice_size = self.size(dim);
+  /* We use these structures solely to find the offset to */
+  /* each slice we are operating on */
+  self_info.reduceDim(dim);
+  values_info.reduceDim(dim);
+  indices_info.reduceDim(dim);
+
+  /* Collapse all other dims */
+  int collapse_self_dim = self_info.collapseDims(dim);
+  int collapse_values_dim = values_info.collapseDims(dim);
+  int collapse_indices_dim = indices_info.collapseDims(dim);
+
+  int64_t num_slices = 1;
+  for (int i = 0; i < self_info.dims; ++i) {
+    num_slices *= self_info.sizes[i];
+  }
+
+  /* This is used as a template parameter to calculate indices.
*/ + /* We only specialize it if all collapsed dim sizes are the */ + /* same; otherwise, we use -1 which is the specialization */ + /* parameter for arbitrary dimensions */ + int all_dims = self_info.dims; + if (values_info.dims != all_dims || indices_info.dims != all_dims) { + all_dims = -1; + } + + if (all_dims == 1) { + l.template launch( + values_info, + collapse_values_dim, + indices_info, + collapse_indices_dim, + self_info, + collapse_self_dim, + num_slices, + slice_size); + } else if (all_dims == 2) { + l.template launch( + values_info, + collapse_values_dim, + indices_info, + collapse_indices_dim, + self_info, + collapse_self_dim, + num_slices, + slice_size); + } else if (all_dims == 3) { + l.template launch( + values_info, + collapse_values_dim, + indices_info, + collapse_indices_dim, + self_info, + collapse_self_dim, + num_slices, + slice_size); + } else { + l.template launch( + values_info, + collapse_values_dim, + indices_info, + collapse_indices_dim, + self_info, + collapse_self_dim, + num_slices, + slice_size); + } +} + +} // namespace native +} // namespace at diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/TensorModeKernel.h b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/TensorModeKernel.h new file mode 100644 index 0000000000000000000000000000000000000000..b5660747997d4eb1ad56f79ec2d1f519921c05c2 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/TensorModeKernel.h @@ -0,0 +1,19 @@ +#pragma once +#include + +namespace at { +class TensorBase; +} + +namespace at { +namespace native { + +void launch_fused_mode_kernel( + const TensorBase &values, const TensorBase &indices, + const TensorBase &self, int64_t slice_size, int64_t slices); + +void launch_apply_mode_kernel( + const TensorBase &values, const TensorBase &indices, + const TensorBase &self, int64_t dim, int64_t ndim); + +}} // namespace at::native diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/TensorTopK.h b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/TensorTopK.h new file mode 100644 index 0000000000000000000000000000000000000000..9eebf2cd6040c4f2df9ad64599910ba0e0cee58f --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/TensorTopK.h @@ -0,0 +1,14 @@ +#pragma once +#include + +namespace at { +class TensorBase; +} + +namespace at { +namespace native { +void launch_gather_topk_kernel( + const TensorBase& self, + int64_t k, int64_t dim, bool largest, + const TensorBase& values, const TensorBase& indices); +}} diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/block_reduce.cuh b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/block_reduce.cuh new file mode 100644 index 0000000000000000000000000000000000000000..e8fd69c0aec93f34985e257dbe3f24f6205f5e72 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/block_reduce.cuh @@ -0,0 +1,143 @@ +#pragma once + +#include + +#include +#include + +namespace at { +namespace native { +namespace cuda_utils { + +constexpr int kCUDABlockReduceNumThreads = 512; +// Algorithmic limitation: BlockReduce does two WarpReduce calls, each +// of which reduces C10_WARP_SIZE elements. 
diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/TensorModeKernel.h b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/TensorModeKernel.h
new file mode 100644
index 0000000000000000000000000000000000000000..b5660747997d4eb1ad56f79ec2d1f519921c05c2
--- /dev/null
+++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/TensorModeKernel.h
@@ -0,0 +1,19 @@
+#pragma once
+#include <cstdint>
+
+namespace at {
+class TensorBase;
+}
+
+namespace at {
+namespace native {
+
+void launch_fused_mode_kernel(
+    const TensorBase &values, const TensorBase &indices,
+    const TensorBase &self, int64_t slice_size, int64_t slices);
+
+void launch_apply_mode_kernel(
+    const TensorBase &values, const TensorBase &indices,
+    const TensorBase &self, int64_t dim, int64_t ndim);
+
+}} // namespace at::native
diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/TensorTopK.h b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/TensorTopK.h
new file mode 100644
index 0000000000000000000000000000000000000000..9eebf2cd6040c4f2df9ad64599910ba0e0cee58f
--- /dev/null
+++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/TensorTopK.h
@@ -0,0 +1,14 @@
+#pragma once
+#include <cstdint>
+
+namespace at {
+class TensorBase;
+}
+
+namespace at {
+namespace native {
+void launch_gather_topk_kernel(
+    const TensorBase& self,
+    int64_t k, int64_t dim, bool largest,
+    const TensorBase& values, const TensorBase& indices);
+}}
diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/block_reduce.cuh b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/block_reduce.cuh
new file mode 100644
index 0000000000000000000000000000000000000000..e8fd69c0aec93f34985e257dbe3f24f6205f5e72
--- /dev/null
+++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/block_reduce.cuh
@@ -0,0 +1,143 @@
+#pragma once
+
+#include <thrust/tuple.h>
+
+#include <ATen/native/SharedReduceOps.h>
+#include <ATen/cuda/DeviceUtils.cuh>
+
+namespace at {
+namespace native {
+namespace cuda_utils {
+
+constexpr int kCUDABlockReduceNumThreads = 512;
+// Algorithmic limitation: BlockReduce does two WarpReduce calls, each
+// of which reduces C10_WARP_SIZE elements. So, at most
+// C10_WARP_SIZE**2 elements can be reduced at a time.
+// NOTE: This is >= the max block size on current hardware anyway (1024).
+constexpr int kCUDABlockReduceMaxThreads = C10_WARP_SIZE * C10_WARP_SIZE;
+
+// Sums `val` across all threads in a warp.
+//
+// Assumptions:
+//   - The size of each block should be a multiple of `C10_WARP_SIZE`
+template <typename T>
+__inline__ __device__ T WarpReduceSum(T val) {
+#pragma unroll
+  for (int offset = (C10_WARP_SIZE >> 1); offset > 0; offset >>= 1) {
+    val += WARP_SHFL_DOWN(val, offset);
+  }
+  return val;
+}
+
+// Picks the maximum `val` across all threads in a warp.
+//
+// Assumptions:
+//   - The size of each block should be a multiple of `C10_WARP_SIZE`
+template <typename T>
+__inline__ __device__ T WarpReduceMax(T val) {
+#pragma unroll
+  for (int offset = (C10_WARP_SIZE >> 1); offset > 0; offset >>= 1) {
+    val = max_propagate_nan(val, WARP_SHFL_DOWN(val, offset));
+  }
+  return val;
+}
+
+struct Block1D {
+  static __forceinline__ __device__ int Tid() { return threadIdx.x; }
+
+  static __forceinline__ __device__ int Warps() {
+    return blockDim.x / C10_WARP_SIZE;
+  }
+};
+
+struct Block2D {
+  static __forceinline__ __device__ int Tid() {
+    return threadIdx.x + threadIdx.y * blockDim.x;
+  }
+
+  static __forceinline__ __device__ int Warps() {
+    return blockDim.x * blockDim.y / C10_WARP_SIZE;
+  }
+};
+
+// Sums `val` across all threads in a block.
+//
+// Warning: the return value is only valid for thread 0.
+// Assumptions:
+//   - The size of each block should be a multiple of `C10_WARP_SIZE`
+//   - `shared` should be a pointer to shared memory with size of, at least,
+//     `sizeof(T) * number_of_warps`
+template <typename T, typename B = Block1D>
+__inline__ __device__ T BlockReduceSum(T val, T* shared) {
+  const int tid = B::Tid();
+  const int lid = tid % C10_WARP_SIZE;
+  const int wid = tid / C10_WARP_SIZE;
+  val = WarpReduceSum(val);
+  __syncthreads(); // prevent races when BlockReduces are called in a row.
+  if (lid == 0) {
+    shared[wid] = val;
+  }
+  __syncthreads();
+  val = (tid < B::Warps()) ? shared[lid] : T(0);
+  if (wid == 0) {
+    val = WarpReduceSum(val);
+  }
+  return val;
+}
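A minimal usage sketch for BlockReduceSum (kernel name and buffers are illustrative): each block needs one shared slot per warp, and only thread 0 of the block holds the final sum, per the warning above.

#include <ATen/native/cuda/block_reduce.cuh>

// Illustrative kernel: block-wise sum of `in`, one partial result per block.
// Assumes blockDim.x is a multiple of C10_WARP_SIZE, per the header's contract.
__global__ void block_sum_kernel(const float* in, float* out, int n) {
  __shared__ float shared[C10_WARP_SIZE];  // >= one slot per warp in the block
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  float v = (i < n) ? in[i] : 0.f;
  v = at::native::cuda_utils::BlockReduceSum<float>(v, shared);
  if (threadIdx.x == 0) {
    out[blockIdx.x] = v;  // only valid on thread 0
  }
}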
+
+// Picks out the maximum `val` across all threads in a block.
+//
+// Warning: the return value is only valid for thread 0.
+// Assumptions:
+//   - The size of each block should be a multiple of `C10_WARP_SIZE`
+//   - `shared` should be a pointer to shared memory with size of, at least,
+//     `sizeof(T) * number_of_warps`
+template <typename T, typename B = Block1D>
+__inline__ __device__ T BlockReduceMax(T val, T* shared) {
+  const int tid = B::Tid();
+  const int lid = tid % C10_WARP_SIZE;
+  const int wid = tid / C10_WARP_SIZE;
+  val = WarpReduceMax(val);
+  __syncthreads(); // prevent races when BlockReduces are called in a row.
+  if (lid == 0) {
+    shared[wid] = val;
+  }
+  __syncthreads();
+  val = (tid < B::Warps()) ? shared[lid] : T(0);
+  if (wid == 0) {
+    val = WarpReduceMax(val);
+  }
+  return val;
+}
+
+template <typename T, class ReduceOp>
+__inline__ __device__ T WarpReduce(T val, const ReduceOp& op) {
+#pragma unroll
+  for (int offset = (C10_WARP_SIZE >> 1); offset > 0; offset >>= 1) {
+    val = op.combine(val, op.warp_shfl_down(val, offset));
+  }
+  return val;
+}
+
+template <typename T, class ReduceOp, typename B = Block1D>
+__inline__ __device__ T
+BlockReduce(T val, const ReduceOp& op, const T& identity_element, T* shared) {
+  const int tid = B::Tid();
+  const int lid = tid % C10_WARP_SIZE;
+  const int wid = tid / C10_WARP_SIZE;
+  val = WarpReduce(val, op);
+  __syncthreads(); // prevent races when BlockReduces are called in a row.
+  if (lid == 0) {
+    shared[wid] = val;
+  }
+  __syncthreads();
+  val = (tid < B::Warps()) ? shared[lid] : identity_element;
+  if (wid == 0) {
+    val = WarpReduce(val, op);
+  }
+  return val;
+}
+
+} // namespace cuda_utils
+} // namespace native
+} // namespace at
diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/im2col.cuh b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/im2col.cuh
new file mode 100644
index 0000000000000000000000000000000000000000..06eef13208c67e88924dea3030ba732aa0671da0
--- /dev/null
+++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/im2col.cuh
@@ -0,0 +1,345 @@
+#pragma once
+
+#include <ATen/AccumulateType.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <ATen/cuda/detail/KernelUtils.h>
+
+#include <c10/macros/Macros.h>
+
+namespace at {
+namespace native {
+
+using namespace at::cuda::detail;
+
+// Kernel for fast unfold+copy
+// (borrowed from Caffe:
+// https://github.com/BVLC/caffe/blob/master/src/caffe/layers/conv_layer.cu)
+// CUDA_NUM_THREADS = 1024
+
+template <typename dt>
+C10_LAUNCH_BOUNDS_1(1024)
+__global__ void im2col_kernel(
+    const int64_t n,
+    const dt* data_im,
+    const int64_t height,
+    const int64_t width,
+    const int64_t kernel_height,
+    const int64_t kernel_width,
+    const int64_t pad_height,
+    const int64_t pad_width,
+    const int64_t stride_height,
+    const int64_t stride_width,
+    const int64_t dilation_height,
+    const int64_t dilation_width,
+    const int64_t height_col,
+    const int64_t width_col,
+    dt* data_col) {
+  CUDA_KERNEL_LOOP(index, n) {
+    int64_t w_out = index % width_col;
+
+    int64_t idx = index / width_col;
+
+    int64_t h_out = idx % height_col;
+    int64_t channel_in = idx / height_col;
+    int64_t channel_out = channel_in * kernel_height * kernel_width;
+    int64_t h_in = h_out * stride_height - pad_height;
+    int64_t w_in = w_out * stride_width - pad_width;
+
+    dt* col = data_col + (channel_out * height_col + h_out) * width_col + w_out;
+    const dt* im = data_im + (channel_in * height + h_in) * width + w_in;
+
+    for (int64_t i = 0; i < kernel_height; ++i) {
+      for (int64_t j = 0; j < kernel_width; ++j) {
+        int64_t h = h_in + i * dilation_height;
+        int64_t w = w_in + j * dilation_width;
+        *col = (h >= 0 && w >= 0 && h < height && w < width)
+            ? im[i * dilation_height * width + j * dilation_width]
+            : static_cast<dt>(0);
+        col += height_col * width_col;
+      }
+    }
+  }
+}
+
+template <typename dt>
+void im2col(
+    cudaStream_t stream,
+    const dt* data_im,
+    const int64_t channels,
+    const int64_t height,
+    const int64_t width,
+    const int64_t height_col,
+    const int64_t width_col,
+    const int64_t kernel_height,
+    const int64_t kernel_width,
+    const int64_t pad_height,
+    const int64_t pad_width,
+    const int64_t stride_height,
+    const int64_t stride_width,
+    const int64_t dilation_height,
+    const int64_t dilation_width,
+    dt* data_col) {
+  // We are going to launch channels * height_col * width_col kernels, each
+  // kernel responsible for copying a single-channel grid.
+  int64_t num_kernels = channels * height_col * width_col;
+  // Launch CUDA_NUM_THREADS = 1024
+  im2col_kernel<<<GET_BLOCKS(num_kernels, 1024), 1024, 0, stream>>>(
+      num_kernels,
+      data_im,
+      height,
+      width,
+      kernel_height,
+      kernel_width,
+      pad_height,
+      pad_width,
+      stride_height,
+      stride_width,
+      dilation_height,
+      dilation_width,
+      height_col,
+      width_col,
+      data_col);
+  C10_CUDA_KERNEL_LAUNCH_CHECK();
+}
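As a concreteness check on the geometry the caller of im2col is expected to supply (values chosen arbitrarily): for a 3x5x5 input with a 3x3 kernel, stride 1, padding 1, dilation 1, height_col = width_col = 5, so num_kernels = 3*5*5 = 75 threads each copy one kernel window into the (C*kernel_h*kernel_w) x (height_col*width_col) column buffer. The host-side arithmetic (helper name is illustrative, not from this header):

#include <cstdint>

// Standard conv output-size formula; effective kernel extent is
// (kernel - 1) * dilation + 1.
int64_t conv_out_size(
    int64_t in, int64_t pad, int64_t kernel, int64_t dilation, int64_t stride) {
  return (in + 2 * pad - ((kernel - 1) * dilation + 1)) / stride + 1;
}

// e.g. in = 5, pad = 1, kernel = 3, dilation = 1, stride = 1:
//   (5 + 2 - 3) / 1 + 1 = 5, so height_col = width_col = 5 and the column
//   buffer is (C * 3 * 3) rows by (5 * 5) columns.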
+
+template <typename dt, typename accT>
+__forceinline__ __device__ void col2im_device(
+    const int64_t index,
+    const dt* data_col,
+    const int64_t height,
+    const int64_t width,
+    const int64_t channels,
+    const int64_t kernel_h,
+    const int64_t kernel_w,
+    const int64_t pad_height,
+    const int64_t pad_width,
+    const int64_t stride_height,
+    const int64_t stride_width,
+    const int64_t dilation_height,
+    const int64_t dilation_width,
+    const int64_t height_col,
+    const int64_t width_col,
+    dt* data_im) {
+  accT val = static_cast<accT>(0);
+  const int64_t w_im = index % width + pad_width;
+  const int64_t h_im = (index / width) % height + pad_height;
+  const int64_t c_im = index / (width * height);
+  int64_t kernel_extent_w = (kernel_w - 1) * dilation_width + 1;
+  int64_t kernel_extent_h = (kernel_h - 1) * dilation_height + 1;
+  // compute the start and end of the output
+  const int64_t w_col_start = (w_im < kernel_extent_w)
+      ? 0
+      : (w_im - kernel_extent_w) / stride_width + 1;
+  const int64_t w_col_end = ::min(w_im / stride_width + 1, width_col);
+  const int64_t h_col_start = (h_im < kernel_extent_h)
+      ? 0
+      : (h_im - kernel_extent_h) / stride_height + 1;
+  const int64_t h_col_end = ::min(h_im / stride_height + 1, height_col);
+
+  // TODO: use LCM of stride and dilation to avoid unnecessary loops
+  for (int64_t h_col = h_col_start; h_col < h_col_end; h_col += 1) {
+    for (int64_t w_col = w_col_start; w_col < w_col_end; w_col += 1) {
+      int64_t h_k = (h_im - h_col * stride_height);
+      int64_t w_k = (w_im - w_col * stride_width);
+      if (h_k % dilation_height == 0 && w_k % dilation_width == 0) {
+        h_k /= dilation_height;
+        w_k /= dilation_width;
+        int64_t data_col_index =
+            (((c_im * kernel_h + h_k) * kernel_w + w_k) * height_col +
+             h_col) *
+                width_col +
+            w_col;
+        val += data_col[data_col_index];
+      }
+    }
+  }
+  data_im[index] = static_cast<dt>(val);
+}
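Worked numbers for the window bounds above (arbitrary configuration, arithmetic only):

// Which output columns read input pixel w_im? Mirrors w_col_start/w_col_end.
// Take width = 5, pad = 1, kernel_w = 3, dilation = 1, stride = 1,
// width_col = 5, and a pixel with padded coordinate w_im = 3:
//   kernel_extent_w = (3 - 1) * 1 + 1 = 3
//   w_col_start     = (3 - 3) / 1 + 1 = 1
//   w_col_end       = min(3 / 1 + 1, 5) = 4
// so output columns 1, 2, 3 read this pixel, through kernel taps
// w_k = w_im - w_col * stride in {2, 1, 0} respectively.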
+
+template <typename dt, typename accT>
+C10_LAUNCH_BOUNDS_1(512)
+__global__ void col2im_kernel(
+    const int64_t n,
+    const dt* data_col,
+    const int64_t height,
+    const int64_t width,
+    const int64_t channels,
+    const int64_t kernel_h,
+    const int64_t kernel_w,
+    const int64_t pad_height,
+    const int64_t pad_width,
+    const int64_t stride_height,
+    const int64_t stride_width,
+    const int64_t dilation_height,
+    const int64_t dilation_width,
+    const int64_t height_col,
+    const int64_t width_col,
+    dt* data_im) {
+  CUDA_KERNEL_LOOP(index, n) {
+    col2im_device<dt, accT>(
+        index,
+        data_col,
+        height,
+        width,
+        channels,
+        kernel_h,
+        kernel_w,
+        pad_height,
+        pad_width,
+        stride_height,
+        stride_width,
+        dilation_height,
+        dilation_width,
+        height_col,
+        width_col,
+        data_im);
+  }
+}
+
+template <typename dt, typename accT>
+void col2im(
+    cudaStream_t stream,
+    const dt* data_col,
+    const int64_t channels,
+    const int64_t height,
+    const int64_t width,
+    const int64_t height_col,
+    const int64_t width_col,
+    const int64_t patch_height,
+    const int64_t patch_width,
+    const int64_t pad_height,
+    const int64_t pad_width,
+    const int64_t stride_height,
+    const int64_t stride_width,
+    const int64_t dilation_height,
+    const int64_t dilation_width,
+    dt* data_im) {
+  int64_t num_kernels = channels * height * width;
+  // To avoid involving atomic operations, we will launch one kernel per
+  // bottom dimension, and then in the kernel add up the top dimensions.
+  // CUDA_NUM_THREADS = 1024
+  col2im_kernel<dt, accT>
+      <<<GET_BLOCKS(num_kernels, 512), 512, 0, stream>>>(
+          num_kernels,
+          data_col,
+          height,
+          width,
+          channels,
+          patch_height,
+          patch_width,
+          pad_height,
+          pad_width,
+          stride_height,
+          stride_width,
+          dilation_height,
+          dilation_width,
+          height_col,
+          width_col,
+          data_im);
+  C10_CUDA_KERNEL_LAUNCH_CHECK();
+}
+
+template <typename dt>
+C10_LAUNCH_BOUNDS_1(512)
+__global__ void col2im_batched_kernel(
+    const int64_t n,
+    const dt* data_col,
+    const int64_t col_batch_stride,
+    const int64_t nbatch,
+    const int64_t height,
+    const int64_t width,
+    const int64_t channels,
+    const int64_t kernel_h,
+    const int64_t kernel_w,
+    const int64_t pad_height,
+    const int64_t pad_width,
+    const int64_t stride_height,
+    const int64_t stride_width,
+    const int64_t dilation_height,
+    const int64_t dilation_width,
+    const int64_t height_col,
+    const int64_t width_col,
+    dt* data_im,
+    const int64_t im_batch_stride) {
+  using accT = at::acc_type<dt, /*is_cuda*/true>;
+  const auto im_numel = n * nbatch;
+
+  CUDA_KERNEL_LOOP_TYPE(index, im_numel, int64_t) {
+    const auto ibatch = index / n;
+    const auto slice_index = index % n;
+
+    col2im_device<dt, accT>(
+        slice_index,
+        data_col + ibatch * col_batch_stride,
+        height,
+        width,
+        channels,
+        kernel_h,
+        kernel_w,
+        pad_height,
+        pad_width,
+        stride_height,
+        stride_width,
+        dilation_height,
+        dilation_width,
+        height_col,
+        width_col,
+        data_im + ibatch * im_batch_stride);
+  }
+}
+
+template <typename dt>
+void col2im_batched(
+    cudaStream_t stream,
+    const dt* data_col,
+    const int64_t col_batch_stride,
+    const int64_t nbatch,
+    const int64_t channels,
+    const int64_t height,
+    const int64_t width,
+    const int64_t height_col,
+    const int64_t width_col,
+    const int64_t patch_height,
+    const int64_t patch_width,
+    const int64_t pad_height,
+    const int64_t pad_width,
+    const int64_t stride_height,
+    const int64_t stride_width,
+    const int64_t dilation_height,
+    const int64_t dilation_width,
+    dt* data_im,
+    const int64_t im_batch_stride) {
+  const int64_t num_kernels = channels * height * width;
+  const int64_t output_numel = nbatch * num_kernels;
+  if (output_numel == 0) {
+    return;  // No work to do
+  }
+
+  // To avoid involving atomic operations, we will launch one kernel per
+  // bottom dimension, and then in the kernel add up the top dimensions.
+  // CUDA_NUM_THREADS = 1024
+  col2im_batched_kernel<dt><<<GET_BLOCKS(output_numel, 512), 512, 0, stream>>>(
+      num_kernels,
+      data_col,
+      col_batch_stride,
+      nbatch,
+      height,
+      width,
+      channels,
+      patch_height,
+      patch_width,
+      pad_height,
+      pad_width,
+      stride_height,
+      stride_width,
+      dilation_height,
+      dilation_width,
+      height_col,
+      width_col,
+      data_im,
+      im_batch_stride);
+  C10_CUDA_KERNEL_LAUNCH_CHECK();
+}
+
+} // namespace native
+} // namespace at
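The batched variant flattens (batch, pixel) into one linear index; the recovery is plain division and modulus (numbers below are arbitrary, for illustration only):

// For im_numel = nbatch * n linear indices, col2im_batched_kernel recovers
//   ibatch      = index / n   (which image in the batch)
//   slice_index = index % n   (which C*H*W element within that image)
// and offsets data_col / data_im by ibatch * {col,im}_batch_stride.
// E.g. n = 75, index = 160  ->  ibatch = 2, slice_index = 10.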
diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/quantized/AffineQuantizerBase.h b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/quantized/AffineQuantizerBase.h
new file mode 100644
index 0000000000000000000000000000000000000000..31526c3ec3c52057463cd00f0dd8556160d4d2df
--- /dev/null
+++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/quantized/AffineQuantizerBase.h
@@ -0,0 +1,47 @@
+#pragma once
+#include <c10/core/ScalarType.h>
+#include <c10/macros/Macros.h>
+
+namespace at {
+namespace native {
+
+// Quantize a float value into a uint value given scale and zero_point
+template <typename T>
+TORCH_API T quantize_val(double scale, int64_t zero_point, float value);
+// TODO combine this with quantize_val once the numerics for ARM are aligned
+// with it
+template <typename T>
+T quantize_val_arm(
+    const float scale,
+    const int32_t zero_point,
+    const float value);
+template <typename T, int precision = 8>
+void quantize_vec(
+    double scale,
+    int64_t zero_point,
+    const float* src,
+    T* dst,
+    size_t count = 8);
+template <typename T>
+TORCH_API float dequantize_val(double scale, int64_t zero_point, T value);
+template <typename T>
+TORCH_API float dequantize_vec(
+    double scale,
+    int64_t zero_point,
+    const T* src,
+    float* dst,
+    size_t count = 8);
+template <typename SRC_T, typename DST_T>
+TORCH_API DST_T requantize_val(double, int64_t, double, int64_t, SRC_T src);
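The affine scheme behind these declarations: q = clamp(round(x / scale) + zero_point, qmin, qmax), with x restored as scale * (q - zero_point). A self-contained reference for uint8 (a sketch of the math only, not the dispatch-optimized library path; names are illustrative):

#include <algorithm>
#include <cmath>
#include <cstdint>

// Reference affine quantization to uint8 (qmin = 0, qmax = 255).
uint8_t quantize_uint8_ref(double scale, int64_t zero_point, float value) {
  int64_t q = static_cast<int64_t>(std::nearbyint(value / scale)) + zero_point;
  return static_cast<uint8_t>(std::min<int64_t>(255, std::max<int64_t>(0, q)));
}

float dequantize_uint8_ref(double scale, int64_t zero_point, uint8_t q) {
  return static_cast<float>(scale * (static_cast<int64_t>(q) - zero_point));
}
// e.g. scale = 0.1, zero_point = 128: 1.0f -> q = 138, and back -> 1.0f.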
+
+// Given a multiplier and a zero_point, requantize int32_t computed values back
+// to quantized values. See comment above
+// make_per_tensor_affine_quantizer function for the usage of int64_t
+template <typename DST_T>
+TORCH_API DST_T
+requantize_from_int(double multiplier, int64_t zero_point, int64_t src);
+
+int quantize_val_float_qparams(float scale, float zero_point, float value, int qmin, int qmax);
+
+} // namespace native
+} // namespace at
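In the usual int8 GEMM pipeline the int32 accumulator carries scale act_scale * weight_scale, so the multiplier passed here is act_scale * weight_scale / out_scale; requantize_from_int folds that rescale together with the output zero_point. Worked numbers (arbitrary):

// act_scale = 0.02, w_scale = 0.05, out_scale = 0.1:
//   multiplier = 0.02 * 0.05 / 0.1 = 0.01
// For an int32 accumulator value src = 1234 and out zero_point = 128:
//   out_q = clamp(round(1234 * 0.01) + 128) = clamp(12 + 128) = 140.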
diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/quantized/Copy.h b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/quantized/Copy.h
new file mode 100644
index 0000000000000000000000000000000000000000..d52c8ff0fb2c7f7f6eed17acceb660482144eef9
--- /dev/null
+++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/quantized/Copy.h
@@ -0,0 +1,10 @@
+#pragma once
+
+#include <ATen/core/Tensor.h>
+
+namespace at {
+namespace native {
+
+Tensor& quantized_copy_from_float_(Tensor& self, const Tensor& src);
+}
+} // namespace at
diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/quantized/FakeQuantAffine.h b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/quantized/FakeQuantAffine.h
new file mode 100644
index 0000000000000000000000000000000000000000..1fb7cfbb0e721f83ba5a9194ad72ea98c97d997d
--- /dev/null
+++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/quantized/FakeQuantAffine.h
@@ -0,0 +1,67 @@
+#pragma once
+
+#include <ATen/core/Tensor.h>
+#include <ATen/Dispatch.h>
+#include <ATen/native/DispatchStub.h>
+
+namespace at {
+
+struct TensorIterator;
+
+namespace native {
+
+using fake_quant_tensor_cachemask_fn = void (*)(
+    Tensor& output,
+    Tensor& mask,
+    const Tensor& input,
+    float sc,
+    int64_t z_point,
+    int64_t quant_min,
+    int64_t quant_max);
+
+using fake_quant_tensor_cachemask_tensor_qparams_fn = void (*)(
+    Tensor& output,
+    Tensor& mask,
+    const Tensor& input,
+    const Tensor& sc,
+    const Tensor& z_point,
+    const Tensor& fake_quant_enabled,
+    int64_t quant_min,
+    int64_t quant_max);
+
+using fake_quant_learnable_grad_tensor_fn = void (*)(
+    TensorIterator& iter,
+    float scale,
+    float inv_scale,
+    int64_t zero_point,
+    int64_t quant_min,
+    int64_t quant_max,
+    float grad_factor);
+
+DECLARE_DISPATCH(fake_quant_tensor_cachemask_fn, fake_quant_tensor_cachemask_stub);
+DECLARE_DISPATCH(fake_quant_tensor_cachemask_tensor_qparams_fn, fake_quant_tensor_cachemask_tensor_qparams_stub);
+DECLARE_DISPATCH(fake_quant_learnable_grad_tensor_fn, fake_quant_grad_learnable_tensor_stub);
+
+using fake_quant_per_channel_fn = void (*)(
+    TensorIterator &iter,
+    int64_t quant_min,
+    int64_t quant_max);
+
+using fake_quant_per_channel_cachemask_fn = void (*)(
+    TensorIterator &iter,
+    TensorIterator &iter_mask,
+    int64_t quant_min,
+    int64_t quant_max);
+
+DECLARE_DISPATCH(fake_quant_per_channel_cachemask_fn, fake_quant_per_channel_cachemask_stub);
+
+using fake_quant_learnable_per_channel_fn = void (*)(
+    TensorIterator &iter,
+    int64_t quant_min,
+    int64_t quant_max,
+    float grad_factor);
+
+DECLARE_DISPATCH(fake_quant_learnable_per_channel_fn, fake_quant_grad_learnable_channel_stub);
+
+} // namespace native
+} // namespace at
diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/quantized/cpu/BinaryOps.h b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/quantized/cpu/BinaryOps.h
new file mode 100644
index 0000000000000000000000000000000000000000..cf86a13c139a1f429ecb2cc4918c04df9e4b3246
--- /dev/null
+++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/quantized/cpu/BinaryOps.h
@@ -0,0 +1,8 @@
+#include <ATen/core/Tensor.h>
+
+namespace at {
+namespace native {
+TORCH_API Tensor
+quantized_add(Tensor qa, Tensor qb, double scale, int64_t zero_point);
+}
+} // namespace at
diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/quantized/cpu/RuyUtils.h b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/quantized/cpu/RuyUtils.h
new file mode 100644
index 0000000000000000000000000000000000000000..72abe1ad817f484e0d269b31cf78b98bf0694e5a
--- /dev/null
+++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/quantized/cpu/RuyUtils.h
@@ -0,0 +1,21 @@
+#pragma once
+
+#ifdef USE_RUY_QMATMUL
+
+#include <ruy/ruy.h>
+
+namespace at {
+namespace native {
+namespace ruy_utils {
+
+ruy::Context* get_ruy_context();
+
+void quantize_multiplier(double scale,
+                         int* multiplier_fixedpoint,
+                         int* multiplier_exponent);
+
+} // namespace ruy_utils
+} // namespace native
+} // namespace at
+
+#endif // USE_RUY_QMATMUL
diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/quantized/cpu/fbgemm_utils.h b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/quantized/cpu/fbgemm_utils.h
new file mode 100644
index 0000000000000000000000000000000000000000..bfaf5b93d667bf6286561cf72c3fb5c487cc1704
--- /dev/null
+++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/quantized/cpu/fbgemm_utils.h
@@ -0,0 +1,411 @@
+#pragma once
+
+#include <ATen/core/Tensor.h>
+#include <ATen/core/List.h>
+#include <ATen/native/quantized/PackedParams.h>
+#include <ATen/native/quantized/cpu/EmbeddingPackedParams.h>
+#include <c10/core/QScheme.h>
+
+#ifdef USE_FBGEMM
+#include <fbgemm/Fbgemm.h>
+#include <fbgemm/FbgemmFP16.h>
+#include <fbgemm/QuantUtils.h>
+
+// The struct for the packed weight matrix (PackBMatrix) and the corresponding
+// column offsets used for the fully connected layer, which are both prepared
+// in the prepacking step to save the computations in the inference. Note the
+// column offsets include the sum of the B columns as well as the scalar term
+// B_zero_point * K, whereas the row offsets created by
+// PackAWithQuantRowOffset/PackAWithIm2Col/PackAWithRowOffset are only the sum
+// of the A rows. The column offsets are needed for the asymmetric quantization
+// (affine quantization) of the input matrix.
+// Note that in JIT mode we can think of a way to fuse col_offsets with bias.
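Why col_offsets carries B_zero_point * K: expanding sum_k (A_qk - A_zp)(B_kj - B_zp) shows the correction for an asymmetric A needs, per output column j, the quantity colsum_j(B_q) - B_zp * K. A minimal host-side computation of that per-column offset (illustrative sketch, not fbgemm's packing code):

#include <cstdint>
#include <vector>

// col_offsets[j] = sum_k B_q[k][j] - B_zero_point * K, so the int32 GEMM
// result can be corrected for a non-zero activation zero_point.
std::vector<int32_t> compute_col_offsets(
    int K, int N, const std::vector<int8_t>& B_q, int32_t B_zero_point) {
  std::vector<int32_t> col_offsets(N, 0);
  for (int j = 0; j < N; ++j) {
    int32_t sum = 0;
    for (int k = 0; k < K; ++k) {
      sum += B_q[k * N + j];  // row-major K x N weight matrix
    }
    col_offsets[j] = sum - B_zero_point * K;
  }
  return col_offsets;
}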
+struct TORCH_API PackedLinearWeight : public LinearPackedParamsBase {
+  PackedLinearWeight(
+      std::unique_ptr<fbgemm::PackBMatrix<int8_t>> w,
+      c10::optional<at::Tensor> bias,
+      std::vector<int32_t> col_offsets,
+      std::vector<float> w_scale,
+      std::vector<int32_t> w_zp,
+      c10::QScheme q_scheme)
+      : w(std::move(w)),
+        bias_(std::move(bias)),
+        col_offsets(std::move(col_offsets)),
+        w_scale(std::move(w_scale)),
+        w_zp(std::move(w_zp)),
+        q_scheme(std::move(q_scheme)) {}
+  std::unique_ptr<fbgemm::PackBMatrix<int8_t>> w;
+  c10::optional<at::Tensor> bias_;
+  std::vector<int32_t> col_offsets;
+  std::vector<float> w_scale;
+  std::vector<int32_t> w_zp;
+  c10::QScheme q_scheme;
+
+  at::Tensor apply(
+      at::Tensor input,
+      double output_scale,
+      int64_t output_zero_point) override;
+
+  at::Tensor apply_relu(
+      at::Tensor input,
+      double output_scale,
+      int64_t output_zero_point) override;
+
+  at::Tensor& apply_out(
+      const at::Tensor& input,
+      double output_scale,
+      int64_t output_zero_point,
+      at::Tensor& output) override;
+
+  at::Tensor& apply_relu_out(
+      const at::Tensor& input,
+      double output_scale,
+      int64_t output_zero_point,
+      at::Tensor& output) override;
+
+  at::Tensor apply_with_input_q_dq_qweight_dq_output_fp32(
+      at::Tensor input,
+      double input_scale,
+      int64_t input_zero_point) override;
+
+  at::Tensor apply_with_input_q_dq_qweight_dq_relu_output_fp32(
+      at::Tensor input,
+      double input_scale,
+      int64_t input_zero_point) override;
+
+  at::Tensor apply_dynamic(at::Tensor input, bool reduce_range = false)
+      override;
+
+  at::Tensor apply_dynamic_relu(at::Tensor input, bool reduce_range = false)
+      override;
+
+  std::tuple<at::Tensor, c10::optional<at::Tensor>> unpack() override;
+
+  c10::optional<at::Tensor> bias() override {
+    return bias_;
+  }
+
+  static c10::intrusive_ptr<LinearPackedParamsBase> prepack(
+      at::Tensor weight,
+      c10::optional<at::Tensor> bias);
+
+ private:
+  template <bool ReluFused>
+  at::Tensor& apply_impl(
+      const at::Tensor& input,
+      double output_scale,
+      int64_t output_zero_point,
+      at::Tensor& output);
+
+  template <bool ReluFused>
+  at::Tensor apply_with_input_q_dq_qweight_dq_output_fp32_impl(
+      const at::Tensor& input,
+      double input_scale,
+      int64_t input_zero_point);
+
+  template <bool ReluFused>
+  at::Tensor apply_dynamic_impl(at::Tensor input, bool reduce_range = false);
+};
+
+struct TORCH_API PackedLinearWeightFp16 : public LinearPackedParamsBase {
+  PackedLinearWeightFp16(
+      std::unique_ptr<fbgemm::PackedGemmMatrixFP16> w,
+      c10::optional<at::Tensor> bias)
+      : w(std::move(w)), bias_(std::move(bias)) {}
+
+  std::unique_ptr<fbgemm::PackedGemmMatrixFP16> w;
+  c10::optional<at::Tensor> bias_;
+
+  at::Tensor apply(
+      at::Tensor /*input*/,
+      double /*output_scale*/,
+      int64_t /*output_zero_point*/) override {
+    TORCH_INTERNAL_ASSERT(false);
+  }
+  at::Tensor apply_relu(
+      at::Tensor /*input*/,
+      double /*output_scale*/,
+      int64_t /*output_zero_point*/) override {
+    TORCH_INTERNAL_ASSERT(false);
+  }
+
+  at::Tensor apply_dynamic(at::Tensor input, bool reduce_range = false)
+      override;
+  at::Tensor apply_dynamic_relu(at::Tensor input, bool reduce_range = false)
+      override;
+
+  at::Tensor& apply_dynamic_out(
+      const at::Tensor& input,
+      at::Tensor& output,
+      bool reduce_range = false) override;
+  at::Tensor& apply_dynamic_relu_out(
+      const at::Tensor& input,
+      at::Tensor& output,
+      bool reduce_range = false) override;
+
+  std::tuple<at::Tensor, c10::optional<at::Tensor>> unpack() override;
+
+  c10::optional<at::Tensor> bias() override {
+    return bias_;
+  }
+
+  static c10::intrusive_ptr<LinearPackedParamsBase> prepack(
+      at::Tensor weight,
+      c10::optional<at::Tensor> bias);
+
+  void set_bias(c10::optional<at::Tensor> bias) override;
+
+ private:
+  template <bool ReluFused>
+  at::Tensor& apply_dynamic_impl(const at::Tensor& input, at::Tensor& output);
+};
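For orientation, the typical call sequence against these packed-params classes (a sketch using only the signatures declared above; the tensors are assumed to be an int8-quantized weight and a quint8-quantized activation):

// auto packed = PackedLinearWeight::prepack(q_weight, bias);   // done once
// at::Tensor out = packed->apply(q_input, out_scale, out_zp);  // per call
//
// prepack() moves the expensive weight reordering and col_offsets
// computation out of the hot path; apply() then runs the int8 GEMM and
// requantizes the int32 accumulators to (out_scale, out_zp).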
+
+template <int kSpatialDim = 2>
+struct TORCH_API PackedConvWeight : public ConvPackedParamsBase<kSpatialDim> {
+  PackedConvWeight(
+      std::unique_ptr<fbgemm::PackWeightsForConv<kSpatialDim>> w,
+      c10::optional<at::Tensor> bias,
+      torch::List<int64_t> stride,
+      torch::List<int64_t> padding,
+      torch::List<int64_t> output_padding,
+      torch::List<int64_t> dilation,
+      int64_t groups,
+      uint8_t transpose,
+      std::vector<int32_t> col_offsets,
+      std::vector<int64_t> kernel,
+      std::vector<float> w_scale,
+      std::vector<int32_t> w_zp,
+      c10::QScheme q_scheme)
+      : w(std::move(w)),
+        bias(std::move(bias)),
+        stride_(std::move(stride)),
+        padding_(std::move(padding)),
+        output_padding_(std::move(output_padding)),
+        dilation_(std::move(dilation)),
+        groups_(groups),
+        transpose_(transpose),
+        col_offsets(std::move(col_offsets)),
+        kernel(std::move(kernel)),
+        w_scale(std::move(w_scale)),
+        w_zp(std::move(w_zp)),
+        q_scheme(q_scheme) {}
+
+  std::unique_ptr<fbgemm::PackWeightsForConv<kSpatialDim>> w;
+  c10::optional<at::Tensor> bias;
+  torch::List<int64_t> stride_;
+  torch::List<int64_t> padding_;
+  torch::List<int64_t> output_padding_;
+  torch::List<int64_t> dilation_;
+  int64_t groups_;
+  uint8_t transpose_;
+  std::vector<int32_t> col_offsets;
+  std::vector<int64_t> kernel;
+  std::vector<float> w_scale;
+  std::vector<int32_t> w_zp;
+  c10::QScheme q_scheme;
+
+  at::Tensor apply(
+      const at::Tensor& input,
+      double output_scale,
+      int64_t output_zero_point) override;
+
+  at::Tensor apply_relu(
+      const at::Tensor& input,
+      double output_scale,
+      int64_t output_zero_point) override;
+
+  at::Tensor apply_dynamic(
+      const at::Tensor& input,
+      bool reduce_range) override;
+
+  std::tuple<at::Tensor, c10::optional<at::Tensor>> unpack() override;
+
+  static c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>> prepack(
+      at::Tensor weight,
+      c10::optional<at::Tensor> bias,
+      torch::List<int64_t> stride,
+      torch::List<int64_t> padding,
+      torch::List<int64_t> output_padding,
+      torch::List<int64_t> dilation,
+      int64_t groups,
+      bool transpose);
+
+  const float* GetBiasData(at::Tensor* bias);
+
+  void GetQuantizationParams(
+      float act_scale,
+      float out_scale,
+      std::vector<float>* output_multiplier_float,
+      std::vector<float>* act_times_w_scale);
+
+  torch::List<int64_t> stride() const override {
+    return stride_;
+  }
+
+  torch::List<int64_t> padding() const override {
+    return padding_;
+  }
+
+  torch::List<int64_t> output_padding() const override {
+    return output_padding_;
+  }
+
+  torch::List<int64_t> dilation() const override {
+    return dilation_;
+  }
+
+  int64_t groups() const override {
+    return groups_;
+  }
+
+  bool transpose() const override {
+    return (bool)transpose_;
+  }
+
+ private:
+  template <bool ReluFused>
+  at::Tensor apply_impl(
+      const at::Tensor& input,
+      double output_scale,
+      int64_t output_zero_point);
+};
+
+// PackWeight: Convert the weight from uint8 to int8.
+inline void convert_uint8_int8(
+    int len,
+    const uint8_t* src_uint8,
+    int8_t* dst_int8) {
+  for (const auto i : c10::irange(len)) {
+    dst_int8[i] = static_cast<int8_t>(static_cast<int32_t>(src_uint8[i]) - 128);
+  }
+}
+
+// UnpackWeight: Convert the weight from int8 to uint8.
+inline void convert_int8_uint8(
+    int len,
+    const int8_t* src_int8,
+    uint8_t* dst_uint8) {
+  for (const auto i : c10::irange(len)) {
+    dst_uint8[i] =
+        static_cast<uint8_t>(static_cast<int32_t>(src_int8[i]) + 128);
+  }
+}
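The +/-128 shift above re-centers the unsigned range onto the signed one (worked values for illustration):

// uint8 0..255  <->  int8 -128..127, the same bit pattern up to an XOR of
// the sign bit:
//   0   -> -128    (0x00 -> 0x80)
//   128 ->    0    (0x80 -> 0x00)
//   255 ->  127    (0xFF -> 0x7F)
// Converting back adds 128, so pack followed by unpack is lossless.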
+
+namespace at {
+namespace native {
+namespace fbgemm_utils {
+
+template <int kSpatialDim = 2>
+fbgemm::conv_param_t<kSpatialDim> MakeFbgemmConvParam(
+    int N,
+    int C,
+    int M,
+    const std::vector<int>& image_shape,
+    int groups,
+    const std::vector<int>& kernels,
+    const std::vector<int>& strides,
+    const std::vector<int>& pads,
+    const std::vector<int>& dilations,
+    const std::vector<int>& output_padding = std::vector<int>(kSpatialDim, 0),
+    bool transposed = false);
+
+// TODO: Remove functions below when ChannelsLast3d is ready.
+Tensor MakeStridedQTensorCPU(
+    const IntArrayRef& sizes,
+    const IntArrayRef& strides,
+    const TensorOptions& options,
+    QuantizerPtr quantizer);
+
+Tensor MakeEmptyAffineQuantizedChannelsLast3dTensor(
+    int64_t N,
+    int64_t C,
+    int64_t D,
+    int64_t H,
+    int64_t W,
+    const TensorOptions& options,
+    double scale,
+    int64_t zero_point);
+
+Tensor MakeEmptyPerChannelAffineQuantizedChannelsLast3dTensor(
+    int64_t N,
+    int64_t C,
+    int64_t D,
+    int64_t H,
+    int64_t W,
+    const TensorOptions& options,
+    const Tensor& scales,
+    const Tensor& zero_points);
+
+Tensor ConvertToChannelsLast3dTensor(const Tensor& src);
+
+template <int kSpatialDim = 2>
+Tensor TransposeConvTensorUnpackConversion(const Tensor& src, int groups);
+
+template <int kSpatialDim>
+Tensor ConvertConvWeightsToChannelLastTensor(
+    const at::Tensor& src,
+    int groups,
+    bool transpose);
+} // namespace fbgemm_utils
+} // namespace native
+} // namespace at
+
+#endif // USE_FBGEMM
+
+struct TORCH_API PackedEmbeddingBagWeight : public EmbeddingPackedParamsBase {
+  PackedEmbeddingBagWeight(
+      at::Tensor packed_w,
+      std::vector<float> w_scale,
+      std::vector<float> w_zp,
+      int64_t bit_rate,
+      c10::QScheme q_scheme,
+      int64_t version)
+      : packed_w(std::move(packed_w)),
+        w_scale(std::move(w_scale)),
+        w_zp(std::move(w_zp)),
+        bit_rate_(bit_rate),
+        q_scheme(q_scheme),
+        version_(version) {
+    // NOLINTNEXTLINE(clang-analyzer-cplusplus.Move)
+    if (!packed_w.is_contiguous()) {
+      packed_w = packed_w.contiguous();
+    }
+  }
+
+  at::Tensor packed_w;
+  std::vector<float> w_scale;
+  std::vector<float> w_zp;
+  int64_t bit_rate_;
+  c10::QScheme q_scheme;
+  int64_t version_;
+
+  at::Tensor unpack() override;
+  static c10::intrusive_ptr<EmbeddingPackedParamsBase> prepack(
+      at::Tensor weight);
+
+  int64_t bit_rate() const override {
+    return bit_rate_;
+  }
+
+  int64_t version() const override {
+    return version_;
+  }
+
+  at::Tensor embeddingbag_byte(
+      const at::Tensor& indices,
+      const c10::optional<at::Tensor>& offsets,
+      bool pruned_weights,
+      const c10::optional<at::Tensor>& per_sample_weights_,
+      const c10::optional<at::Tensor>& compressed_indices_mapping,
+      bool include_last_offset,
+      bool is_embedding_op) override;
+
+  at::Tensor embeddingbag_4bit(
+      const at::Tensor& indices,
+      const c10::optional<at::Tensor>& offsets,
+      bool pruned_weights,
+      const c10::optional<at::Tensor>& per_sample_weights_,
+      const c10::optional<at::Tensor>& compressed_indices_mapping,
+      bool include_last_offset,
+      bool is_embedding_op) override;
+};
diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/quantized/cpu/qembeddingbag_prepack.h b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/quantized/cpu/qembeddingbag_prepack.h
new file mode 100644
index 0000000000000000000000000000000000000000..0a65f3f07f397b931c1a4b6bd781e6308643117f
--- /dev/null
+++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/quantized/cpu/qembeddingbag_prepack.h
@@ -0,0 +1,13 @@
+#pragma once
+#include <ATen/core/Tensor.h>
+
+namespace at { namespace native {
+
+Tensor& qembeddingbag_byte_prepack_out(Tensor& output, const Tensor& weight);
+
+Tensor qembeddingbag_byte_prepack(const Tensor& weight);
+
+Tensor qembeddingbag_byte_prepack_meta(const Tensor& weight);
+
+} // namespace native
+} // namespace at
diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/quantization/__pycache__/stubs.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/quantization/__pycache__/stubs.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fc3aa8c84a1f7903534e6eb63307ecaf071be71b
Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/quantization/__pycache__/stubs.cpython-311.pyc differ