Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/__pycache__/__init__.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/error.py +56 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/pass_base.py +435 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/pass_infra/__pycache__/proxy_value.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/pass_infra/proxy_value.py +41 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/serde/__pycache__/__init__.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/serde/union.py +69 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__init__.py +150 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/__init__.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/compile_fx.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/config.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/coordinate_descent_tuner.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/debug.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/dependencies.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/freezing.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/hooks.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/inductor_prims.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/metrics.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/quantized_lowerings.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/select_algorithm.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/sizevars.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/test_case.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/test_operators.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/virtualized.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/wrapper_benchmark.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cpp_prefix.h +595 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cpp_wrapper_cpu.py +1851 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cpp_wrapper_cuda.py +328 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/__init__.py +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/__pycache__/cuda_template.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/cuda_kernel.py +374 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/__pycache__/__init__.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/__pycache__/gemm_operation_extensions.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/gemm_template.py +706 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/memory_planning.py +799 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/multi_kernel.py +413 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/triton_foreach.py +250 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/triton_split_scan.py +180 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/triton_utils.py +130 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/wrapper.py +1543 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/comm_analysis.py +273 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/cudagraph_trees.py +2159 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/hooks.py +28 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/ops_handler.py +655 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/optimize_indexing.py +118 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/triton_heuristics.py +1527 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/__init__.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/computation.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/extract_compiled_graph.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/ts_backend.cpython-311.pyc +0 -0
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (20.6 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/error.py
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from enum import Enum
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
class ExportErrorType(Enum):
|
| 5 |
+
# User providing invalid inputs to either tracer, or other public facing APIs
|
| 6 |
+
INVALID_INPUT_TYPE = 1
|
| 7 |
+
|
| 8 |
+
# User returning values from their models that we don’t support.
|
| 9 |
+
INVALID_OUTPUT_TYPE = 2
|
| 10 |
+
|
| 11 |
+
# Generated IR does not conform to Export IR Specification.
|
| 12 |
+
VIOLATION_OF_SPEC = 3
|
| 13 |
+
|
| 14 |
+
# User’s code contains types and functionalities we don’t support.
|
| 15 |
+
NOT_SUPPORTED = 4
|
| 16 |
+
|
| 17 |
+
# User's code didn't provide necessary details for us to successfully trace and export.
|
| 18 |
+
# For example, we use a lot of decorators and ask users to annotate their model.
|
| 19 |
+
MISSING_PROPERTY = 5
|
| 20 |
+
|
| 21 |
+
# User is using an API without proper initialization step.
|
| 22 |
+
UNINITIALIZED = 6
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def internal_assert(pred: bool, assert_msg: str) -> None:
|
| 26 |
+
"""
|
| 27 |
+
This is exir's custom assert method. It internally just throws InternalError.
|
| 28 |
+
Note that the sole purpose is to throw our own error while maintaining similar syntax
|
| 29 |
+
as python assert.
|
| 30 |
+
"""
|
| 31 |
+
|
| 32 |
+
if not pred:
|
| 33 |
+
raise InternalError(assert_msg)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class InternalError(Exception):
|
| 37 |
+
"""
|
| 38 |
+
Raised when an internal invariance is violated in EXIR stack.
|
| 39 |
+
Should hint users to report a bug to dev and expose the original
|
| 40 |
+
error message.
|
| 41 |
+
"""
|
| 42 |
+
|
| 43 |
+
def __init__(self, message: str) -> None:
|
| 44 |
+
super().__init__(message)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
class ExportError(Exception):
|
| 48 |
+
"""
|
| 49 |
+
This type of exception is raised for errors that are directly caused by the user
|
| 50 |
+
code. In general, user errors happen during model authoring, tracing, using our public
|
| 51 |
+
facing APIs, and writing graph passes.
|
| 52 |
+
"""
|
| 53 |
+
|
| 54 |
+
def __init__(self, error_code: ExportErrorType, message: str) -> None:
|
| 55 |
+
prefix = f"[{error_code}]: "
|
| 56 |
+
super().__init__(prefix + message)
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/pass_base.py
ADDED
|
@@ -0,0 +1,435 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import operator
|
| 2 |
+
import traceback
|
| 3 |
+
import typing
|
| 4 |
+
from contextlib import nullcontext
|
| 5 |
+
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
from functorch.experimental.control_flow import _unstack_pytree
|
| 9 |
+
from torch import fx
|
| 10 |
+
from torch._dispatch.python import enable_python_dispatcher
|
| 11 |
+
from torch._export.pass_infra.node_metadata import NodeMetadata
|
| 12 |
+
from torch._export.pass_infra.proxy_value import ProxyValue
|
| 13 |
+
from torch._subclasses import FakeTensor, UnsupportedFakeTensorException
|
| 14 |
+
from torch._subclasses.fake_tensor import FakeTensorMode
|
| 15 |
+
from torch.fx import traceback as fx_traceback
|
| 16 |
+
from torch.fx.experimental.proxy_tensor import PythonKeyTracer
|
| 17 |
+
from torch.fx.graph import CodeGen
|
| 18 |
+
from torch.fx.passes.infra.pass_base import PassBase, PassResult
|
| 19 |
+
from torch.fx.passes.shape_prop import _extract_tensor_metadata, TensorMetadata
|
| 20 |
+
from torch.utils import _pytree as pytree
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
__all__ = ["_ExportPassBaseDeprecatedDoNotUse"]
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
Argument = Any
|
| 27 |
+
Value = Any
|
| 28 |
+
Fn = Callable[..., Any]
|
| 29 |
+
PassType = Callable[[torch.fx.GraphModule], Optional[PassResult]]
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
_TORCH_SYM_OPS: Set[Callable] = {
|
| 33 |
+
torch.sym_int,
|
| 34 |
+
torch.sym_ite,
|
| 35 |
+
torch.sym_max,
|
| 36 |
+
torch.sym_min,
|
| 37 |
+
torch.sym_not,
|
| 38 |
+
torch.sym_sqrt,
|
| 39 |
+
}
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class ExportPassBaseError(RuntimeError):
|
| 43 |
+
pass
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class _ExportPassBaseDeprecatedDoNotUse(PassBase):
|
| 47 |
+
"""
|
| 48 |
+
Interpreter-based pass class to help users maintain the IR spec while writing
|
| 49 |
+
transformations.
|
| 50 |
+
"""
|
| 51 |
+
|
| 52 |
+
@staticmethod
|
| 53 |
+
def _create_dummy_node_metadata():
|
| 54 |
+
return NodeMetadata({"stack_trace": "".join(traceback.format_stack(limit=1))})
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
class ExportTracer(PythonKeyTracer):
|
| 58 |
+
def __init__(self, callback: "_ExportPassBaseDeprecatedDoNotUse", codegen: CodeGen) -> None:
|
| 59 |
+
super().__init__()
|
| 60 |
+
self.callback = callback
|
| 61 |
+
self.root = torch.nn.Module()
|
| 62 |
+
self.graph = torch.fx.Graph()
|
| 63 |
+
self.graph.set_codegen(codegen)
|
| 64 |
+
self.tensor_attrs: Dict[str, torch.Tensor] = {} # type: ignore[assignment]
|
| 65 |
+
self.fake_tensor_mode: Optional[FakeTensorMode] = None
|
| 66 |
+
self.submodules: Dict[torch.nn.Module, str] = {}
|
| 67 |
+
|
| 68 |
+
def trace(self) -> None:
|
| 69 |
+
raise ExportPassBaseError("ExportTracer doesn't support trace().")
|
| 70 |
+
|
| 71 |
+
def create_arg(self, a: Argument) -> torch.fx.Node:
|
| 72 |
+
if isinstance(a, torch.nn.Module):
|
| 73 |
+
if a not in self.submodules:
|
| 74 |
+
name_submodule = f"submodule_{len(self.submodules)}"
|
| 75 |
+
self.root.add_module(name_submodule, a)
|
| 76 |
+
self.submodules[a] = name_submodule
|
| 77 |
+
elif isinstance(a, FakeTensor):
|
| 78 |
+
if not hasattr(a, "constant") or a.constant is None:
|
| 79 |
+
raise ExportPassBaseError(f"Cannot add {a} to graph.")
|
| 80 |
+
a = a.constant
|
| 81 |
+
node = super().create_arg(a)
|
| 82 |
+
if (
|
| 83 |
+
isinstance(a, torch.Tensor)
|
| 84 |
+
and isinstance(node, torch.fx.Node)
|
| 85 |
+
and node.op == "get_attr"
|
| 86 |
+
):
|
| 87 |
+
self.set_metadata(node, a)
|
| 88 |
+
self.callback.on_attr(ProxyValue(a, node))
|
| 89 |
+
return node
|
| 90 |
+
|
| 91 |
+
def set_metadata(
|
| 92 |
+
self, node: torch.fx.Node, value: Argument,
|
| 93 |
+
) -> None:
|
| 94 |
+
# propagate the fake tensor or sym nodes
|
| 95 |
+
def make_val(
|
| 96 |
+
x: Argument,
|
| 97 |
+
) -> Union[FakeTensor, torch.SymInt, torch.SymFloat, torch.SymBool, int, float, bool, str, None]:
|
| 98 |
+
if isinstance(x, FakeTensor):
|
| 99 |
+
return x
|
| 100 |
+
elif isinstance(x, torch.Tensor):
|
| 101 |
+
if x.is_quantized:
|
| 102 |
+
# TODO (tmanlaibaatar) properly support Quantized FakeTensor
|
| 103 |
+
x = torch.dequantize(x)
|
| 104 |
+
|
| 105 |
+
try:
|
| 106 |
+
assert self.fake_tensor_mode is not None
|
| 107 |
+
# TODO we should allocate static shapes
|
| 108 |
+
# for param/buffer values
|
| 109 |
+
if isinstance(x, torch.nn.Parameter):
|
| 110 |
+
fake_tensor = self.fake_tensor_mode.from_tensor(
|
| 111 |
+
x, static_shapes=True
|
| 112 |
+
)
|
| 113 |
+
else:
|
| 114 |
+
fake_tensor = self.fake_tensor_mode.from_tensor(x)
|
| 115 |
+
except UnsupportedFakeTensorException:
|
| 116 |
+
# TODO: This is just a workaround to get over the
|
| 117 |
+
# x.as_subclass error
|
| 118 |
+
print(
|
| 119 |
+
"Fakeifying a Tensor subclass is not supported \
|
| 120 |
+
right now. Instead a TensorMetadata is used."
|
| 121 |
+
)
|
| 122 |
+
fake_tensor = None
|
| 123 |
+
return fake_tensor
|
| 124 |
+
elif isinstance(x, (torch.SymInt, torch.SymFloat, torch.SymBool, int, float, bool, str)):
|
| 125 |
+
return x
|
| 126 |
+
else:
|
| 127 |
+
return None
|
| 128 |
+
|
| 129 |
+
node.meta["val"] = pytree.tree_map(make_val, value)
|
| 130 |
+
|
| 131 |
+
# Set the tensor_metadata for values that do not have a corresponding FakeTensor
|
| 132 |
+
def make_tensor_meta(x: Argument) -> Optional[TensorMetadata]:
|
| 133 |
+
if not isinstance(x, FakeTensor) and isinstance(x, torch.Tensor):
|
| 134 |
+
if x.is_quantized:
|
| 135 |
+
# TODO (tmanlaibaatar) properly support Quantized FakeTensor
|
| 136 |
+
x = torch.dequantize(x)
|
| 137 |
+
|
| 138 |
+
try:
|
| 139 |
+
assert self.fake_tensor_mode is not None
|
| 140 |
+
_ = self.fake_tensor_mode.from_tensor(x)
|
| 141 |
+
tensor_meta = None
|
| 142 |
+
except UnsupportedFakeTensorException:
|
| 143 |
+
# TODO: This is just a workaround to get over the
|
| 144 |
+
# x.as_subclass error
|
| 145 |
+
tensor_meta = _extract_tensor_metadata(x)
|
| 146 |
+
return tensor_meta
|
| 147 |
+
else:
|
| 148 |
+
return None
|
| 149 |
+
|
| 150 |
+
node.meta["tensor_meta"] = pytree.tree_map(make_tensor_meta, value)
|
| 151 |
+
|
| 152 |
+
class ExportInterpreter(fx.Interpreter):
|
| 153 |
+
def __init__(self, callback: "_ExportPassBaseDeprecatedDoNotUse", gm: fx.GraphModule) -> None:
|
| 154 |
+
super().__init__(gm)
|
| 155 |
+
self.callback = callback
|
| 156 |
+
self.node: torch.fx.Node = next(iter(gm.graph.nodes))
|
| 157 |
+
|
| 158 |
+
def placeholder(
|
| 159 |
+
self,
|
| 160 |
+
target: str,
|
| 161 |
+
args: Tuple[Argument, ...],
|
| 162 |
+
kwargs: Dict[str, Argument],
|
| 163 |
+
) -> ProxyValue:
|
| 164 |
+
arg = super().placeholder(target, args, kwargs)
|
| 165 |
+
return self.callback.placeholder(target, arg, NodeMetadata(self.node.meta))
|
| 166 |
+
|
| 167 |
+
def output(
|
| 168 |
+
self,
|
| 169 |
+
target: torch.fx.node.Target,
|
| 170 |
+
args: Tuple[Argument, ...],
|
| 171 |
+
kwargs: Dict[str, Argument],
|
| 172 |
+
) -> ProxyValue:
|
| 173 |
+
return self.callback.output(args[0], NodeMetadata(self.node.meta)).data
|
| 174 |
+
|
| 175 |
+
def call_function(
|
| 176 |
+
self,
|
| 177 |
+
target: torch.fx.node.Target,
|
| 178 |
+
args: Tuple[Argument, ...],
|
| 179 |
+
kwargs: Dict[str, Argument],
|
| 180 |
+
) -> ProxyValue:
|
| 181 |
+
meta = NodeMetadata(self.node.meta)
|
| 182 |
+
|
| 183 |
+
if target == operator.getitem:
|
| 184 |
+
value, key = args
|
| 185 |
+
return self.callback.call_getitem(value, key, meta)
|
| 186 |
+
elif getattr(target, "__module__", None) in {"_operator", "math"}:
|
| 187 |
+
assert callable(target)
|
| 188 |
+
return self.callback.call_sym(target, args, meta)
|
| 189 |
+
elif target in _TORCH_SYM_OPS:
|
| 190 |
+
assert callable(target)
|
| 191 |
+
return self.callback.call_sym(target, args, meta)
|
| 192 |
+
elif isinstance(target, (torch._ops.OpOverload, torch._ops.OpOverloadPacket)):
|
| 193 |
+
return self.callback.call_operator(
|
| 194 |
+
target,
|
| 195 |
+
args,
|
| 196 |
+
kwargs,
|
| 197 |
+
meta,
|
| 198 |
+
)
|
| 199 |
+
elif target == torch.ops.higher_order.cond:
|
| 200 |
+
pred, true_fn, false_fn, inputs = args
|
| 201 |
+
return self.callback.call_cond(pred, true_fn, false_fn, inputs, meta)
|
| 202 |
+
elif target == torch.ops.higher_order.map_impl:
|
| 203 |
+
f, mapped_args, operands = args # type: ignore[assignment]
|
| 204 |
+
return self.callback.call_map(f, mapped_args, operands, meta)
|
| 205 |
+
# For other unregistered HigherOrderOps, just interpret them blindly
|
| 206 |
+
elif isinstance(target, torch._ops.HigherOrderOperator):
|
| 207 |
+
return self.callback._fx(
|
| 208 |
+
"call_function",
|
| 209 |
+
target,
|
| 210 |
+
args,
|
| 211 |
+
kwargs,
|
| 212 |
+
meta,
|
| 213 |
+
)
|
| 214 |
+
else:
|
| 215 |
+
raise ExportPassBaseError(f"Unsupported target type: {target}")
|
| 216 |
+
|
| 217 |
+
def get_attr(
|
| 218 |
+
self, target: str, args: Tuple[Argument, ...], kwargs: Dict[str, Argument]
|
| 219 |
+
) -> Argument:
|
| 220 |
+
return super().get_attr(target, args, kwargs)
|
| 221 |
+
|
| 222 |
+
def call_module(
|
| 223 |
+
self,
|
| 224 |
+
target: torch.fx.node.Target,
|
| 225 |
+
args: Tuple[Argument, ...],
|
| 226 |
+
kwargs: Dict[str, Argument],
|
| 227 |
+
) -> None:
|
| 228 |
+
raise ExportPassBaseError("call_module is not supported.")
|
| 229 |
+
|
| 230 |
+
def call_method(
|
| 231 |
+
self, target: str, args: Tuple[Argument, ...], kwargs: Dict[str, Argument]
|
| 232 |
+
) -> None:
|
| 233 |
+
raise ExportPassBaseError("call_method is not supported.")
|
| 234 |
+
|
| 235 |
+
def run_node(self, n: torch.fx.Node) -> Argument:
|
| 236 |
+
self.node = n
|
| 237 |
+
self.callback.node_debug_str = n.format_node()
|
| 238 |
+
return super().run_node(n)
|
| 239 |
+
|
| 240 |
+
def __init__(self) -> None:
|
| 241 |
+
self.interpreter = torch.fx.Interpreter(
|
| 242 |
+
torch.fx.GraphModule(torch.nn.Module(), torch.fx.Graph())
|
| 243 |
+
)
|
| 244 |
+
self.tracer = self.ExportTracer(self, CodeGen())
|
| 245 |
+
self.fake_tensor_mode: Optional[FakeTensorMode] = None
|
| 246 |
+
self._initialized = True
|
| 247 |
+
self.node_debug_str: typing.Optional[str] = None
|
| 248 |
+
|
| 249 |
+
def _fx(
|
| 250 |
+
self,
|
| 251 |
+
kind: str,
|
| 252 |
+
target: torch.fx.node.Target,
|
| 253 |
+
args: Tuple[Argument, ...],
|
| 254 |
+
kwargs: Dict[str, Argument],
|
| 255 |
+
meta: NodeMetadata,
|
| 256 |
+
) -> ProxyValue:
|
| 257 |
+
args_data, kwargs_data = pytree.tree_map_only(
|
| 258 |
+
ProxyValue, lambda x: x.data, (args, kwargs)
|
| 259 |
+
)
|
| 260 |
+
res_data = getattr(self.interpreter, kind)(target, args_data, kwargs_data)
|
| 261 |
+
args_proxy, kwargs_proxy = pytree.tree_map_only(
|
| 262 |
+
ProxyValue, lambda x: x.proxy, (args, kwargs)
|
| 263 |
+
)
|
| 264 |
+
|
| 265 |
+
name = None
|
| 266 |
+
if isinstance(target, torch._ops.OpOverload):
|
| 267 |
+
name = self.tracer.graph._target_to_str(target.overloadpacket.__name__)
|
| 268 |
+
|
| 269 |
+
res_proxy = self.tracer.create_proxy(kind, target, args_proxy, kwargs_proxy, name=name)
|
| 270 |
+
res_proxy.node.meta.update(meta.data)
|
| 271 |
+
self.tracer.set_metadata(res_proxy.node, res_data)
|
| 272 |
+
return ProxyValue(res_data, res_proxy)
|
| 273 |
+
|
| 274 |
+
def inputs(self, graph_module: torch.fx.GraphModule) -> List[Argument]:
|
| 275 |
+
# TODO(angelayi): Update this with what we decide to do for metadata in
|
| 276 |
+
# the exported graph module
|
| 277 |
+
if (args := graph_module.meta.get("args", None)) is not None:
|
| 278 |
+
return list(args)
|
| 279 |
+
|
| 280 |
+
def extract_input(node: torch.fx.Node) -> Optional[FakeTensor]:
|
| 281 |
+
if "val" in node.meta:
|
| 282 |
+
fake = node.meta["val"]
|
| 283 |
+
if hasattr(fake, "constant") and fake.constant is not None:
|
| 284 |
+
return fake.constant
|
| 285 |
+
return fake
|
| 286 |
+
elif tensor_meta := node.meta.get("tensor_meta"):
|
| 287 |
+
assert self.fake_tensor_mode is not None
|
| 288 |
+
return FakeTensor(
|
| 289 |
+
self.fake_tensor_mode,
|
| 290 |
+
torch.empty(
|
| 291 |
+
tensor_meta.shape,
|
| 292 |
+
dtype=tensor_meta.dtype,
|
| 293 |
+
device="meta",
|
| 294 |
+
requires_grad=tensor_meta.requires_grad,
|
| 295 |
+
memory_format=tensor_meta.memory_format,
|
| 296 |
+
),
|
| 297 |
+
torch.device("cpu"),
|
| 298 |
+
)
|
| 299 |
+
elif len(node.users) == 0:
|
| 300 |
+
return None
|
| 301 |
+
raise ExportPassBaseError(
|
| 302 |
+
f"Cannot construct an input for graph module: {graph_module}.",
|
| 303 |
+
)
|
| 304 |
+
|
| 305 |
+
return [
|
| 306 |
+
extract_input(node)
|
| 307 |
+
for node in graph_module.graph.nodes
|
| 308 |
+
if node.op == "placeholder"
|
| 309 |
+
]
|
| 310 |
+
|
| 311 |
+
def on_attr(self, attr: ProxyValue) -> None:
|
| 312 |
+
pass
|
| 313 |
+
|
| 314 |
+
def placeholder(self, name: str, arg: Argument, meta: NodeMetadata) -> ProxyValue:
|
| 315 |
+
arg_proxy = self.tracer.create_proxy("placeholder", name, (), {})
|
| 316 |
+
arg_proxy.node.meta = meta.data
|
| 317 |
+
self.tracer.set_metadata(arg_proxy.node, arg)
|
| 318 |
+
return ProxyValue(arg, arg_proxy)
|
| 319 |
+
|
| 320 |
+
def call_operator(
|
| 321 |
+
self,
|
| 322 |
+
op,
|
| 323 |
+
args: Tuple[Argument, ...],
|
| 324 |
+
kwargs: Dict[str, Argument],
|
| 325 |
+
meta: NodeMetadata,
|
| 326 |
+
) -> ProxyValue:
|
| 327 |
+
return self._fx("call_function", op, args, kwargs, meta)
|
| 328 |
+
|
| 329 |
+
def call_sym(
|
| 330 |
+
self,
|
| 331 |
+
target: Fn,
|
| 332 |
+
args: Tuple[Argument, ...],
|
| 333 |
+
meta: NodeMetadata,
|
| 334 |
+
) -> ProxyValue:
|
| 335 |
+
return self._fx("call_function", target, args, {}, meta)
|
| 336 |
+
|
| 337 |
+
def call_cond(
|
| 338 |
+
self,
|
| 339 |
+
pred: ProxyValue,
|
| 340 |
+
true_fn: torch.fx.GraphModule,
|
| 341 |
+
false_fn: torch.fx.GraphModule,
|
| 342 |
+
inputs: List[Argument],
|
| 343 |
+
meta: NodeMetadata,
|
| 344 |
+
) -> ProxyValue:
|
| 345 |
+
true_branch = self.call_submodule(true_fn, tuple(inputs))
|
| 346 |
+
false_branch = self.call_submodule(false_fn, tuple(inputs))
|
| 347 |
+
assert true_branch is not None
|
| 348 |
+
assert false_branch is not None
|
| 349 |
+
return self._fx(
|
| 350 |
+
"call_function",
|
| 351 |
+
torch.ops.higher_order.cond,
|
| 352 |
+
(pred, true_branch.graph_module, false_branch.graph_module, list(inputs)),
|
| 353 |
+
{},
|
| 354 |
+
meta,
|
| 355 |
+
)
|
| 356 |
+
|
| 357 |
+
def call_map(
|
| 358 |
+
self,
|
| 359 |
+
f: torch.fx.GraphModule,
|
| 360 |
+
mapped_args: List[ProxyValue],
|
| 361 |
+
operands: List[ProxyValue],
|
| 362 |
+
meta: NodeMetadata,
|
| 363 |
+
) -> ProxyValue:
|
| 364 |
+
xs = _unstack_pytree([arg.data for arg in mapped_args])[0]
|
| 365 |
+
f_branch = self.call_submodule(f, tuple(xs + [arg.data for arg in operands]))
|
| 366 |
+
assert f_branch is not None
|
| 367 |
+
return self._fx(
|
| 368 |
+
"call_function",
|
| 369 |
+
torch.ops.higher_order.map_impl,
|
| 370 |
+
(f_branch.graph_module, mapped_args, operands),
|
| 371 |
+
{},
|
| 372 |
+
meta,
|
| 373 |
+
)
|
| 374 |
+
|
| 375 |
+
def call_getitem(
|
| 376 |
+
self, value: ProxyValue, key: int, meta: NodeMetadata
|
| 377 |
+
) -> ProxyValue:
|
| 378 |
+
return self._fx("call_function", operator.getitem, (value, key), {}, meta)
|
| 379 |
+
|
| 380 |
+
def output(self, results: List[Argument], meta: NodeMetadata) -> ProxyValue:
|
| 381 |
+
return self._fx("output", "output", (results,), {}, meta)
|
| 382 |
+
|
| 383 |
+
def call_submodule(
|
| 384 |
+
self, graph_module: fx.GraphModule, inputs: Tuple[Argument, ...]
|
| 385 |
+
) -> PassResult:
|
| 386 |
+
prev_tracer, self.tracer = self.tracer, self.ExportTracer(
|
| 387 |
+
self, graph_module.graph._codegen
|
| 388 |
+
)
|
| 389 |
+
self.tracer.fake_tensor_mode = prev_tracer.fake_tensor_mode
|
| 390 |
+
interpreter = self.ExportInterpreter(self, graph_module)
|
| 391 |
+
prev_interpreter, self.interpreter = self.interpreter, torch.fx.Interpreter(
|
| 392 |
+
torch.fx.GraphModule(torch.nn.Module(), torch.fx.Graph())
|
| 393 |
+
)
|
| 394 |
+
inputs_data = pytree.tree_map_only(ProxyValue, lambda x: x.data, inputs)
|
| 395 |
+
with fx_traceback.preserve_node_meta():
|
| 396 |
+
interpreter.run(*inputs_data)
|
| 397 |
+
|
| 398 |
+
new_graph_module = torch.fx.GraphModule(self.tracer.root, self.tracer.graph)
|
| 399 |
+
|
| 400 |
+
self.tracer = prev_tracer
|
| 401 |
+
self.interpreter = prev_interpreter
|
| 402 |
+
return PassResult(
|
| 403 |
+
new_graph_module,
|
| 404 |
+
True,
|
| 405 |
+
)
|
| 406 |
+
|
| 407 |
+
def call(self, graph_module: fx.GraphModule) -> PassResult:
|
| 408 |
+
if not getattr(self, "_initialized", False):
|
| 409 |
+
raise ExportPassBaseError(
|
| 410 |
+
"ExportPass is not initialized with __init__().",
|
| 411 |
+
)
|
| 412 |
+
|
| 413 |
+
inputs = self.inputs(graph_module)
|
| 414 |
+
|
| 415 |
+
fake_tensor_mode = None
|
| 416 |
+
for i in inputs:
|
| 417 |
+
if isinstance(i, FakeTensor):
|
| 418 |
+
assert (
|
| 419 |
+
fake_tensor_mode is None or fake_tensor_mode is i.fake_mode
|
| 420 |
+
), "Multiple fake tensor mode detected."
|
| 421 |
+
fake_tensor_mode = i.fake_mode
|
| 422 |
+
if fake_tensor_mode is None:
|
| 423 |
+
self.tracer.fake_tensor_mode = FakeTensorMode(allow_non_fake_inputs=True)
|
| 424 |
+
fake_tensor_mode = nullcontext() # type: ignore[assignment]
|
| 425 |
+
dispatcher_mode = nullcontext() # type: ignore[assignment]
|
| 426 |
+
else:
|
| 427 |
+
fake_tensor_mode.allow_non_fake_inputs = True
|
| 428 |
+
self.tracer.fake_tensor_mode = fake_tensor_mode
|
| 429 |
+
dispatcher_mode = enable_python_dispatcher() # type: ignore[assignment]
|
| 430 |
+
self.fake_tensor_mode = self.tracer.fake_tensor_mode
|
| 431 |
+
|
| 432 |
+
with fake_tensor_mode, dispatcher_mode: # type: ignore[assignment, union-attr]
|
| 433 |
+
result = self.call_submodule(graph_module, tuple(inputs))
|
| 434 |
+
|
| 435 |
+
return result
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/pass_infra/__pycache__/proxy_value.cpython-311.pyc
ADDED
|
Binary file (2.86 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/pass_infra/proxy_value.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# pyre-strict
|
| 2 |
+
from typing import Union
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class ProxyValue:
|
| 8 |
+
# pyre-ignore
|
| 9 |
+
def __init__(self, data, proxy: Union[torch.fx.Proxy, torch.fx.Node]):
|
| 10 |
+
# pyre-ignore
|
| 11 |
+
self.data = data
|
| 12 |
+
self.proxy_or_node = proxy
|
| 13 |
+
|
| 14 |
+
@property
|
| 15 |
+
def node(self) -> torch.fx.Node:
|
| 16 |
+
if isinstance(self.proxy_or_node, torch.fx.Node):
|
| 17 |
+
return self.proxy_or_node
|
| 18 |
+
assert isinstance(self.proxy_or_node, torch.fx.Proxy)
|
| 19 |
+
return self.proxy_or_node.node
|
| 20 |
+
|
| 21 |
+
@property
|
| 22 |
+
def proxy(self) -> torch.fx.Proxy:
|
| 23 |
+
if not isinstance(self.proxy_or_node, torch.fx.Proxy):
|
| 24 |
+
raise RuntimeError(
|
| 25 |
+
f"ProxyValue doesn't have attached Proxy object. Node: {self.proxy_or_node.format_node()}"
|
| 26 |
+
)
|
| 27 |
+
return self.proxy_or_node
|
| 28 |
+
|
| 29 |
+
def to_tensor(self) -> torch.Tensor:
|
| 30 |
+
assert isinstance(self.data, torch.Tensor)
|
| 31 |
+
return self.data
|
| 32 |
+
|
| 33 |
+
def is_tensor(self) -> bool:
|
| 34 |
+
return isinstance(self.data, torch.Tensor)
|
| 35 |
+
|
| 36 |
+
# pyre-ignore
|
| 37 |
+
def __iter__(self):
|
| 38 |
+
yield from self.data
|
| 39 |
+
|
| 40 |
+
def __bool__(self) -> bool:
|
| 41 |
+
return bool(self.data)
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/serde/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (220 Bytes). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/serde/union.py
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import functools
|
| 2 |
+
from dataclasses import fields
|
| 3 |
+
from typing import Hashable, Set
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class _UnionTag(str):
    """String subclass marking which field of a tagged-union dataclass is active.

    The tag remembers the union class it belongs to (``_cls``) so that
    equality checks can validate the compared string against the legal
    field names of that class.
    """

    # The owning union dataclass; assigned exactly once in create().
    _cls: Hashable

    @staticmethod
    def create(t, cls):
        # Build a tag for field name `t` and record its owner class.
        tag = _UnionTag(t)
        assert not hasattr(tag, "_cls")
        tag._cls = cls
        return tag

    def __eq__(self, cmp) -> bool:
        # Only string comparisons are supported.  Comparing against a name
        # that is not a field of the owning union is treated as a caller bug
        # (assert), not a plain "not equal".
        assert isinstance(cmp, str)
        other = str(cmp)
        assert other in _get_field_names(
            self._cls
        ), f"{other} is not a valid tag for {self._cls}. Available tags: {_get_field_names(self._cls)}"
        return str(self) == other

    def __hash__(self):
        # Hash as a plain string so the tag interoperates with str dict keys.
        return hash(str(self))
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
@functools.lru_cache(maxsize=None)
|
| 29 |
+
def _get_field_names(cls) -> Set[str]:
|
| 30 |
+
return {f.name for f in fields(cls)}
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class _Union:
    """Base class emulating a tagged union on top of a dataclass.

    Exactly one field is "active"; every other field stays None and reading
    it raises AttributeError.  Instances must be built via ``create()``,
    which records the active field's name in ``_type``.
    """

    # Name of the active field, wrapped in a _UnionTag for validation.
    _type: _UnionTag

    @classmethod
    def create(cls, **kwargs):
        # Exactly one keyword argument selects the active field; all other
        # dataclass fields are filled with None placeholders.
        assert len(kwargs) == 1
        obj = cls(**{**{f.name: None for f in fields(cls)}, **kwargs})  # type: ignore[arg-type]
        obj._type = _UnionTag.create(next(iter(kwargs.keys())), cls)
        return obj

    def __post_init__(self):
        # Field names reserved by the union machinery must not be shadowed.
        assert not any(f.name in ("type", "_type", "create", "value") for f in fields(self))  # type: ignore[arg-type, misc]

    @property
    def type(self) -> str:
        """Name of the active field; raises if not built through create()."""
        try:
            return self._type
        except AttributeError as e:
            raise RuntimeError(
                f"Please use {type(self).__name__}.create to instantiate the union type."
            ) from e

    @property
    def value(self):
        """The value stored in the active field."""
        return getattr(self, self.type)

    def __getattribute__(self, name):
        # Reading an inactive (None) union field is an error; anything else
        # falls through to normal attribute lookup.
        attr = super().__getattribute__(name)
        if attr is None and name in _get_field_names(type(self)) and name != self.type:  # type: ignore[arg-type]
            raise AttributeError(f"Field {name} is not set.")
        return attr

    def __str__(self):
        return self.__repr__()

    def __repr__(self):
        return f"{type(self).__name__}({self.type}={getattr(self, self.type)})"
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__init__.py
ADDED
|
@@ -0,0 +1,150 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Any, Dict, List, Optional
|
| 2 |
+
|
| 3 |
+
import torch.fx
|
| 4 |
+
import torch.utils._pytree as pytree
|
| 5 |
+
|
| 6 |
+
__all__ = ["compile", "list_mode_options", "list_options", "cudagraph_mark_step_begin"]
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def compile(
    gm: torch.fx.GraphModule,
    example_inputs: List[torch.Tensor],
    options: Optional[Dict[str, Any]] = None,
):
    """
    Compile an FX graph with TorchInductor.  This lets graphs captured
    without TorchDynamo be compiled directly.

    Args:
        gm: The FX graph to compile.
        example_inputs: List of tensor inputs.
        options: Optional dict of config options. See `torch._inductor.config`.

    Returns:
        Callable with same behavior as gm but faster.
    """
    # Imported lazily so `import torch._inductor` stays cheap.
    from .compile_fx import compile_fx

    compiled_fn = compile_fx(gm, example_inputs, config_patches=options)
    return compiled_fn
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def aot_compile(
    gm: torch.fx.GraphModule,
    example_inputs: List[torch.Tensor],
    options: Optional[Dict[str, Any]] = None,
) -> str:
    """
    Ahead-of-time compile a given FX graph with TorchInductor into a shared library.

    Args:
        gm: The FX graph to compile.
        example_inputs: List of tensor inputs.
        options: Optional dict of config options.  See `torch._inductor.config`.

    Returns:
        Path to the generated shared library
    """
    from .compile_fx import compile_fx_aot

    # We will serialize the pytree info into the .so as constant strings
    in_spec = None
    out_spec = None
    if isinstance(gm.graph._codegen, torch.fx.graph._PyTreeCodeGen):
        # NOTE: this mutates `gm` — the pytree codegen is stripped and the
        # module recompiled so the AOT artifact takes flattened arguments.
        codegen = gm.graph._codegen
        gm.graph._codegen = torch.fx.graph.CodeGen()
        gm.recompile()

        if codegen.pytree_info.in_spec is not None:
            in_spec = codegen.pytree_info.in_spec
        if codegen.pytree_info.out_spec is not None:
            out_spec = codegen.pytree_info.out_spec

    else:
        # Modules without pytree codegen may still carry specs as attributes
        # (set by export flows).
        if hasattr(gm, "_in_spec"):
            in_spec = gm._in_spec
        if hasattr(gm, "_out_spec"):
            out_spec = gm._out_spec

    # Empty string stands for "no spec" on the consumer side.
    serialized_in_spec = pytree.treespec_dumps(in_spec) if in_spec is not None else ""
    serialized_out_spec = (
        pytree.treespec_dumps(out_spec) if out_spec is not None else ""
    )

    # Fold the serialized specs into the user-provided config patches
    # (user options win nothing here — spec keys are always overwritten).
    options = (
        {
            "aot_inductor.serialized_in_spec": serialized_in_spec,
            "aot_inductor.serialized_out_spec": serialized_out_spec,
        }
        if options is None
        else {
            **options,
            "aot_inductor.serialized_in_spec": serialized_in_spec,
            "aot_inductor.serialized_out_spec": serialized_out_spec,
        }
    )

    return compile_fx_aot(
        gm,
        example_inputs,
        config_patches=options,
    )
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def list_mode_options(
    mode: Optional[str] = None, dynamic: Optional[bool] = None
) -> Dict[str, Any]:
    r"""Returns a dictionary describing the optimizations that each of the available
    modes passed to `torch.compile()` performs.

    Args:
        mode (str, optional): The mode to return the optimizations for.
            If None, returns optimizations for all modes
        dynamic (bool, optional): Whether dynamic shape is enabled.

    Example::
        >>> torch._inductor.list_mode_options()
    """

    # Building blocks shared by several modes.
    cudagraphs = {"triton.cudagraphs": True}
    autotune = {"max_autotune": True}
    mode_options: Dict[str, Dict[str, bool]] = {
        # no extra optimizations
        "default": {},
        # enable cudagraphs
        "reduce-overhead": dict(cudagraphs),
        # enable max-autotune
        "max-autotune-no-cudagraphs": dict(autotune),
        # enable max-autotune + cudagraphs
        "max-autotune": {**autotune, **cudagraphs},
    }
    # Falsy mode (None or "") returns the whole table, matching historical
    # behavior; an unknown non-empty mode raises KeyError.
    return mode_options[mode] if mode else mode_options  # type: ignore[return-value]
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def list_options() -> List[str]:
|
| 129 |
+
r"""Returns a dictionary describing the optimizations and debug configurations
|
| 130 |
+
that are available to `torch.compile()`.
|
| 131 |
+
|
| 132 |
+
The options are documented in `torch._inductor.config`.
|
| 133 |
+
|
| 134 |
+
Example::
|
| 135 |
+
|
| 136 |
+
>>> torch._inductor.list_options()
|
| 137 |
+
"""
|
| 138 |
+
|
| 139 |
+
from torch._inductor import config
|
| 140 |
+
|
| 141 |
+
current_config: Dict[str, Any] = config.shallow_copy_dict()
|
| 142 |
+
|
| 143 |
+
return list(current_config.keys())
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
def cudagraph_mark_step_begin():
    "Indicates that a new iteration of inference or training is about to begin."
    # Imported lazily so cudagraph machinery is only pulled in when used.
    from .cudagraph_trees import mark_step_begin

    mark_step_begin()
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (5.23 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/compile_fx.cpython-311.pyc
ADDED
|
Binary file (68.1 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/config.cpython-311.pyc
ADDED
|
Binary file (17 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/coordinate_descent_tuner.cpython-311.pyc
ADDED
|
Binary file (12.1 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/debug.cpython-311.pyc
ADDED
|
Binary file (38.4 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/dependencies.cpython-311.pyc
ADDED
|
Binary file (33 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/freezing.cpython-311.pyc
ADDED
|
Binary file (16.5 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/hooks.cpython-311.pyc
ADDED
|
Binary file (1.33 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/inductor_prims.cpython-311.pyc
ADDED
|
Binary file (4.98 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/metrics.cpython-311.pyc
ADDED
|
Binary file (16.3 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/quantized_lowerings.cpython-311.pyc
ADDED
|
Binary file (730 Bytes). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/select_algorithm.cpython-311.pyc
ADDED
|
Binary file (64.2 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/sizevars.cpython-311.pyc
ADDED
|
Binary file (39.4 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/test_case.cpython-311.pyc
ADDED
|
Binary file (3.26 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/test_operators.cpython-311.pyc
ADDED
|
Binary file (1.99 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/virtualized.cpython-311.pyc
ADDED
|
Binary file (21.4 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/wrapper_benchmark.cpython-311.pyc
ADDED
|
Binary file (14.7 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cpp_prefix.h
ADDED
|
@@ -0,0 +1,595 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <algorithm>
|
| 4 |
+
#include <atomic>
|
| 5 |
+
#include <cmath>
|
| 6 |
+
#include <cstdlib>
|
| 7 |
+
#include <limits>
|
| 8 |
+
#include <omp.h>
|
| 9 |
+
|
| 10 |
+
#include <ATen/NumericUtils.h>
|
| 11 |
+
#include <ATen/core/PhiloxRNGEngine.h>
|
| 12 |
+
#include <ATen/native/Math.h>
|
| 13 |
+
|
| 14 |
+
#include <c10/util/Float8_e4m3fn.h>
|
| 15 |
+
#include <c10/util/Float8_e5m2.h>
|
| 16 |
+
#include <c10/util/BFloat16.h>
|
| 17 |
+
#include <c10/util/BFloat16-math.h>
|
| 18 |
+
#include <c10/util/generic_math.h>
|
| 19 |
+
#include <c10/util/Half.h>
|
| 20 |
+
#include <c10/util/TypeCast.h>
|
| 21 |
+
|
| 22 |
+
#if defined(CPU_CAPABILITY_AVX512) || defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_ZVECTOR)
|
| 23 |
+
#define INDUCTOR_USE_VECTOR_TYPES() 1
|
| 24 |
+
#else
|
| 25 |
+
#define INDUCTOR_USE_VECTOR_TYPES() 0
|
| 26 |
+
#endif
|
| 27 |
+
|
| 28 |
+
#if INDUCTOR_USE_VECTOR_TYPES()
|
| 29 |
+
#include <ATen/cpu/vec/functional.h>
|
| 30 |
+
#include <ATen/cpu/vec/vec.h>
|
| 31 |
+
#include <ATen/cpu/vec/vec_n.h>
|
| 32 |
+
#endif
|
| 33 |
+
|
| 34 |
+
typedef at::Half half;
|
| 35 |
+
typedef at::BFloat16 bfloat16;
|
| 36 |
+
|
| 37 |
+
typedef at::Float8_e4m3fn float8_e4m3fn;
|
| 38 |
+
typedef at::Float8_e5m2 float8_e5m2;
|
| 39 |
+
|
| 40 |
+
// Running accumulator for Welford's online mean/variance algorithm.
// T may be a scalar or an at::vec::Vectorized<> SIMD type.
template <typename T>
struct Welford {
  T mean = T(0);    // running mean
  T m2 = T(0);      // sum of squared deviations from the mean
  T weight = T(0);  // total weight (element count) accumulated so far
};
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
// Trait: true only for at::vec::Vectorized<T> (when vector types are enabled).
template <typename T>
struct IsVecType: std::false_type {};

#if INDUCTOR_USE_VECTOR_TYPES()
template <typename T>
struct IsVecType<at::vec::Vectorized<T>>: std::true_type {};
#endif
|
| 55 |
+
|
| 56 |
+
// Merge two Welford accumulators (the parallel-combine step of Welford's
// algorithm).  Works for both scalar T and Vectorized<T>.
template <typename T>
Welford<T> welford_combine(const Welford<T> &a, const Welford<T> &b) {
  if constexpr (!IsVecType<T>::value) {
    // Scalar fast path: an empty accumulator contributes nothing, and
    // skipping it avoids dividing by a zero new_weight below.
    if (a.weight == 0) {
      return b;
    }
    if (b.weight == 0) {
      return a;
    }
  }
  auto delta = b.mean - a.mean;
  auto new_weight = a.weight + b.weight;
  auto wb_over_w = b.weight / new_weight;
  if constexpr (IsVecType<T>::value) {
    // Guard against division by zero
    wb_over_w = T::blendv(wb_over_w, T(0), new_weight == T(0));
  }
  auto result = Welford<T>{
    a.mean + delta * wb_over_w,
    a.m2 + b.m2 + delta * delta * a.weight * wb_over_w,
    new_weight
  };
  return result;
}
|
| 80 |
+
|
| 81 |
+
// Fold a single data point into a Welford accumulator (the classic
// one-sample update step).
template <typename T>
Welford<T> welford_combine(const Welford<T> &acc, T data) {
  // Add a single data point
  auto delta = data - acc.mean;
  auto new_weight = acc.weight + T(1);
  auto new_mean = acc.mean + delta / new_weight;
  auto new_delta = data - new_mean;
  auto result = Welford<T>{
    new_mean,
    // delta * new_delta is the numerically stable m2 increment.
    acc.m2 + delta * new_delta,
    new_weight
  };
  return result;
}
|
| 95 |
+
|
| 96 |
+
// Refer to https://github.com/pytorch/pytorch/blob/b5b36cf0c4e1958f1ff25120f5d4beeef3288187/
|
| 97 |
+
// aten/src/ATen/native/SharedReduceOps.h#L419-L445
|
| 98 |
+
template <typename scalar_t>
|
| 99 |
+
inline bool greater_or_nan(scalar_t a, scalar_t b, int64_t idx_a, int64_t idx_b) {
|
| 100 |
+
// If (a == b), then choose the one with lower idx, else max(a, b)
|
| 101 |
+
if (at::_isnan(a)) {
|
| 102 |
+
if (at::_isnan(b)) {
|
| 103 |
+
return idx_a < idx_b;
|
| 104 |
+
}
|
| 105 |
+
return true;
|
| 106 |
+
}
|
| 107 |
+
return (a == b) ? idx_a < idx_b : (a > b);
|
| 108 |
+
}
|
| 109 |
+
|
| 110 |
+
template <typename scalar_t>
|
| 111 |
+
inline bool less_or_nan(scalar_t a, scalar_t b, int64_t idx_a, int64_t idx_b) {
|
| 112 |
+
// If (a == b), then choose the one with lower idx, else min(a, b)
|
| 113 |
+
if (at::_isnan(a)) {
|
| 114 |
+
if (at::_isnan(b)) {
|
| 115 |
+
return idx_a < idx_b;
|
| 116 |
+
}
|
| 117 |
+
return true;
|
| 118 |
+
}
|
| 119 |
+
return (a == b) ? idx_a < idx_b : (a < b);
|
| 120 |
+
}
|
| 121 |
+
|
| 122 |
+
#if INDUCTOR_USE_VECTOR_TYPES()
|
| 123 |
+
// Generic "shuffle down by n" for reduction trees: lane i receives lane i+n
// for the first lane of each 2n-wide group.  Implemented via a store/load
// round-trip; architecture-specific overloads below avoid the memory trip.
template <typename scalar_t>
inline at::vec::Vectorized<scalar_t> vec_shuffle_down(at::vec::Vectorized<scalar_t> x, size_t n) {
  using Vec = at::vec::Vectorized<scalar_t>;
  alignas(alignof(Vec)) scalar_t array[Vec::size()];
  x.store(array);
  for (size_t i = 0; i + n < Vec::size(); i += 2 * n) {
    array[i] = array[i + n];
  }
  return Vec::loadu(array);
}
|
| 133 |
+
|
| 134 |
+
#ifdef CPU_CAPABILITY_AVX2
|
| 135 |
+
inline at::vec::Vectorized<float> vec_shuffle_down(at::vec::Vectorized<float> x, size_t n) {
|
| 136 |
+
using vec_t = at::vec::Vectorized<float>;
|
| 137 |
+
#define SHUFFLE_MASK(z, y, x, w) ((z << 6) | (y << 4) | (x << 2) | w)
|
| 138 |
+
switch (n) {
|
| 139 |
+
case 1:
|
| 140 |
+
return vec_t(_mm256_permute_ps(x, SHUFFLE_MASK(1, 1, 3, 3)));
|
| 141 |
+
case 2:
|
| 142 |
+
return vec_t(_mm256_permute_ps(x, SHUFFLE_MASK(2, 2, 2, 2)));
|
| 143 |
+
case 4:
|
| 144 |
+
return vec_t(_mm256_permute2f128_ps(x, x, SHUFFLE_MASK(1, 1, 1, 1)));
|
| 145 |
+
}
|
| 146 |
+
TORCH_CHECK(false, "Unhandled vec_shuffle_down value ", n);
|
| 147 |
+
}
|
| 148 |
+
#endif
|
| 149 |
+
|
| 150 |
+
// Reduce a vector-lane Welford accumulator to a single scalar one using a
// log2(lanes) shuffle/combine tree; lane 0 holds the final answer.
template <typename scalar_t>
Welford<scalar_t> welford_vec_reduce_all(Welford<at::vec::Vectorized<scalar_t>> acc) {
  using Vec = at::vec::Vectorized<scalar_t>;
  for (size_t n = 1; n < Vec::size(); n *= 2) {
    auto shuffled = Welford<Vec>{
      vec_shuffle_down(acc.mean, n),
      vec_shuffle_down(acc.m2, n),
      vec_shuffle_down(acc.weight, n)
    };
    acc = welford_combine(acc, shuffled);
  }

  // Extract lane 0 of each component into the scalar result.
  Welford<scalar_t> result;
  alignas(alignof(Vec)) scalar_t array[Vec::size()];
  acc.mean.store(array);
  result.mean = array[0];

  acc.m2.store(array);
  result.m2 = array[0];

  acc.weight.store(array);
  result.weight = array[0];

  return result;
}
|
| 175 |
+
#endif
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
// Generic modulus: `%` for integral types, std::fmod for float/double.
template <typename T, typename U> inline typename std::common_type<T, U>::type mod(T a, U b) { return a % b; }
template <> inline float mod(float a, float b) { return std::fmod(a, b); }
template <> inline double mod(double a, double b) { return std::fmod(a, b); }
|
| 181 |
+
|
| 182 |
+
template <typename scalar_t>
|
| 183 |
+
inline scalar_t max_propagate_nan(scalar_t a, scalar_t b) {
|
| 184 |
+
if (at::_isnan(a)) {
|
| 185 |
+
return a;
|
| 186 |
+
}
|
| 187 |
+
return a > b ? a : b;
|
| 188 |
+
}
|
| 189 |
+
|
| 190 |
+
template <typename scalar_t>
|
| 191 |
+
inline scalar_t min_propagate_nan(scalar_t a, scalar_t b) {
|
| 192 |
+
if (at::_isnan(a)) {
|
| 193 |
+
return a;
|
| 194 |
+
}
|
| 195 |
+
return a < b ? a : b;
|
| 196 |
+
}
|
| 197 |
+
|
| 198 |
+
// Map the low 31 bits of `value` uniformly onto [0, 1).
constexpr float uint32_to_uniform_float(uint32_t value) {
  // maximum value such that `MAX_INT * scale < 1.0` (with float rounding)
  constexpr float scale = 4.6566127342e-10;
  return static_cast<float>(value & 0x7FFFFFFF) * scale;
}
|
| 203 |
+
|
| 204 |
+
// Uniform [0, 1) float from a Philox counter-based RNG at (seed, offset).
float normalized_rand_cpu(uint32_t seed, uint32_t offset) {
  return uint32_to_uniform_float(at::Philox4_32(seed, 0, offset)());
}
|
| 207 |
+
|
| 208 |
+
// Standard-normal sample from a Philox RNG at (seed, offset).
float randn_cpu(uint32_t seed, uint32_t offset) {
  at::Philox4_32 engine(seed, 0, offset);
  return engine.randn(10);
}
|
| 212 |
+
|
| 213 |
+
// Random int64 in [low, high) from a Philox RNG at (seed, offset).
// NOTE(review): the modulo mapping has slight bias for ranges that do not
// divide 2^64, and assumes high > low (high == low would divide by zero) —
// both match the upstream contract; confirm callers guarantee high > low.
int64_t randint64_cpu(uint32_t seed, uint32_t offset, int64_t low, int64_t high) {
  auto gen = at::Philox4_32(seed, 0, offset);
  // Combine two 32-bit draws into 64 random bits.
  uint64_t r0 = gen();
  uint64_t r1 = gen();
  uint64_t result = r0 | (r1 << 32);
  return static_cast<int64_t>(result % (high - low)) + low;
}
|
| 220 |
+
|
| 221 |
+
// Maps a floating type to an unsigned integer type of identical size,
// used for the CAS-based atomic_add below.  Unspecialized types map to
// themselves.
template <typename T> struct AsIntegerType { typedef T type; };
template <> struct AsIntegerType<float> { typedef uint32_t type; };
template <> struct AsIntegerType<double> { typedef uint64_t type; };
template <> struct AsIntegerType<bfloat16> { typedef uint16_t type; };
|
| 225 |
+
|
| 226 |
+
// Read a value through a volatile pointer.  Regular types are loaded
// directly; reduced-precision floats (half/bfloat16) are rebuilt from
// their raw bit field because volatile loads of the wrapper class are
// not otherwise well-formed.
template <typename T>
typename std::enable_if<!std::is_reduced_floating_point<T>::value, T>::type
inline fetch_value(volatile T *addr) {
  return *addr;
}

template <typename T>
typename std::enable_if<std::is_reduced_floating_point<T>::value, T>::type
inline fetch_value(volatile T *addr) {
  // Reconstruct from the raw 16-bit payload.
  return T(addr->x, T::from_bits());
}
|
| 237 |
+
|
| 238 |
+
// Atomic add for non-integral types, implemented as a compare-exchange
// loop on the same-sized integer representation (AsIntegerType).
template <typename T>
typename std::enable_if<!std::is_integral<T>::value>::type
atomic_add(volatile T *addr, T offset) {
  typedef typename AsIntegerType<T>::type alt_type;

  // The reinterpretation below is only sound if the atomic wrapper adds
  // no storage on top of T.
  static_assert(sizeof(std::atomic<alt_type>) == sizeof(T),
                "std::atomic issue");

  alt_type expected;

  alt_type desired;

  std::atomic<alt_type> *atomic_addr = (std::atomic<alt_type> *)addr;
  do {
    // Re-read the current value each retry; another thread may have
    // updated it between iterations.
    T val = fetch_value(addr);
    reinterpret_cast<T *>(&expected)[0] = val;
    reinterpret_cast<T *>(&desired)[0] = val + offset;
  } while (!atomic_addr->compare_exchange_weak(expected, desired,
                                               std::memory_order_relaxed));
}
|
| 258 |
+
|
| 259 |
+
// Since C++20 float is supported by fetch_add, but the performance may not
|
| 260 |
+
// better than compare_exchange_weak, which can be checked by microbenchmark
|
| 261 |
+
// inductor_cpu_atomic.py
|
| 262 |
+
// Atomic add for integral types: delegate to the hardware fetch_add.
template <typename T>
typename std::enable_if<std::is_integral<T>::value>::type
atomic_add(volatile T *addr, T offset) {
  static_assert(sizeof(std::atomic<T>) == sizeof(T),
                "std::atomic issue");
  std::atomic<T> *atomic_addr = (std::atomic<T> *)addr;
  atomic_addr->fetch_add(offset, std::memory_order_relaxed);
}
|
| 270 |
+
|
| 271 |
+
// This function is used to convert bool or uint8 to float mask for
|
| 272 |
+
// vectorization. The caller needs to make sure the src represents TRUE/FALSE
|
| 273 |
+
// correctly.
|
| 274 |
+
// Convert a truthy flag into an all-ones / all-zeros float bit pattern
// (the lane format SIMD blend/mask operations expect).
template <typename T>
inline float flag_to_float_scalar(T src) {
  float ret;
  // Write the mask bits directly; the value is NOT a numeric float.
  *(uint32_t*)(&ret) = src ? 0xFFFFFFFF : 0;
  return ret;
}
|
| 280 |
+
|
| 281 |
+
#if defined(CPU_CAPABILITY_AVX512) || defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_ZVECTOR)
|
| 282 |
+
|
| 283 |
+
// Load floats from `src`, keeping only lanes whose mask bits are all-ones;
// other lanes become 0.
inline at::vec::Vectorized<float> masked_load(const float* src, at::vec::Vectorized<float> mask) {
# if defined(CPU_CAPABILITY_AVX512)
  at::vec::Vectorized<float> zero_vec(0);
  auto all_ones = _mm512_set1_epi32(0xFFFFFFFF);
  // Turn the float-bit mask into a k-register predicate, then mask-load.
  auto mmask = _mm512_cmp_epi32_mask(_mm512_castps_si512(mask), all_ones, _MM_CMPINT_EQ);
  return _mm512_mask_loadu_ps(zero_vec, mmask, src);
# elif defined(CPU_CAPABILITY_AVX2)
  auto all_ones = _mm256_set1_epi32(0xFFFFFFFF);
  auto mmask = _mm256_cmpeq_epi32(_mm256_castps_si256(mask), all_ones);
  return _mm256_maskload_ps(src, mmask);
# elif defined(CPU_CAPABILITY_ZVECTOR)
  // NOTE(review): unlike AVX, this path loads all lanes then ANDs with the
  // mask, so it assumes the full vector at `src` is readable.
  auto result = at::vec::Vectorized<float>::loadu(src);
  return (result & mask);
# else
# error Unsupported vectorization CPU capability
# endif
}
|
| 300 |
+
|
| 301 |
+
// masked_load overload for 16-bit floats (bfloat16 / half): keeps the first
// 8 lanes selected by the float mask, zeroing everything else.
template <typename T>
typename std::enable_if<std::is_same<T, bfloat16>::value || std::is_same<T, half>::value, at::vec::Vectorized<T>>::type
inline masked_load(const T* src, at::vec::Vectorized<float> mask) {
# if defined(CPU_CAPABILITY_AVX512)
  auto all_ones = _mm512_set1_epi32(0xFFFFFFFF)
;
  auto mmask = _mm512_cmp_epi32_mask(_mm512_castps_si512(mask), all_ones, _MM_CMPINT_EQ);
  auto zero = _mm256_set1_epi16(0);
  auto temp = _mm256_mask_loadu_epi16(zero, mmask, src);
  // Upper half of the 512-bit result is zeroed explicitly.
  return _mm512_inserti32x8(_mm512_castsi256_si512(temp), zero, 1);
# elif defined(CPU_CAPABILITY_AVX2)
  auto all_ones = _mm256_set1_epi32(0xFFFFFFFF);
  auto mmask_vec = _mm256_cmpeq_epi32(_mm256_castps_si256(mask), all_ones);
  __at_align__ uint32_t mmask[8];
  _mm256_storeu_si256(reinterpret_cast<__m256i*>(mmask), mmask_vec);
  // No 16-bit mask-load on AVX2; gather the selected lanes by hand.
  __at_align__ uint16_t result[16];
  for (auto i = 0; i < 8; i++) {
    result[i] = mmask[i] == 0xFFFFFFFF ? src[i].x: uint16_t(0);
  }
  return at::vec::Vectorized<T>::loadu(result);
# elif defined(CPU_CAPABILITY_ZVECTOR)
  auto result = at::vec::Vectorized<T>::loadu(src, 8);
  uint32_t maskdata[8] = { 0 };
  uint16_t maskdata_dest[16] = { 0 };
  mask.store(maskdata);
  // Widened 32-bit mask lanes are narrowed to 16-bit lanes for the AND.
  for (auto i = 0; i < 8; i++) {
    maskdata_dest[i] = (maskdata[i] == 0xFFFFFFFF) ? 0xFFFF: 0;
  }
  auto maskvector = at::vec::Vectorized<T>::loadu(maskdata_dest);
  return (result & maskvector);
# else
# error Unsupported vectorization CPU capability
# endif
}
|
| 334 |
+
|
| 335 |
+
// masked_load overload for 8-bit integers: keeps the first 8 lanes selected
// by the float mask, zeroing everything else.
template <typename T>
typename std::enable_if<std::is_same<T, uint8_t>::value || std::is_same<T, int8_t>::value, at::vec::Vectorized<T>>::type
inline masked_load(const T* src, at::vec::Vectorized<float> mask) {
# if defined(CPU_CAPABILITY_AVX512)
  auto all_ones = _mm512_set1_epi32(0xFFFFFFFF);
  auto mmask = _mm512_cmp_epi32_mask(_mm512_castps_si512(mask), all_ones, _MM_CMPINT_EQ);
  auto zero = _mm_set1_epi8(0);
  auto temp = _mm_mask_loadu_epi8(zero, mmask, src);
  // Place the 128-bit payload in lane 0 of an otherwise-zero 512-bit vector.
  return _mm512_inserti64x2(_mm512_set1_epi32(0), temp, 0);
# elif defined(CPU_CAPABILITY_AVX2)
  auto all_ones = _mm256_set1_epi32(0xFFFFFFFF);
  auto mmask_vec = _mm256_cmpeq_epi32(_mm256_castps_si256(mask), all_ones);
  __at_align__ uint32_t mmask[8];
  _mm256_storeu_si256(reinterpret_cast<__m256i*>(mmask), mmask_vec);
  // No 8-bit mask-load on AVX2; gather the selected lanes by hand.
  __at_align__ T result[32];
  for (auto i = 0; i < 8; i++) {
    result[i] = mmask[i] == 0xFFFFFFFF ? src[i]: T(0);
  }
  return at::vec::Vectorized<T>::loadu(result);
# elif defined(CPU_CAPABILITY_ZVECTOR)
  auto result = at::vec::Vectorized<T>::loadu(src, 8);
  uint32_t maskdata[8];
  T maskdata_dest[32] = { 0 };
  mask.store(maskdata);
  // Widened 32-bit mask lanes are narrowed to 8-bit lanes for the AND.
  for (auto i = 0; i < 8; i++) {
    maskdata_dest[i] = (maskdata[i] == 0xFFFFFFFF) ? 0xFF: 0;
  }
  auto maskvector = at::vec::Vectorized<T>::loadu(maskdata_dest);
  return (result & maskvector);
# else
# error Unsupported vectorization CPU capability
# endif
}
|
| 368 |
+
|
| 369 |
+
// Vector counterpart of flag_to_float_scalar: convert a lane-per-element
// flag array into a float-bit mask vector.
template <typename T>
inline at::vec::Vectorized<float> flag_to_float_vec(const T* src) {
  __at_align__ float dst_tmp[at::vec::Vectorized<float>::size()];
#pragma unroll
  for (int64_t i = 0; i < at::vec::Vectorized<float>::size(); i++) {
    dst_tmp[i] = flag_to_float_scalar(src[i]);
  }
  return at::vec::Vectorized<float>::loadu(dst_tmp);
}
|
| 378 |
+
|
| 379 |
+
// Widen a reduced-precision vector to float, keeping only the first half
// of the converted lanes (convert_to_float yields two float vectors).
template <typename scalar_t>
inline at::vec::Vectorized<float> cvt_lowp_fp_to_fp32(
    at::vec::Vectorized<scalar_t> src) {
  at::vec::Vectorized<float> res_vec1(0);
  at::vec::Vectorized<float> res_vec2(0);
  std::tie(res_vec1, res_vec2) = at::vec::convert_to_float<scalar_t>(src);
  return res_vec1;
}
|
| 387 |
+
|
| 388 |
+
template <typename scalar_t>
|
| 389 |
+
inline at::vec::Vectorized<scalar_t> cvt_fp32_to_lowp_fp(
|
| 390 |
+
at::vec::Vectorized<float> src) {
|
| 391 |
+
return at::vec::convert_from_float<scalar_t>(src, src);
|
| 392 |
+
}
|
| 393 |
+
|
| 394 |
+
inline at::vec::Vectorized<float> mask_convert_to_float(at::vec::Vectorized<float> src) {
|
| 395 |
+
auto zeros = at::vec::Vectorized<float>(0);
|
| 396 |
+
auto ones = at::vec::Vectorized<float>(1);
|
| 397 |
+
return at::vec::Vectorized<float>::blendv(zeros, ones, src);
|
| 398 |
+
}
|
| 399 |
+
|
| 400 |
+
template <typename scalar_t>
|
| 401 |
+
inline
|
| 402 |
+
typename std::enable_if<std::is_same<scalar_t, bfloat16>::value || std::is_same<scalar_t, half>::value, at::vec::Vectorized<scalar_t>>::type
|
| 403 |
+
mask_convert_to_lowp(at::vec::Vectorized<float> src) {
|
| 404 |
+
auto fp_vec = mask_convert_to_float(src);
|
| 405 |
+
return cvt_fp32_to_lowp_fp<scalar_t>(fp_vec);
|
| 406 |
+
}
|
| 407 |
+
|
| 408 |
+
template <typename SRC>
|
| 409 |
+
inline at::vec::Vectorized<float> vec_convert_to_mask(at::vec::Vectorized<SRC> src) {
|
| 410 |
+
assert(
|
| 411 |
+
at::vec::Vectorized<float>::size() == at::vec::Vectorized<SRC>::size());
|
| 412 |
+
at::vec::Vectorized<float> res_vec(0);
|
| 413 |
+
__at_align__ float dst_tmp[at::vec::Vectorized<float>::size()];
|
| 414 |
+
__at_align__ SRC src_tmp[at::vec::Vectorized<SRC>::size()];
|
| 415 |
+
src.store(src_tmp);
|
| 416 |
+
|
| 417 |
+
#pragma unroll
|
| 418 |
+
for (int i = 0; i < at::vec::Vectorized<float>::size(); i++) {
|
| 419 |
+
*(uint32_t*)(dst_tmp + i) = src_tmp[i] ? 0xFFFFFFFF : 0;
|
| 420 |
+
}
|
| 421 |
+
|
| 422 |
+
return res_vec.loadu(dst_tmp);
|
| 423 |
+
}
|
| 424 |
+
|
| 425 |
+
template <typename SRC>
|
| 426 |
+
inline at::vec::Vectorized<float> to_float_mask(at::vec::Vectorized<SRC> src) {
|
| 427 |
+
return vec_convert_to_mask(src);
|
| 428 |
+
}
|
| 429 |
+
|
| 430 |
+
#if defined(CPU_CAPABILITY_AVX512) || defined(CPU_CAPABILITY_AVX2)
|
| 431 |
+
template <>
|
| 432 |
+
inline at::vec::Vectorized<float> to_float_mask(at::vec::Vectorized<int> src) {
|
| 433 |
+
#if defined(CPU_CAPABILITY_AVX2)
|
| 434 |
+
return at::vec::Vectorized<float>(_mm256_castsi256_ps(src));
|
| 435 |
+
#else
|
| 436 |
+
return at::vec::Vectorized<float>(_mm512_castsi512_ps(src));
|
| 437 |
+
#endif
|
| 438 |
+
}
|
| 439 |
+
#endif
|
| 440 |
+
|
| 441 |
+
template <>
|
| 442 |
+
inline at::vec::Vectorized<float> to_float_mask(at::vec::Vectorized<float> src) {
|
| 443 |
+
return src;
|
| 444 |
+
}
|
| 445 |
+
|
| 446 |
+
inline at::vec::Vectorized<float> to_float_mask(int src) {
|
| 447 |
+
union {
|
| 448 |
+
float fmask;
|
| 449 |
+
uint32_t imask;
|
| 450 |
+
} mask;
|
| 451 |
+
mask.imask = src ? 0xFFFFFFFF : 0;
|
| 452 |
+
return at::vec::Vectorized<float>(mask.fmask);
|
| 453 |
+
}
|
| 454 |
+
|
| 455 |
+
inline bool all_zero(at::vec::Vectorized<float> src) {
|
| 456 |
+
# if defined(CPU_CAPABILITY_AVX512)
|
| 457 |
+
auto src_int = _mm512_castps_si512(src);
|
| 458 |
+
__mmask16 mask = _mm512_test_epi32_mask(src_int, src_int);
|
| 459 |
+
return mask == 0;
|
| 460 |
+
# elif defined(CPU_CAPABILITY_AVX2)
|
| 461 |
+
return _mm256_testz_ps(src, src);
|
| 462 |
+
# else
|
| 463 |
+
__at_align__ int mask[at::vec::Vectorized<float>::size()];
|
| 464 |
+
src.store(mask);
|
| 465 |
+
for (int i = 0; i < at::vec::Vectorized<float>::size(); i++) {
|
| 466 |
+
if (mask[i] != 0) {
|
| 467 |
+
return false;
|
| 468 |
+
}
|
| 469 |
+
}
|
| 470 |
+
return true;
|
| 471 |
+
# endif
|
| 472 |
+
}
|
| 473 |
+
|
| 474 |
+
inline bool vector_lane_mask_check(at::vec::Vectorized<float> src, int lane) {
|
| 475 |
+
# if defined(CPU_CAPABILITY_AVX512)
|
| 476 |
+
return _mm512_movepi32_mask(_mm512_castps_si512(src)) & (1 << lane);
|
| 477 |
+
# elif defined(CPU_CAPABILITY_AVX2)
|
| 478 |
+
return _mm256_movemask_ps(src) & (1 << lane);
|
| 479 |
+
# else
|
| 480 |
+
__at_align__ int mask[at::vec::Vectorized<float>::size()];
|
| 481 |
+
src.store(mask);
|
| 482 |
+
return mask[lane] != 0;
|
| 483 |
+
# endif
|
| 484 |
+
}
|
| 485 |
+
|
| 486 |
+
inline at::vec::Vectorized<float> cvt_int64_to_fp32(at::vec::VectorizedN<int64_t,2> src) {
|
| 487 |
+
# if defined(CPU_CAPABILITY_AVX512)
|
| 488 |
+
auto low = _mm512_cvtepi64_ps(src[0]);
|
| 489 |
+
auto high = _mm512_cvtepi64_ps(src[1]);
|
| 490 |
+
return _mm512_insertf32x8(_mm512_castps256_ps512(low), high, 1);
|
| 491 |
+
# elif defined(CPU_CAPABILITY_AVX2)
|
| 492 |
+
auto low_double = at::vec::convert_to_fp_of_same_size<double>(src[0]);
|
| 493 |
+
auto low = _mm256_cvtpd_ps(low_double);
|
| 494 |
+
auto high_double = at::vec::convert_to_fp_of_same_size<double>(src[1]);
|
| 495 |
+
auto high = _mm256_cvtpd_ps(high_double);
|
| 496 |
+
return _mm256_insertf128_ps(_mm256_castps128_ps256(low), high, 1);
|
| 497 |
+
# else
|
| 498 |
+
constexpr int float_vec_size = at::vec::Vectorized<float>::size();
|
| 499 |
+
constexpr int int64_vec_size = at::vec::Vectorized<int64_t>::size();
|
| 500 |
+
__at_align__ float result[float_vec_size];
|
| 501 |
+
__at_align__ int64_t src_buf[int64_vec_size];
|
| 502 |
+
for (int i = 0; i < 2; i++) {
|
| 503 |
+
src[i].store(src_buf + i * int64_vec_size);
|
| 504 |
+
for (int j = 0; j < int64_vec_size; j++) {
|
| 505 |
+
result[i * int64_vec_size + j] = static_cast<float>(src_buf[i * int64_vec_size + j]);
|
| 506 |
+
}
|
| 507 |
+
}
|
| 508 |
+
return at::vec::Vectorized<float>::loadu(result);
|
| 509 |
+
# endif
|
| 510 |
+
}
|
| 511 |
+
|
| 512 |
+
inline at::vec::VectorizedN<int64_t,2> cvt_fp32_to_int64(at::vec::Vectorized<float> src) {
|
| 513 |
+
at::vec::VectorizedN<int64_t,2> result;
|
| 514 |
+
# if defined(CPU_CAPABILITY_AVX512)
|
| 515 |
+
result[0] = _mm512_cvt_roundps_epi64(_mm512_castps512_ps256(src), _MM_FROUND_TO_ZERO |_MM_FROUND_NO_EXC);
|
| 516 |
+
result[1] = _mm512_cvt_roundps_epi64(_mm512_extractf32x8_ps(src, 1), _MM_FROUND_TO_ZERO |_MM_FROUND_NO_EXC);
|
| 517 |
+
# elif defined(CPU_CAPABILITY_AVX2)
|
| 518 |
+
auto int32_vec = at::vec::convert_to_int_of_same_size(src);
|
| 519 |
+
result[0] = _mm256_cvtepi32_epi64(_mm256_castsi256_si128(int32_vec));
|
| 520 |
+
result[1] = _mm256_cvtepi32_epi64(_mm256_extracti128_si256(int32_vec, 1));
|
| 521 |
+
# else
|
| 522 |
+
constexpr int float_vec_size = at::vec::Vectorized<float>::size();
|
| 523 |
+
constexpr int int64_vec_size = at::vec::Vectorized<int64_t>::size();
|
| 524 |
+
__at_align__ float src_buf[float_vec_size];
|
| 525 |
+
__at_align__ int64_t result_buf[int64_vec_size];
|
| 526 |
+
src.store(src_buf);
|
| 527 |
+
for (int i = 0; i < 2; i++) {
|
| 528 |
+
for (int j = 0; j < int64_vec_size; j++) {
|
| 529 |
+
result_buf[j] = static_cast<int64_t>(src_buf[i * int64_vec_size + j]);
|
| 530 |
+
}
|
| 531 |
+
result[i] = at::vec::Vectorized<int64_t>::loadu(result_buf);
|
| 532 |
+
}
|
| 533 |
+
# endif
|
| 534 |
+
return result;
|
| 535 |
+
}
|
| 536 |
+
|
| 537 |
+
inline at::vec::Vectorized<int32_t> cvt_int64_to_int32(at::vec::VectorizedN<int64_t,2> src) {
|
| 538 |
+
# if defined(CPU_CAPABILITY_AVX512)
|
| 539 |
+
auto low = _mm512_cvtepi64_epi32(src[0]);
|
| 540 |
+
auto high = _mm512_cvtepi64_epi32(src[1]);
|
| 541 |
+
return _mm512_inserti32x8(_mm512_castsi256_si512(low), high, 1);
|
| 542 |
+
# elif defined(CPU_CAPABILITY_AVX2)
|
| 543 |
+
auto low = _mm256_shuffle_epi32(src[0], _MM_SHUFFLE(2, 0, 2, 0));
|
| 544 |
+
auto high = _mm256_shuffle_epi32(src[1], _MM_SHUFFLE(2, 0, 2, 0));
|
| 545 |
+
auto low_perm = _mm256_permute4x64_epi64(low, _MM_SHUFFLE(3, 1, 2, 0));
|
| 546 |
+
auto high_perm = _mm256_permute4x64_epi64(high, _MM_SHUFFLE(3, 1, 2, 0));
|
| 547 |
+
return _mm256_blend_epi32(low_perm, high_perm, 0xF0);
|
| 548 |
+
# else
|
| 549 |
+
constexpr int int32_vec_size = at::vec::Vectorized<int32_t>::size();
|
| 550 |
+
constexpr int int64_vec_size = at::vec::Vectorized<int64_t>::size();
|
| 551 |
+
__at_align__ int32_t result[int32_vec_size];
|
| 552 |
+
__at_align__ int64_t src_buf[int64_vec_size];
|
| 553 |
+
for (int i = 0; i < 2; i++) {
|
| 554 |
+
src[i].store(src_buf + i * int64_vec_size);
|
| 555 |
+
for (int j = 0; j < int64_vec_size; j++) {
|
| 556 |
+
result[i * int64_vec_size + j] = static_cast<int32_t>(src_buf[i * int64_vec_size + j]);
|
| 557 |
+
}
|
| 558 |
+
}
|
| 559 |
+
return at::vec::Vectorized<int32_t>::loadu(result);
|
| 560 |
+
# endif
|
| 561 |
+
}
|
| 562 |
+
|
| 563 |
+
inline at::vec::VectorizedN<int64_t,2> cvt_int32_to_int64(at::vec::Vectorized<int32_t> src) {
|
| 564 |
+
at::vec::VectorizedN<int64_t,2> result;
|
| 565 |
+
# if defined(CPU_CAPABILITY_AVX512)
|
| 566 |
+
result[0] = _mm512_cvtepi32_epi64(_mm512_castsi512_si256(src));
|
| 567 |
+
result[1] = _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(src, 1));
|
| 568 |
+
# elif defined(CPU_CAPABILITY_AVX2)
|
| 569 |
+
result[0] = _mm256_cvtepi32_epi64(_mm256_castsi256_si128(src));
|
| 570 |
+
result[1] = _mm256_cvtepi32_epi64(_mm256_extracti128_si256(src, 1));
|
| 571 |
+
#else
|
| 572 |
+
constexpr int int32_vec_size = at::vec::Vectorized<int32_t>::size();
|
| 573 |
+
constexpr int int64_vec_size = at::vec::Vectorized<int64_t>::size();
|
| 574 |
+
__at_align__ int32_t src_buf[int32_vec_size];
|
| 575 |
+
__at_align__ int64_t result_buf[int64_vec_size];
|
| 576 |
+
src.store(src_buf);
|
| 577 |
+
for (int i = 0; i < 2; i++) {
|
| 578 |
+
for (int j = 0; j < int64_vec_size; j++) {
|
| 579 |
+
result_buf[j] = static_cast<int64_t>(src_buf[i * int64_vec_size + j]);
|
| 580 |
+
}
|
| 581 |
+
result[i] = at::vec::Vectorized<int64_t>::loadu(result_buf);
|
| 582 |
+
}
|
| 583 |
+
# endif
|
| 584 |
+
return result;
|
| 585 |
+
}
|
| 586 |
+
|
| 587 |
+
inline at::vec::VectorizedN<int64_t,2> mask_convert_to_int64(at::vec::Vectorized<float> src) {
|
| 588 |
+
return cvt_fp32_to_int64(mask_convert_to_float(src));
|
| 589 |
+
}
|
| 590 |
+
|
| 591 |
+
inline at::vec::Vectorized<float> to_float_mask(at::vec::VectorizedN<int64_t,2> src) {
|
| 592 |
+
return to_float_mask(cvt_int64_to_int32(src));
|
| 593 |
+
}
|
| 594 |
+
|
| 595 |
+
#endif
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cpp_wrapper_cpu.py
ADDED
|
@@ -0,0 +1,1851 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import functools
|
| 2 |
+
import os
|
| 3 |
+
import sys
|
| 4 |
+
from itertools import count
|
| 5 |
+
from typing import List, Optional, Tuple
|
| 6 |
+
|
| 7 |
+
import sympy
|
| 8 |
+
from sympy import Expr
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
import torch._ops
|
| 12 |
+
from .. import config, ir
|
| 13 |
+
|
| 14 |
+
from ..codecache import CudaKernelParamCache
|
| 15 |
+
from ..utils import cache_on_self, sympy_product
|
| 16 |
+
from ..virtualized import V
|
| 17 |
+
from .common import IndentedBuffer
|
| 18 |
+
from .wrapper import EnterSubgraphLine, ExitSubgraphLine, pexpr, WrapperCodeGen
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class CppWrapperCpu(WrapperCodeGen):
|
| 22 |
+
"""
|
| 23 |
+
Generates cpp wrapper for running on CPU and calls cpp kernels
|
| 24 |
+
"""
|
| 25 |
+
|
| 26 |
+
def __init__(self):
|
| 27 |
+
if not hasattr(self, "device"):
|
| 28 |
+
self.device = "cpu"
|
| 29 |
+
super().__init__()
|
| 30 |
+
self.declare = "auto "
|
| 31 |
+
self.declare_maybe_reference = "decltype(auto) "
|
| 32 |
+
self.ending = ";"
|
| 33 |
+
self.open_bracket = "{"
|
| 34 |
+
self.closed_bracket = "}"
|
| 35 |
+
self.comment = "//"
|
| 36 |
+
self.namespace = "at::"
|
| 37 |
+
self.none_str = "nullptr" if config.abi_compatible else "at::Tensor()"
|
| 38 |
+
self.extern_call_ops = set()
|
| 39 |
+
self.size = "sizes()"
|
| 40 |
+
self.stride = "strides()"
|
| 41 |
+
self.cuda = False
|
| 42 |
+
self.supports_intermediate_hooks = False
|
| 43 |
+
self.outputs_need_copy = set()
|
| 44 |
+
self.kernel_callsite_id = count()
|
| 45 |
+
self.int_array_id = count() # for int array local variable declarations
|
| 46 |
+
self.declared_int_array_vars = set()
|
| 47 |
+
self.tmp_tensor_id = count() # for tmp tensor local variable declarations
|
| 48 |
+
self.arg_var_id = count()
|
| 49 |
+
self.used_cached_devices = set()
|
| 50 |
+
self.used_cached_dtypes = set()
|
| 51 |
+
self.cached_output_id = count()
|
| 52 |
+
self.scalar_to_tensor_id = count()
|
| 53 |
+
|
| 54 |
+
from .cpp import cexpr, CppPrinter
|
| 55 |
+
|
| 56 |
+
self.expr_printer = cexpr
|
| 57 |
+
|
| 58 |
+
# CppPrinter sometimes calls at::native functions which causes problems in
|
| 59 |
+
# the ABI-compatible mode. Currently we are hitting this problem when codegen
|
| 60 |
+
# Grid computation expressions, but we my need to fix other size computation
|
| 61 |
+
# as well.
|
| 62 |
+
class GridExprCppPrinter(CppPrinter):
|
| 63 |
+
def _print_FloorDiv(self, expr):
|
| 64 |
+
x, div = expr.args
|
| 65 |
+
x = self.paren(self.doprint(x))
|
| 66 |
+
div = self.paren(self.doprint(div))
|
| 67 |
+
assert expr.is_integer, "Expect integers in GridExprPrinter"
|
| 68 |
+
return f"({x}/{div})"
|
| 69 |
+
|
| 70 |
+
self.grid_expr_printer = GridExprCppPrinter().doprint
|
| 71 |
+
|
| 72 |
+
def generate_kernel_call(
|
| 73 |
+
self,
|
| 74 |
+
name,
|
| 75 |
+
call_args,
|
| 76 |
+
grid=None,
|
| 77 |
+
device_index=None,
|
| 78 |
+
cuda=True,
|
| 79 |
+
triton=True,
|
| 80 |
+
arg_types=None,
|
| 81 |
+
grid_fn: str = "grid",
|
| 82 |
+
triton_meta=None,
|
| 83 |
+
):
|
| 84 |
+
"""
|
| 85 |
+
Generates kernel call code.
|
| 86 |
+
|
| 87 |
+
cuda: Defines whether the backend is GPU. Otherwise the backend is CPU.
|
| 88 |
+
|
| 89 |
+
triton: Defines whether the GPU backend uses Triton for codegen.
|
| 90 |
+
Otherwise it uses the CUDA language for codegen.
|
| 91 |
+
Only valid when cuda == True.
|
| 92 |
+
"""
|
| 93 |
+
if cuda:
|
| 94 |
+
return super().generate_kernel_call(
|
| 95 |
+
name,
|
| 96 |
+
call_args,
|
| 97 |
+
grid,
|
| 98 |
+
device_index,
|
| 99 |
+
cuda,
|
| 100 |
+
triton,
|
| 101 |
+
arg_types,
|
| 102 |
+
grid_fn,
|
| 103 |
+
)
|
| 104 |
+
else:
|
| 105 |
+
if config.abi_compatible:
|
| 106 |
+
assert arg_types is not None and len(call_args) == len(
|
| 107 |
+
arg_types
|
| 108 |
+
), "Mismatch call_args and arg_types in generate_kernel_call"
|
| 109 |
+
new_args = []
|
| 110 |
+
for idx, arg in enumerate(call_args):
|
| 111 |
+
if "*" in arg_types[idx]:
|
| 112 |
+
var_name = f"var_{next(self.arg_var_id)}"
|
| 113 |
+
self.writeline(
|
| 114 |
+
f"auto* {var_name} = get_data_ptr_wrapper({arg});"
|
| 115 |
+
)
|
| 116 |
+
new_args.append(f"({arg_types[idx]})({var_name})")
|
| 117 |
+
else:
|
| 118 |
+
# arg is a scalar
|
| 119 |
+
new_args.append(arg)
|
| 120 |
+
self.writeline(self.wrap_kernel_call(name, new_args))
|
| 121 |
+
else:
|
| 122 |
+
self.writeline(self.wrap_kernel_call(name, call_args))
|
| 123 |
+
|
| 124 |
+
def write_constant(self, name, hashed):
|
| 125 |
+
# include a hash so our code cache gives different constants different files
|
| 126 |
+
self.header.writeline(f"// {name} {hashed}")
|
| 127 |
+
|
| 128 |
+
def write_header(self):
|
| 129 |
+
if V.graph.is_const_graph:
|
| 130 |
+
# We do not write header for constant graph, it will be written by main module.
|
| 131 |
+
return
|
| 132 |
+
|
| 133 |
+
if V.graph.aot_mode:
|
| 134 |
+
for header_cpp_file in ("interface.cpp", "implementation.cpp"):
|
| 135 |
+
with open(
|
| 136 |
+
os.path.join(
|
| 137 |
+
os.path.dirname(__file__), "aoti_runtime", header_cpp_file
|
| 138 |
+
)
|
| 139 |
+
) as f:
|
| 140 |
+
self.header.splice(f.read())
|
| 141 |
+
else:
|
| 142 |
+
self.header.splice(
|
| 143 |
+
"""
|
| 144 |
+
import torch
|
| 145 |
+
from torch._inductor.codecache import CppWrapperCodeCache
|
| 146 |
+
|
| 147 |
+
cpp_wrapper_src = (
|
| 148 |
+
'''
|
| 149 |
+
"""
|
| 150 |
+
)
|
| 151 |
+
|
| 152 |
+
if config.abi_compatible:
|
| 153 |
+
if config.c_shim_version == "1":
|
| 154 |
+
self.header.splice("#include <torch/csrc/inductor/aoti_torch/c/shim.h>")
|
| 155 |
+
else:
|
| 156 |
+
self.header.splice(
|
| 157 |
+
f"#include <torch/csrc/inductor/aoti_torch/generated/c_shim_{self.device}.h>"
|
| 158 |
+
)
|
| 159 |
+
self.header.splice(
|
| 160 |
+
"""
|
| 161 |
+
#include <torch/csrc/inductor/aoti_runtime/arrayref_tensor.h>
|
| 162 |
+
#include <torch/csrc/inductor/aoti_runtime/thread_local.h>
|
| 163 |
+
#include <torch/csrc/inductor/aoti_runtime/scalar_to_tensor.h>
|
| 164 |
+
"""
|
| 165 |
+
)
|
| 166 |
+
if V.graph.aot_mode:
|
| 167 |
+
self.header.splice(
|
| 168 |
+
"""
|
| 169 |
+
#include <torch/csrc/inductor/aoti_runtime/model.h>
|
| 170 |
+
"""
|
| 171 |
+
)
|
| 172 |
+
else:
|
| 173 |
+
self.header.splice(
|
| 174 |
+
"""
|
| 175 |
+
#include <ATen/ATen.h>
|
| 176 |
+
#include <ATen/core/dispatch/Dispatcher.h>
|
| 177 |
+
#include <ATen/native/BinaryOps.h>
|
| 178 |
+
#include <torch/csrc/inductor/aoti_runtime/utils.h>
|
| 179 |
+
#include <torch/csrc/inductor/aoti_torch/tensor_converter.h>
|
| 180 |
+
#include <torch/csrc/inductor/inductor_ops.h>
|
| 181 |
+
#include <torch/types.h>
|
| 182 |
+
#include <ATen/ops/bernoulli_native.h>
|
| 183 |
+
|
| 184 |
+
#define reinterpret_tensor torch::inductor::_reinterpret_tensor
|
| 185 |
+
#define alloc_from_pool torch::inductor::_alloc_from_pool
|
| 186 |
+
"""
|
| 187 |
+
)
|
| 188 |
+
|
| 189 |
+
self.header.splice("#include <c10/util/generic_math.h>")
|
| 190 |
+
|
| 191 |
+
if not V.graph.aot_mode:
|
| 192 |
+
self.header.splice(
|
| 193 |
+
"""
|
| 194 |
+
#include <pybind11/pybind11.h>
|
| 195 |
+
|
| 196 |
+
using namespace torch::aot_inductor;
|
| 197 |
+
"""
|
| 198 |
+
)
|
| 199 |
+
|
| 200 |
+
from .memory_planning import ALIGN_BYTES
|
| 201 |
+
|
| 202 |
+
# Round up to the nearest multiple of ALIGN_BYTES
|
| 203 |
+
# ALIGN_BYTES must be a power of 2
|
| 204 |
+
self.header.splice(
|
| 205 |
+
f"""
|
| 206 |
+
[[maybe_unused]] static int64_t align(int64_t nbytes) {{
|
| 207 |
+
return (nbytes + {ALIGN_BYTES} - 1) & -{ALIGN_BYTES};
|
| 208 |
+
}}
|
| 209 |
+
"""
|
| 210 |
+
)
|
| 211 |
+
|
| 212 |
+
def mark_output_type(self):
|
| 213 |
+
# mark output type to unwrap tensor back to python scalar
|
| 214 |
+
from ..ir import ShapeAsConstantBuffer
|
| 215 |
+
|
| 216 |
+
output_is_tensor = dict()
|
| 217 |
+
for idx, x in enumerate(V.graph.graph_outputs):
|
| 218 |
+
if isinstance(x, ShapeAsConstantBuffer):
|
| 219 |
+
output_is_tensor[idx] = False
|
| 220 |
+
else:
|
| 221 |
+
output_is_tensor[idx] = True
|
| 222 |
+
|
| 223 |
+
self.output_is_tensor = output_is_tensor
|
| 224 |
+
|
| 225 |
+
def write_prefix(self):
|
| 226 |
+
if V.graph.is_const_graph:
|
| 227 |
+
# We do not write prefix for constant graph, it will be written by main module.
|
| 228 |
+
return
|
| 229 |
+
|
| 230 |
+
if V.graph.aot_mode:
|
| 231 |
+
self.prefix.writeline("namespace torch {")
|
| 232 |
+
self.prefix.writeline("namespace aot_inductor {")
|
| 233 |
+
|
| 234 |
+
def write_input_output_info(
|
| 235 |
+
self,
|
| 236 |
+
info_kind: str,
|
| 237 |
+
idx: int,
|
| 238 |
+
name: str,
|
| 239 |
+
):
|
| 240 |
+
self.prefix.writeline(f"""{info_kind}[{idx}].name = "{name}";""")
|
| 241 |
+
|
| 242 |
+
@staticmethod
|
| 243 |
+
def get_input_cpp_type(input):
|
| 244 |
+
assert config.use_minimal_arrayref_interface
|
| 245 |
+
from .cpp import DTYPE_TO_CPP
|
| 246 |
+
|
| 247 |
+
if isinstance(input, sympy.Expr):
|
| 248 |
+
from ..graph import may_get_constant_buffer_dtype
|
| 249 |
+
|
| 250 |
+
dtype = may_get_constant_buffer_dtype(input)
|
| 251 |
+
assert dtype is not None, f"Failed to get the dtype of sympy.Expr: {input}"
|
| 252 |
+
return DTYPE_TO_CPP[dtype]
|
| 253 |
+
return f"ArrayRefTensor<{DTYPE_TO_CPP[input.get_dtype()]}>"
|
| 254 |
+
|
| 255 |
+
def write_wrapper_decl(self):
    """Emit the C++ entry-point prologue (run_impl / _const_run_impl /
    inductor_entry_impl) plus the code that unpacks input handles,
    constants, and symbolic-scalar inputs into local variables."""
    inputs_len = len(V.graph.graph_inputs.keys())
    if V.graph.aot_mode:
        if config.use_minimal_arrayref_interface and not V.graph.is_const_graph:
            from .cpp import DTYPE_TO_CPP

            input_cpp_types = ", ".join(
                f"{CppWrapperCpu.get_input_cpp_type(x)}"
                for x in V.graph.graph_inputs.values()
            )
            output_arrayref_types = ", ".join(
                f"ArrayRefTensor<{DTYPE_TO_CPP[x.get_dtype()]}>"
                for x in V.graph.graph_outputs
            )

            self.prefix.splice(
                f"""
                using AOTInductorModelInputs = std::tuple<{input_cpp_types}>;
                using AOTInductorModelOutputs = std::tuple<{output_arrayref_types}>;
                """
            )

        if V.graph.const_module:
            self.header.splice(V.graph.const_module.wrapper_code.header)
            self.prefix.splice(V.graph.const_code)

        if V.graph.is_const_graph:
            self.prefix.splice(
                """
                void AOTInductorModel::_const_run_impl(
                    std::vector<AtenTensorHandle>& output_handles,
                    DeviceStreamType stream,
                    AOTIProxyExecutorHandle proxy_executor
                ) {
                """
            )
        else:
            if not config.aot_inductor.use_runtime_constant_folding:
                # If we do not split the constant graph, we'll just create
                # an empty implementation when wrapping the main module.
                self.prefix.splice(
                    """
                    void AOTInductorModel::_const_run_impl(
                        std::vector<AtenTensorHandle>& output_handles,
                        DeviceStreamType stream,
                        AOTIProxyExecutorHandle proxy_executor
                    ) {}

                    """
                )

            run_impl_proto = """
                void AOTInductorModel::run_impl(
                    AtenTensorHandle*
                        input_handles, // array of input AtenTensorHandle; handles
                                       // are stolen; the array itself is borrowed
                    AtenTensorHandle*
                        output_handles, // array for writing output AtenTensorHandle; handles
                                        // will be stolen by the caller; the array itself is
                                        // borrowed
                    DeviceStreamType stream,
                    AOTIProxyExecutorHandle proxy_executor
                ) {
                """
            if config.use_minimal_arrayref_interface:
                self.prefix.splice(
                    """
                    template <>
                    AOTInductorModelOutputs AOTInductorModel::run_impl_minimal_arrayref_interface<
                      AOTInductorModelInputs, AOTInductorModelOutputs>(
                        const AOTInductorModelInputs& inputs,
                        DeviceStreamType stream,
                        AOTIProxyExecutorHandle proxy_executor
                    ) {
                    """
                )
                # The handle-based run_impl becomes a thin adapter that converts
                # to/from the arrayref representation and delegates.
                self.suffix.splice(run_impl_proto)
                self.suffix.splice(
                    """
                    AOTInductorModelInputs inputs;
                    convert_handles_to_inputs(input_handles, inputs);
                    auto outputs = run_impl_minimal_arrayref_interface<AOTInductorModelInputs, AOTInductorModelOutputs>(
                        inputs, stream, proxy_executor);
                    // NOTE: outputs is full of ArrayRef to thread_local storage. If in the future we need this
                    // interface to perform well for a DSO using the minimal arrayref interface, all we need
                    // to do is provide ThreadLocalCachedTensor for each one!
                    convert_outputs_to_handles(outputs, output_handles);
                }
                    """
                )
                self.suffix.splice(
                    """
                    extern "C" AOTIRuntimeError AOTInductorModelRunMinimalArrayrefInterface(
                        AOTInductorModelHandle model_handle,
                        const AOTInductorModelInputs& inputs,
                        AOTInductorModelOutputs& outputs) {
                      auto model = reinterpret_cast<torch::aot_inductor::AOTInductorModel*>(model_handle);
                      CONVERT_EXCEPTION_TO_ERROR_CODE({
                          outputs = model->run_impl_minimal_arrayref_interface<AOTInductorModelInputs, AOTInductorModelOutputs>(
                              inputs,
                              (torch::aot_inductor::DeviceStreamType)nullptr,
                              nullptr);
                      })
                    }
                    """
                )
            else:
                self.prefix.splice(run_impl_proto)
    else:
        # JIT (non-AOT) mode: a plain pybind-facing entry point.
        self.prefix.splice(
            """
            void inductor_entry_impl(
                AtenTensorHandle*
                    input_handles, // array of input AtenTensorHandle; handles
                                   // are stolen; the array itself is borrowed
                AtenTensorHandle*
                    output_handles // array for writing output AtenTensorHandle; handles
                                   // will be stolen by the caller; the array itself is
                                   // borrowed)
            ) {
            """
        )
    with self.prefix.indent():
        # assign inputs and outputs in both cases so the later codegen can be simplified
        if not config.use_minimal_arrayref_interface:
            if not V.graph.is_const_graph:
                if V.graph.aot_mode:
                    num_args = len(V.graph.graph_inputs)
                else:
                    # Weights are promoted in the JIT mode
                    num_args = len(V.graph.graph_inputs) + len(V.graph.constants)
                    self.prefix.splice(
                        """
                        pybind11::gil_scoped_release release;
                        """
                    )

                if config.abi_compatible:
                    self.prefix.splice(
                        f"""
                        auto inputs = steal_from_raw_handles_to_raii_handles(input_handles, {num_args});
                        """
                    )
                else:
                    # This looks dumb, but can avoid creating two versions of code in the AOTInductor runtime.
                    self.prefix.splice(
                        f"""
                        auto inputs = alloc_tensors_by_stealing_from_handles(input_handles, {num_args});
                        """
                    )

        if inputs_len != 0:
            for idx, input_key in enumerate(V.graph.graph_inputs.keys()):
                if config.use_minimal_arrayref_interface:
                    self.prefix.writeline(
                        f"auto {input_key} = std::get<{idx}>(inputs);"
                    )
                    continue
                # unwrap input tensor back to scalar
                if isinstance(V.graph.graph_inputs[input_key], sympy.Expr):
                    from ..graph import may_get_constant_buffer_dtype
                    from .cpp import DTYPE_TO_CPP

                    dtype = may_get_constant_buffer_dtype(
                        V.graph.graph_inputs[input_key]
                    )
                    assert (
                        dtype is not None
                    ), "Fails to get the dtype of the sympy.Expr"
                    cpp_dtype = DTYPE_TO_CPP[dtype]
                    if config.abi_compatible:
                        self.prefix.writeline(f"{cpp_dtype} {input_key};")
                        dtype_str = str(dtype).split(".")[-1]
                        self.prefix.writeline(
                            f"aoti_torch_item_{dtype_str}(inputs[{idx}], &{input_key});"
                        )
                    else:
                        self.prefix.writeline(
                            f"{cpp_dtype} {input_key} = inputs[{idx}].item<{cpp_dtype}>();"
                        )
                else:
                    self.prefix.writeline(
                        f"auto {input_key} = std::move(inputs[{idx}]);"
                    )

        assert all(
            isinstance(v, torch.Tensor) for v in list(V.graph.constants.values())
        ), "Expect all constants to be Tensor"
        for idx, constants_key in enumerate(V.graph.constants.keys()):
            if V.graph.aot_mode:
                # Weights are stored in constants_ and owned by RAIIAtenTensorHandle there.
                # Don't call std::move here because it will cause constants_ to lose the ownership.
                if config.abi_compatible:
                    self.prefix.writeline(
                        f"""auto {constants_key} = constants_->at({idx});"""
                    )
                else:
                    self.prefix.writeline(
                        f"auto {constants_key} = *tensor_handle_to_tensor_pointer("
                        + f"""constants_->at({idx}));"""
                    )
            else:
                # Append constants as inputs to the graph
                constants_idx = inputs_len + idx
                self.prefix.writeline(
                    f"auto {constants_key} = inputs[{constants_idx}];"
                )

        self.codegen_inputs(self.prefix, V.graph.graph_inputs)

        if V.graph.aot_mode:
            if not V.graph.is_const_graph:
                if config.use_minimal_arrayref_interface:
                    # TODO: input shape checking for regular tensor interface as well?
                    self.codegen_input_numel_asserts()
                else:
                    self.prefix.writeline("inputs.clear();")
            self.prefix.writeline(
                "auto& kernels = static_cast<AOTInductorModelKernels&>(*this->kernels_.get());"
            )
|
| 477 |
+
|
| 478 |
+
def codegen_input_numel_asserts(self):
    """Emit an assert_numel(...) check for every real tensor input."""
    for name, buf in V.graph.graph_inputs.items():
        # Symbolic scalar inputs have no numel to check.
        if isinstance(buf, sympy.Expr):
            continue

        # comparing strides for 0 size tensor is tricky. Ignore them for now.
        if sympy_product(buf.get_size()) == 0:
            continue

        self.prefix.writeline(f"assert_numel({name}, {buf.get_numel()});")
|
| 488 |
+
|
| 489 |
+
def codegen_input_size_var_decl(self, code: IndentedBuffer, name):
    """Declare `{name}_size` for an input; ABI-compatible mode queries the
    C shim instead of calling aten methods directly."""
    if not config.abi_compatible:
        super().codegen_input_size_var_decl(code, name)
        return
    code.writeline(f"int64_t* {name}_size;")
    code.writeline(
        f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_sizes({name}, &{name}_size));"
    )
|
| 497 |
+
|
| 498 |
+
def codegen_input_stride_var_decl(self, code: IndentedBuffer, name):
    """Declare `{name}_stride` for an input; ABI-compatible mode queries the
    C shim instead of calling aten methods directly."""
    if not config.abi_compatible:
        super().codegen_input_stride_var_decl(code, name)
        return
    code.writeline(f"int64_t* {name}_stride;")
    code.writeline(
        f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_strides({name}, &{name}_stride));"
    )
|
| 506 |
+
|
| 507 |
+
def codegen_model_kernels(self):
    """Emit the AOTInductorModelKernels class holding one CUfunction slot
    per kernel used by this graph (and its constant-folding subgraph)."""
    self.prefix.writeline("namespace {")
    self.prefix.writeline(
        "class AOTInductorModelKernels : public AOTInductorModelKernelsBase {"
    )
    self.prefix.writeline("  public:")
    # Collect every kernel name that needs a member slot: generated kernels,
    # user-defined Triton kernels, and kernels from the const subgraph.
    declare_kernel = set(self.src_to_kernel.values())
    declare_kernel.update(
        entry[0] for entry in self.user_defined_kernel_cache.values()
    )
    if V.graph.const_module:
        declare_kernel.update(
            V.graph.const_module.wrapper_code.src_to_kernel.values()
        )
    for kernel in declare_kernel:
        self.prefix.writeline(f"    CUfunction {kernel}{{nullptr}};")
    self.prefix.writeline("};")
    self.prefix.writeline("}  // namespace")
|
| 525 |
+
|
| 526 |
+
def codegen_model_constructor(self):
    """Emit the AOTInductorModel constructor that registers input, constant,
    and output metadata.

    // Generated code example
    AOTInductorModel::AOTInductorModel()
        : AOTInductorModelBase(4, 1) {
    inputs_info_[0].name = "input0";
    inputs_info_[0].dtype = "torch.float16";
    ...
    constants_info_[0].name = "L__self___weight";
    constants_info_[0].dtype = at::kFloat;
    constants_info_[0].offset = 0;
    constants_info_[0].data_size = 8192;
    constants_info_[0].shape = {64, 32};
    constants_info_[0].stride = {32, 1};
    ...
    outputs_info_[0].name = "output0";
    outputs_info_[0].dtype = "torch.float16";
    }
    """
    num_inputs = len(V.graph.graph_inputs)
    num_outputs = len(V.graph.graph_outputs)
    num_constants = len(V.graph.constants)
    self.prefix.splice(
        f"""
        AOTInductorModel::AOTInductorModel(std::shared_ptr<ConstantMap> constants_map,
                                           std::shared_ptr<std::vector<ConstantHandle>> constants_array,
                                           const std::string& device_str,
                                           std::optional<std::string> cubin_dir)
            : AOTInductorModelBase({num_inputs}, {num_outputs}, {num_constants}, device_str, cubin_dir) {{
        """
    )

    with self.prefix.indent():
        for idx, (name, inp) in enumerate(V.graph.graph_inputs.items()):
            assert not isinstance(
                inp, sympy.Expr
            ), f"input {name=} cannot be symbolic"
            self.write_input_output_info("inputs_info_", idx, name)

        for idx, (name, tensor) in enumerate(V.graph.constants.items()):
            assert isinstance(tensor, torch.Tensor)
            self.prefix.writeline(f"""constants_info_[{idx}].name = "{name}";""")
            self.prefix.writeline(
                f"constants_info_[{idx}].dtype = static_cast<int32_t>({self.codegen_dtype(tensor.dtype)});"
            )
            self.prefix.writeline(
                f"constants_info_[{idx}].offset = {tensor.storage_offset()};"
            )
            self.prefix.writeline(
                f"constants_info_[{idx}].data_size = {tensor.untyped_storage().nbytes()};"
            )
            from_folded = "true" if name in V.graph.folded_constants else "false"
            self.prefix.writeline(
                f"constants_info_[{idx}].from_folded = {from_folded};"
            )

            size_str = ", ".join([str(s) for s in tensor.size()])
            self.prefix.writeline(f"constants_info_[{idx}].shape = {{{size_str}}};")

            stride_str = ", ".join([str(s) for s in tensor.stride()])
            self.prefix.writeline(
                f"constants_info_[{idx}].stride = {{{stride_str}}};"
            )
            if name in V.graph.dynamo_flat_name_to_original_fqn:
                original_fqn = V.graph.dynamo_flat_name_to_original_fqn.get(
                    name, name
                )
            elif name in V.graph.allocated_constant_name:
                original_fqn = V.graph.allocated_constant_name[name]
            else:
                raise AssertionError("original_fqn must be set for constant")
            self.prefix.writeline(
                f"""constants_info_[{idx}].original_fqn = "{original_fqn}";"""
            )
        self.prefix.writeline("update_constants_map(std::move(constants_map));")
        self.prefix.writeline("update_constants_array(std::move(constants_array));")

        def escape_string(x):
            # Escape backslashes first so the later escapes aren't doubled.
            return (
                x.replace("\\", "\\\\")
                .replace('"', '\\"')
                .replace("\n", "\\n")
                .replace("\t", "\\t")
            )

        self.prefix.writeline(
            f'in_spec_ = "{escape_string(config.aot_inductor.serialized_in_spec)}";'
        )
        self.prefix.writeline(
            f'out_spec_ = "{escape_string(config.aot_inductor.serialized_out_spec)}";'
        )

        for idx, output in enumerate(V.graph.graph_outputs):
            # BUGFIX: the assert message previously interpolated `name`, the
            # stale loop variable from the constants/inputs loops above (and a
            # NameError when both were empty). Use the output index instead.
            assert not isinstance(
                output, sympy.Expr
            ), f"output at index {idx} cannot be symbolic"
            name = f"output{idx}"
            self.write_input_output_info("outputs_info_", idx, name)

        self.prefix.writeline(
            "this->kernels_ = std::make_unique<AOTInductorModelKernels>();"
        )

    self.prefix.writeline("}")
|
| 631 |
+
|
| 632 |
+
def codegen_const_run_driver(self):
    """Emit AOTInductorModel::const_run_impl, the driver for runtime
    constant folding.

    // Generated code example
    std::unordered_map<std::string, AtenTensorHandle> AOTInductorModel::const_run_impl(
        DeviceStreamType stream,
        AOTIProxyExecutorHandle proxy_executor,
        bool initialization
    ) {
        std::unordered_map<std::string, AtenTensorHandle> folded_constants_map;
        std::vector<AtenTensorHandle> output_handles;
        // build up output_handles over here.
        _const_run_impl(output_handles, stream, proxy_executor);
        // build up folded_constants_map
        return folded_constants_map;
    }
    """
    self.prefix.splice(
        """
        std::unordered_map<std::string, AtenTensorHandle> AOTInductorModel::const_run_impl(
            DeviceStreamType stream,
            AOTIProxyExecutorHandle proxy_executor,
            bool initialization
        ) {
        """
    )
    if not config.aot_inductor.use_runtime_constant_folding:
        # Folding disabled: warn on unexpected non-init calls and return empty.
        self.prefix.splice(
            """
            if (!initialization) {
                std::cerr << "[WARNING] Calling constant_folding in model, but compiled with config: "
                          << "aot_inductor.use_runtime_constant_folding=False\\n";
            }
            return {};
        }
            """
        )
        return

    with self.prefix.indent():
        # This is a mapping to the index of constant folding graph's output
        const_index_mapping: List[Optional[Tuple[int, str]]] = [None] * len(
            V.graph.const_output_index
        )
        for idx, (name, _) in enumerate(V.graph.constants.items()):
            if name in V.graph.const_output_index:
                const_index_mapping[V.graph.const_output_index[name]] = (idx, name)  # type: ignore[call-overload]
        assert (
            None not in const_index_mapping
        ), "Not all constant gets mapped for constant folding graph."

        self.prefix.writeline(
            f"""
            std::unordered_map<std::string, AtenTensorHandle> folded_constants_map;
            folded_constants_map.reserve({len(const_index_mapping)});
            std::vector<AtenTensorHandle> output_handles({len(const_index_mapping)});
            """
        )

        self.prefix.splice(
            """
            // The below assignment of output_handles to constants is not used directly.
            // It's only used to memo the correspondence of handle and constants.
            """
        )
        for output_idx, (const_idx, _) in enumerate(const_index_mapping):  # type: ignore[misc]
            self.prefix.writeline(
                f"output_handles[{output_idx}] = constants_->at({const_idx});"
            )

        self.prefix.writeline(
            "_const_run_impl(output_handles, stream, proxy_executor);"
        )

        for output_idx, (_, const_name) in enumerate(const_index_mapping):  # type: ignore[misc]
            self.prefix.writeline(
                f'folded_constants_map["{const_name}"] = output_handles[{output_idx}];'
            )
        self.prefix.writeline("return folded_constants_map;")

    self.prefix.writeline("}")
|
| 714 |
+
|
| 715 |
+
def generate(self, is_inference):
    """Generate the full wrapper; in AOT mode first emit the model-level
    scaffolding (kernel holder class, constructor, const-run driver)."""
    if V.graph.aot_mode and not V.graph.is_const_graph:
        self.codegen_model_kernels()
        self.codegen_model_constructor()
        self.codegen_const_run_driver()
    self.write_wrapper_decl()
    return super().generate(is_inference)
|
| 722 |
+
|
| 723 |
+
def finalize_prefix(self):
    """Prepend cached-dtype/device macros (ABI-compatible mode only) to the
    accumulated prefix buffer."""
    cached_dtypes_buffer = IndentedBuffer()
    if config.abi_compatible:
        for dtype in self.used_cached_dtypes:
            cached_dtypes_buffer.writeline(f"CACHE_TORCH_DTYPE({dtype});")
        for device in self.used_cached_devices:
            cached_dtypes_buffer.writeline(f"CACHE_TORCH_DEVICE({device});")
    # The existing prefix goes after the cache macros.
    cached_dtypes_buffer.splice(self.prefix)
    self.prefix = cached_dtypes_buffer
|
| 732 |
+
|
| 733 |
+
def define_kernel(
    self, name: str, kernel: str, metadata: Optional[str] = None, cuda=False
):
    """Append the kernel source verbatim to the header buffer; `name`,
    `metadata`, and `cuda` are unused for the CPU C++ wrapper."""
    self.header.splice(f"\n{kernel}\n")
|
| 737 |
+
|
| 738 |
+
def codegen_scalar_to_tensor(self, output: str):
    """Wrap a scalar expression into a fresh RAII tensor handle and return
    the generated variable name."""
    name = f"scalar_to_tensor_{next(self.scalar_to_tensor_id)}"
    self.wrapper_call.writeline(
        f"RAIIAtenTensorHandle {name} = scalar_to_tensor_handle({output});"
    )
    return name
|
| 744 |
+
|
| 745 |
+
@cache_on_self
def get_output_refs(self):
    """Return the C++ reference expression for each graph output; shape
    constants are wrapped in torch::tensor() in non-ABI-compatible mode."""
    refs = []
    for x in V.graph.graph_outputs:
        ref = x.codegen_reference(self.wrapper_call)
        if isinstance(x, ir.ShapeAsConstantBuffer) and not config.abi_compatible:
            ref = f"torch::tensor({ref})"
        refs.append(ref)
    return refs
|
| 753 |
+
|
| 754 |
+
def generate_return(self, output_refs):
    """Emit C++ that moves/copies each graph output into the caller-owned
    output_handles array (or the arrayref output tuple)."""
    cst_names = V.graph.constants.keys()
    arr_iface = (
        not V.graph.is_const_graph and config.use_minimal_arrayref_interface
    )  # For brevity.

    def use_thread_local_cached_output_tensor(idx, output):
        # Fallback path for outputs that are not tensor handles: stage them
        # through a thread_local cache object.
        cached_output_name = f"cached_output_{next(self.cached_output_id)}"
        cache_type = "Array" if arr_iface else "Tensor"
        self.wrapper_call.writeline(
            f"thread_local ThreadLocalCachedOutput{cache_type}<std::decay_t<decltype({output})>> "
            f"{cached_output_name}({output});"
        )
        if arr_iface:
            self.wrapper_call.writeline(
                f"{cached_output_name}.copy_data_from({output});"
            )
            output_entry = f"std::get<{idx}>(output_arrayref_tensors)"
            element_type = f"std::decay_t<decltype({output_entry}.data()[0])>"
            self.wrapper_call.writeline(
                f"{output_entry} = {cached_output_name}.arrayref_tensor<{element_type}>();"
            )
        else:
            self.wrapper_call.writeline(
                f"{cached_output_name}.copy_data_from({output});"
            )
            self.wrapper_call.writeline(
                f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_new_uninitialized_tensor(&output_handles[{idx}]));"
            )
            self.wrapper_call.writeline(
                f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_assign_tensors({cached_output_name}.tensor(), "
                f"output_handles[{idx}]));"
            )

    if arr_iface:
        self.wrapper_call.writeline(
            "AOTInductorModelOutputs output_arrayref_tensors;"
        )
    for idx, output in enumerate(output_refs):
        if config.abi_compatible:
            output_buffer = V.graph.graph_outputs[idx]
            if isinstance(output_buffer, ir.ShapeAsConstantBuffer):
                # Need to wrap scalar into tensor as the main function returns a vector of tensors
                output_tensor = self.codegen_scalar_to_tensor(output)
                self.wrapper_call.writeline(
                    f"output_handles[{idx}] = {output_tensor}.release();"
                )
                continue

            output_is_tensor_handle_expr = (
                f"std::is_same_v<std::decay_t<decltype({output})>,"
                "RAIIAtenTensorHandle> || "
                f"std::is_same_v<std::decay_t<decltype({output})>,"
                "AtenTensorHandle> || "
                f"std::is_same_v<std::decay_t<decltype({output})>,"
                "ConstantHandle>"
            )
            self.wrapper_call.writeline(
                f"if constexpr ({output_is_tensor_handle_expr}) {{"
            )
            with self.wrapper_call.indent():
                if arr_iface:
                    cached_output_name = (
                        f"cached_output_{next(self.cached_output_id)}"
                    )
                    output_value_type = f"std::decay_t<decltype(std::get<{idx}>(output_arrayref_tensors).data()[0])>"
                    self.wrapper_call.writeline(
                        f"thread_local RAIIAtenTensorHandle {cached_output_name};"
                    )
                    if output in cst_names:
                        # NOTE(return_constant): In some rare cases where we return
                        # a constant, we have to return a copy of this constant,
                        # because (1) constants are not owned by the Model instance
                        # (2) constants remain the same cross inference runs,
                        # assuming they are not updated at runtime Basically, we
                        # cannot release or transfer the ownership of any original
                        # constant to the user.
                        self.wrapper_call.writeline(
                            f"AtenTensorHandle {cached_output_name}_tmp;"
                        )
                        self.wrapper_call.writeline(
                            f"aoti_torch_clone({output}, &{cached_output_name}_tmp);"
                        )
                        self.wrapper_call.writeline(
                            f"{cached_output_name} = {cached_output_name}_tmp;"
                        )
                    else:
                        self.wrapper_call.writeline(
                            f"{cached_output_name} = {output}.release();"
                        )
                    self.wrapper_call.writeline(
                        f"convert_handle_to_arrayref_tensor({cached_output_name}, "
                        f"std::get<{idx}>(output_arrayref_tensors));"
                    )
                else:
                    if output in cst_names:
                        # See NOTE(return_constant) above.
                        self.wrapper_call.writeline(
                            f"aoti_torch_clone({output}, &output_handles[{idx}]);"
                        )
                    else:
                        self.wrapper_call.writeline(
                            f"output_handles[{idx}] = {output}.release();"
                        )
            self.wrapper_call.writeline("} else {")
            with self.wrapper_call.indent():
                use_thread_local_cached_output_tensor(idx, output)
            self.wrapper_call.writeline("}")
        else:
            assert (
                not arr_iface
            ), "minimal ArrayRef interface is only supported in ABI-compatible mode"
            if output in cst_names:
                # See NOTE(return_constant) above.
                output_expr = f"{output}.clone()"
            else:
                output_expr = output
            self.wrapper_call.writeline(
                f"output_handles[{idx}] = reinterpret_cast<AtenTensorHandle>("
                + f"new at::Tensor({output_expr}));"
            )
    if arr_iface:
        self.wrapper_call.writeline("return output_arrayref_tensors;")
|
| 878 |
+
|
| 879 |
+
def generate_before_suffix(self, result):
    """Close the entry-point function body opened in write_wrapper_decl."""
    if V.graph.is_const_graph:
        return
    closer = (
        "} // AOTInductorModel::run_impl"
        if V.graph.aot_mode
        else "} // inductor_entry_impl"
    )
    result.writeline(closer)
|
| 885 |
+
|
| 886 |
+
def generate_end(self, result):
    """Finish the generated source: close namespaces in AOT mode, or emit
    the pybind loader plus the Python `call` wrapper in JIT mode."""
    if V.graph.aot_mode:
        if V.graph.is_const_graph:
            result.writeline("} // AOTInductorModel::_const_run_impl")
        else:
            result.writeline("} // namespace aot_inductor")
            result.writeline("} // namespace torch")
        return

    # Close the cpp_wrapper_src string literal and compile it via pybind.
    result.writeline("'''\n)")
    result.splice(
        f"""
        inductor_entry = CppWrapperCodeCache.load_pybinding(
            ["std::vector<at::Tensor>"], cpp_wrapper_src, {self.cuda}, {len(V.graph.graph_outputs)})
        """
    )

    # unwrap output tensor back to python scalar
    if all(x for x in self.output_is_tensor.values()):
        # If no ShapeAsConstantBuffer in the output, directly return the output as tensors
        return_str = "return f(args_tensor)"
    else:
        outputs = [
            f"outputs[{i}]" if self.output_is_tensor[i] else f"outputs[{i}].item()"
            for i in range(len(V.graph.graph_outputs))
        ]
        outputs_str = f"[{', '.join(outputs)}]"
        return_str = f"""
                outputs = f(args_tensor)
                return {outputs_str}
        """

    args_str = "args_tensor = [arg if isinstance(arg, torch.Tensor) else torch.tensor(arg) for arg in args]"
    if V.graph.constants:
        # Append constants to the input args for cpp wrapper.
        # Python wrapper directly gets the value inside the wrapper call
        # as a global variable passed when calling exec(code, mod.__dict__, mod.__dict__).
        # For cpp wrapper, we need to pass this python value to the inductor_entry_impl function explicitly.
        assert all(
            isinstance(v, torch.Tensor) for v in list(V.graph.constants.values())
        ), "Expect all constants to be Tensor"
        constants_str = f"[{', '.join(V.graph.constants.keys())}]"
        args_str += f"""
                constants_tensor = {constants_str}
                args_tensor.extend(constants_tensor)
        """

    # Wrap the func to support setting result._boxed_call = True
    result.splice(
        f"""
        def _wrap_func(f):
            def g(args):
                {args_str}
                {return_str}
            return g
        call = _wrap_func(inductor_entry)
        """
    )
|
| 944 |
+
|
| 945 |
+
def generate_c_shim_extern_kernel_call(self, kernel, args):
    """Emit a call to the aoti_torch_* C shim for a fallback aten kernel."""
    # In the abi_compatible mode, we call fallback aten ops through a C shim layer
    self.allow_stack_allocation = False
    kernel_tokens = kernel.split("::")
    kernel_suffix = kernel_tokens[-1]
    if kernel_suffix == "call":
        kernel_suffix = kernel_tokens[-2]
    if config.c_shim_version == "1":
        shim_fn = f"aoti_torch_{kernel_suffix}"
    else:
        shim_fn = f"aoti_torch_{self.device}_{kernel_suffix}"

    # HACK: val_to_arg_str jams multiple arguments together using a comma. If that
    # ever breaks, it needs to be reworked to be able to return multiple arguments,
    # and the split-on-comma code here needs to be removed.
    wrapped_args = []
    for x in args:
        for piece in x.split(", "):
            # We only really *need* convert_arrayref_tensor_to_tensor for
            # ArrayRefTensors. The code flowing into here uses `0` for nullptr,
            # which convert_arrayref_tensor_to_tensor would blindly coerce to int,
            # so just avoid wrapping integers.
            # NOTE(review): isdigit() is False for negative literals like "-1",
            # so those would get wrapped too — presumably never produced here;
            # confirm against val_to_arg_str before relying on it.
            if not piece.isdigit():
                piece = f"convert_arrayref_tensor_to_tensor({piece})"
            wrapped_args.append(piece)
    self.writeline(
        f"AOTI_TORCH_ERROR_CODE_CHECK({shim_fn}({', '.join(wrapped_args)}));"
    )
|
| 974 |
+
|
| 975 |
+
def generate_c_shim_extern_kernel_alloc(self, extern_kernel, args):
    """Emit a C-shim extern kernel call that allocates its single output,
    then wrap the returned handle in an RAII owner."""
    # registered output buffer name
    name = extern_kernel.name
    output_handle_name = f"{name}_handle"
    self.writeline(f"AtenTensorHandle {output_handle_name};")
    self.generate_c_shim_extern_kernel_call(
        extern_kernel.get_kernel_name(), args + [f"&{output_handle_name}"]
    )
    self.writeline(f"RAIIAtenTensorHandle {name}({output_handle_name});")
|
| 985 |
+
|
| 986 |
+
def generate_extern_kernel_alloc(self, extern_kernel, args):
    """Dispatch allocating extern-kernel codegen to the C shim path when in
    ABI-compatible mode, otherwise defer to the base implementation."""
    if not config.abi_compatible:
        super().generate_extern_kernel_alloc(extern_kernel, args)
        return
    self.generate_c_shim_extern_kernel_alloc(extern_kernel, args)
|
| 991 |
+
|
| 992 |
+
def generate_c_shim_fallback_kernel(self, fallback_kernel, args):
    """Emit a C-shim call for a multi-output fallback kernel.

    Builds one out-argument per kernel output (tensor handle, int scalar, or
    nullptr for None outputs), invokes the ABI-compatible shim, then emits
    the RAII wrappers for the tensor handles after the call.
    """
    output_args = []
    output_raii_handles = []
    output_name_base = fallback_kernel.get_name()
    for idx, output in enumerate(fallback_kernel.outputs):
        if isinstance(output, ir.MultiOutput):
            name = f"{output.get_name()}"
            output_handle_name = f"{name}_handle"
            if output.indices:
                assert (
                    output.indices[0][1] == idx
                ), f"expected {output.indices[0][1]=} == {idx=} for {output_name_base=}"
            self.writeline(f"AtenTensorHandle {output_handle_name};")
            output_args.append(f"&{output_handle_name}")
            output_raii_handles.append(
                f"RAIIAtenTensorHandle {name}({output_handle_name});"
            )
        elif isinstance(output, int):
            output_name = f"{output_name_base}_{idx}"
            self.writeline(f"int64_t {output_name} = {output};")
            output_args.append(f"&{output_name}")
        elif output is None:
            output_args.append("nullptr")
        else:
            # BUGFIX: this message was a plain string literal missing the `f`
            # prefix, so `{output=}` was never interpolated.
            raise NotImplementedError(f"unsupported type of {output=}")
    args = args + output_args
    assert (
        fallback_kernel.abi_compatible_kernel is not None
    ), f"abi_compatible_kernel is None for {fallback_kernel.python_kernel_name=}"
    self.generate_c_shim_extern_kernel_call(
        fallback_kernel.abi_compatible_kernel, args
    )
    for raii_handle in output_raii_handles:
        self.writeline(raii_handle)
|
| 1026 |
+
|
| 1027 |
+
def generate_fallback_kernel(self, fallback_kernel, args):
    """Dispatch fallback-kernel codegen to the C shim path when in
    ABI-compatible mode, otherwise defer to the base implementation."""
    if not config.abi_compatible:
        super().generate_fallback_kernel(fallback_kernel, args)
        return
    self.generate_c_shim_fallback_kernel(fallback_kernel, args)
|
| 1032 |
+
|
| 1033 |
+
def generate_extern_kernel_out(self, output_view, codegen_reference, args, kernel):
    """Emit an out-variant extern-kernel call, prepending the output argument.

    Note: mutates ``args`` in place by inserting the output at position 0.
    """
    if output_view:
        # Materialize the strided view once so the kernel writes through it.
        strided_name = f"{output_view.get_name()}_as_strided"
        self.writeline(f"auto {strided_name} = {output_view.codegen_reference()};")
        args.insert(0, strided_name)
    else:
        args.insert(0, f"{codegen_reference}")

    if not config.abi_compatible:
        self.writeline(self.wrap_kernel_call(kernel, args))
    else:
        self.generate_c_shim_extern_kernel_call(kernel, args)
|
| 1047 |
+
|
| 1048 |
+
def generate_user_defined_triton_kernel(
    self, kernel_name, grid, configs, args, triton_meta
):
    """Emit a call to a user-defined Triton kernel.

    With a single grid candidate it is used directly; otherwise the grid is
    chosen by matching the autotuned meta-parameters cached for this kernel
    (CudaKernelParamCache) against each config's kwargs.
    """
    assert len(grid) != 0
    if len(grid) == 1:
        grid_decision = grid[0]
    else:
        meta = CudaKernelParamCache.get(kernel_name)
        assert meta is not None
        grid_decision = None
        # Pick the grid of the config whose kwargs all equal the cached
        # winning meta-parameters.
        for i, c in enumerate(configs):
            if all(arg == meta["meta"][key] for key, arg in c.kwargs.items()):
                grid_decision = grid[i]
                break
        assert grid_decision is not None

    self.generate_kernel_call(
        kernel_name,
        args,
        grid=grid_decision,
        device_index=V.graph.scheduler.current_device.index,
        cuda=True,
        triton=True,
        triton_meta=triton_meta,
    )
|
| 1073 |
+
|
| 1074 |
+
def generate_scatter_fallback(
    self, output, inputs, kernel, python_kernel_name, src_is_tensor, reduce, kwargs
):
    """Emit a fallback call for a scatter-style op.

    In ABI-compatible mode the ``at::`` kernel name is rewritten to its
    ``aoti_torch_`` shim equivalent.
    """
    # TODO: support other overload for cpp wrapper and remove the below assertions
    if config.abi_compatible:
        # call the ABI shim function instead of the ATen one
        kernel = kernel.replace("at::", "aoti_torch_")
    line = f"{kernel}({output}, {','.join(map(str, inputs))}"
    if python_kernel_name == "aten.scatter_":
        if src_is_tensor:
            if reduce:
                line += f", {V.graph.wrapper_code.val_to_arg_str(reduce)}"
        else:
            assert (
                reduce is None
            ), "Expect reduce to be None for aten.scatter_ with scalar src"
    else:
        # Non-scatter_ overloads: remaining arguments arrive pre-rendered.
        line += f", {','.join(kwargs)}"
    line += f"){self.ending}"
    self.writeline(line)
|
| 1094 |
+
|
| 1095 |
+
def generate_index_put_fallback(self, kernel, x, indices, values, accumulate):
    """Emit a fallback call for index_put; the fallback mutates ``x`` in place."""
    if V.graph.aot_mode and V.graph.cpp_wrapper and config.abi_compatible:
        # See the comment in codegen_reinterpret_view about why having something like
        # RAIIAtenTensorHandle(tmp_tensor_handle_2) in a tmp array can cause the corresponding
        # tensor prematurely deallocated, thus this std::vector().data() trick here.
        indices_str = (
            f"std::vector<AtenTensorHandle>{{{', '.join(indices)}}}.data()"
        )
        args = [x, indices_str, str(len(indices)), values, accumulate]
    else:
        indices_str = (
            f"{self.open_bracket}{', '.join(indices)}{self.closed_bracket}"
        )
        args = [x, indices_str, values, accumulate]

    args.insert(0, x)  # set x as the output tensor, this fallback mutates x.
    self.writeline(self.wrap_kernel_call(kernel, args))
|
| 1112 |
+
|
| 1113 |
+
def add_benchmark_harness(self, output):
    """Skip benchmark-harness generation in AOT mode; otherwise defer to base."""
    if not V.graph.aot_mode:
        super().add_benchmark_harness(output)
|
| 1117 |
+
|
| 1118 |
+
def codegen_sizevar(self, x: Expr) -> str:
    """Render a symbolic size expression as C++ source, after simplification."""
    simplified = V.graph.sizevars.simplify(x)
    return self.expr_printer(simplified)
|
| 1120 |
+
|
| 1121 |
+
def codegen_tuple_access(self, basename: str, name: str, index: str) -> str:
    """Render access to one element of a tuple-valued kernel result."""
    # In ABI-compatible mode outputs are returned via out-pointer arguments,
    # so each element already has its own named variable.
    return name if config.abi_compatible else f"std::get<{index}>({basename})"
|
| 1127 |
+
|
| 1128 |
+
def codegen_shape_tuple(self, shape: Tuple[Expr, ...]) -> str:
    """Render a shape as a C++ brace-enclosed initializer list."""
    parts = [self.codegen_sizevar(s) for s in shape]
    if not parts:
        return "{}"
    if len(parts) == 1:
        # Trailing comma distinguishes a one-element list from a scalar.
        return f"{{{parts[0]}, }}"
    return f"{{{', '.join(parts)}}}"
|
| 1135 |
+
|
| 1136 |
+
def codegen_dynamic_scalar(self, node):
    """Emit code that extracts a C++ scalar from a 0-d tensor (``.item()``)."""
    from .cpp import DTYPE_TO_ATEN, DTYPE_TO_CPP

    (data,) = (t.codegen_reference() for t in node.inputs)
    if config.abi_compatible:
        dtype = node.inputs[0].get_dtype()
        # e.g. torch.float32 -> "float32", used to select the typed shim.
        dtype_str = str(dtype).split(".")[-1]
        self.writeline(f"{DTYPE_TO_CPP[dtype]} {node.sym};")
        self.writeline(f"aoti_torch_item_{dtype_str}({data}, &{node.sym});")
        # record in unbacked_symbol_decls so we won't generate a declaration of the symbol again
        self.unbacked_symbol_decls.add(str(node.sym))
    else:
        if node.is_bool:
            self.writeline(f"bool {node.sym} = {data}.item() ? 1 : 0;")
        else:
            # e.g. "at::kFloat" -> "toFloat", the ATen Scalar accessor name.
            convert_type = DTYPE_TO_ATEN[node.inputs[0].get_dtype()].replace(
                "at::k", "to"
            )
            self.writeline(f"auto {node.sym} = {data}.item().{convert_type}();")
|
| 1155 |
+
|
| 1156 |
+
def can_stack_allocate_buffer(self, buffer):
    """Whether ``buffer`` may be allocated on the C++ stack.

    Requires stack allocation to be enabled, a CPU buffer, a provably static
    shape, and contiguous strides.
    """
    return (
        self.allow_stack_allocation
        and buffer.get_device().type == "cpu"
        and self.can_prove_buffer_has_static_shape(buffer)
        and ir.is_contiguous_strides_for_shape(
            buffer.get_stride(), buffer.get_size()
        )
    )
|
| 1165 |
+
|
| 1166 |
+
def make_buffer_free(self, buffer):
    """Return the C++ statement that frees ``buffer``, or "" when no free is needed.

    No free is emitted for MultiOutput layouts, for stack-allocated buffers
    in AOT mode, or (with the minimal arrayref interface) for AOT graph inputs.
    """
    return (
        ""
        if isinstance(buffer.get_layout(), ir.MultiOutputLayout)
        or (V.graph.aot_mode and buffer.get_name() in self.stack_allocated_buffers)
        or (
            config.use_minimal_arrayref_interface
            and V.graph.aot_mode
            and buffer.get_name() in V.graph.graph_inputs
        )
        else f"{buffer.get_name()}.reset();"
    )
|
| 1178 |
+
|
| 1179 |
+
def make_free_by_names(self, names_to_del: List[str]):
    """Build the C++ statements releasing each named buffer."""
    resets = [f"{name}.reset();" for name in names_to_del]
    return " ".join(resets)
|
| 1181 |
+
|
| 1182 |
+
def codegen_exact_buffer_reuse(self, old_name: str, new_name: str, del_line: str):
    """Reuse a dead buffer's storage for a new buffer of identical layout."""
    if not config.abi_compatible:
        return super().codegen_exact_buffer_reuse(old_name, new_name, del_line)
    # RAIIAtenTensorHandle is move-only, so reuse is expressed as a std::move.
    return f"auto {new_name} = std::move({old_name});  // reuse"
|
| 1187 |
+
|
| 1188 |
+
def generate_profiler_mark_wrapper_call(self, stack):
    """Emit a RECORD_FUNCTION marker so the wrapper call shows in profiler traces."""
    record_line = (
        'RECORD_FUNCTION("inductor_wrapper_call", c10::ArrayRef<c10::IValue>());'
    )
    self.wrapper_call.writeline(record_line)
|
| 1192 |
+
|
| 1193 |
+
def write_triton_header_once(self):
    # Intentionally a no-op for the C++ wrapper.
    pass
|
| 1195 |
+
|
| 1196 |
+
def generate_start_graph(self):
    # Intentionally a no-op for the C++ wrapper.
    pass
|
| 1198 |
+
|
| 1199 |
+
def generate_end_graph(self):
    # Intentionally a no-op for the C++ wrapper.
    pass
|
| 1201 |
+
|
| 1202 |
+
def generate_inf_and_nan_checker(self, nodes):
    """Emit runtime inf/NaN checks for every buffer produced by ``nodes``."""
    for buf in nodes.get_names():
        # TODO: Add buf name directly into check_inf_and_nan.
        check = f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_check_inf_and_nan({buf}));"
        self.writeline(check)
|
| 1208 |
+
|
| 1209 |
+
def codegen_device(self, device):
    """Render a torch.device as C++ source.

    ABI mode returns "cached_torch_device_type_<type>,<index>" — a
    comma-joined pair that make_allocation splits apart again; otherwise a
    c10::Device expression is returned.
    """
    if config.abi_compatible:
        # Record the device type so its cached constant gets declared once.
        self.used_cached_devices.add(device.type)
        return f"cached_torch_device_type_{device.type},{device.index if device.index else 0}"
    else:
        from .cpp import DEVICE_TO_ATEN

        return (
            f"c10::Device({DEVICE_TO_ATEN[device.type]}, {device.index})"
            if device.index is not None
            else f"{DEVICE_TO_ATEN[device.type]}"
        )
|
| 1221 |
+
|
| 1222 |
+
def codegen_dtype(self, dtype):
    """Render a torch dtype as a C++ expression, using a cached constant in ABI mode."""
    if not config.abi_compatible:
        from .cpp import DTYPE_TO_ATEN

        return DTYPE_TO_ATEN[dtype]
    # e.g. torch.float32 -> "float32"; record it so the cached constant
    # gets declared once in the generated code.
    dtype_str = str(dtype).split(".")[-1]
    self.used_cached_dtypes.add(dtype_str)
    return f"cached_torch_dtype_{dtype_str}"
|
| 1231 |
+
|
| 1232 |
+
@functools.lru_cache(None)
def codegen_int_array_var(
    self,
    int_array: str,
    writer=None,
    known_statically=False,
    graph=None,  # for per-graph caching
):
    """Declare (at most once) a C++ ``int64_t[]`` holding ``int_array`` and
    return its variable name.

    The lru_cache dedupes identical (self, int_array, writer, known_statically,
    graph) requests so repeated shapes/strides share one declaration.
    NOTE(review): lru_cache on a method keeps ``self`` alive for the cache's
    lifetime (ruff B019) — presumably fine for a per-compilation writer object;
    confirm before reusing this pattern elsewhere.
    """
    # Because the memory planning is done in two passes (see the implementation
    # of self.generate), the writeline behavior is different in the two passes.
    # As a result, the emitted int array declarations may appear in a later
    # position of the generated code, so the second pass codegen should not
    # reuse int array declarations generated in the first pass
    if writer is None:
        # The first pass codegen uses `self` as the writer
        writer = self

    var = f"int_array_{next(self.int_array_id)}"
    if var not in self.declared_int_array_vars:
        self.declared_int_array_vars.add(var)
        if known_statically:
            writer.writeline(f"static constexpr int64_t {var}[] = {int_array};")
        else:
            writer.writeline(f"int64_t {var}[] = {int_array};")
    return var
|
| 1257 |
+
|
| 1258 |
+
def make_buffer_allocation(self, buffer):
    """Allocate storage for ``buffer``, stack-allocating when provably safe."""
    stack_buf = buffer if self.can_stack_allocate_buffer(buffer) else None
    return self.make_allocation(
        buffer.get_name(),
        buffer.get_device(),
        buffer.get_dtype(),
        buffer.get_size(),
        buffer.get_stride(),
        stack_buf,
    )
|
| 1267 |
+
|
| 1268 |
+
def make_allocation(
    self, name, device, dtype, shape, stride, buffer_if_can_stack_allocate=None
):
    """Return the C++ declaration/allocation statement for a new buffer.

    ABI mode: either a stack-backed ArrayRefTensor (when
    ``buffer_if_can_stack_allocate`` is given) or an aoti_torch_empty_strided
    call wrapped in RAIIAtenTensorHandle. Non-ABI mode: an at:: /
    empty_strided expression.
    """
    orig_stride = stride
    device_str = self.codegen_device(device)
    dtype_code = self.codegen_dtype(dtype)
    size = self.codegen_shape_tuple(shape)
    stride = self.codegen_shape_tuple(orig_stride)
    if config.abi_compatible:
        size_array_var = self.codegen_int_array_var(
            size,
            self.wrapper_call,
            known_statically=self.is_statically_known_list_of_ints(shape),
            graph=self.get_codegened_graph(),
        )
        stride_array_var = self.codegen_int_array_var(
            stride,
            self.wrapper_call,
            known_statically=self.is_statically_known_list_of_ints(orig_stride),
            graph=self.get_codegened_graph(),
        )
        # codegen_device returns a "<type>,<index>" pair in ABI mode.
        device_type, device_id = device_str.split(",")
        device_idx = "this->device_idx_" if V.graph.aot_mode else device_id
        if buffer_if_can_stack_allocate is not None:
            from .cpp import DTYPE_TO_CPP

            self.stack_allocated_buffers[name] = buffer_if_can_stack_allocate
            cpp_type = DTYPE_TO_CPP[dtype]
            numel = buffer_if_can_stack_allocate.get_numel()
            # Note: we don't zero storage because empty_strided doesn't zero either.
            self.wrapper_call.writeline(f"{cpp_type} {name}_storage[{numel}];")
            args = [
                f"{name}_storage",
                size_array_var,
                stride_array_var,
                device_type,
                device_idx,
            ]
            return f"ArrayRefTensor<{cpp_type}> {name}({', '.join(args)});"

        args = [
            str(len(shape)),
            size_array_var,
            stride_array_var,
            dtype_code,
            device_type,
            device_idx,
            f"&{name}_handle",
        ]

        self.wrapper_call.writeline(f"AtenTensorHandle {name}_handle;")
        self.wrapper_call.writeline(
            f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_empty_strided({', '.join(args)}));"
        )

        return f"RAIIAtenTensorHandle {name}({name}_handle);"

    # Non-ABI path: in AOT mode, rebuild the device expression around the
    # runtime device index.
    if V.graph.aot_mode and device_str.startswith("c10::Device("):
        tensor_device = f"{device_str.split(',')[0]}, this->device_idx_)"
    else:
        tensor_device = device_str

    if device.type == "cpu":
        return f"at::Tensor {name} = at::detail::empty_strided_cpu({size}, {stride}, {dtype_code});"
    if device.type == "cuda":
        return (
            f"at::Tensor {name} = at::detail::empty_strided_cuda("
            f"{size}, {stride}, {dtype_code}, c10::DeviceType::CUDA);"
        )
    return (
        f"{self.declare}{name} = {self.namespace}empty_strided("
        f"{size}, {stride}, at::TensorOptions({tensor_device}).dtype({dtype_code})){self.ending}"
    )
|
| 1341 |
+
|
| 1342 |
+
def codegen_alloc_from_pool(self, name, offset, dtype, shape, stride) -> str:
    """Render an allocation carved out of memory pool ``name`` at byte ``offset``."""
    if config.abi_compatible:
        size = self.codegen_shape_tuple(shape)
        stride = self.codegen_shape_tuple(stride)
        tmp_name = f"tmp_tensor_handle_{next(self.tmp_tensor_id)}"
        args = [
            name,
            pexpr(offset),  # bytes not numel
            self.codegen_dtype(dtype),
            str(len(shape)),
            self.codegen_int_array_var(
                size, self.wrapper_call, graph=self.get_codegened_graph()
            ),
            self.codegen_int_array_var(
                stride, self.wrapper_call, graph=self.get_codegened_graph()
            ),
            f"&{tmp_name}",
        ]
        self.wrapper_call.writeline(f"AtenTensorHandle {tmp_name};")
        self.wrapper_call.writeline(
            f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch__alloc_from_pool({', '.join(args)}));"
        )
        # Returned as a temporary RAII handle owning the new tensor.
        return f"RAIIAtenTensorHandle({tmp_name})"

    return "alloc_from_pool({})".format(
        ", ".join(
            [
                name,
                pexpr(offset),  # bytes not numel
                self.codegen_dtype(dtype),
                self.codegen_shape_tuple(shape),
                self.codegen_shape_tuple(stride),
            ]
        )
    )
|
| 1377 |
+
|
| 1378 |
+
def codegen_reinterpret_view(
    self, data, size_list, stride_list, offset, writer
) -> str:
    """Render a reinterpreted (size/stride/offset) view of ``data``.

    ABI mode emits a reinterpret_tensor_wrapper call into ``writer`` and
    returns either the raw temp name (stack-allocated source) or a
    wrap_with_raii_handle_if_needed expression; non-ABI mode returns a plain
    reinterpret_tensor expression.
    """
    dim = str(len(size_list))
    size = self.codegen_shape_tuple(size_list)
    stride = self.codegen_shape_tuple(stride_list)
    offset = self.codegen_sizevar(offset)

    if config.abi_compatible:
        tmp_name = f"tmp_tensor_handle_{next(self.tmp_tensor_id)}"
        # Because the memory planning is done in two passes (see the implementation
        # of self.generate), the writeline behavior is different in the two passes.
        if writer is None:
            writer = self

        args = [
            f"{data.get_name()}",
            dim,
            self.codegen_int_array_var(
                size,
                writer,
                known_statically=self.is_statically_known_list_of_ints(size_list),
                graph=self.get_codegened_graph(),
            ),
            self.codegen_int_array_var(
                stride,
                writer,
                known_statically=self.is_statically_known_list_of_ints(stride_list),
                graph=self.get_codegened_graph(),
            ),
            offset,
        ]

        def gen_reinterpret_call(writer, args):
            # Emits: auto <tmp_name> = reinterpret_tensor_wrapper(<args>);
            writer.writeline(
                f"auto {tmp_name} = reinterpret_tensor_wrapper({', '.join(args)});"
            )

        if (
            self.can_stack_allocate_buffer(data)
            and self.is_statically_known_list_of_ints(size_list)
            and self.is_statically_known_list_of_ints(stride_list)
            and ir.is_contiguous_strides_for_shape(stride_list, size_list)
        ):
            # Stack-allocated source: the temp is safe to use directly.
            gen_reinterpret_call(writer, args)
            return tmp_name

        gen_reinterpret_call(writer, args)

        # NB, the return handle here represents a temporary tensor, which will be automatically
        # released.
        # Here's a sample usage in the cpp wrapper code:
        # ```
        # aoti_torch_addmm_out(
        #     buf1,
        #     arg1_1,
        #     RAIIAtenTensorHandle(tmp_tensor_handle_0),
        #     buf0,
        #     1L,
        #     1L));
        # ```
        # RAIIAtenTensorHandle(tmp_tensor_handle_0) will be released after the call to addmm_out.
        # This could be problematic when it's used in a different pattern, for example:
        # ````
        # AtenTensorHandle tensor_args[] = {RAIIAtenTensorHandle(tmp_tensor_handle_2), buf5, buf6};
        # aoti_torch_proxy_executor_call_function(..., tensor_args);
        # ````
        # RAIIAtenTensorHandle(tmp_tensor_handle_2) will be invalid when it's used in the latter
        # kernel call.
        #
        # This is solved by updating the proxy_executor invocation to
        # ```
        # aoti_torch_proxy_executor_call_function(...,
        #     std::vector<AtenTensorHandle>{
        #         RAIIAtenTensorHandle(tmp_tensor_handle_2), buf5, buf6
        #     }.data()
        # );
        # ```
        return f"wrap_with_raii_handle_if_needed({tmp_name})"
    else:
        args = [data.get_name(), size, stride, offset]
        return f"reinterpret_tensor({', '.join(args)})"
|
| 1460 |
+
|
| 1461 |
+
def codegen_device_copy(self, src, dst):
    """Emit a copy of ``src`` into ``dst``."""
    if not config.abi_compatible:
        self.writeline(f"{dst}.copy_({src});")
        return
    self.writeline(
        f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_tensor_copy_(expensive_copy_to_tensor_if_needed({src}), {dst}));"
    )
|
| 1468 |
+
|
| 1469 |
+
def codegen_multi_output(self, name, value):
    """Codegen a MultiOutput binding, except in ABI mode where outputs are
    retrieved via out-pointer arguments and need no binding."""
    if config.abi_compatible:
        return
    super().codegen_multi_output(name, value)
|
| 1474 |
+
|
| 1475 |
+
def codegen_subgraph_prefix(self, subgraph, outer_inputs, outer_outputs):
    """Bind each outer input to the corresponding subgraph input name."""
    for inner_input, outer_input in zip(subgraph.graph.graph_inputs, outer_inputs):
        if config.abi_compatible:
            # in ABI-compatible mode, we copy the underlying at::Tensor of the conditional
            # input (outer_input) into another at::Tensor to be used as a subgraph input
            # (inner_input) in the nested scope. we can't std::move here, as the codegened
            # outer input may be an expression / rvalue (e.g., reinterpret_view(x)), so we
            # can't necessarily std::move it back to the origin (x).
            self.writeline(f"AtenTensorHandle {inner_input}_handle;")
            self.writeline(
                f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_assign_tensors_out({outer_input}, &{inner_input}_handle));"
            )
            self.writeline(
                f"RAIIAtenTensorHandle {inner_input}({inner_input}_handle);"
            )
        else:
            self.writeline(
                f"{self.declare}{inner_input} = {outer_input}{self.ending}"
            )
|
| 1494 |
+
|
| 1495 |
+
def codegen_subgraph_suffix(self, subgraph, outer_inputs, outer_outputs):
    """Assign each subgraph output back to its corresponding outer output."""
    pairs = zip(subgraph.graph.graph_outputs, outer_outputs)
    for inner_output, outer_output in pairs:
        src = inner_output.codegen_reference()
        if config.abi_compatible:
            # RAIIAtenTensorHandle's copy constructor is deleted, so the
            # subgraph output must be moved into the outer output.
            src = f"std::move({src})"
        self.writeline(f"{outer_output} = {src}{self.ending}")
|
| 1506 |
+
|
| 1507 |
+
def codegen_conditional(self, conditional):
    """Codegen a torch.cond-style conditional as a C++ if/else over two subgraphs."""
    name = conditional.get_name()
    outer_inputs = [f"{buf.codegen_reference()}" for buf in conditional.operands]
    if config.abi_compatible:
        outer_outputs = []
        for out in conditional.outputs:
            # in ABI-compatible mode, ir.MultiOutput is not codegened,
            # hence pre-declare output variables directly and separately
            self.writeline(f"RAIIAtenTensorHandle {out.get_name()};")
            outer_outputs.append(out.get_name())
        predicate = f"{conditional.predicate.get_name()}_scalar"
        self.writeline(f"bool {predicate};")
        # in ABI-compatible mode, we need to use the ABI shim function
        # to extract a C++ bool from the underlying scalar bool Tensor
        self.writeline(
            f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_item_bool({conditional.predicate.codegen_reference()}, &{predicate}));"
        )
    else:
        # in non-ABI-compatible mode, we can codegen the conditional outputs
        # as array of at::Tensor instances, as the ir.MultiOutput is codegened
        outer_outputs = [f"{name}[{i}]" for i in range(len(conditional.outputs))]
        self.writeline(f"at::Tensor {name}[{len(conditional.outputs)}];")
        predicate = f"{conditional.predicate.codegen_reference()}.item<bool>()"

    self.writeline(f"if ({predicate}) {{")
    self.writeline(EnterSubgraphLine(self, conditional.true_subgraph.graph))
    self.codegen_subgraph(conditional.true_subgraph, outer_inputs, outer_outputs)
    self.writeline(ExitSubgraphLine(self))
    self.writeline("} else {")
    self.writeline(EnterSubgraphLine(self, conditional.false_subgraph.graph))
    self.codegen_subgraph(conditional.false_subgraph, outer_inputs, outer_outputs)
    self.writeline(ExitSubgraphLine(self))
    self.writeline("}")
|
| 1540 |
+
|
| 1541 |
+
def generate_extern_kernel_args_decl_if_needed(
    self, op_overload, raw_args, output_args
):
    """Flatten an op's inputs and output slots into the tensor/int argument
    lists expected by the proxy executor.

    Returns (new_tensor_args, new_int_args): rendered tensor references and
    rendered integer expressions, in schema order. Static-valued argument
    types (float, bool, string, dtype, device) are baked into the serialized
    node and skipped here.
    """
    arg_types = [x.real_type for x in op_overload._schema.arguments]
    return_types = [x.type for x in op_overload._schema.returns]

    new_tensor_args = []
    new_int_args = []

    def fill_args(arg, arg_type):
        # Classify one input argument by its schema type and append its
        # rendered form to new_tensor_args / new_int_args (closure mutation).
        static_arg_types = (
            torch.FloatType,
            torch.BoolType,
            torch.StringType,
            torch.Type,
            torch.DeviceObjType,
        )
        inductor_tensor_buffers = (
            ir.Buffer,
            ir.ReinterpretView,
        )

        if isinstance(arg_type, torch.TensorType):
            assert isinstance(arg, inductor_tensor_buffers), f"got {type(arg)}"
            new_tensor_args.append(f"{arg.codegen_reference()}")
        elif isinstance(arg_type, torch.IntType):
            # int
            new_int_args.append(str(arg))
        elif isinstance(arg_type, torch.SymIntType):
            # SymInt
            expr = arg.node.expr if isinstance(arg, torch.SymInt) else arg
            new_int_args.append(self.expr_printer(expr))
        elif isinstance(arg_type, torch.NumberType):
            # Scalar of type int
            assert isinstance(arg, (int, float, bool))
            # Only treat int Scalar as dynamic
            if isinstance(arg, int):
                new_int_args.append(str(arg))
        elif isinstance(arg_type, torch.ListType):
            assert isinstance(arg, (list, tuple))

            # List[Tensor]
            if isinstance(arg_type.getElementType(), torch.TensorType):
                new_tensor_args.extend([f"{a.codegen_reference()}" for a in arg])
            # List[Optional[Tensor]]
            elif isinstance(
                arg_type.getElementType(), torch.OptionalType
            ) and isinstance(
                arg_type.getElementType().getElementType(), torch.TensorType
            ):
                new_tensor_args.extend(
                    [f"{a.codegen_reference()}" for a in arg if a is not None]
                )
            # List[int]
            elif isinstance(arg_type.getElementType(), torch.IntType):
                new_int_args.extend([str(a) for a in arg])
            # List[SymInt]
            elif isinstance(arg_type.getElementType(), torch.SymIntType):
                expressions = [
                    a.node.expr if isinstance(a, torch.SymInt) else a for a in arg
                ]
                new_int_args.extend(
                    [self.expr_printer(expr) for expr in expressions]
                )
            # List[Scalar]
            elif isinstance(arg_type.getElementType(), torch.NumberType):
                # Only treat int Scalar as dynamic
                is_int_type = [isinstance(a, int) for a in arg]
                if any(is_int_type):
                    assert all(
                        is_int_type
                    ), "AOTInductor only supports int scalars of the same type"
                    new_int_args.extend([str(a) for a in arg])
            else:
                assert isinstance(
                    arg_type.getElementType(), static_arg_types  # type: ignore[arg-type]
                ), f"Fall through arguments must be one of static_arg_types, got {type(arg_type)}"
        else:
            assert isinstance(
                arg_type, static_arg_types  # type: ignore[arg-type]
            ), f"Fall through arguments must be one of static_arg_types, got {type(arg_type)}"

    for arg, arg_type in zip(raw_args, arg_types):
        if arg is not None:
            if isinstance(arg_type, torch.OptionalType):
                # Unwrap Optional[T] and classify as T.
                fill_args(arg, arg_type.getElementType())
            else:
                fill_args(arg, arg_type)

    def fill_output_arg(arg, return_type):
        # Declare an output handle for one tensor return and register it as
        # a tensor argument (closure mutation of new_tensor_args).
        if isinstance(return_type, torch.TensorType):
            self.writeline(f"AtenTensorHandle {arg}_handle;  // output buffer")
            self.writeline(
                f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_new_uninitialized_tensor(&{arg}_handle));"
            )
            self.writeline(f"RAIIAtenTensorHandle {arg}({arg}_handle);")
            new_tensor_args.append(f"{arg}")
        elif isinstance(return_type, torch.SymIntType):
            raise NotImplementedError("NYI support for return type: SymInt")
        elif isinstance(return_type, torch.ListType) and isinstance(
            return_type.getElementType(), torch.SymIntType
        ):
            raise NotImplementedError("NYI support for return type: List[SymInt]")
        else:
            raise AssertionError(f"Unsupported return type found: {return_type}")

    # TODO: Only support tensor(s) returns for now, SymInt is not implemented yet
    for return_type in return_types:
        if isinstance(return_type, (torch.TensorType)):
            pass
        elif isinstance(return_type, torch.OptionalType):
            assert isinstance(return_type.getElementType(), torch.TensorType)
        elif isinstance(return_type, torch.ListType):
            assert isinstance(return_type.getElementType(), torch.TensorType)
        else:
            raise NotImplementedError(
                f"return type {return_type} is not yet supported."
            )

    for output_arg in output_args:
        assert output_arg is not None, "Optional return types are not yet supported"
        if isinstance(output_arg, (list, tuple)):
            for out in output_arg:
                fill_output_arg(out, torch.TensorType.get())
        else:
            fill_output_arg(output_arg, torch.TensorType.get())

    return new_tensor_args, new_int_args
|
| 1669 |
+
|
| 1670 |
+
def generate_extern_kernel_alloc_and_find_schema_if_needed(
    self,
    name,
    kernel,
    codegen_args,
    cpp_op_schema,
    cpp_kernel_key,
    cpp_kernel_overload_name="",
    op_overload=None,
    raw_args=None,
    outputs=None,
):
    """Route extern-kernel codegen to the fbcode (proxy-executor) or OSS path."""
    if not config.is_fbcode():
        return self.generate_extern_kernel_alloc_and_find_schema_if_needed_oss(
            name,
            kernel,
            codegen_args,
            cpp_op_schema,
            cpp_kernel_key,
            cpp_kernel_overload_name,
        )

    # The fbcode path goes through the proxy executor and needs the full
    # op schema plus raw (unrendered) arguments and output buffers.
    assert op_overload is not None
    assert raw_args is not None
    assert outputs is not None

    return self.generate_extern_kernel_alloc_and_find_schema_if_needed_fbcode(
        name,
        cpp_kernel_key,
        op_overload,
        raw_args,
        outputs,
    )
|
| 1703 |
+
|
| 1704 |
+
def generate_extern_kernel_alloc_and_find_schema_if_needed_oss(
    self,
    name,
    kernel,
    codegen_args,
    cpp_op_schema,
    cpp_kernel_key,
    cpp_kernel_overload_name="",
):
    """OSS path: look up the op through the c10 dispatcher (declared once per
    kernel key) and emit the typed call assigned to ``name``."""
    if cpp_kernel_key not in self.extern_call_ops:
        self.writeline(
            f"static auto op_{cpp_kernel_key} = c10::Dispatcher::singleton()"
        )
        self.writeline(
            f'\t.findSchemaOrThrow("{kernel}", "{cpp_kernel_overload_name}")'
        )
        self.writeline(f"\t.typed<{cpp_op_schema}>();")
        self.extern_call_ops.add(cpp_kernel_key)

    self.writeline(
        f"auto {name} = op_{cpp_kernel_key}.call({', '.join(codegen_args)});"
    )
|
| 1726 |
+
|
| 1727 |
+
def generate_extern_kernel_alloc_and_find_schema_if_needed_fbcode(
    self,
    name,
    cpp_kernel_key,
    op_overload,
    raw_args,  # contains both args and flatten kwargs
    outputs,
):
    """fbcode path: emit a proxy-executor call carrying flattened int and
    tensor arguments (inputs followed by output buffers)."""

    def extract_output_name(out):
        # Map an output IR node (or nested list/tuple of them) to its
        # generated buffer name(s), preserving the pytree structure.
        assert out is not None, "None, i.e. optional output is not supported"
        if isinstance(out, ir.MultiOutput):
            return out.get_name()
        elif isinstance(out, (list, tuple)):
            return type(out)(extract_output_name(o) for o in out)
        else:
            raise AssertionError(f"Unexpected output: {type(out)}")

    # output_args has the same pytree structure as outputs
    output_args = extract_output_name(outputs)
    if isinstance(output_args, str):
        output_args = [output_args]

    (
        tensor_call_args,
        int_call_args,
    ) = self.generate_extern_kernel_args_decl_if_needed(
        op_overload, raw_args, output_args
    )

    tensor_call_args_str = ", ".join(tensor_call_args)
    int_call_args_str = ", ".join(int_call_args)

    # The serialized extern-kernel node for this call is the most recently
    # appended one.
    extern_kernel_node_index = len(V.graph.extern_kernel_nodes) - 1

    self.writeline(
        f"aoti_torch_proxy_executor_call_function(proxy_executor, "
        f"{extern_kernel_node_index}, "
        f"{len(int_call_args)}, "
        f"std::vector<int64_t>{{{int_call_args_str}}}.data(), "
        f"{len(tensor_call_args)}, "
        f"std::vector<AtenTensorHandle>{{{tensor_call_args_str}}}.data());"
    )

    self.extern_call_ops.add(cpp_kernel_key)
|
| 1771 |
+
|
| 1772 |
+
def generate_reset_kernel_saved_flags(self):
|
| 1773 |
+
pass
|
| 1774 |
+
|
| 1775 |
+
def generate_save_uncompiled_kernels(self):
|
| 1776 |
+
pass
|
| 1777 |
+
|
| 1778 |
+
def val_to_cpp_arg_str(self, type_, val, is_legacy_abi) -> str:
|
| 1779 |
+
if (
|
| 1780 |
+
config.abi_compatible
|
| 1781 |
+
and not is_legacy_abi
|
| 1782 |
+
and isinstance(type_, torch.OptionalType)
|
| 1783 |
+
):
|
| 1784 |
+
if val is None:
|
| 1785 |
+
return "0" # nullptr is not available in C
|
| 1786 |
+
if not isinstance(type_.getElementType(), torch.TensorType):
|
| 1787 |
+
var_name = f"var_{next(self.arg_var_id)}"
|
| 1788 |
+
self.writeline(f"auto {var_name} = {self.val_to_arg_str(val)};")
|
| 1789 |
+
return f"&{var_name}"
|
| 1790 |
+
elif config.c_shim_version == "2":
|
| 1791 |
+
# Similar to other data type, use pointer to denote optional tensor arg in v2 C shim
|
| 1792 |
+
base_handle = self.val_to_arg_str(val)
|
| 1793 |
+
if "wrap_with_raii_handle_if_needed" in base_handle:
|
| 1794 |
+
# wrap_with_raii_handle_if_needed creates a temp RAIIAtenTensorHandle, so we need to
|
| 1795 |
+
# explicitly store it. Otherwise, it will be destroyed before the fallback kernel call.
|
| 1796 |
+
tmp_var_name = f"var_{next(self.arg_var_id)}"
|
| 1797 |
+
self.writeline(
|
| 1798 |
+
f"RAIIAtenTensorHandle {tmp_var_name} = {base_handle};"
|
| 1799 |
+
)
|
| 1800 |
+
base_handle = tmp_var_name
|
| 1801 |
+
var_name = f"var_{next(self.arg_var_id)}"
|
| 1802 |
+
self.writeline(f"AtenTensorHandle {var_name} = {base_handle}.get();")
|
| 1803 |
+
return f"&{var_name}"
|
| 1804 |
+
|
| 1805 |
+
return self.val_to_arg_str(val)
|
| 1806 |
+
|
| 1807 |
+
def val_to_arg_str(self, val) -> str:
|
| 1808 |
+
if val is None:
|
| 1809 |
+
# When None is passed as an argument, it represents an optional that does not contain a value.
|
| 1810 |
+
if config.abi_compatible:
|
| 1811 |
+
return "0" # nullptr is not available in C
|
| 1812 |
+
return "c10::nullopt"
|
| 1813 |
+
elif isinstance(val, bool):
|
| 1814 |
+
if config.abi_compatible:
|
| 1815 |
+
return "1" if val else "0"
|
| 1816 |
+
else:
|
| 1817 |
+
return "true" if val else "false"
|
| 1818 |
+
elif isinstance(val, int):
|
| 1819 |
+
# uint64_t is long on Linux, but long long on MacOS
|
| 1820 |
+
return f"{val}LL" if sys.platform == "darwin" else f"{val}L"
|
| 1821 |
+
elif isinstance(val, str):
|
| 1822 |
+
return f'"{val}"'
|
| 1823 |
+
elif isinstance(
|
| 1824 |
+
val, (ir.Buffer, ir.ReinterpretView, ir.StorageBox, ir.TensorBox)
|
| 1825 |
+
):
|
| 1826 |
+
return val.codegen_reference()
|
| 1827 |
+
elif isinstance(val, torch.device):
|
| 1828 |
+
return self.codegen_device(val)
|
| 1829 |
+
elif isinstance(val, torch.dtype):
|
| 1830 |
+
return self.codegen_dtype(val)
|
| 1831 |
+
elif isinstance(val, float) and val in [float("inf"), float("-inf")]:
|
| 1832 |
+
if val == float("inf"):
|
| 1833 |
+
return "std::numeric_limits<float>::infinity()"
|
| 1834 |
+
else:
|
| 1835 |
+
return "-std::numeric_limits<float>::infinity()"
|
| 1836 |
+
elif isinstance(val, (list, tuple)):
|
| 1837 |
+
# FIXME handle embedded optional types?
|
| 1838 |
+
result = f"{{{', '.join(self.val_to_arg_str(x) for x in val)}}}"
|
| 1839 |
+
if config.abi_compatible:
|
| 1840 |
+
static = self.is_statically_known_list_of_ints(val)
|
| 1841 |
+
# Need to pass the array length because we can't use std::vector
|
| 1842 |
+
int_var_array = self.codegen_int_array_var(
|
| 1843 |
+
result,
|
| 1844 |
+
known_statically=static,
|
| 1845 |
+
graph=self.get_codegened_graph(),
|
| 1846 |
+
)
|
| 1847 |
+
return f"{int_var_array}, {len(val)}"
|
| 1848 |
+
else:
|
| 1849 |
+
return result
|
| 1850 |
+
else:
|
| 1851 |
+
return repr(val)
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cpp_wrapper_cuda.py
ADDED
|
@@ -0,0 +1,328 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import functools
|
| 2 |
+
import os
|
| 3 |
+
from itertools import chain, count
|
| 4 |
+
from typing import Any, List, Optional, TYPE_CHECKING
|
| 5 |
+
|
| 6 |
+
import sympy
|
| 7 |
+
|
| 8 |
+
from torch._inductor.codecache import get_cpp_wrapper_cubin_path_name
|
| 9 |
+
|
| 10 |
+
from .. import config
|
| 11 |
+
from ..codecache import CudaKernelParamCache
|
| 12 |
+
from ..triton_heuristics import grid as default_grid
|
| 13 |
+
from ..virtualized import V
|
| 14 |
+
from .cpp_wrapper_cpu import CppWrapperCpu
|
| 15 |
+
from .wrapper import SymbolicCallArg
|
| 16 |
+
|
| 17 |
+
if TYPE_CHECKING:
|
| 18 |
+
from ..graph import GraphLowering
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def is_int(s: str) -> bool:
|
| 22 |
+
# Cpp code gen adds L at the end of ints
|
| 23 |
+
# Lets remove it for checking whether we have an int or not
|
| 24 |
+
if s and s[-1] == "L":
|
| 25 |
+
s = s[:-1]
|
| 26 |
+
try:
|
| 27 |
+
int(s)
|
| 28 |
+
except ValueError:
|
| 29 |
+
return False
|
| 30 |
+
except TypeError:
|
| 31 |
+
return False
|
| 32 |
+
return True
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def is_float(s: str) -> bool:
|
| 36 |
+
try:
|
| 37 |
+
float(s)
|
| 38 |
+
except ValueError:
|
| 39 |
+
return False
|
| 40 |
+
return True
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class CppWrapperCuda(CppWrapperCpu):
|
| 44 |
+
"""
|
| 45 |
+
Generates cpp wrapper for running on GPU and calls CUDA kernels
|
| 46 |
+
"""
|
| 47 |
+
|
| 48 |
+
def __init__(self):
|
| 49 |
+
self.device = "cuda"
|
| 50 |
+
super().__init__()
|
| 51 |
+
self.grid_id = count()
|
| 52 |
+
self.cuda = True
|
| 53 |
+
|
| 54 |
+
def write_header(self):
|
| 55 |
+
if V.graph.is_const_graph:
|
| 56 |
+
# We do not write header for constant graph, it will be written by main module.
|
| 57 |
+
return
|
| 58 |
+
|
| 59 |
+
super().write_header()
|
| 60 |
+
|
| 61 |
+
self.header.splice("#include <filesystem>")
|
| 62 |
+
if config.abi_compatible:
|
| 63 |
+
self.header.splice(
|
| 64 |
+
"#include <torch/csrc/inductor/aoti_runtime/utils_cuda.h>"
|
| 65 |
+
)
|
| 66 |
+
else:
|
| 67 |
+
self.header.splice(
|
| 68 |
+
"""
|
| 69 |
+
#include <c10/cuda/CUDAGuard.h>
|
| 70 |
+
#include <c10/cuda/CUDAStream.h>
|
| 71 |
+
#include <ATen/cuda/EmptyTensor.h>
|
| 72 |
+
"""
|
| 73 |
+
)
|
| 74 |
+
|
| 75 |
+
self.header.splice(
|
| 76 |
+
"""
|
| 77 |
+
#define CUDA_DRIVER_CHECK(EXPR) \\
|
| 78 |
+
do { \\
|
| 79 |
+
CUresult code = EXPR; \\
|
| 80 |
+
const char *msg; \\
|
| 81 |
+
cuGetErrorString(code, &msg); \\
|
| 82 |
+
if (code != CUDA_SUCCESS) { \\
|
| 83 |
+
throw std::runtime_error( \\
|
| 84 |
+
std::string("CUDA driver error: ") + \\
|
| 85 |
+
std::string(msg)); \\
|
| 86 |
+
} \\
|
| 87 |
+
} while (0);
|
| 88 |
+
|
| 89 |
+
namespace {
|
| 90 |
+
|
| 91 |
+
struct Grid {
|
| 92 |
+
Grid(uint32_t x, uint32_t y, uint32_t z)
|
| 93 |
+
: grid_x(x), grid_y(y), grid_z(z) {}
|
| 94 |
+
uint32_t grid_x;
|
| 95 |
+
uint32_t grid_y;
|
| 96 |
+
uint32_t grid_z;
|
| 97 |
+
|
| 98 |
+
bool is_non_zero() {
|
| 99 |
+
return grid_x > 0 && grid_y > 0 && grid_z > 0;
|
| 100 |
+
}
|
| 101 |
+
};
|
| 102 |
+
|
| 103 |
+
} // anonymous namespace
|
| 104 |
+
|
| 105 |
+
static inline CUfunction loadKernel(
|
| 106 |
+
std::string filePath,
|
| 107 |
+
const std::string &funcName,
|
| 108 |
+
uint32_t sharedMemBytes,
|
| 109 |
+
const std::optional<std::string> &cubinDir = std::nullopt) {
|
| 110 |
+
if (cubinDir) {
|
| 111 |
+
std::filesystem::path p1{*cubinDir};
|
| 112 |
+
std::filesystem::path p2{filePath};
|
| 113 |
+
filePath = (p1 / p2.filename()).string();
|
| 114 |
+
}
|
| 115 |
+
|
| 116 |
+
CUmodule mod;
|
| 117 |
+
CUfunction func;
|
| 118 |
+
CUDA_DRIVER_CHECK(cuModuleLoad(&mod, filePath.c_str()));
|
| 119 |
+
CUDA_DRIVER_CHECK(cuModuleGetFunction(&func, mod, funcName.c_str()));
|
| 120 |
+
if (sharedMemBytes > 0) {
|
| 121 |
+
CUDA_DRIVER_CHECK(cuFuncSetAttribute(
|
| 122 |
+
func,
|
| 123 |
+
CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
|
| 124 |
+
sharedMemBytes
|
| 125 |
+
))
|
| 126 |
+
}
|
| 127 |
+
return func;
|
| 128 |
+
}
|
| 129 |
+
|
| 130 |
+
static inline void launchKernel(
|
| 131 |
+
CUfunction func,
|
| 132 |
+
uint32_t gridX,
|
| 133 |
+
uint32_t gridY,
|
| 134 |
+
uint32_t gridZ,
|
| 135 |
+
uint32_t numWarps,
|
| 136 |
+
uint32_t sharedMemBytes,
|
| 137 |
+
void* args[],
|
| 138 |
+
cudaStream_t stream) {
|
| 139 |
+
CUDA_DRIVER_CHECK(cuLaunchKernel(
|
| 140 |
+
func, gridX, gridY, gridZ, 32*numWarps, 1, 1, sharedMemBytes, stream, args, nullptr
|
| 141 |
+
));
|
| 142 |
+
}
|
| 143 |
+
"""
|
| 144 |
+
)
|
| 145 |
+
|
| 146 |
+
def write_get_raw_stream(self, index, graph=None):
|
| 147 |
+
name = f"stream{index}"
|
| 148 |
+
self.writeline(f"cudaStream_t {name};")
|
| 149 |
+
self.writeline(
|
| 150 |
+
f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_current_cuda_stream({index}, (void**)&{name}));"
|
| 151 |
+
)
|
| 152 |
+
return name
|
| 153 |
+
|
| 154 |
+
def define_kernel(
|
| 155 |
+
self, name: str, kernel: str, metadata: Optional[str] = None, cuda=True
|
| 156 |
+
):
|
| 157 |
+
if not cuda:
|
| 158 |
+
return super().define_kernel(name, kernel, metadata, cuda)
|
| 159 |
+
|
| 160 |
+
def generate(self, is_inference):
|
| 161 |
+
self.prefix.writeline("\n")
|
| 162 |
+
if not V.graph.aot_mode:
|
| 163 |
+
for kernel in chain(
|
| 164 |
+
self.src_to_kernel.values(),
|
| 165 |
+
[entry[0] for entry in self.user_defined_kernel_cache.values()],
|
| 166 |
+
):
|
| 167 |
+
self.prefix.writeline(f"static CUfunction {kernel} = nullptr;")
|
| 168 |
+
self.prefix.writeline("\n")
|
| 169 |
+
return super().generate(is_inference)
|
| 170 |
+
|
| 171 |
+
@functools.lru_cache(None)
|
| 172 |
+
def generate_load_kernel_once(
|
| 173 |
+
self,
|
| 174 |
+
name: str,
|
| 175 |
+
mangled_name: str,
|
| 176 |
+
cubin_path: str,
|
| 177 |
+
shared_mem: int,
|
| 178 |
+
graph: "GraphLowering", # for per-graph caching
|
| 179 |
+
):
|
| 180 |
+
if V.graph.aot_mode:
|
| 181 |
+
self.writeline(f"if (kernels.{name} == nullptr) {{")
|
| 182 |
+
self.writeline(
|
| 183 |
+
f""" kernels.{name} = loadKernel("{cubin_path}", "{mangled_name}", {shared_mem}, this->cubin_dir_);"""
|
| 184 |
+
)
|
| 185 |
+
self.writeline("}")
|
| 186 |
+
else:
|
| 187 |
+
self.writeline(f"if ({name} == nullptr) {{")
|
| 188 |
+
self.writeline(
|
| 189 |
+
f""" {name} = loadKernel("{cubin_path}", "{mangled_name}", {shared_mem});"""
|
| 190 |
+
)
|
| 191 |
+
self.writeline("}")
|
| 192 |
+
|
| 193 |
+
def generate_args_decl(self, call_args):
|
| 194 |
+
dynamic_symbols = V.graph.sizevars.free_symbols()
|
| 195 |
+
# TODO: only works for constant now, need type info
|
| 196 |
+
new_args = []
|
| 197 |
+
for arg in call_args:
|
| 198 |
+
var_name = f"var_{next(self.arg_var_id)}"
|
| 199 |
+
if isinstance(arg, (sympy.Integer, sympy.Symbol, SymbolicCallArg)):
|
| 200 |
+
self.writeline(f"auto {var_name} = {arg};")
|
| 201 |
+
elif isinstance(arg, sympy.Expr):
|
| 202 |
+
self.writeline(f"auto {var_name} = {self.expr_printer(arg)};")
|
| 203 |
+
elif is_int(arg):
|
| 204 |
+
self.writeline(f"int {var_name} = {arg};")
|
| 205 |
+
elif is_float(arg):
|
| 206 |
+
self.writeline(f"float {var_name} = {arg};")
|
| 207 |
+
elif any(str(arg) == s.name for s in dynamic_symbols):
|
| 208 |
+
self.writeline(f"auto {var_name} = {arg};")
|
| 209 |
+
elif arg == "nullptr":
|
| 210 |
+
self.writeline(f"auto {var_name} = nullptr;")
|
| 211 |
+
elif arg == "c10::nullopt":
|
| 212 |
+
self.writeline(f"auto {var_name} = c10::nullopt;")
|
| 213 |
+
else:
|
| 214 |
+
if config.abi_compatible:
|
| 215 |
+
self.writeline(f"CUdeviceptr {var_name};")
|
| 216 |
+
self.writeline(
|
| 217 |
+
f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_data_ptr({arg}, reinterpret_cast<void**>(&{var_name})));"
|
| 218 |
+
)
|
| 219 |
+
else:
|
| 220 |
+
self.writeline(
|
| 221 |
+
f"CUdeviceptr {var_name} = reinterpret_cast<CUdeviceptr>({arg}.data_ptr());"
|
| 222 |
+
)
|
| 223 |
+
new_args.append(f"&{var_name}")
|
| 224 |
+
|
| 225 |
+
return ", ".join(new_args)
|
| 226 |
+
|
| 227 |
+
def generate_default_grid(self, name: str, grid: List[Any], cuda: bool = True):
|
| 228 |
+
"""
|
| 229 |
+
Generate grid configs for launching a CUDA kernel using the grid
|
| 230 |
+
function from triton_heuristics.
|
| 231 |
+
"""
|
| 232 |
+
if not cuda:
|
| 233 |
+
return grid
|
| 234 |
+
assert isinstance(grid, list), f"expected {grid=} to be a list"
|
| 235 |
+
grid = [e.inner_expr if isinstance(e, SymbolicCallArg) else e for e in grid]
|
| 236 |
+
grid_fn = default_grid(*grid)
|
| 237 |
+
params = CudaKernelParamCache.get(name)
|
| 238 |
+
assert (
|
| 239 |
+
params is not None
|
| 240 |
+
), f"cuda kernel parameters for {name} should already exist at this moment, only found {CudaKernelParamCache.get_keys()}"
|
| 241 |
+
block_cfg = {
|
| 242 |
+
"XBLOCK": params["x_block"],
|
| 243 |
+
"YBLOCK": params["y_block"],
|
| 244 |
+
"ZBLOCK": params["z_block"],
|
| 245 |
+
}
|
| 246 |
+
return grid_fn(block_cfg)
|
| 247 |
+
|
| 248 |
+
def generate_kernel_call(
|
| 249 |
+
self,
|
| 250 |
+
name,
|
| 251 |
+
call_args,
|
| 252 |
+
grid=None,
|
| 253 |
+
device_index=None,
|
| 254 |
+
cuda=True,
|
| 255 |
+
triton=True,
|
| 256 |
+
arg_types=None,
|
| 257 |
+
grid_fn: str = "grid",
|
| 258 |
+
triton_meta=None,
|
| 259 |
+
):
|
| 260 |
+
if not cuda:
|
| 261 |
+
# Even in CppWrapperCuda, we may see cpp kernels
|
| 262 |
+
return super().generate_kernel_call(
|
| 263 |
+
name, call_args, grid, device_index, cuda, triton, arg_types
|
| 264 |
+
)
|
| 265 |
+
|
| 266 |
+
params = CudaKernelParamCache.get(name)
|
| 267 |
+
assert (
|
| 268 |
+
params is not None
|
| 269 |
+
), f"cuda kernel parameters for {name} should already exist at this moment"
|
| 270 |
+
mangled_name = params.get("mangled_name", None)
|
| 271 |
+
assert mangled_name is not None, "missing mangled_name"
|
| 272 |
+
cubin_path = params.get(get_cpp_wrapper_cubin_path_name(), None)
|
| 273 |
+
assert cubin_path is not None and os.path.exists(
|
| 274 |
+
cubin_path
|
| 275 |
+
), f"cubin file should already exist at this moment: {cubin_path}"
|
| 276 |
+
shared_mem = params.get("shared_mem", 0)
|
| 277 |
+
|
| 278 |
+
self.generate_load_kernel_once(
|
| 279 |
+
name, mangled_name, cubin_path, shared_mem, V.graph
|
| 280 |
+
)
|
| 281 |
+
|
| 282 |
+
# args with value 1 are added into equal_to_1 and constants
|
| 283 |
+
# in triton_meta (in the Python codegen) which makes them
|
| 284 |
+
# inlined in the PTX and compiled CUBIN
|
| 285 |
+
if (
|
| 286 |
+
triton_meta is not None
|
| 287 |
+
and "configs" in triton_meta
|
| 288 |
+
and triton_meta["configs"]
|
| 289 |
+
):
|
| 290 |
+
equal_to_1 = triton_meta["configs"][0].equal_to_1
|
| 291 |
+
call_args = [arg for i, arg in enumerate(call_args) if i not in equal_to_1]
|
| 292 |
+
|
| 293 |
+
call_args = self.generate_args_decl(call_args)
|
| 294 |
+
kernel_args_var = f"kernel_args_var_{next(self.kernel_callsite_id)}"
|
| 295 |
+
self.writeline(f"void* {kernel_args_var}[] = {{{call_args}}};")
|
| 296 |
+
stream = (
|
| 297 |
+
"stream"
|
| 298 |
+
if V.graph.aot_mode
|
| 299 |
+
else self.write_get_raw_stream(device_index, V.graph)
|
| 300 |
+
)
|
| 301 |
+
grid_name = f"{name}_grid_{next(self.grid_id)}"
|
| 302 |
+
assert isinstance(
|
| 303 |
+
grid, (list, tuple)
|
| 304 |
+
), f"expected grid to be a list or tuple but got: {grid=}"
|
| 305 |
+
|
| 306 |
+
grid = [V.graph.sizevars.simplify(item) for item in grid]
|
| 307 |
+
grid_uses_symbolic_shapes = any(item.free_symbols for item in grid)
|
| 308 |
+
grid_args = [self.grid_expr_printer(item) for item in grid]
|
| 309 |
+
grid_args_str = ", ".join(grid_args)
|
| 310 |
+
self.writeline(f"Grid {grid_name} = Grid({grid_args_str});")
|
| 311 |
+
|
| 312 |
+
if grid_uses_symbolic_shapes:
|
| 313 |
+
self.writeline(f"if ({grid_name}.is_non_zero()) {{")
|
| 314 |
+
kernel_var_name = f"kernels.{name}" if V.graph.aot_mode else name
|
| 315 |
+
self.writeline(
|
| 316 |
+
"launchKernel({}, {}, {}, {}, {}, {}, {}, {});".format(
|
| 317 |
+
kernel_var_name,
|
| 318 |
+
f"{grid_name}.grid_x",
|
| 319 |
+
f"{grid_name}.grid_y",
|
| 320 |
+
f"{grid_name}.grid_z",
|
| 321 |
+
params["num_warps"],
|
| 322 |
+
params["shared_mem"],
|
| 323 |
+
kernel_args_var,
|
| 324 |
+
stream,
|
| 325 |
+
)
|
| 326 |
+
)
|
| 327 |
+
if grid_uses_symbolic_shapes:
|
| 328 |
+
self.writeline("}")
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/__init__.py
ADDED
|
File without changes
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/__pycache__/cuda_template.cpython-311.pyc
ADDED
|
Binary file (12.5 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/cuda_kernel.py
ADDED
|
@@ -0,0 +1,374 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
from typing import Any, Callable, Dict, List, Optional, TYPE_CHECKING, Union
|
| 3 |
+
|
| 4 |
+
from ... import ir
|
| 5 |
+
from ...autotune_process import CUDABenchmarkRequest
|
| 6 |
+
from ...ir import Buffer, CUDATemplateBuffer, IRNode, Layout, TensorBox
|
| 7 |
+
from ...select_algorithm import ChoiceCaller
|
| 8 |
+
from ...utils import sympy_product
|
| 9 |
+
from ...virtualized import V
|
| 10 |
+
|
| 11 |
+
from ..common import IndentedBuffer, Kernel, OpOverrides, PrimitiveInfoType
|
| 12 |
+
from ..cpp import CppPrinter, DTYPE_TO_CPP
|
| 13 |
+
|
| 14 |
+
if TYPE_CHECKING:
|
| 15 |
+
from torch._inductor.codegen.cuda.cuda_template import CUDATemplate
|
| 16 |
+
|
| 17 |
+
log = logging.getLogger(__name__)
|
| 18 |
+
|
| 19 |
+
cexpr = CppPrinter().doprint
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def _normalize_idx(index: int, total_length: int) -> int:
|
| 23 |
+
return index if index >= 0 else index + total_length
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class CUDAKernel(Kernel):
|
| 27 |
+
"""
|
| 28 |
+
Baseclass for CUDA / Cutlass based Kernels
|
| 29 |
+
"""
|
| 30 |
+
|
| 31 |
+
overrides = OpOverrides # type: ignore[assignment]
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class CUDATemplateKernel(CUDAKernel):
|
| 35 |
+
"""
|
| 36 |
+
Template kernels defined by CUDA / Cutlass in C++.
|
| 37 |
+
"""
|
| 38 |
+
|
| 39 |
+
_EXTRA_CPP_ARGS = "size_t* workspace_size, uint8_t* workspace, cudaStream_t stream"
|
| 40 |
+
|
| 41 |
+
def __init__(self, kernel_name):
|
| 42 |
+
"""
|
| 43 |
+
Initializes a new instance of the CUDATemplateKernel class.
|
| 44 |
+
|
| 45 |
+
Args:
|
| 46 |
+
kernel_name (str): The name of the kernel.
|
| 47 |
+
"""
|
| 48 |
+
super().__init__()
|
| 49 |
+
self.kernel_name = kernel_name
|
| 50 |
+
# Mapping from arg name to IRNode.
|
| 51 |
+
self.named_nodes: Dict[str, IRNode] = {}
|
| 52 |
+
|
| 53 |
+
def arg_name(self, node: IRNode) -> Optional[str]:
|
| 54 |
+
"""
|
| 55 |
+
Returns arg name of a given input or output node.
|
| 56 |
+
"""
|
| 57 |
+
if node is None:
|
| 58 |
+
return None
|
| 59 |
+
return {**self.args.input_buffers, **self.args.output_buffers}.get(
|
| 60 |
+
node.get_name(), None
|
| 61 |
+
)
|
| 62 |
+
|
| 63 |
+
def check_not_null(self, node: IRNode) -> str:
|
| 64 |
+
"""
|
| 65 |
+
Generates code to check that a node is not null.
|
| 66 |
+
"""
|
| 67 |
+
|
| 68 |
+
if node is None:
|
| 69 |
+
return ""
|
| 70 |
+
|
| 71 |
+
size_str = self.size(node, 0, -1)
|
| 72 |
+
name_str = self.arg_name(node)
|
| 73 |
+
if name_str is None:
|
| 74 |
+
return ""
|
| 75 |
+
|
| 76 |
+
res = IndentedBuffer(initial_indent=2)
|
| 77 |
+
res.tabwidth = 1
|
| 78 |
+
res.splice(
|
| 79 |
+
f"""
|
| 80 |
+
{{
|
| 81 |
+
if (!{name_str}) {{
|
| 82 |
+
int64_t {name_str}_size = {size_str};
|
| 83 |
+
if ({name_str}_size > 0) {{
|
| 84 |
+
throw std::runtime_error("input {name_str} is null but size is not 0!");
|
| 85 |
+
}}
|
| 86 |
+
}}
|
| 87 |
+
}}
|
| 88 |
+
"""
|
| 89 |
+
)
|
| 90 |
+
return res.getvalue()
|
| 91 |
+
|
| 92 |
+
def def_kernel(
|
| 93 |
+
self,
|
| 94 |
+
inputs: List[IRNode],
|
| 95 |
+
outputs: List[IRNode],
|
| 96 |
+
names_str: str = "",
|
| 97 |
+
input_reorder: Optional[List[int]] = None,
|
| 98 |
+
) -> str:
|
| 99 |
+
"""
|
| 100 |
+
Hook called from template code to generate function definition and
|
| 101 |
+
needed args.
|
| 102 |
+
|
| 103 |
+
Args:
|
| 104 |
+
inputs: List of input IRNodes
|
| 105 |
+
outputs: List of output IRNodes
|
| 106 |
+
names_str: Comma separated list of input + output argument names.
|
| 107 |
+
input_reorder: The actual order of input nodes.
|
| 108 |
+
e.g. The template might have input argument defined as [X, W, Bias],
|
| 109 |
+
and the actual input passed into this template could be [Bias, X, W].
|
| 110 |
+
In this case, the `input_reorder` would be [2, 0, 1].
|
| 111 |
+
"""
|
| 112 |
+
|
| 113 |
+
names = [x.strip() for x in names_str.strip().split(",")]
|
| 114 |
+
if len(inputs) + len(outputs) != len(names):
|
| 115 |
+
raise RuntimeError(
|
| 116 |
+
f"{len(inputs) + len(outputs)=} != {len(names)=}, {inputs=}, {outputs=}, {names=}"
|
| 117 |
+
)
|
| 118 |
+
|
| 119 |
+
if input_reorder is not None:
|
| 120 |
+
assert len(inputs) == len(input_reorder)
|
| 121 |
+
else:
|
| 122 |
+
input_reorder = list(range(len(inputs)))
|
| 123 |
+
|
| 124 |
+
for idx in input_reorder:
|
| 125 |
+
name = names[idx]
|
| 126 |
+
node = inputs[idx]
|
| 127 |
+
if node is not None:
|
| 128 |
+
self.named_nodes[name] = node
|
| 129 |
+
self.args.input_buffers[node.get_name()] = name
|
| 130 |
+
|
| 131 |
+
for name, node in zip(names[len(inputs) : len(inputs) + len(outputs)], outputs):
|
| 132 |
+
if node is not None:
|
| 133 |
+
self.named_nodes[name] = node
|
| 134 |
+
self.args.output_buffers[node.get_name()] = name
|
| 135 |
+
|
| 136 |
+
arg_defs, *_ = self.args.cpp_argdefs()
|
| 137 |
+
return f"PT_EXPORT int {self.kernel_name}({', '.join(arg_defs)}, {self._EXTRA_CPP_ARGS})"
|
| 138 |
+
|
| 139 |
+
def call_kernel(
|
| 140 |
+
self, name: str, node: "CUDATemplateBuffer", epilogue_nodes: List[ir.Buffer] # type: ignore[name-defined]
|
| 141 |
+
) -> None:
|
| 142 |
+
"""
|
| 143 |
+
Generates code to call the kernel through V.graph.wrapper_code.
|
| 144 |
+
used from within torch._inductor.wrapper.WrapperCodeGen
|
| 145 |
+
|
| 146 |
+
name: Name of kernel function.
|
| 147 |
+
node: The CUDATemplateBuffer node which contains information about the kernel, it's fused epilogue nodes
|
| 148 |
+
as well as all required inputs and outputs.
|
| 149 |
+
"""
|
| 150 |
+
wrapper = V.graph.wrapper_code
|
| 151 |
+
_, call_args, _ = self.args.python_argdefs()
|
| 152 |
+
# dynamo wraps unspec variable as 0d CPU tensor, need convert to scalar
|
| 153 |
+
for i in range(len(call_args)):
|
| 154 |
+
if V.graph.is_unspec_arg(call_args[i]):
|
| 155 |
+
call_args[i] = call_args[i] + ".item()"
|
| 156 |
+
else:
|
| 157 |
+
call_args[i] = f"c_void_p({call_args[i]}.data_ptr())"
|
| 158 |
+
|
| 159 |
+
# workspace_size ptr is NULL to mark this call is not intended for retrieving workspace_size.
|
| 160 |
+
# workspace_size should have already been retrieved prior to this call.
|
| 161 |
+
call_args.append("None")
|
| 162 |
+
|
| 163 |
+
if node.get_workspace_size() > 0:
|
| 164 |
+
call_args.append(f"c_void_p({node.get_name()}_workspace.data_ptr())")
|
| 165 |
+
else:
|
| 166 |
+
call_args.append("None")
|
| 167 |
+
|
| 168 |
+
wrapper.generate_kernel_call(
|
| 169 |
+
name,
|
| 170 |
+
call_args,
|
| 171 |
+
device_index=V.graph.scheduler.current_device.index,
|
| 172 |
+
cuda=True,
|
| 173 |
+
triton=False,
|
| 174 |
+
)
|
| 175 |
+
|
| 176 |
+
def dtype(self, node: IRNode) -> Optional[str]:
|
| 177 |
+
"""
|
| 178 |
+
Generates code which represents dtype of a given node.
|
| 179 |
+
"""
|
| 180 |
+
|
| 181 |
+
if node is None:
|
| 182 |
+
return "void"
|
| 183 |
+
return DTYPE_TO_CPP.get(node.get_layout().dtype)
|
| 184 |
+
|
| 185 |
+
def offset(self, node: IRNode) -> str:
|
| 186 |
+
"""
|
| 187 |
+
Generates code which represents offset of a given node.
|
| 188 |
+
"""
|
| 189 |
+
|
| 190 |
+
if node is None:
|
| 191 |
+
return "0"
|
| 192 |
+
return str(node.get_layout().offset)
|
| 193 |
+
|
| 194 |
+
def ptr(self, node: IRNode) -> str:
|
| 195 |
+
"""
|
| 196 |
+
Generates code which represents pointer of a given node.
|
| 197 |
+
"""
|
| 198 |
+
|
| 199 |
+
if node is None:
|
| 200 |
+
return "nullptr"
|
| 201 |
+
arg_name = self.arg_name(node)
|
| 202 |
+
if arg_name is None:
|
| 203 |
+
return "nullptr"
|
| 204 |
+
offset = self.offset(node)
|
| 205 |
+
return arg_name if offset == "0" else f"{arg_name} + {offset}"
|
| 206 |
+
|
| 207 |
+
def size(
|
| 208 |
+
self,
|
| 209 |
+
node: IRNode,
|
| 210 |
+
start_index: int,
|
| 211 |
+
end_index: Optional[int] = None,
|
| 212 |
+
default_value: int = 0,
|
| 213 |
+
) -> str:
|
| 214 |
+
"""
|
| 215 |
+
Hook called from template code to get the size of an arg.
|
| 216 |
+
Generates code which represents size of a given node in [start_index, end_index).
|
| 217 |
+
If node is None, returns default_value.
|
| 218 |
+
|
| 219 |
+
TODO: Will add needed args to pass it in if it is dynamic.
|
| 220 |
+
"""
|
| 221 |
+
|
| 222 |
+
if node is None:
|
| 223 |
+
return str(default_value)
|
| 224 |
+
|
| 225 |
+
start_index = _normalize_idx(start_index, len(node.get_size()))
|
| 226 |
+
if end_index is None:
|
| 227 |
+
end_index = start_index
|
| 228 |
+
end_index = _normalize_idx(end_index, len(node.get_size()))
|
| 229 |
+
|
| 230 |
+
sizes = node.get_size()[start_index : end_index + 1]
|
| 231 |
+
if len(sizes) == 0:
|
| 232 |
+
return str(default_value)
|
| 233 |
+
|
| 234 |
+
val = sympy_product(sizes)
|
| 235 |
+
return cexpr(self.rename_indexing(val))
|
| 236 |
+
|
| 237 |
+
def stride(self, node: IRNode, index: int, default_value: int = 0) -> str:
|
| 238 |
+
"""
|
| 239 |
+
Hook called from template code to get the stride of an arg.
|
| 240 |
+
Generates code which represents stride of a given node at index.
|
| 241 |
+
If node is None, returns default_value.
|
| 242 |
+
|
| 243 |
+
TODO: Will add needed args to pass it in if it is dynamic.
|
| 244 |
+
"""
|
| 245 |
+
|
| 246 |
+
if node is None:
|
| 247 |
+
return str(default_value)
|
| 248 |
+
|
| 249 |
+
index = _normalize_idx(index, len(node.get_size()))
|
| 250 |
+
if index < 0:
|
| 251 |
+
return str(default_value)
|
| 252 |
+
|
| 253 |
+
stride = node.get_stride()[index]
|
| 254 |
+
return cexpr(self.rename_indexing(stride))
|
| 255 |
+
|
| 256 |
+
def row_or_column_stride(self, node: IRNode, default_value: int = 0) -> str:
|
| 257 |
+
"""
|
| 258 |
+
Hook called from template code to get the row or column stride of an arg.
|
| 259 |
+
This is required by some CUTLASS 2.X APIs.
|
| 260 |
+
If the node is in row_major, it returns stride[-2].
|
| 261 |
+
If the node is in column_major, it returns stride[-1].
|
| 262 |
+
|
| 263 |
+
TODO: Will add needed args to pass it in if it is dynamic.
|
| 264 |
+
"""
|
| 265 |
+
|
| 266 |
+
if node is None or len(node.get_stride()) < 2:
|
| 267 |
+
return str(default_value)
|
| 268 |
+
|
| 269 |
+
stride0 = node.get_stride()[-1]
|
| 270 |
+
stride1 = node.get_stride()[-2]
|
| 271 |
+
if stride0 == 1:
|
| 272 |
+
return cexpr(self.rename_indexing(stride1))
|
| 273 |
+
elif stride1 == 1:
|
| 274 |
+
return cexpr(self.rename_indexing(stride0))
|
| 275 |
+
else:
|
| 276 |
+
raise RuntimeError(
|
| 277 |
+
f"At least 1 stride should be 1. Strides: {node.get_stride()=}"
|
| 278 |
+
)
|
| 279 |
+
|
| 280 |
+
|
| 281 |
+
class CUDATemplateCaller(ChoiceCaller):
|
| 282 |
+
"""
|
| 283 |
+
CUDATemplateCaller
|
| 284 |
+
|
| 285 |
+
This class represents a caller for CUDA template kernels. It is a subclass of ChoiceCaller.
|
| 286 |
+
Attributes:
|
| 287 |
+
name (str): The name of the caller.
|
| 288 |
+
category (str): The category of the caller.
|
| 289 |
+
bmreq (CUDABenchmarkRequest): The benchmark request for the caller.
|
| 290 |
+
template_buffer (CUDATemplateBuffer): The template buffer for the caller.
|
| 291 |
+
"""
|
| 292 |
+
|
| 293 |
+
def __init__(
|
| 294 |
+
self,
|
| 295 |
+
name: str,
|
| 296 |
+
category: str,
|
| 297 |
+
input_nodes: List[Buffer],
|
| 298 |
+
layout: Layout,
|
| 299 |
+
make_kernel_render: Callable[[CUDATemplateBuffer, Optional[List[IRNode]]], str],
|
| 300 |
+
bmreq: CUDABenchmarkRequest,
|
| 301 |
+
template: "CUDATemplate", # type: ignore[name-defined]
|
| 302 |
+
info_kwargs: Optional[Dict[str, Union[PrimitiveInfoType, List[PrimitiveInfoType]]]], # type: ignore[type-arg]
|
| 303 |
+
):
|
| 304 |
+
super().__init__(name, input_nodes, layout)
|
| 305 |
+
self.category = category
|
| 306 |
+
self.make_kernel_render = make_kernel_render
|
| 307 |
+
self.bmreq = bmreq
|
| 308 |
+
self.template = template
|
| 309 |
+
self.info_kwargs = info_kwargs
|
| 310 |
+
|
| 311 |
+
def precompile(self) -> None:
|
| 312 |
+
assert self.bmreq is not None
|
| 313 |
+
self.bmreq.precompile()
|
| 314 |
+
|
| 315 |
+
def benchmark(self, *args, out) -> float:
|
| 316 |
+
assert self.bmreq is not None
|
| 317 |
+
return self.bmreq.benchmark(
|
| 318 |
+
*args, output_tensor=out
|
| 319 |
+
) # @TODO: Hack for ensuring that Cutlass Kernel is preferred
|
| 320 |
+
|
| 321 |
+
def __str__(self):
|
| 322 |
+
return f"CUDATemplateCaller(source_file={self.bmreq.source_file})"
|
| 323 |
+
|
| 324 |
+
def call_name(self) -> str:
|
| 325 |
+
return f"cuda_template_kernels.{self.name}"
|
| 326 |
+
|
| 327 |
+
def hash_key(self) -> str:
|
| 328 |
+
return "-".join(
|
| 329 |
+
[
|
| 330 |
+
self.category,
|
| 331 |
+
self.bmreq.hash_key,
|
| 332 |
+
]
|
| 333 |
+
)
|
| 334 |
+
|
| 335 |
+
def info_dict(self) -> Dict[str, Union[PrimitiveInfoType, List[PrimitiveInfoType]]]:
|
| 336 |
+
"""Information returned here is logged to the autotune log file when that is enabled."""
|
| 337 |
+
if self.info_kwargs is not None and "op" in self.info_kwargs:
|
| 338 |
+
op: Any = self.info_kwargs["op"]
|
| 339 |
+
epilogue_node_names: List[str] = [
|
| 340 |
+
getattr(en, "name", "no_name")
|
| 341 |
+
for en in self.info_kwargs.get("epilogue_nodes", []) # type: ignore[union-attr]
|
| 342 |
+
]
|
| 343 |
+
epilogue_node_strs: List[str] = [
|
| 344 |
+
str(en) for en in self.info_kwargs.get("epilogue_nodes", []) # type: ignore[union-attr]
|
| 345 |
+
]
|
| 346 |
+
return {
|
| 347 |
+
"backend": "CUDA",
|
| 348 |
+
"op_type": type(op).__name__,
|
| 349 |
+
"op_conf_name": str(op.configuration_name()),
|
| 350 |
+
"op_arch": str(op.arch),
|
| 351 |
+
"tile_shape": str(op.tile_description.tile_shape),
|
| 352 |
+
"epilogue_schedule": str(op.epilogue_schedule),
|
| 353 |
+
"kernel_schedule": str(op.kernel_schedule),
|
| 354 |
+
"element_accumulator": str(op.accumulator_type()),
|
| 355 |
+
"op_name": str(op.procedural_name()),
|
| 356 |
+
"epilogue_node_names": epilogue_node_names, # type: ignore[dict-item]
|
| 357 |
+
"epilogue_node_strs": epilogue_node_strs, # type: ignore[dict-item]
|
| 358 |
+
"instruction_shape": str(
|
| 359 |
+
op.tile_description.math_instruction.instruction_shape
|
| 360 |
+
),
|
| 361 |
+
}
|
| 362 |
+
else:
|
| 363 |
+
return {"backend": "CUDA", "op_type": "unknown"}
|
| 364 |
+
|
| 365 |
+
def output_node(self) -> TensorBox:
|
| 366 |
+
return TensorBox.create(
|
| 367 |
+
CUDATemplateBuffer(
|
| 368 |
+
layout=self.layout,
|
| 369 |
+
inputs=self.input_nodes,
|
| 370 |
+
make_kernel_render=self.make_kernel_render,
|
| 371 |
+
workspace_size=self.bmreq.workspace_size,
|
| 372 |
+
template=self.template,
|
| 373 |
+
)
|
| 374 |
+
)
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (252 Bytes). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/__pycache__/gemm_operation_extensions.cpython-311.pyc
ADDED
|
Binary file (10.2 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/gemm_template.py
ADDED
|
@@ -0,0 +1,706 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
import logging
|
| 3 |
+
import re
|
| 4 |
+
from typing import cast, Dict, List, Optional, Tuple
|
| 5 |
+
|
| 6 |
+
from ...config import cuda as inductor_cuda_config
|
| 7 |
+
from ...ir import Buffer, CUDATemplateBuffer, FixedLayout, IRNode, Layout
|
| 8 |
+
from ..common import IndentedBuffer
|
| 9 |
+
|
| 10 |
+
from . import cutlass_utils
|
| 11 |
+
from .cuda_kernel import CUDATemplateKernel
|
| 12 |
+
from .cuda_template import CUTLASSTemplate
|
| 13 |
+
from .cutlass_epilogue_gen import (
|
| 14 |
+
CutlassEVTEpilogueArgumentFormatter,
|
| 15 |
+
CutlassEVTEpilogueTypeFormatter,
|
| 16 |
+
)
|
| 17 |
+
|
| 18 |
+
log = logging.getLogger(__name__)
|
| 19 |
+
|
| 20 |
+
GEMM_TEMPLATE = r"""
|
| 21 |
+
{{template.header().getvalue()}}
|
| 22 |
+
{{template.globals().getvalue()}}
|
| 23 |
+
{{instance_definition}}
|
| 24 |
+
// When workspace_size is not a nullptr, populates requested workspace_size and returns.
|
| 25 |
+
// Otherwise, computes the Gemm kernel using the given workspace ptr.
|
| 26 |
+
extern "C" {
|
| 27 |
+
{{kernel.def_kernel(inputs=[X, W, Bias], outputs=[Y], names_str="X, W, Bias, Y", input_reorder=input_reorder)}} {
|
| 28 |
+
try {
|
| 29 |
+
{{kernel.check_not_null(X)}}
|
| 30 |
+
{{kernel.check_not_null(W)}}
|
| 31 |
+
{{kernel.check_not_null(Bias)}}
|
| 32 |
+
{{kernel.check_not_null(Y)}}
|
| 33 |
+
int64_t B = {{kernel.size(Y, 0, -3, default_value=1)}};
|
| 34 |
+
int64_t M = {{kernel.size(X, -2)}};
|
| 35 |
+
int64_t K = {{kernel.size(X, -1)}};
|
| 36 |
+
int64_t N = {{kernel.size(W, -1)}};
|
| 37 |
+
using ElementComputeEpilogue = {{instance_type}}::ElementAccumulator;
|
| 38 |
+
using coord_t = cutlass::gemm::GemmCoord::Index;
|
| 39 |
+
{{instance_type}}::Arguments arguments;
|
| 40 |
+
{{template.render_gemm_arguments(argument_template, epilogue_template, should_swap_xw,
|
| 41 |
+
X, W, Bias, Y, alpha, beta, kernel, epilogue_args)}}
|
| 42 |
+
{{instance_type}} gemm_op;
|
| 43 |
+
if (workspace_size) {
|
| 44 |
+
*workspace_size = gemm_op.get_workspace_size(arguments);
|
| 45 |
+
return 0;
|
| 46 |
+
}
|
| 47 |
+
{
|
| 48 |
+
auto status = gemm_op.can_implement(arguments);
|
| 49 |
+
CUTLASS_CHECK(status);
|
| 50 |
+
}
|
| 51 |
+
{
|
| 52 |
+
auto status = gemm_op.initialize(arguments, workspace, stream);
|
| 53 |
+
CUTLASS_CHECK(status);
|
| 54 |
+
}
|
| 55 |
+
{
|
| 56 |
+
auto status = gemm_op(stream);
|
| 57 |
+
CUTLASS_CHECK(status);
|
| 58 |
+
}
|
| 59 |
+
}
|
| 60 |
+
catch (std::exception& e) {
|
| 61 |
+
std::cerr << "Runtime error: " << e.what() << std::endl;
|
| 62 |
+
return -1;
|
| 63 |
+
}
|
| 64 |
+
catch (...) {
|
| 65 |
+
return -1;
|
| 66 |
+
}
|
| 67 |
+
return 0;
|
| 68 |
+
}
|
| 69 |
+
}
|
| 70 |
+
"""
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
GEMM_ARGS_CUTLASS_2X = r"""
|
| 74 |
+
int64_t batch_stride_x = {{kernel.stride(X, -3)}};
|
| 75 |
+
int64_t row_stride_x = {{kernel.row_or_column_stride(X)}};
|
| 76 |
+
int64_t batch_stride_w = {{kernel.stride(W, -3)}};
|
| 77 |
+
int64_t row_stride_w = {{kernel.row_or_column_stride(W)}};
|
| 78 |
+
int64_t batch_stride_bias = {{kernel.stride(Bias, -3)}};
|
| 79 |
+
int64_t row_stride_bias = {{kernel.row_or_column_stride(Bias)}};
|
| 80 |
+
int64_t batch_stride_y = {{kernel.stride(Y, -3)}};
|
| 81 |
+
int64_t row_stride_y = {{kernel.row_or_column_stride(Y)}};
|
| 82 |
+
// Initialize GemmUniversalInstance arguments.
|
| 83 |
+
arguments = {
|
| 84 |
+
{{template.gemm_mode()}}, // GemmUniversalMode mode
|
| 85 |
+
{
|
| 86 |
+
static_cast<coord_t>(M),
|
| 87 |
+
static_cast<coord_t>(N),
|
| 88 |
+
static_cast<coord_t>(K)
|
| 89 |
+
}, // GemmCoord problem_size
|
| 90 |
+
{{split_k if split_k > 1 else 'B'}}, // int batch_count
|
| 91 |
+
{ElementComputeEpilogue({{alpha}}), ElementComputeEpilogue({{beta}})}, // typename EpilogueOutputOp::Params epilogue
|
| 92 |
+
{{template.cutlass_type_cast(X, kernel.ptr(X))}}, // void const * ptr_A
|
| 93 |
+
{{template.cutlass_type_cast(W, kernel.ptr(W))}}, // void const * ptr_B
|
| 94 |
+
{{template.cutlass_type_cast(Bias, kernel.ptr(Bias))}}, // void const * ptr_C
|
| 95 |
+
{{template.cutlass_type_cast(Y, kernel.ptr(Y))}}, // void * ptr_D
|
| 96 |
+
batch_stride_x, // int64_t batch_stride_A
|
| 97 |
+
batch_stride_w, // int64_t batch_stride_B
|
| 98 |
+
batch_stride_bias, // int64_t batch_stride_C
|
| 99 |
+
batch_stride_y, // int64_t batch_stride_D
|
| 100 |
+
row_stride_x, // typename LayoutA::Stride::LongIndex lda
|
| 101 |
+
row_stride_w, // typename LayoutB::Stride::LongIndex ldb
|
| 102 |
+
row_stride_bias, // typename LayoutC::Stride::LongIndex ldc
|
| 103 |
+
row_stride_y, // typename LayoutC::Stride::LongIndex ldd
|
| 104 |
+
};
|
| 105 |
+
"""
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
GEMM_ARGS_CUTLASS_3X = r"""
|
| 109 |
+
// Initialize GemmUniversal3xInstance arguments.
|
| 110 |
+
arguments = {
|
| 111 |
+
{{template.gemm_mode()}}, // GemmUniversalMode mode
|
| 112 |
+
{
|
| 113 |
+
static_cast<coord_t>({{M}}),
|
| 114 |
+
static_cast<coord_t>({{N}}),
|
| 115 |
+
static_cast<coord_t>(K),
|
| 116 |
+
static_cast<coord_t>(B)
|
| 117 |
+
}, // ProblemShape problem_shape
|
| 118 |
+
{
|
| 119 |
+
{{template.cutlass_type_cast(X, kernel.ptr(X))}}, // ElementA const* ptr_A
|
| 120 |
+
{
|
| 121 |
+
{{template.cute_int(kernel.stride(X, -2), "stride_x0")}},
|
| 122 |
+
{{template.cute_int(kernel.stride(X, -1), "stride_x1")}},
|
| 123 |
+
{{template.cute_int(kernel.stride(X, -3), "batch_stride_x")}}
|
| 124 |
+
}, // StrideA dA
|
| 125 |
+
{{template.cutlass_type_cast(W, kernel.ptr(W))}}, // ElementB const* ptr_B
|
| 126 |
+
{
|
| 127 |
+
{{template.cute_int(kernel.stride(W, -1), "stride_w1")}},
|
| 128 |
+
{{template.cute_int(kernel.stride(W, -2), "stride_w0")}},
|
| 129 |
+
{{template.cute_int(kernel.stride(W, -3), "batch_stride_w")}}
|
| 130 |
+
}, // StrideB dB
|
| 131 |
+
}, // MainloopArguments mainloop
|
| 132 |
+
{{epilogue_arguments}}
|
| 133 |
+
};
|
| 134 |
+
"""
|
| 135 |
+
|
| 136 |
+
GEMM_ARGS_CUTLASS_3X_EPILOGUE = r"""
|
| 137 |
+
// see https://tinyurl.com/4rk89z48
|
| 138 |
+
{
|
| 139 |
+
{{epilogue_args}}, // thread, typename FusionCallbacks::Arguments ( EVT ) or ThreadEpilogueOp::Params (non-EVT )
|
| 140 |
+
{{template.cutlass_type_cast(Bias, kernel.ptr(Bias))}}, // ElementC const* ptr_C
|
| 141 |
+
{
|
| 142 |
+
{{template.cute_int(kernel.stride(Bias, -2, 1), "stride_bias0")}},
|
| 143 |
+
{{template.cute_int(kernel.stride(Bias, -1, 1), "stride_bias1")}},
|
| 144 |
+
{{template.cute_int(kernel.stride(Bias, -3), "batch_stride_bias")}}
|
| 145 |
+
}, // StrideC dC
|
| 146 |
+
{{template.cutlass_type_cast(Y, kernel.ptr(Y))}}, // ElementD const* ptr_D
|
| 147 |
+
{
|
| 148 |
+
{{template.cute_int(kernel.stride(Y, -2), "stride_y0")}},
|
| 149 |
+
{{template.cute_int(kernel.stride(Y, -1), "stride_y1")}},
|
| 150 |
+
{{template.cute_int(kernel.stride(Y, -3), "batch_stride_y")}}
|
| 151 |
+
}, // StrideD dD
|
| 152 |
+
}, // EpilogueArguments epilogue
|
| 153 |
+
"""
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
class CUTLASSGemmTemplate(CUTLASSTemplate):
|
| 157 |
+
"""
|
| 158 |
+
CUTLASS GEMM template, which is used to generate CUTLASS GEMM kernels
|
| 159 |
+
including those which allow flexible fusions with epilogues.
|
| 160 |
+
"""
|
| 161 |
+
|
| 162 |
+
def __init__(
|
| 163 |
+
self,
|
| 164 |
+
input_nodes: List[Buffer],
|
| 165 |
+
layout: Layout,
|
| 166 |
+
alpha: float,
|
| 167 |
+
beta: float,
|
| 168 |
+
input_reorder: Optional[List[int]] = None,
|
| 169 |
+
can_fuse_epilogue: Optional[bool] = None,
|
| 170 |
+
):
|
| 171 |
+
"""
|
| 172 |
+
Args:
|
| 173 |
+
input_nodes: input nodes of the kernel
|
| 174 |
+
layout: layout of the output node
|
| 175 |
+
alpha: alpha value of the GEMM operation
|
| 176 |
+
beta: beta value of the GEMM operation
|
| 177 |
+
input_reorder: reorder of the input nodes
|
| 178 |
+
can_fuse_epilogue: If set to True, will only list and use operators capable of flexible epilogue fusions.
|
| 179 |
+
If False, it will not use those. If None, both may be listed, but it will not allow fusions.
|
| 180 |
+
Defaults to None
|
| 181 |
+
"""
|
| 182 |
+
super().__init__("cutlass_gemm", input_nodes, layout, input_reorder)
|
| 183 |
+
self.alpha = alpha
|
| 184 |
+
self.beta = beta
|
| 185 |
+
self.can_fuse_epilogue = can_fuse_epilogue
|
| 186 |
+
|
| 187 |
+
@staticmethod
|
| 188 |
+
def add_cutlass_gemm_choices(
|
| 189 |
+
choices,
|
| 190 |
+
layout,
|
| 191 |
+
input_nodes,
|
| 192 |
+
alpha=1,
|
| 193 |
+
beta=0,
|
| 194 |
+
input_reorder=None,
|
| 195 |
+
fuseable=True,
|
| 196 |
+
non_fuseable=True,
|
| 197 |
+
):
|
| 198 |
+
if non_fuseable:
|
| 199 |
+
if fuseable:
|
| 200 |
+
# list both fuseable and non-fuseable ops, and treat them all as non-fuseable
|
| 201 |
+
can_fuse_epilogue = False
|
| 202 |
+
else:
|
| 203 |
+
can_fuse_epilogue = None
|
| 204 |
+
|
| 205 |
+
cutlass_template = CUTLASSGemmTemplate(
|
| 206 |
+
input_nodes,
|
| 207 |
+
layout,
|
| 208 |
+
alpha=alpha,
|
| 209 |
+
beta=beta,
|
| 210 |
+
input_reorder=input_reorder,
|
| 211 |
+
can_fuse_epilogue=can_fuse_epilogue,
|
| 212 |
+
)
|
| 213 |
+
ops = cutlass_template.gen_ops()
|
| 214 |
+
for op in ops:
|
| 215 |
+
cutlass_template.maybe_append_choice(
|
| 216 |
+
choices,
|
| 217 |
+
op=op,
|
| 218 |
+
)
|
| 219 |
+
else:
|
| 220 |
+
ops = []
|
| 221 |
+
if fuseable:
|
| 222 |
+
cutlass_template_evt = CUTLASSGemmTemplate(
|
| 223 |
+
input_nodes,
|
| 224 |
+
layout,
|
| 225 |
+
alpha=alpha,
|
| 226 |
+
beta=beta,
|
| 227 |
+
input_reorder=input_reorder,
|
| 228 |
+
can_fuse_epilogue=True,
|
| 229 |
+
)
|
| 230 |
+
# This will list only ops capable of EVT fusion
|
| 231 |
+
ops_evt = cutlass_template_evt.gen_ops()
|
| 232 |
+
for op in ops_evt:
|
| 233 |
+
cutlass_template_evt.maybe_append_choice(
|
| 234 |
+
choices,
|
| 235 |
+
op=op,
|
| 236 |
+
)
|
| 237 |
+
else:
|
| 238 |
+
ops_evt = []
|
| 239 |
+
log.debug(
|
| 240 |
+
"Added %d cutlass gemm configs and %d fuseable gemm configs.",
|
| 241 |
+
len(ops),
|
| 242 |
+
len(ops_evt),
|
| 243 |
+
)
|
| 244 |
+
|
| 245 |
+
def header(self) -> IndentedBuffer:
|
| 246 |
+
res = super().header()
|
| 247 |
+
res.splice(
|
| 248 |
+
"""
|
| 249 |
+
#include "cutlass/gemm/gemm.h"
|
| 250 |
+
#include "cutlass/gemm/device/gemm_universal.h"
|
| 251 |
+
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
| 252 |
+
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
| 253 |
+
#include "cutlass/gemm/collective/collective_builder.hpp"
|
| 254 |
+
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
| 255 |
+
#include "cutlass/epilogue/collective/default_epilogue.hpp"
|
| 256 |
+
#include "cutlass/epilogue/thread/linear_combination.h"
|
| 257 |
+
#include "cutlass/gemm/dispatch_policy.hpp"
|
| 258 |
+
#include "cutlass/gemm/kernel/tile_scheduler.hpp"
|
| 259 |
+
#include "cutlass/util/distribution.h"
|
| 260 |
+
#include "cutlass/util/packed_stride.hpp"
|
| 261 |
+
#include "cutlass/util/tensor_view_io.h"
|
| 262 |
+
"""
|
| 263 |
+
)
|
| 264 |
+
return res
|
| 265 |
+
|
| 266 |
+
@staticmethod
|
| 267 |
+
def cutlass_layout(torch_layout) -> "Optional[cutlass_lib.LayoutType]": # type: ignore[name-defined] # noqa: F821
|
| 268 |
+
assert cutlass_utils.try_import_cutlass()
|
| 269 |
+
import cutlass_library.library as cutlass_lib
|
| 270 |
+
|
| 271 |
+
if torch_layout.stride[-1] == 1:
|
| 272 |
+
return cutlass_lib.LayoutType.RowMajor
|
| 273 |
+
elif torch_layout.stride[-2] == 1:
|
| 274 |
+
return cutlass_lib.LayoutType.ColumnMajor
|
| 275 |
+
else:
|
| 276 |
+
return None
|
| 277 |
+
|
| 278 |
+
@staticmethod
|
| 279 |
+
def flip_cutlass_layout(
|
| 280 |
+
cutlass_layout: "cutlass_lib.LayoutType", # type: ignore[name-defined] # noqa: F821
|
| 281 |
+
) -> "cutlass_lib.LayoutType": # type: ignore[name-defined] # noqa: F821
|
| 282 |
+
assert cutlass_utils.try_import_cutlass()
|
| 283 |
+
import cutlass_library.library as cutlass_lib
|
| 284 |
+
|
| 285 |
+
if cutlass_layout == cutlass_lib.LayoutType.RowMajor:
|
| 286 |
+
return cutlass_lib.LayoutType.ColumnMajor
|
| 287 |
+
else:
|
| 288 |
+
return cutlass_lib.LayoutType.RowMajor
|
| 289 |
+
|
| 290 |
+
@staticmethod
|
| 291 |
+
def layout_match(torch_layout, cutlass_layout) -> bool:
|
| 292 |
+
return CUTLASSGemmTemplate.cutlass_layout(torch_layout) == cutlass_layout
|
| 293 |
+
|
| 294 |
+
@staticmethod
|
| 295 |
+
def set_alignment(torch_layout, op_element) -> bool:
|
| 296 |
+
alignment = cutlass_utils.get_max_alignment(torch_layout)
|
| 297 |
+
if alignment < op_element.alignment:
|
| 298 |
+
return False
|
| 299 |
+
else:
|
| 300 |
+
op_element.alignment = alignment
|
| 301 |
+
return True
|
| 302 |
+
|
| 303 |
+
@staticmethod
|
| 304 |
+
def has_tma_epilogue(op) -> bool:
|
| 305 |
+
assert cutlass_utils.try_import_cutlass()
|
| 306 |
+
import cutlass_library.library as cutlass_lib
|
| 307 |
+
|
| 308 |
+
result = False
|
| 309 |
+
if op.gemm_kind == cutlass_lib.GemmKind.Universal3x:
|
| 310 |
+
epilogue_schedule_str = str(op.epilogue_schedule).split(".")[-1]
|
| 311 |
+
result = epilogue_schedule_str.lower().startswith("tma")
|
| 312 |
+
return result
|
| 313 |
+
|
| 314 |
+
@staticmethod
|
| 315 |
+
def supports_evt(op: "cutlass_library.gemm_op.GemmOperation") -> bool: # type: ignore[name-defined] # noqa: F821
|
| 316 |
+
"""
|
| 317 |
+
returns True if the op is capable of flexible epilogue fusions
|
| 318 |
+
using epilogue visitor trees.
|
| 319 |
+
|
| 320 |
+
See https://github.com/NVIDIA/cutlass/blob/e01b9b5029b7caca5a43c29f7d2714d7cf1dcae8/examples/49_hopper_gemm_with_collective_builder/49_collective_builder.cu#L283-L285 # noqa: B950
|
| 321 |
+
"""
|
| 322 |
+
assert cutlass_utils.try_import_cutlass()
|
| 323 |
+
import cutlass_library.library as cutlass_lib
|
| 324 |
+
|
| 325 |
+
if op.gemm_kind != cutlass_lib.GemmKind.Universal3x:
|
| 326 |
+
return False
|
| 327 |
+
if op.epilogue_schedule not in (
|
| 328 |
+
cutlass_lib.EpilogueScheduleType.TmaWarpSpecialized,
|
| 329 |
+
cutlass_lib.EpilogueScheduleType.TmaWarpSpecializedCooperative,
|
| 330 |
+
):
|
| 331 |
+
return False
|
| 332 |
+
|
| 333 |
+
return True
|
| 334 |
+
|
| 335 |
+
def render_evt_epilogue_declaration(
|
| 336 |
+
self,
|
| 337 |
+
template_output_node_name: str,
|
| 338 |
+
evt_type_name: str,
|
| 339 |
+
epilogue_nodes: List[IRNode],
|
| 340 |
+
) -> str:
|
| 341 |
+
"""Generates the epilogue for the EVT epilogue fusion"""
|
| 342 |
+
return CutlassEVTEpilogueTypeFormatter.ir_to_evt_string(
|
| 343 |
+
template_output_node_name, evt_type_name, epilogue_nodes
|
| 344 |
+
)
|
| 345 |
+
|
| 346 |
+
def define_gemm_instance(
|
| 347 |
+
self,
|
| 348 |
+
op: "cutlass_library.gemm_op.GemmOperation", # type: ignore[name-defined] # noqa: F821
|
| 349 |
+
output_buffer_name: str,
|
| 350 |
+
epilogue_nodes: Optional[List[IRNode]] = None,
|
| 351 |
+
) -> Tuple[str, str]:
|
| 352 |
+
assert cutlass_utils.try_import_cutlass()
|
| 353 |
+
import cutlass_library.gemm_operation as cutlass_gemm_op
|
| 354 |
+
import cutlass_library.library as cutlass_lib
|
| 355 |
+
|
| 356 |
+
from torch._inductor.codegen.cuda.cutlass_lib_extensions.gemm_operation_extensions import (
|
| 357 |
+
EmitGemmUniversal3xInstanceWithEVT,
|
| 358 |
+
)
|
| 359 |
+
|
| 360 |
+
if op.gemm_kind == cutlass_lib.GemmKind.Universal3x:
|
| 361 |
+
if epilogue_nodes is not None and len(epilogue_nodes) > 0:
|
| 362 |
+
emitter = EmitGemmUniversal3xInstanceWithEVT()
|
| 363 |
+
op.epilogue_functor = lambda epilogue_functor_type_name: self.render_evt_epilogue_declaration(
|
| 364 |
+
output_buffer_name, epilogue_functor_type_name, epilogue_nodes
|
| 365 |
+
)
|
| 366 |
+
else:
|
| 367 |
+
emitter = cutlass_gemm_op.EmitGemmUniversal3xInstance()
|
| 368 |
+
op_def = emitter.emit(op)
|
| 369 |
+
pattern = re.compile(r"\s*struct\s(.*?)\s:")
|
| 370 |
+
decl = [line for line in op_def.split("\n") if "struct " in line][-1]
|
| 371 |
+
else:
|
| 372 |
+
if epilogue_nodes is not None and len(epilogue_nodes) > 0:
|
| 373 |
+
raise RuntimeError(
|
| 374 |
+
"EVT epilogue fusion is not supported for Cutlass 2.x ops."
|
| 375 |
+
)
|
| 376 |
+
emitter = cutlass_gemm_op.EmitGemmInstance()
|
| 377 |
+
op_def = emitter.emit(op)
|
| 378 |
+
op_def = op_def.replace(
|
| 379 |
+
"cutlass::gemm::device::Gemm", "cutlass::gemm::device::GemmUniversal"
|
| 380 |
+
)
|
| 381 |
+
op_def = op_def.replace("false,", "")
|
| 382 |
+
pattern = re.compile(r"\s*using\s(.*?)\s=")
|
| 383 |
+
decl = op_def.split("\n")[2]
|
| 384 |
+
match = pattern.match(decl)
|
| 385 |
+
if match is None:
|
| 386 |
+
raise RuntimeError("Invalid Gemm config: \n" + op_def)
|
| 387 |
+
op_type = match.groups()[0]
|
| 388 |
+
if op.gemm_kind == cutlass_lib.GemmKind.Universal3x:
|
| 389 |
+
op_def += f"\n using {op_type}_device_type = cutlass::gemm::device::GemmUniversalAdapter<{op_type}>;\n"
|
| 390 |
+
op_type = f"{op_type}_device_type"
|
| 391 |
+
return op_def, op_type
|
| 392 |
+
|
| 393 |
+
@staticmethod
|
| 394 |
+
def should_swap_XW(
|
| 395 |
+
bias: IRNode,
|
| 396 |
+
beta: float,
|
| 397 |
+
) -> bool:
|
| 398 |
+
return True
|
| 399 |
+
|
| 400 |
+
# TODO(ipiszy): Check whether it's necessary to swap X/W.
|
| 401 |
+
# strides = bias.get_stride()
|
| 402 |
+
# if strides[-1] != 1:
|
| 403 |
+
# return True
|
| 404 |
+
# for stride in strides[:-1]:
|
| 405 |
+
# if stride != 0:
|
| 406 |
+
# return True
|
| 407 |
+
# return False
|
| 408 |
+
|
| 409 |
+
@staticmethod
|
| 410 |
+
def swap_XW(
|
| 411 |
+
op: "cutlass_library.gemm_op.GemmOperation", # type: ignore[name-defined] # noqa: F821
|
| 412 |
+
) -> "cutlass_library.gemm_op.GemmOperation": # type: ignore[name-defined] # noqa: F821
|
| 413 |
+
# Swap X and W in GemmOperation.
|
| 414 |
+
new_op = copy.deepcopy(op)
|
| 415 |
+
new_op.A.layout = CUTLASSGemmTemplate.flip_cutlass_layout(new_op.A.layout)
|
| 416 |
+
new_op.B.layout = CUTLASSGemmTemplate.flip_cutlass_layout(new_op.B.layout)
|
| 417 |
+
new_op.A, new_op.B = new_op.B, new_op.A
|
| 418 |
+
new_op.C.layout = CUTLASSGemmTemplate.flip_cutlass_layout(new_op.C.layout)
|
| 419 |
+
new_op.D.layout = CUTLASSGemmTemplate.flip_cutlass_layout(new_op.D.layout)
|
| 420 |
+
return new_op
|
| 421 |
+
|
| 422 |
+
def filter_op(
|
| 423 |
+
self,
|
| 424 |
+
op: "cutlass_library.gemm_op.GemmOperation", # type: ignore[name-defined] # noqa: F821
|
| 425 |
+
) -> "cutlass_library.gemm_op.GemmOperation": # type: ignore[name-defined] # noqa: F821
|
| 426 |
+
assert cutlass_utils.try_import_cutlass()
|
| 427 |
+
import cutlass_library.library as cutlass_lib
|
| 428 |
+
|
| 429 |
+
# Skip simt kernels
|
| 430 |
+
if (
|
| 431 |
+
op.tile_description.math_instruction.opcode_class
|
| 432 |
+
== cutlass_lib.OpcodeClass.Simt
|
| 433 |
+
):
|
| 434 |
+
return None
|
| 435 |
+
|
| 436 |
+
# Only keep GemmUniversal kernels
|
| 437 |
+
if op.gemm_kind not in {
|
| 438 |
+
cutlass_lib.GemmKind.Universal,
|
| 439 |
+
cutlass_lib.GemmKind.Universal3x,
|
| 440 |
+
}:
|
| 441 |
+
return None
|
| 442 |
+
# Filter ops by dtypes.
|
| 443 |
+
X = self.input_nodes[0]
|
| 444 |
+
W = self.input_nodes[1]
|
| 445 |
+
accumulator_torch_dtype = cutlass_utils.get_accumulator_dtype(
|
| 446 |
+
[X.get_dtype(), W.get_dtype()],
|
| 447 |
+
)
|
| 448 |
+
if not (
|
| 449 |
+
cutlass_utils.dtype_match(X.get_dtype(), op.A.element)
|
| 450 |
+
and cutlass_utils.dtype_match(W.get_dtype(), op.B.element)
|
| 451 |
+
and cutlass_utils.dtype_match(
|
| 452 |
+
self.output_node.get_layout().dtype, op.C.element
|
| 453 |
+
)
|
| 454 |
+
and cutlass_utils.dtype_match(
|
| 455 |
+
accumulator_torch_dtype, op.accumulator_type()
|
| 456 |
+
)
|
| 457 |
+
):
|
| 458 |
+
return None
|
| 459 |
+
|
| 460 |
+
# Filter ops by input layouts.
|
| 461 |
+
if not (
|
| 462 |
+
self.layout_match(X.get_layout(), op.A.layout)
|
| 463 |
+
and self.layout_match(W.get_layout(), op.B.layout)
|
| 464 |
+
):
|
| 465 |
+
return None
|
| 466 |
+
|
| 467 |
+
# Update op.
|
| 468 |
+
op = copy.deepcopy(op)
|
| 469 |
+
|
| 470 |
+
# Set output layout.
|
| 471 |
+
op.D.layout = CUTLASSGemmTemplate.cutlass_layout(self.output_node.get_layout())
|
| 472 |
+
|
| 473 |
+
# Filter ops by alignments and set alignments.
|
| 474 |
+
if not (
|
| 475 |
+
self.set_alignment(X.get_layout(), op.A)
|
| 476 |
+
and self.set_alignment(W.get_layout(), op.B)
|
| 477 |
+
and self.set_alignment(self.output_node.get_layout(), op.D)
|
| 478 |
+
):
|
| 479 |
+
return None
|
| 480 |
+
|
| 481 |
+
# Set epilogue.
|
| 482 |
+
# TODO: update epilogue functor according to epilogues.
|
| 483 |
+
op.element_epilogue = op.accumulator_type()
|
| 484 |
+
|
| 485 |
+
# Set bias layout and alignment.
|
| 486 |
+
if len(self.input_nodes) >= 3 and self.input_nodes[2] is not None:
|
| 487 |
+
Bias = self.input_nodes[2]
|
| 488 |
+
bias_layout = CUTLASSGemmTemplate.cutlass_layout(Bias.get_layout())
|
| 489 |
+
if op.gemm_kind != cutlass_lib.GemmKind.Universal3x:
|
| 490 |
+
if bias_layout != op.D.layout:
|
| 491 |
+
# For cutlass2, bias and output layout must match
|
| 492 |
+
return None
|
| 493 |
+
else:
|
| 494 |
+
op.C.layout = bias_layout
|
| 495 |
+
if not self.set_alignment(Bias.get_layout(), op.C):
|
| 496 |
+
return None
|
| 497 |
+
else:
|
| 498 |
+
if op.gemm_kind == cutlass_lib.GemmKind.Universal3x:
|
| 499 |
+
op.C.element = cutlass_lib.DataType.void
|
| 500 |
+
else:
|
| 501 |
+
op.C.layout = op.D.layout
|
| 502 |
+
supports_evt: bool = self.supports_evt(op)
|
| 503 |
+
if (self.can_fuse_epilogue is not None) and (
|
| 504 |
+
self.can_fuse_epilogue != supports_evt
|
| 505 |
+
):
|
| 506 |
+
return None
|
| 507 |
+
if inductor_cuda_config.cutlass_only_evt_capable_ops and not supports_evt:
|
| 508 |
+
return None
|
| 509 |
+
return op
|
| 510 |
+
|
| 511 |
+
def gen_ops(self) -> "List[cutlass_gemm_op.GemmOperation]": # type: ignore[name-defined] # noqa: F821
|
| 512 |
+
assert cutlass_utils.try_import_cutlass()
|
| 513 |
+
import cutlass_library.gemm_operation as cutlass_gemm_op
|
| 514 |
+
import cutlass_library.library as cutlass_lib
|
| 515 |
+
|
| 516 |
+
ops = cutlass_utils.gen_ops()[cutlass_lib.OperationKind.Gemm]
|
| 517 |
+
res: Dict[str, cutlass_gemm_op.GemmOperation] = dict()
|
| 518 |
+
num_3x_ops = 0
|
| 519 |
+
num_2x_ops = 0
|
| 520 |
+
for op_dict in ops.values():
|
| 521 |
+
for op_list in op_dict.values():
|
| 522 |
+
for op in op_list:
|
| 523 |
+
assert isinstance(op, cutlass_gemm_op.GemmOperation)
|
| 524 |
+
filter_res = self.filter_op(op)
|
| 525 |
+
if (
|
| 526 |
+
filter_res is not None
|
| 527 |
+
and res.get(filter_res.configuration_name(), None) is None
|
| 528 |
+
):
|
| 529 |
+
res[filter_res.configuration_name()] = filter_res
|
| 530 |
+
for op in res.values():
|
| 531 |
+
if op.gemm_kind == cutlass_lib.GemmKind.Universal3x:
|
| 532 |
+
num_3x_ops += 1
|
| 533 |
+
else:
|
| 534 |
+
num_2x_ops += 1
|
| 535 |
+
log.debug(
|
| 536 |
+
"Got cutlass configs: total number of ops: %d, "
|
| 537 |
+
"total number of 3x ops: %d, total number of 2x ops: %d",
|
| 538 |
+
len(res),
|
| 539 |
+
num_3x_ops,
|
| 540 |
+
num_2x_ops,
|
| 541 |
+
)
|
| 542 |
+
return list(res.values())[: inductor_cuda_config.cutlass_max_profiling_configs]
|
| 543 |
+
|
| 544 |
+
def gemm_mode(self) -> str:
|
| 545 |
+
sizes = self.output_node.get_size()
|
| 546 |
+
if len(sizes) > 2:
|
| 547 |
+
return "cutlass::gemm::GemmUniversalMode::kBatched"
|
| 548 |
+
else:
|
| 549 |
+
return "cutlass::gemm::GemmUniversalMode::kGemm"
|
| 550 |
+
|
| 551 |
+
def render_gemm_arguments(
|
| 552 |
+
self,
|
| 553 |
+
argument_template: str,
|
| 554 |
+
epilogue_template: str,
|
| 555 |
+
should_swap_xw: bool,
|
| 556 |
+
X: IRNode,
|
| 557 |
+
W: IRNode,
|
| 558 |
+
Bias: IRNode,
|
| 559 |
+
Y: IRNode,
|
| 560 |
+
alpha: float,
|
| 561 |
+
beta: float,
|
| 562 |
+
kernel: CUDATemplateKernel,
|
| 563 |
+
epilogue_args,
|
| 564 |
+
) -> str:
|
| 565 |
+
options = dict(
|
| 566 |
+
alpha=self.alpha,
|
| 567 |
+
beta=self.beta,
|
| 568 |
+
X=X,
|
| 569 |
+
W=W,
|
| 570 |
+
Y=Y,
|
| 571 |
+
Bias=Bias,
|
| 572 |
+
template=self,
|
| 573 |
+
kernel=kernel,
|
| 574 |
+
M="M",
|
| 575 |
+
N="N",
|
| 576 |
+
epilogue_args=epilogue_args,
|
| 577 |
+
)
|
| 578 |
+
|
| 579 |
+
if epilogue_template is not None:
|
| 580 |
+
if should_swap_xw:
|
| 581 |
+
# Swap
|
| 582 |
+
def clone_with_transposed_stride(node: IRNode) -> IRNode:
|
| 583 |
+
old_layout = node.get_layout()
|
| 584 |
+
new_stride = list(old_layout.stride)
|
| 585 |
+
new_stride[-2], new_stride[-1] = new_stride[-1], new_stride[-2]
|
| 586 |
+
new_layout = FixedLayout(
|
| 587 |
+
old_layout.device,
|
| 588 |
+
old_layout.dtype,
|
| 589 |
+
list(old_layout.size),
|
| 590 |
+
new_stride,
|
| 591 |
+
old_layout.offset,
|
| 592 |
+
)
|
| 593 |
+
return Buffer(node.get_name(), new_layout)
|
| 594 |
+
|
| 595 |
+
new_X = clone_with_transposed_stride(X)
|
| 596 |
+
new_W = clone_with_transposed_stride(W)
|
| 597 |
+
new_Bias = clone_with_transposed_stride(Bias)
|
| 598 |
+
new_Y = clone_with_transposed_stride(Y)
|
| 599 |
+
options["X"], options["W"], options["Bias"], options["Y"] = (
|
| 600 |
+
new_W,
|
| 601 |
+
new_X,
|
| 602 |
+
new_Bias,
|
| 603 |
+
new_Y,
|
| 604 |
+
)
|
| 605 |
+
options["M"], options["N"] = "N", "M"
|
| 606 |
+
|
| 607 |
+
epilogue_arguments = self._template_from_string(epilogue_template).render(
|
| 608 |
+
**options
|
| 609 |
+
)
|
| 610 |
+
arguments = self._template_from_string(argument_template).render(
|
| 611 |
+
epilogue_arguments=epilogue_arguments, **options
|
| 612 |
+
)
|
| 613 |
+
else:
|
| 614 |
+
arguments = self._template_from_string(GEMM_ARGS_CUTLASS_2X).render(
|
| 615 |
+
split_k=1, **options
|
| 616 |
+
)
|
| 617 |
+
return arguments
|
| 618 |
+
|
| 619 |
+
def render( # type: ignore[override]
|
| 620 |
+
self,
|
| 621 |
+
kernel: CUDATemplateKernel,
|
| 622 |
+
op: "cutlass_gemm_op.GemmOperation" = None, # type: ignore[name-defined] # noqa: F821
|
| 623 |
+
template_buffer_node: Optional[CUDATemplateBuffer] = None,
|
| 624 |
+
epilogue_nodes: Optional[List[IRNode]] = None,
|
| 625 |
+
**kwargs,
|
| 626 |
+
) -> str:
|
| 627 |
+
if epilogue_nodes is not None and len(epilogue_nodes) > 0:
|
| 628 |
+
assert self.can_fuse_epilogue and CUTLASSGemmTemplate.supports_evt(
|
| 629 |
+
op
|
| 630 |
+
), "op does not support EVT epilogue fusion"
|
| 631 |
+
assert (
|
| 632 |
+
template_buffer_node is not None
|
| 633 |
+
), "Template node is required for epilogue fusion"
|
| 634 |
+
assert isinstance(
|
| 635 |
+
template_buffer_node, CUDATemplateBuffer
|
| 636 |
+
), f"Template node has to be a CUDATemplateBuffer, is type {type(template_buffer_node)}"
|
| 637 |
+
assert (
|
| 638 |
+
template_buffer_node.name is not None
|
| 639 |
+
), "Output node has to be a Buffer with a name"
|
| 640 |
+
# This is the name of the output of the Matmul, before epilogues are applied.
|
| 641 |
+
# it is not necessarily materialized in global memory if we have an epilogue
|
| 642 |
+
|
| 643 |
+
template_output_node_name = (
|
| 644 |
+
template_buffer_node.name if template_buffer_node is not None else None
|
| 645 |
+
)
|
| 646 |
+
|
| 647 |
+
assert cutlass_utils.try_import_cutlass()
|
| 648 |
+
import cutlass_library.gemm_operation as cutlass_gemm_op
|
| 649 |
+
import cutlass_library.library as cutlass_lib
|
| 650 |
+
|
| 651 |
+
assert isinstance(
|
| 652 |
+
op, cutlass_gemm_op.GemmOperation
|
| 653 |
+
), "op argument is required and has to be an instance of GemmOperation"
|
| 654 |
+
if template_buffer_node is not None:
|
| 655 |
+
self.output_node = template_buffer_node
|
| 656 |
+
if epilogue_nodes is not None and len(epilogue_nodes) > 0:
|
| 657 |
+
self.output_node = cast(Buffer, epilogue_nodes[-1])
|
| 658 |
+
|
| 659 |
+
assert len(self.input_nodes) >= 2 and self.output_node is not None
|
| 660 |
+
X, W = self.input_nodes[0], self.input_nodes[1]
|
| 661 |
+
Y = self.output_node
|
| 662 |
+
Bias = None if len(self.input_nodes) == 2 else self.input_nodes[2]
|
| 663 |
+
|
| 664 |
+
epilogue_template: Optional[str] = None
|
| 665 |
+
should_swap_xw: bool = False
|
| 666 |
+
epilogue_args = f"{{ElementComputeEpilogue({self.alpha}), ElementComputeEpilogue({self.beta})}}"
|
| 667 |
+
if op.gemm_kind == cutlass_lib.GemmKind.Universal3x:
|
| 668 |
+
if Bias is not None and self.has_tma_epilogue(op):
|
| 669 |
+
if self.should_swap_XW(Bias, self.beta):
|
| 670 |
+
# TMA epilogue requires bias vector in column major to get best perf.
|
| 671 |
+
op = self.swap_XW(op)
|
| 672 |
+
should_swap_xw = True
|
| 673 |
+
if epilogue_nodes is not None and len(epilogue_nodes) > 0:
|
| 674 |
+
epilogue_args = (
|
| 675 |
+
CutlassEVTEpilogueArgumentFormatter.ir_to_evt_argument_string(
|
| 676 |
+
cast(str, template_output_node_name), epilogue_nodes
|
| 677 |
+
)
|
| 678 |
+
)
|
| 679 |
+
epilogue_template = GEMM_ARGS_CUTLASS_3X_EPILOGUE
|
| 680 |
+
argument_template = GEMM_ARGS_CUTLASS_3X
|
| 681 |
+
else:
|
| 682 |
+
# TODO: Support split_k.
|
| 683 |
+
argument_template = GEMM_ARGS_CUTLASS_2X
|
| 684 |
+
|
| 685 |
+
instance_definition, instance_type = self.define_gemm_instance(
|
| 686 |
+
op, cast(str, template_output_node_name), epilogue_nodes
|
| 687 |
+
)
|
| 688 |
+
options = dict(
|
| 689 |
+
alpha=self.alpha,
|
| 690 |
+
beta=self.beta,
|
| 691 |
+
X=X,
|
| 692 |
+
W=W,
|
| 693 |
+
Y=Y,
|
| 694 |
+
Bias=Bias,
|
| 695 |
+
epilogue_template=epilogue_template,
|
| 696 |
+
argument_template=argument_template,
|
| 697 |
+
should_swap_xw=should_swap_xw,
|
| 698 |
+
template=self,
|
| 699 |
+
kernel=kernel,
|
| 700 |
+
instance_definition=instance_definition,
|
| 701 |
+
instance_type=instance_type,
|
| 702 |
+
input_reorder=self.input_reorder,
|
| 703 |
+
epilogue_args=epilogue_args,
|
| 704 |
+
)
|
| 705 |
+
res = self._template_from_string(GEMM_TEMPLATE).render(**options)
|
| 706 |
+
return res
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/memory_planning.py
ADDED
|
@@ -0,0 +1,799 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import collections
|
| 4 |
+
import dataclasses
|
| 5 |
+
import itertools
|
| 6 |
+
import pprint
|
| 7 |
+
from typing import Any, Dict, Iterable, List, Optional, Protocol
|
| 8 |
+
|
| 9 |
+
import sympy
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
from .. import config, ir
|
| 13 |
+
from ..utils import cache_on_self, CachedMethod, IndentedBuffer
|
| 14 |
+
from ..virtualized import V
|
| 15 |
+
|
| 16 |
+
from .wrapper import (
|
| 17 |
+
AllocateLine,
|
| 18 |
+
FreeIfNotReusedLine,
|
| 19 |
+
MemoryPlanningLine,
|
| 20 |
+
NullLine,
|
| 21 |
+
ReuseLine,
|
| 22 |
+
)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
ALIGN_BYTES = 64
|
| 26 |
+
assert (ALIGN_BYTES & (ALIGN_BYTES - 1)) == 0 and ALIGN_BYTES >= 8, "must be power of 2"
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def _align(nbytes):
|
| 30 |
+
"""Round up to the nearest multiple of ALIGN_BYTES"""
|
| 31 |
+
return (nbytes + ALIGN_BYTES - 1) & -ALIGN_BYTES
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def _is_aligned(v: sympy.Expr):
|
| 35 |
+
"""v can be statically proven to be a multiple of ALIGN_BYTES"""
|
| 36 |
+
if isinstance(v, (sympy.Add, sympy.Max)):
|
| 37 |
+
return all(map(_is_aligned, v.args))
|
| 38 |
+
return isinstance(v, align) or sympy.gcd(v, ALIGN_BYTES) == ALIGN_BYTES
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class align(sympy.Function):
|
| 42 |
+
"""Symbolically round up to the nearest multiple of ALIGN_BYTES"""
|
| 43 |
+
|
| 44 |
+
nargs = (1,)
|
| 45 |
+
is_integer = True
|
| 46 |
+
|
| 47 |
+
@classmethod
|
| 48 |
+
def eval(cls, value):
|
| 49 |
+
if isinstance(value, (int, sympy.Integer)):
|
| 50 |
+
return _align(int(value))
|
| 51 |
+
if _is_aligned(value):
|
| 52 |
+
return value
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
@dataclasses.dataclass
|
| 56 |
+
class LiveRange:
|
| 57 |
+
"""
|
| 58 |
+
A range where a given tensor is live. Begin and end are both counters
|
| 59 |
+
representing points in the program of grouped memory operations.
|
| 60 |
+
Begin is inclusive, end is exclusive.
|
| 61 |
+
|
| 62 |
+
Invariant: begin <= end
|
| 63 |
+
"""
|
| 64 |
+
|
| 65 |
+
begin: float # int | ±inf
|
| 66 |
+
end: float # int | ±inf
|
| 67 |
+
|
| 68 |
+
def contains(self, other: LiveRange):
|
| 69 |
+
"""Is other entirely within self"""
|
| 70 |
+
return self.begin <= other.begin and other.end <= self.end
|
| 71 |
+
|
| 72 |
+
def join(self, other: LiveRange):
|
| 73 |
+
"""Combine two ranges using a union operation"""
|
| 74 |
+
return LiveRange(min(self.begin, other.begin), max(self.end, other.end))
|
| 75 |
+
|
| 76 |
+
def __len__(self):
|
| 77 |
+
return self.end - self.begin
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
class LiveRanges:
|
| 81 |
+
"""
|
| 82 |
+
A collection of LiveRange regions, allowing for non-contiguous
|
| 83 |
+
live regions.
|
| 84 |
+
|
| 85 |
+
Invariant: LiveRanges.ranges is in sorted order and non-overlapping
|
| 86 |
+
"""
|
| 87 |
+
|
| 88 |
+
def __init__(self, ranges: Iterable[LiveRange]):
|
| 89 |
+
ranges = [*sorted(ranges, key=lambda x: x.begin)]
|
| 90 |
+
self.ranges = ranges[:1]
|
| 91 |
+
for r in ranges[1:]:
|
| 92 |
+
assert self.ranges[-1].begin <= r.begin
|
| 93 |
+
if self.ranges[-1].end >= r.begin:
|
| 94 |
+
self.ranges[-1] = LiveRange.join(self.ranges[-1], r)
|
| 95 |
+
else:
|
| 96 |
+
self.ranges.append(r)
|
| 97 |
+
|
| 98 |
+
def overlaps(self, other: LiveRanges):
|
| 99 |
+
"""Check if any pair of ranges in self and other overlap"""
|
| 100 |
+
left = collections.deque(self.ranges)
|
| 101 |
+
right = collections.deque(other.ranges)
|
| 102 |
+
while left and right:
|
| 103 |
+
if left[0].begin > right[0].begin:
|
| 104 |
+
left, right = right, left
|
| 105 |
+
assert left[0].begin <= right[0].begin
|
| 106 |
+
if left[0].end > right[0].begin:
|
| 107 |
+
return True
|
| 108 |
+
left.popleft()
|
| 109 |
+
return False
|
| 110 |
+
|
| 111 |
+
@property
|
| 112 |
+
def begin(self):
|
| 113 |
+
return self.ranges[0].begin
|
| 114 |
+
|
| 115 |
+
@property
|
| 116 |
+
def end(self):
|
| 117 |
+
return self.ranges[-1].end
|
| 118 |
+
|
| 119 |
+
def __repr__(self):
|
| 120 |
+
return f"{self.__class__.__name__}([{', '.join(map(repr, self.ranges))}])"
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
class AllocationTreeNode:
|
| 124 |
+
"""
|
| 125 |
+
Abstract base class for nodes in allocation pool.
|
| 126 |
+
"""
|
| 127 |
+
|
| 128 |
+
def allocate(self, block: Allocation, is_last: bool) -> bool:
|
| 129 |
+
"""
|
| 130 |
+
Try to assign block to a memory location in this bool. Return True if
|
| 131 |
+
an assignment was made.
|
| 132 |
+
"""
|
| 133 |
+
return False
|
| 134 |
+
|
| 135 |
+
def get_live_ranges(self) -> LiveRanges:
|
| 136 |
+
"""Aggregate LiveRanges for all objects below this in tree"""
|
| 137 |
+
raise NotImplementedError()
|
| 138 |
+
|
| 139 |
+
def get_size_hint(self) -> int:
|
| 140 |
+
"""Number of bytes used for example inputs"""
|
| 141 |
+
raise NotImplementedError()
|
| 142 |
+
|
| 143 |
+
def get_symbolic_size(self) -> sympy.Expr:
|
| 144 |
+
"""Number of bytes needed at runtime"""
|
| 145 |
+
raise NotImplementedError()
|
| 146 |
+
|
| 147 |
+
def finalize(self, pool, offset) -> AllocationTreeNode:
|
| 148 |
+
"""Called after all allocations have been made"""
|
| 149 |
+
return self
|
| 150 |
+
|
| 151 |
+
def is_empty(self):
|
| 152 |
+
return False
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
@dataclasses.dataclass
|
| 156 |
+
class Allocation(AllocationTreeNode):
|
| 157 |
+
"""
|
| 158 |
+
Represents memory allocated to a given node in the allocation pool.
|
| 159 |
+
"""
|
| 160 |
+
|
| 161 |
+
node: ir.Buffer
|
| 162 |
+
live_range: LiveRange
|
| 163 |
+
size_hint: int
|
| 164 |
+
symbolic_size: sympy.Expr
|
| 165 |
+
allocated: bool = False
|
| 166 |
+
pool: Optional[AllocationPool] = None
|
| 167 |
+
offset: Optional[sympy.Expr] = None
|
| 168 |
+
|
| 169 |
+
@property
|
| 170 |
+
def device(self):
|
| 171 |
+
return self.node.get_device()
|
| 172 |
+
|
| 173 |
+
def get_live_ranges(self):
|
| 174 |
+
return LiveRanges([self.live_range])
|
| 175 |
+
|
| 176 |
+
def get_size_hint(self):
|
| 177 |
+
return self.size_hint
|
| 178 |
+
|
| 179 |
+
def get_symbolic_size(self):
|
| 180 |
+
return self.symbolic_size
|
| 181 |
+
|
| 182 |
+
def mark_allocated(self):
|
| 183 |
+
assert not self.allocated
|
| 184 |
+
self.allocated = True
|
| 185 |
+
|
| 186 |
+
def finalize(self, pool, offset):
|
| 187 |
+
assert self.pool is None and self.offset is None
|
| 188 |
+
self.pool = pool
|
| 189 |
+
self.offset = offset
|
| 190 |
+
return self
|
| 191 |
+
|
| 192 |
+
def codegen_alloc_from_pool(self, wrapper):
|
| 193 |
+
assert self.pool
|
| 194 |
+
node = self.node
|
| 195 |
+
shape = tuple(node.get_size())
|
| 196 |
+
stride = tuple(node.get_stride())
|
| 197 |
+
return wrapper.codegen_alloc_from_pool(
|
| 198 |
+
self.pool.name, self.offset, node.get_dtype(), shape, stride
|
| 199 |
+
)
|
| 200 |
+
|
| 201 |
+
def __repr__(self):
|
| 202 |
+
return (
|
| 203 |
+
f"{self.__class__.__name__}("
|
| 204 |
+
f"node={self.node.get_name()}, "
|
| 205 |
+
f"live_range={self.live_range}, "
|
| 206 |
+
f"size_hint={self.size_hint}, "
|
| 207 |
+
f"symbolic_size={self.symbolic_size}, "
|
| 208 |
+
f"pool={self.pool.name if self.pool else None}, "
|
| 209 |
+
f"offset={self.offset})"
|
| 210 |
+
)
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
@dataclasses.dataclass
|
| 214 |
+
class Empty(AllocationTreeNode):
|
| 215 |
+
"""
|
| 216 |
+
Placeholder to represent empty space in the allocation pool.
|
| 217 |
+
Only exists to get the size_hint correct in parent nodes.
|
| 218 |
+
"""
|
| 219 |
+
|
| 220 |
+
size_hint: int
|
| 221 |
+
|
| 222 |
+
def get_live_ranges(self):
|
| 223 |
+
return LiveRanges([])
|
| 224 |
+
|
| 225 |
+
def get_size_hint(self):
|
| 226 |
+
return self.size_hint
|
| 227 |
+
|
| 228 |
+
def get_symbolic_size(self):
|
| 229 |
+
return 0
|
| 230 |
+
|
| 231 |
+
def is_empty(self):
|
| 232 |
+
return True
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
class MemorySplitProtocol(Protocol):
|
| 236 |
+
get_live_ranges: CachedMethod[[], LiveRanges]
|
| 237 |
+
get_size_hint: CachedMethod[[], int]
|
| 238 |
+
get_symbolic_size: CachedMethod[[], sympy.Expr]
|
| 239 |
+
|
| 240 |
+
def _allocate(self, block: Allocation, is_last: bool) -> bool:
|
| 241 |
+
...
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
class ClearCacheOnAllocateMixin(MemorySplitProtocol):
|
| 245 |
+
"""
|
| 246 |
+
Helper to assist in caching get_live_ranges, get_size_hint, and
|
| 247 |
+
get_symbolic_size.
|
| 248 |
+
"""
|
| 249 |
+
|
| 250 |
+
def allocate(self, block: Allocation, is_last: bool):
|
| 251 |
+
is_allocated = self._allocate(block, is_last)
|
| 252 |
+
if is_allocated:
|
| 253 |
+
self.clear_cache()
|
| 254 |
+
return is_allocated
|
| 255 |
+
|
| 256 |
+
def clear_cache(self):
|
| 257 |
+
self.get_live_ranges.clear_cache(self)
|
| 258 |
+
self.get_size_hint.clear_cache(self)
|
| 259 |
+
self.get_symbolic_size.clear_cache(self)
|
| 260 |
+
|
| 261 |
+
|
| 262 |
+
@dataclasses.dataclass
|
| 263 |
+
class TemporalSplit(ClearCacheOnAllocateMixin, AllocationTreeNode):
|
| 264 |
+
"""
|
| 265 |
+
Contains a list of allocations not overlapping in LiveRanges.
|
| 266 |
+
|
| 267 |
+
Invariant: no pair (a,b) in self.allocations will have:
|
| 268 |
+
a.get_live_ranges().overlaps(b.get_live_ranges())
|
| 269 |
+
"""
|
| 270 |
+
|
| 271 |
+
allocations: List[AllocationTreeNode]
|
| 272 |
+
|
| 273 |
+
def _allocate(self, block: Allocation, is_last: bool):
|
| 274 |
+
slot_size = self.get_size_hint()
|
| 275 |
+
block_size = block.get_size_hint()
|
| 276 |
+
if not is_last and block_size > slot_size:
|
| 277 |
+
return False # doesn't fit
|
| 278 |
+
|
| 279 |
+
block_live = block.get_live_ranges()
|
| 280 |
+
overlapping = [
|
| 281 |
+
s for s in self.allocations if s.get_live_ranges().overlaps(block_live)
|
| 282 |
+
]
|
| 283 |
+
if len(overlapping) > 1:
|
| 284 |
+
# TODO(jansel): we could try harder here by merging overlapping in space
|
| 285 |
+
return False
|
| 286 |
+
elif len(overlapping) == 1:
|
| 287 |
+
return overlapping[0].allocate(block, is_last)
|
| 288 |
+
else:
|
| 289 |
+
block.mark_allocated()
|
| 290 |
+
|
| 291 |
+
if len(self.allocations) == 1 and isinstance(self.allocations[-1], Empty):
|
| 292 |
+
self.allocations.pop()
|
| 293 |
+
|
| 294 |
+
if slot_size == block_size:
|
| 295 |
+
# perfect fit
|
| 296 |
+
self.allocations.append(block)
|
| 297 |
+
elif slot_size > block_size:
|
| 298 |
+
self.allocations.append(
|
| 299 |
+
SpatialSplit.create(block, slot_size - block_size)
|
| 300 |
+
)
|
| 301 |
+
else: # grow this allocation
|
| 302 |
+
assert is_last
|
| 303 |
+
self.allocations = [
|
| 304 |
+
*(
|
| 305 |
+
SpatialSplit.create(a, block_size - slot_size)
|
| 306 |
+
for a in self.allocations
|
| 307 |
+
),
|
| 308 |
+
block,
|
| 309 |
+
]
|
| 310 |
+
return True
|
| 311 |
+
|
| 312 |
+
@cache_on_self
|
| 313 |
+
def get_live_ranges(self) -> LiveRanges:
|
| 314 |
+
return LiveRanges(
|
| 315 |
+
itertools.chain.from_iterable(
|
| 316 |
+
x.get_live_ranges().ranges for x in self.allocations
|
| 317 |
+
)
|
| 318 |
+
)
|
| 319 |
+
|
| 320 |
+
@cache_on_self
|
| 321 |
+
def get_size_hint(self) -> int:
|
| 322 |
+
if not self.allocations:
|
| 323 |
+
return 0
|
| 324 |
+
return max(x.get_size_hint() for x in self.allocations)
|
| 325 |
+
|
| 326 |
+
@cache_on_self
|
| 327 |
+
def get_symbolic_size(self) -> sympy.Expr:
|
| 328 |
+
if not self.allocations:
|
| 329 |
+
return 0 # type: ignore[return-value]
|
| 330 |
+
return sympy.Max(*[x.get_symbolic_size() for x in self.allocations])
|
| 331 |
+
|
| 332 |
+
def is_empty(self):
|
| 333 |
+
return len(self.allocations) == 1 and self.allocations[0].is_empty()
|
| 334 |
+
|
| 335 |
+
def finalize(self, pool, offset):
|
| 336 |
+
self.allocations = [block.finalize(pool, offset) for block in self.allocations]
|
| 337 |
+
self.clear_cache()
|
| 338 |
+
if len(self.allocations) == 1:
|
| 339 |
+
return self.allocations[0]
|
| 340 |
+
return self
|
| 341 |
+
|
| 342 |
+
|
| 343 |
+
@dataclasses.dataclass
|
| 344 |
+
class SpatialSplit(ClearCacheOnAllocateMixin, AllocationTreeNode):
|
| 345 |
+
"""
|
| 346 |
+
Contains two allocations, left and right, that do not overlap in space.
|
| 347 |
+
Right will be allocated immediately after left in memory.
|
| 348 |
+
"""
|
| 349 |
+
|
| 350 |
+
left: TemporalSplit
|
| 351 |
+
right: TemporalSplit
|
| 352 |
+
|
| 353 |
+
@staticmethod
|
| 354 |
+
def create(left, extra_space):
|
| 355 |
+
assert isinstance(left, AllocationTreeNode)
|
| 356 |
+
assert isinstance(extra_space, int) and extra_space >= 1
|
| 357 |
+
return SpatialSplit(TemporalSplit([left]), TemporalSplit([Empty(extra_space)]))
|
| 358 |
+
|
| 359 |
+
def _allocate(self, block: Allocation, is_last: bool):
|
| 360 |
+
return self.left.allocate(block, False) or self.right.allocate(block, is_last)
|
| 361 |
+
|
| 362 |
+
@cache_on_self
|
| 363 |
+
def get_live_ranges(self):
|
| 364 |
+
return LiveRanges(
|
| 365 |
+
itertools.chain(
|
| 366 |
+
self.left.get_live_ranges().ranges, self.right.get_live_ranges().ranges
|
| 367 |
+
)
|
| 368 |
+
)
|
| 369 |
+
|
| 370 |
+
@cache_on_self
|
| 371 |
+
def get_size_hint(self) -> int:
|
| 372 |
+
return _align(self.left.get_size_hint()) + self.right.get_size_hint()
|
| 373 |
+
|
| 374 |
+
@cache_on_self
|
| 375 |
+
def get_symbolic_size(self) -> sympy.Expr:
|
| 376 |
+
return align(self.left.get_symbolic_size()) + self.right.get_symbolic_size()
|
| 377 |
+
|
| 378 |
+
def finalize(self, pool, offset):
|
| 379 |
+
self.left = self.left.finalize(pool, offset)
|
| 380 |
+
self.right = self.right.finalize(
|
| 381 |
+
pool, offset + align(self.left.get_symbolic_size())
|
| 382 |
+
)
|
| 383 |
+
self.clear_cache()
|
| 384 |
+
if self.right.is_empty():
|
| 385 |
+
return self.left
|
| 386 |
+
return self
|
| 387 |
+
|
| 388 |
+
|
| 389 |
+
@dataclasses.dataclass
|
| 390 |
+
class AllocationPool:
|
| 391 |
+
"""
|
| 392 |
+
Represents a pool of allocations that will be generated by a single
|
| 393 |
+
call to torch.empty.
|
| 394 |
+
"""
|
| 395 |
+
|
| 396 |
+
device: torch.device
|
| 397 |
+
root: TemporalSplit
|
| 398 |
+
can_expand: bool = True
|
| 399 |
+
restrict_live_range: Optional[LiveRange] = None
|
| 400 |
+
name: Optional[str] = None
|
| 401 |
+
names_to_del: List[str] = dataclasses.field(default_factory=list)
|
| 402 |
+
creation_cache: Dict[str, str] = dataclasses.field(default_factory=dict)
|
| 403 |
+
|
| 404 |
+
def allocate(self, block: Allocation, is_last: bool):
|
| 405 |
+
if self.restrict_live_range and not self.restrict_live_range.contains(
|
| 406 |
+
block.live_range
|
| 407 |
+
):
|
| 408 |
+
return False
|
| 409 |
+
|
| 410 |
+
is_last = self.can_expand and is_last
|
| 411 |
+
if self.root.allocate(block, is_last):
|
| 412 |
+
return True
|
| 413 |
+
|
| 414 |
+
if is_last:
|
| 415 |
+
return self.allocate_at_end(block)
|
| 416 |
+
|
| 417 |
+
return False
|
| 418 |
+
|
| 419 |
+
def allocate_at_end(self, block):
|
| 420 |
+
block.mark_allocated()
|
| 421 |
+
self.root = TemporalSplit([SpatialSplit(self.root, TemporalSplit([block]))])
|
| 422 |
+
return True
|
| 423 |
+
|
| 424 |
+
def finalize(self, name):
|
| 425 |
+
assert not self.name
|
| 426 |
+
self.name = name
|
| 427 |
+
self.names_to_del.append(name)
|
| 428 |
+
self.root.finalize(self, 0)
|
| 429 |
+
|
| 430 |
+
def codegen_create(self, wrapper, code: IndentedBuffer):
|
| 431 |
+
assert self.name
|
| 432 |
+
nbytes = self.root.get_symbolic_size()
|
| 433 |
+
for block in self.root.allocations:
|
| 434 |
+
if isinstance(block, Allocation) and nbytes == block.get_symbolic_size():
|
| 435 |
+
# optimization: fuse first allocation and pool creation
|
| 436 |
+
node = block.node
|
| 437 |
+
code.writeline(
|
| 438 |
+
wrapper.make_allocation(
|
| 439 |
+
self.name,
|
| 440 |
+
device=self.device,
|
| 441 |
+
dtype=node.get_dtype(),
|
| 442 |
+
shape=tuple(node.get_size()),
|
| 443 |
+
stride=tuple(node.get_stride()),
|
| 444 |
+
)
|
| 445 |
+
)
|
| 446 |
+
self.creation_cache[block.codegen_alloc_from_pool(wrapper)] = self.name
|
| 447 |
+
return
|
| 448 |
+
else:
|
| 449 |
+
code.writeline(
|
| 450 |
+
wrapper.make_allocation(
|
| 451 |
+
self.name,
|
| 452 |
+
device=self.device,
|
| 453 |
+
dtype=torch.uint8,
|
| 454 |
+
shape=(nbytes,),
|
| 455 |
+
stride=(1,),
|
| 456 |
+
)
|
| 457 |
+
)
|
| 458 |
+
|
| 459 |
+
def codegen_destroy(self, wrapper, code: IndentedBuffer):
|
| 460 |
+
code.writeline(wrapper.make_free_by_names(self.names_to_del))
|
| 461 |
+
|
| 462 |
+
def __eq__(self, other):
|
| 463 |
+
return self is other
|
| 464 |
+
|
| 465 |
+
def __hash__(self):
|
| 466 |
+
return id(self)
|
| 467 |
+
|
| 468 |
+
|
| 469 |
+
@dataclasses.dataclass
|
| 470 |
+
class AllocationPools:
|
| 471 |
+
"""
|
| 472 |
+
Collection of many AllocationPool objects grouped by device.
|
| 473 |
+
"""
|
| 474 |
+
|
| 475 |
+
device_to_pools: Dict[torch.device, List[AllocationPool]] = dataclasses.field(
|
| 476 |
+
default_factory=dict
|
| 477 |
+
)
|
| 478 |
+
|
| 479 |
+
def get_pools(self, block):
|
| 480 |
+
if block.device not in self.device_to_pools:
|
| 481 |
+
self.device_to_pools[block.device] = []
|
| 482 |
+
return self.device_to_pools[block.device]
|
| 483 |
+
|
| 484 |
+
def allocate(self, block: Allocation):
|
| 485 |
+
pools = self.get_pools(block)
|
| 486 |
+
|
| 487 |
+
for pool in pools:
|
| 488 |
+
if pool.allocate(block, is_last=pool is pools[-1]):
|
| 489 |
+
return
|
| 490 |
+
|
| 491 |
+
# everything is full, make a new pool
|
| 492 |
+
pools.append(
|
| 493 |
+
AllocationPool(
|
| 494 |
+
block.device,
|
| 495 |
+
TemporalSplit([block]),
|
| 496 |
+
can_expand=config.memory_pool != "none",
|
| 497 |
+
)
|
| 498 |
+
)
|
| 499 |
+
block.mark_allocated()
|
| 500 |
+
|
| 501 |
+
def allocate_output(self, block: Allocation):
|
| 502 |
+
"""Outputs get different pools so memory gets freed properly"""
|
| 503 |
+
pools = self.get_pools(block)
|
| 504 |
+
if pools and config.memory_pool in ("outputs", "combined"):
|
| 505 |
+
pools[-1].allocate_at_end(block)
|
| 506 |
+
else:
|
| 507 |
+
# create a new pool
|
| 508 |
+
block.mark_allocated()
|
| 509 |
+
pools.append(
|
| 510 |
+
AllocationPool(
|
| 511 |
+
block.device,
|
| 512 |
+
TemporalSplit([block]),
|
| 513 |
+
can_expand=config.memory_pool == "combined",
|
| 514 |
+
)
|
| 515 |
+
)
|
| 516 |
+
|
| 517 |
+
def finalize(self):
|
| 518 |
+
"""Called at the end of allocation process"""
|
| 519 |
+
for i, pool in enumerate(
|
| 520 |
+
itertools.chain.from_iterable(self.device_to_pools.values())
|
| 521 |
+
):
|
| 522 |
+
pool.finalize(f"pool{i}")
|
| 523 |
+
|
| 524 |
+
def pprint(self):
|
| 525 |
+
for pool in itertools.chain.from_iterable(self.device_to_pools.values()):
|
| 526 |
+
print()
|
| 527 |
+
print(pool.name)
|
| 528 |
+
print(pool.root.get_live_ranges())
|
| 529 |
+
pprint.pprint(pool.root)
|
| 530 |
+
|
| 531 |
+
|
| 532 |
+
class BufferGroup:
|
| 533 |
+
"""
|
| 534 |
+
Due to inplace reuse an allocated buffer can have many names.
|
| 535 |
+
This tracks these collections of buffers sharing underlying memory.
|
| 536 |
+
"""
|
| 537 |
+
|
| 538 |
+
def __init__(self, node: ir.Buffer):
|
| 539 |
+
self.node = node
|
| 540 |
+
self.names = [node.get_name()]
|
| 541 |
+
self.is_output = False
|
| 542 |
+
self.allocation: Optional[Allocation] = None
|
| 543 |
+
self.live_range = LiveRange(float("inf"), -float("inf"))
|
| 544 |
+
|
| 545 |
+
def update_usage(self, timestep: int):
|
| 546 |
+
"""Expand self.live_range to include timestep"""
|
| 547 |
+
self.live_range = LiveRange(
|
| 548 |
+
min(timestep, self.live_range.begin),
|
| 549 |
+
max(timestep, self.live_range.end),
|
| 550 |
+
)
|
| 551 |
+
|
| 552 |
+
def sym_nbytes(self):
|
| 553 |
+
return self.node.get_layout().storage_size() * self.node.get_dtype().itemsize
|
| 554 |
+
|
| 555 |
+
def make_allocation(self):
|
| 556 |
+
assert not self.allocation, "multiple allocations"
|
| 557 |
+
assert isinstance(self.live_range.begin, int), "live ranges not computed"
|
| 558 |
+
nbytes = self.sym_nbytes()
|
| 559 |
+
# For now, fallback value will be used if we encounter an unbacked SymInt. The longer-term plan is to have
|
| 560 |
+
# size_hint() use better heuristics for unbackeds, at which point the fallback value will be ignored.
|
| 561 |
+
size_hint = V.graph.sizevars.size_hint(nbytes, fallback=64)
|
| 562 |
+
self.allocation = Allocation(
|
| 563 |
+
self.node,
|
| 564 |
+
self.live_range,
|
| 565 |
+
size_hint=size_hint,
|
| 566 |
+
symbolic_size=nbytes,
|
| 567 |
+
)
|
| 568 |
+
|
| 569 |
+
def __repr__(self):
|
| 570 |
+
return (
|
| 571 |
+
f"{self.__class__.__name__}({self.names!r}, is_output={self.is_output}, "
|
| 572 |
+
f"live_range={self.live_range}"
|
| 573 |
+
)
|
| 574 |
+
|
| 575 |
+
|
| 576 |
+
@dataclasses.dataclass
|
| 577 |
+
class PoolMemoryPlanningLine(MemoryPlanningLine):
|
| 578 |
+
"""Abstract base class for {Alloc,Dealloc}FromPoolLine"""
|
| 579 |
+
|
| 580 |
+
group: BufferGroup
|
| 581 |
+
timestep: Optional[int] = None
|
| 582 |
+
|
| 583 |
+
@property
|
| 584 |
+
def node(self):
|
| 585 |
+
return self.group.node
|
| 586 |
+
|
| 587 |
+
|
| 588 |
+
@dataclasses.dataclass
|
| 589 |
+
class AllocFromPoolLine(PoolMemoryPlanningLine):
|
| 590 |
+
"""Similar to AllocationLine, but takes memory from a pool"""
|
| 591 |
+
|
| 592 |
+
is_first_pool_usage: bool = False
|
| 593 |
+
|
| 594 |
+
def codegen(self, code: IndentedBuffer):
|
| 595 |
+
allocation = self.group.allocation
|
| 596 |
+
assert allocation and allocation.pool
|
| 597 |
+
pool = allocation.pool
|
| 598 |
+
name = self.node.get_name()
|
| 599 |
+
|
| 600 |
+
if self.is_first_pool_usage:
|
| 601 |
+
pool.codegen_create(self.wrapper, code)
|
| 602 |
+
|
| 603 |
+
pool.names_to_del.extend(self.group.names)
|
| 604 |
+
alloc_from_pool = allocation.codegen_alloc_from_pool(self.wrapper)
|
| 605 |
+
if alloc_from_pool in pool.creation_cache:
|
| 606 |
+
code.writeline(
|
| 607 |
+
self.wrapper.make_tensor_alias(
|
| 608 |
+
name, pool.creation_cache[alloc_from_pool], "alloc"
|
| 609 |
+
)
|
| 610 |
+
)
|
| 611 |
+
else:
|
| 612 |
+
pool.creation_cache[alloc_from_pool] = name
|
| 613 |
+
code.writeline(
|
| 614 |
+
f"{self.wrapper.declare}{name} = {alloc_from_pool}{self.wrapper.ending}"
|
| 615 |
+
)
|
| 616 |
+
|
| 617 |
+
|
| 618 |
+
@dataclasses.dataclass
|
| 619 |
+
class DeallocFromPoolLine(PoolMemoryPlanningLine):
|
| 620 |
+
"""Similar to FreeIfNotReusedLine, but takes memory from a pool"""
|
| 621 |
+
|
| 622 |
+
is_last_pool_usage: bool = False
|
| 623 |
+
|
| 624 |
+
def codegen(self, code: IndentedBuffer):
|
| 625 |
+
if self.is_last_pool_usage:
|
| 626 |
+
assert self.group.allocation and self.group.allocation.pool
|
| 627 |
+
self.group.allocation.pool.codegen_destroy(self.wrapper, code)
|
| 628 |
+
|
| 629 |
+
|
| 630 |
+
@dataclasses.dataclass
|
| 631 |
+
class MemoryPlanner:
|
| 632 |
+
"""
|
| 633 |
+
Coordination object to run memory planning passes during wrapper
|
| 634 |
+
codegen.
|
| 635 |
+
"""
|
| 636 |
+
|
| 637 |
+
wrapper: Any
|
| 638 |
+
pools: AllocationPools = dataclasses.field(default_factory=AllocationPools)
|
| 639 |
+
buffer_groups: Optional[List[BufferGroup]] = None
|
| 640 |
+
|
| 641 |
+
def plan(self, lines: List[Any]) -> List[Any]:
|
| 642 |
+
"""Call all the memory planning passes in sequence"""
|
| 643 |
+
lines = [*lines]
|
| 644 |
+
self.drop_removed_buffers(lines)
|
| 645 |
+
self.convert_to_pool_lines(lines)
|
| 646 |
+
self.compute_live_ranges(lines)
|
| 647 |
+
self.allocate_groups()
|
| 648 |
+
self.mark_first_last_usage(lines)
|
| 649 |
+
return lines
|
| 650 |
+
|
| 651 |
+
def drop_removed_buffers(self, lines):
|
| 652 |
+
"""
|
| 653 |
+
Replace any memory planning lines in V.graph.removed_buffers with NullLine
|
| 654 |
+
"""
|
| 655 |
+
# drop any removed buffers
|
| 656 |
+
for i, line in enumerate(lines):
|
| 657 |
+
if isinstance(line, (AllocateLine, FreeIfNotReusedLine, ReuseLine)):
|
| 658 |
+
if line.node.get_name() in V.graph.removed_buffers:
|
| 659 |
+
lines[i] = NullLine(self.wrapper)
|
| 660 |
+
|
| 661 |
+
def compute_buffer_groups(self, lines):
|
| 662 |
+
"""
|
| 663 |
+
Populates self.buffer_groups with BufferGroup objects that join
|
| 664 |
+
allocations with common storage (due to inplace reuse) into a
|
| 665 |
+
single object.
|
| 666 |
+
"""
|
| 667 |
+
name_to_group = {}
|
| 668 |
+
for line in lines:
|
| 669 |
+
if isinstance(line, AllocateLine):
|
| 670 |
+
name = line.node.get_name()
|
| 671 |
+
assert name not in name_to_group
|
| 672 |
+
name_to_group[name] = BufferGroup(line.node)
|
| 673 |
+
elif isinstance(line, ReuseLine):
|
| 674 |
+
old_name = line.node.get_name()
|
| 675 |
+
new_name = line.reused_as.get_name()
|
| 676 |
+
assert new_name not in name_to_group
|
| 677 |
+
# TODO(jansel): we should support reusing buffers created via ExternKernelAlloc
|
| 678 |
+
if old_name in name_to_group:
|
| 679 |
+
name_to_group[old_name].names.append(new_name)
|
| 680 |
+
name_to_group[new_name] = name_to_group[old_name]
|
| 681 |
+
|
| 682 |
+
outputs = set(V.graph.get_output_names())
|
| 683 |
+
unique_groups = [*{id(g): g for g in name_to_group.values()}.values()]
|
| 684 |
+
for group in unique_groups:
|
| 685 |
+
group.is_output = any(x in outputs for x in group.names)
|
| 686 |
+
|
| 687 |
+
assert self.buffer_groups is None
|
| 688 |
+
self.buffer_groups = unique_groups
|
| 689 |
+
return name_to_group
|
| 690 |
+
|
| 691 |
+
def convert_to_pool_lines(self, lines):
|
| 692 |
+
"""
|
| 693 |
+
Convert AllocateLine/FreeIfNotReusedLine/ReuseLine into their
|
| 694 |
+
pool-based counterparts.
|
| 695 |
+
"""
|
| 696 |
+
name_to_group = self.compute_buffer_groups(lines)
|
| 697 |
+
for i, line in enumerate(lines):
|
| 698 |
+
if isinstance(line, AllocateLine):
|
| 699 |
+
if line.node.get_name() in name_to_group:
|
| 700 |
+
lines[i] = AllocFromPoolLine(
|
| 701 |
+
self.wrapper, name_to_group[line.node.get_name()]
|
| 702 |
+
)
|
| 703 |
+
elif isinstance(line, FreeIfNotReusedLine):
|
| 704 |
+
assert not line.is_reused
|
| 705 |
+
if line.node.get_name() in name_to_group:
|
| 706 |
+
lines[i] = DeallocFromPoolLine(
|
| 707 |
+
self.wrapper, name_to_group[line.node.get_name()]
|
| 708 |
+
)
|
| 709 |
+
elif isinstance(line, ReuseLine):
|
| 710 |
+
if line.node.get_name() in name_to_group:
|
| 711 |
+
line.delete_old = False
|
| 712 |
+
|
| 713 |
+
def compute_live_ranges(self, lines):
|
| 714 |
+
"""Populate every BufferGroup.live_ranges field based on first/last usage"""
|
| 715 |
+
timestep = 0
|
| 716 |
+
worklist = collections.deque(lines)
|
| 717 |
+
while worklist:
|
| 718 |
+
if isinstance(worklist[0], MemoryPlanningLine):
|
| 719 |
+
timestep += 1
|
| 720 |
+
while worklist and isinstance(worklist[0], MemoryPlanningLine):
|
| 721 |
+
line = worklist.popleft()
|
| 722 |
+
if isinstance(line, PoolMemoryPlanningLine):
|
| 723 |
+
line.group.update_usage(timestep)
|
| 724 |
+
line.timestep = timestep
|
| 725 |
+
else:
|
| 726 |
+
worklist.popleft()
|
| 727 |
+
|
| 728 |
+
timestep += 1
|
| 729 |
+
assert self.buffer_groups is not None
|
| 730 |
+
for group in self.buffer_groups:
|
| 731 |
+
if group.is_output:
|
| 732 |
+
group.update_usage(timestep)
|
| 733 |
+
|
| 734 |
+
def allocate_groups(self):
|
| 735 |
+
"""
|
| 736 |
+
Assign every allocation to a specific location in a specific AllocationPool.
|
| 737 |
+
"""
|
| 738 |
+
assert config.memory_pool in ("none", "intermediates", "outputs", "combined")
|
| 739 |
+
assert self.buffer_groups is not None
|
| 740 |
+
|
| 741 |
+
for group in self.buffer_groups:
|
| 742 |
+
group.make_allocation()
|
| 743 |
+
|
| 744 |
+
outputs: List[Allocation] = []
|
| 745 |
+
intermediates: List[Allocation] = []
|
| 746 |
+
for group in self.buffer_groups:
|
| 747 |
+
assert group.allocation
|
| 748 |
+
if group.is_output and config.memory_pool != "combined":
|
| 749 |
+
outputs.append(group.allocation)
|
| 750 |
+
else:
|
| 751 |
+
intermediates.append(group.allocation)
|
| 752 |
+
|
| 753 |
+
for block in sorted(
|
| 754 |
+
outputs,
|
| 755 |
+
key=lambda x: (
|
| 756 |
+
x.size_hint,
|
| 757 |
+
-len(x.live_range),
|
| 758 |
+
),
|
| 759 |
+
):
|
| 760 |
+
self.pools.allocate_output(block)
|
| 761 |
+
|
| 762 |
+
for block in sorted(
|
| 763 |
+
intermediates,
|
| 764 |
+
key=lambda x: (
|
| 765 |
+
-x.size_hint,
|
| 766 |
+
-len(x.live_range),
|
| 767 |
+
),
|
| 768 |
+
):
|
| 769 |
+
self.pools.allocate(block)
|
| 770 |
+
|
| 771 |
+
self.pools.finalize()
|
| 772 |
+
|
| 773 |
+
def mark_first_last_usage(self, lines):
|
| 774 |
+
"""
|
| 775 |
+
Populate the AllocFromPoolLine.is_first_pool_usage and
|
| 776 |
+
DeallocFromPoolLine.is_last_pool_usage fields so that pools
|
| 777 |
+
are created/destroyed.
|
| 778 |
+
"""
|
| 779 |
+
seen = set()
|
| 780 |
+
for line in lines:
|
| 781 |
+
if isinstance(line, AllocFromPoolLine):
|
| 782 |
+
assert line.group.allocation
|
| 783 |
+
pool = line.group.allocation.pool
|
| 784 |
+
assert pool is not None
|
| 785 |
+
if pool not in seen:
|
| 786 |
+
line.is_first_pool_usage = True
|
| 787 |
+
seen.add(pool)
|
| 788 |
+
|
| 789 |
+
seen = set()
|
| 790 |
+
for line in reversed(lines):
|
| 791 |
+
if isinstance(line, DeallocFromPoolLine):
|
| 792 |
+
assert line.group.allocation
|
| 793 |
+
pool = line.group.allocation.pool
|
| 794 |
+
assert pool is not None
|
| 795 |
+
if pool not in seen:
|
| 796 |
+
line.is_last_pool_usage = (
|
| 797 |
+
pool.root.get_live_ranges().end <= line.timestep
|
| 798 |
+
)
|
| 799 |
+
seen.add(pool)
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/multi_kernel.py
ADDED
|
@@ -0,0 +1,413 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import os
|
| 3 |
+
from typing import Any, List
|
| 4 |
+
|
| 5 |
+
from torch._inductor.metrics import get_metric_table, is_metric_table_enabled
|
| 6 |
+
|
| 7 |
+
from .. import config
|
| 8 |
+
from ..codecache import PyCodeCache, TritonFuture
|
| 9 |
+
from ..utils import cache_on_self, do_bench
|
| 10 |
+
from ..virtualized import V
|
| 11 |
+
from .common import TensorArg
|
| 12 |
+
|
| 13 |
+
log = logging.getLogger(__name__)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def get_kernel_argdefs(kernel):
|
| 17 |
+
arg_defs, _, _ = kernel.args.python_argdefs()
|
| 18 |
+
return arg_defs
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def _get_all_args(args_list):
|
| 22 |
+
all_args = max(args_list, key=len)[:]
|
| 23 |
+
for args in args_list:
|
| 24 |
+
assert set(args).issubset(set(all_args)), f"{args} v.s. {all_args}"
|
| 25 |
+
|
| 26 |
+
return all_args
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def get_all_kernel_argdefs(kernels):
|
| 30 |
+
"""
|
| 31 |
+
The logic here must match with `get_all_call_args`.
|
| 32 |
+
"""
|
| 33 |
+
argdefs_list = [get_kernel_argdefs(kernel) for kernel in kernels]
|
| 34 |
+
|
| 35 |
+
return _get_all_args(argdefs_list)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def get_all_call_args(call_args_list):
|
| 39 |
+
"""
|
| 40 |
+
Passed in the call_args for each subkernel and return the call_args for the
|
| 41 |
+
combined multi-kernel.
|
| 42 |
+
|
| 43 |
+
Note an algorithm as follows does not always work:
|
| 44 |
+
```
|
| 45 |
+
all_call_args: Dict[
|
| 46 |
+
Any, None
|
| 47 |
+
] = {} # use a dict rather than set to maintain insertion order
|
| 48 |
+
for call_args in call_args_list:
|
| 49 |
+
all_call_args.update({arg: None for arg in call_args})
|
| 50 |
+
|
| 51 |
+
all_call_args = list(all_call_args.keys())
|
| 52 |
+
```
|
| 53 |
+
It will fail if any kernel has the same argument passed in multiple times.
|
| 54 |
+
Check test_pass_same_arg_multi_times in test_multi_kernel.py
|
| 55 |
+
|
| 56 |
+
Instead, we pick the longest call args and assert that otehr call args are
|
| 57 |
+
a subset of it.
|
| 58 |
+
"""
|
| 59 |
+
return _get_all_args(call_args_list)
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def get_numel_argdefs(kernel):
|
| 63 |
+
numel_argdefs = []
|
| 64 |
+
for tree in kernel.range_trees:
|
| 65 |
+
if tree.prefix != "r" or kernel.inside_reduction:
|
| 66 |
+
numel_argdefs.append(f"{tree.prefix}numel")
|
| 67 |
+
|
| 68 |
+
return numel_argdefs
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
class MultiKernelState:
|
| 72 |
+
"""
|
| 73 |
+
Maintain state of multi-kernel compilation so we don't define duplicated
|
| 74 |
+
multi-kernel for the same set of sub-kernels.
|
| 75 |
+
|
| 76 |
+
V.graph.wrapper_code has a reference to MultiKernelState instance.
|
| 77 |
+
"""
|
| 78 |
+
|
| 79 |
+
def __init__(self):
|
| 80 |
+
self.subkernel_to_kernel_name = {}
|
| 81 |
+
|
| 82 |
+
def define_kernel(self, kernels):
|
| 83 |
+
"""
|
| 84 |
+
Previously we name the multi kernel as "multi_kernel_{kernel_names[0]}".
|
| 85 |
+
This has some minor issue.
|
| 86 |
+
|
| 87 |
+
E.g. for persistent reduction https://gist.github.com/shunting314/39e7c00ff8bb2055942ed5a3255d61ca ,
|
| 88 |
+
there are 2 flavors of non-persistent reduction:
|
| 89 |
+
https://gist.github.com/shunting314/056d43d35907e87efb883970b35c17d4
|
| 90 |
+
and
|
| 91 |
+
https://gist.github.com/shunting314/02ee753b65c513c54e695626afe682bd
|
| 92 |
+
|
| 93 |
+
The only different is cache eviction policy.
|
| 94 |
+
|
| 95 |
+
We should name the multi-kernel differently in these 2 cases.
|
| 96 |
+
"""
|
| 97 |
+
kernel_names = tuple(k.kernel_name for k in kernels)
|
| 98 |
+
if kernel_names in self.subkernel_to_kernel_name:
|
| 99 |
+
return self.subkernel_to_kernel_name[kernel_names]
|
| 100 |
+
|
| 101 |
+
# name the multi kernel based on the first kernel
|
| 102 |
+
multi_kernel_name = f"multi_kernel_{len(self.subkernel_to_kernel_name)}"
|
| 103 |
+
self.subkernel_to_kernel_name[kernel_names] = multi_kernel_name
|
| 104 |
+
|
| 105 |
+
if V.graph.cpp_wrapper:
|
| 106 |
+
# we should not generate any python code for multi-kernel during
|
| 107 |
+
# the second pass of cpp-wrapper.
|
| 108 |
+
return multi_kernel_name
|
| 109 |
+
|
| 110 |
+
wrapper = V.graph.wrapper_code
|
| 111 |
+
|
| 112 |
+
kernel_call_def_code = "\n".join(
|
| 113 |
+
[
|
| 114 |
+
f"""
|
| 115 |
+
def call{idx}(need_clone_args=False):
|
| 116 |
+
args = [{', '.join(get_kernel_argdefs(kernels[idx]))}]
|
| 117 |
+
if need_clone_args:
|
| 118 |
+
args, _ = multi_kernel_call.kernels[{idx}].clone_args(*args)
|
| 119 |
+
multi_kernel_call.kernels[{idx}].run(*args, {', '.join(get_numel_argdefs(kernels[idx]))}, grid=grid, stream=stream)
|
| 120 |
+
""".format(
|
| 121 |
+
idx
|
| 122 |
+
).strip(
|
| 123 |
+
"\n"
|
| 124 |
+
)
|
| 125 |
+
for idx in range(len(kernels))
|
| 126 |
+
]
|
| 127 |
+
)
|
| 128 |
+
|
| 129 |
+
# add subkernel src code hashes to the multi-kernel source code so changing a
|
| 130 |
+
# subkernel implementation will result in a differnt py file for
|
| 131 |
+
# multi-kernel. This makes cache implementation straightforward since
|
| 132 |
+
# we can decide cache file name based on multi-kernel py file name
|
| 133 |
+
# directly.
|
| 134 |
+
#
|
| 135 |
+
# Without the hash added for subkernels, the cache file may be shared by
|
| 136 |
+
# different subkernels which is incorrect.
|
| 137 |
+
subkernel_hashes = "\n".join(
|
| 138 |
+
f"# subkernel{i} code hash: {kernel.code_hash}"
|
| 139 |
+
for i, kernel in enumerate(kernels)
|
| 140 |
+
)
|
| 141 |
+
|
| 142 |
+
src_code = f"""
|
| 143 |
+
{subkernel_hashes}
|
| 144 |
+
def run(multi_kernel_call, {', '.join(get_all_kernel_argdefs(kernels))}, {', '.join(get_numel_argdefs(kernels[0]))}, grid, stream):
|
| 145 |
+
{kernel_call_def_code}
|
| 146 |
+
multi_kernel_call.run_with_argless_kernels([call0, call1])
|
| 147 |
+
""" # noqa: B950 line too long
|
| 148 |
+
wrapper.header.splice(
|
| 149 |
+
f"""
|
| 150 |
+
{multi_kernel_name} = async_compile.multi_kernel({multi_kernel_name!r}, [
|
| 151 |
+
{", ".join(kernel_names)},
|
| 152 |
+
],
|
| 153 |
+
'''
|
| 154 |
+
"""
|
| 155 |
+
)
|
| 156 |
+
wrapper.header.splice(src_code)
|
| 157 |
+
wrapper.header.splice(
|
| 158 |
+
"""
|
| 159 |
+
'''
|
| 160 |
+
)
|
| 161 |
+
"""
|
| 162 |
+
)
|
| 163 |
+
|
| 164 |
+
return multi_kernel_name
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
class MultiKernel:
|
| 168 |
+
"""
|
| 169 |
+
This class maintains the compile time state for multi kernels.
|
| 170 |
+
|
| 171 |
+
Assume we do codegen for a MultiKernel encapsulating kernel1 and kernel2.
|
| 172 |
+
The generated definition for the multi-kernel will looks like:
|
| 173 |
+
```
|
| 174 |
+
multi_kernel_kernel1 = MultiKernelCall([kernel1, kernel2], multi_kernel_definition_code)
|
| 175 |
+
```
|
| 176 |
+
|
| 177 |
+
Here is an concrete example: https://gist.github.com/shunting314/d9f3fb6bc6cee3dbae005825ca196d39
|
| 178 |
+
"""
|
| 179 |
+
|
| 180 |
+
def __init__(self, kernels):
|
| 181 |
+
assert len(kernels) >= 2
|
| 182 |
+
|
| 183 |
+
self.kernels = kernels
|
| 184 |
+
self.kernel_name = V.graph.wrapper_code.multi_kernel_state.define_kernel(
|
| 185 |
+
kernels
|
| 186 |
+
)
|
| 187 |
+
|
| 188 |
+
# need this since some code in inductor check if the kernel object has an args
|
| 189 |
+
# attribute to decide if it's a non-null kernel.
|
| 190 |
+
self.args = object()
|
| 191 |
+
|
| 192 |
+
def call_kernel(self, kernel_name):
|
| 193 |
+
"""
|
| 194 |
+
Collect the union of arguments from all subkernels as the arguments
|
| 195 |
+
for the multi-kernel.
|
| 196 |
+
"""
|
| 197 |
+
assert kernel_name == self.kernel_name
|
| 198 |
+
call_args_list = [kernel.get_call_args() for kernel in self.kernels]
|
| 199 |
+
|
| 200 |
+
all_call_args = get_all_call_args(call_args_list)
|
| 201 |
+
grid: List[Any] = []
|
| 202 |
+
|
| 203 |
+
if V.graph.cpp_wrapper:
|
| 204 |
+
# for the second pass of cpp-wrapper codegen, we should call
|
| 205 |
+
# the fast kernel directly
|
| 206 |
+
picked_kernel = MultiKernelCall.lookup_choice(kernel_name)
|
| 207 |
+
kernel_name = self.kernels[picked_kernel].kernel_name
|
| 208 |
+
final_call_args = call_args_list[picked_kernel]
|
| 209 |
+
else:
|
| 210 |
+
final_call_args = all_call_args
|
| 211 |
+
|
| 212 |
+
# numels for all subkernels should be the same. Use kernels[0] here
|
| 213 |
+
self.kernels[0].add_numel_to_call_args_and_grid(
|
| 214 |
+
kernel_name, final_call_args, grid
|
| 215 |
+
)
|
| 216 |
+
|
| 217 |
+
grid = V.graph.wrapper_code.generate_default_grid(kernel_name, grid)
|
| 218 |
+
|
| 219 |
+
V.graph.wrapper_code.generate_kernel_call(
|
| 220 |
+
kernel_name,
|
| 221 |
+
final_call_args,
|
| 222 |
+
grid,
|
| 223 |
+
V.graph.scheduler.current_device.index,
|
| 224 |
+
)
|
| 225 |
+
|
| 226 |
+
def codegen_nan_check(self):
|
| 227 |
+
wrapper = V.graph.wrapper_code
|
| 228 |
+
seen = set()
|
| 229 |
+
for k in self.kernels:
|
| 230 |
+
_, call_args, arg_types = k.args.python_argdefs()
|
| 231 |
+
for arg, arg_type in zip(call_args, arg_types):
|
| 232 |
+
if arg in seen:
|
| 233 |
+
continue
|
| 234 |
+
seen.add(arg)
|
| 235 |
+
if isinstance(arg_type, TensorArg):
|
| 236 |
+
line = f"assert not {arg}.isnan().any().item()"
|
| 237 |
+
wrapper.writeline(line)
|
| 238 |
+
line = f"assert not {arg}.isinf().any().item()"
|
| 239 |
+
wrapper.writeline(line)
|
| 240 |
+
|
| 241 |
+
@property
|
| 242 |
+
def removed_buffers(self):
|
| 243 |
+
return set.intersection(*[k.removed_buffers for k in self.kernels])
|
| 244 |
+
|
| 245 |
+
@property
|
| 246 |
+
def inplaced_to_remove(self):
|
| 247 |
+
return set.intersection(*[k.inplaced_to_remove for k in self.kernels])
|
| 248 |
+
|
| 249 |
+
@property
|
| 250 |
+
@cache_on_self
|
| 251 |
+
def inplace_update_buffers(self):
|
| 252 |
+
"""
|
| 253 |
+
Make sure all kernels have the same inplace update mappings.
|
| 254 |
+
"""
|
| 255 |
+
for k in self.kernels[1:]:
|
| 256 |
+
assert k.inplace_update_buffers == self.kernels[0].inplace_update_buffers
|
| 257 |
+
return self.kernels[0].inplace_update_buffers
|
| 258 |
+
|
| 259 |
+
def warn_mix_layout(self, kernel_name: str):
|
| 260 |
+
pass
|
| 261 |
+
|
| 262 |
+
|
| 263 |
+
class MultiKernelCall:
|
| 264 |
+
"""
|
| 265 |
+
This class is called at run time to actually run the kernel
|
| 266 |
+
"""
|
| 267 |
+
|
| 268 |
+
def __init__(self, multi_kernel_name, kernels, src_code):
|
| 269 |
+
assert len(kernels) >= 2
|
| 270 |
+
self._kernels = kernels
|
| 271 |
+
self.multi_kernel_name = multi_kernel_name
|
| 272 |
+
|
| 273 |
+
self._run = PyCodeCache.load(src_code).run
|
| 274 |
+
self.disable_cache = os.environ.get(
|
| 275 |
+
"TORCHINDUCTOR_DISABLE_MULTI_KERNEL_CACHE"
|
| 276 |
+
) == "1" or is_metric_table_enabled("persistent_red_perf")
|
| 277 |
+
|
| 278 |
+
self.picked_kernel = None
|
| 279 |
+
if config.triton.multi_kernel > 1:
|
| 280 |
+
# manually force a subkernel to ease perf testing
|
| 281 |
+
picked_by_config = config.triton.multi_kernel - 2
|
| 282 |
+
assert picked_by_config < len(self._kernels)
|
| 283 |
+
self.picked_kernel = picked_by_config
|
| 284 |
+
elif not self.disable_cache:
|
| 285 |
+
self.load_cache()
|
| 286 |
+
|
| 287 |
+
self._recorded = False
|
| 288 |
+
|
| 289 |
+
def cache_file_path(self):
|
| 290 |
+
py_file_path = self._run.__globals__["__file__"]
|
| 291 |
+
return os.path.splitext(py_file_path)[0] + ".picked_kernel"
|
| 292 |
+
|
| 293 |
+
def load_cache(self):
|
| 294 |
+
assert self.picked_kernel is None
|
| 295 |
+
path = self.cache_file_path()
|
| 296 |
+
if os.path.exists(path):
|
| 297 |
+
with open(path) as fd:
|
| 298 |
+
self.picked_kernel = int(fd.read())
|
| 299 |
+
assert self.picked_kernel >= 0 and self.picked_kernel < len(
|
| 300 |
+
self._kernels
|
| 301 |
+
)
|
| 302 |
+
log.debug(
|
| 303 |
+
"Load picked kernel %d from cache file %s", self.picked_kernel, path
|
| 304 |
+
)
|
| 305 |
+
|
| 306 |
+
def store_cache(self):
|
| 307 |
+
assert self.picked_kernel is not None
|
| 308 |
+
path = self.cache_file_path()
|
| 309 |
+
with open(path, "w") as fd:
|
| 310 |
+
fd.write(str(self.picked_kernel))
|
| 311 |
+
log.debug("Store picked kernel %d to cache file %s", self.picked_kernel, path)
|
| 312 |
+
|
| 313 |
+
@property
|
| 314 |
+
def kernels(self):
|
| 315 |
+
"""
|
| 316 |
+
Read results from future.
|
| 317 |
+
|
| 318 |
+
This should be called after parallel compilation is done.
|
| 319 |
+
In case you call this before compilation is done,
|
| 320 |
+
it may slow down the parallel compilation.
|
| 321 |
+
"""
|
| 322 |
+
for i, kernel in enumerate(self._kernels):
|
| 323 |
+
if isinstance(kernel, TritonFuture):
|
| 324 |
+
self._kernels[i] = kernel.result()
|
| 325 |
+
|
| 326 |
+
return self._kernels
|
| 327 |
+
|
| 328 |
+
def run(self, *args, **kwargs):
|
| 329 |
+
self._run(self, *args, **kwargs)
|
| 330 |
+
|
| 331 |
+
@staticmethod
|
| 332 |
+
def benchmark_sub_kernels(kernel_calls):
|
| 333 |
+
"""
|
| 334 |
+
Benchmark all the sub kernels and return the execution time
|
| 335 |
+
(in milliseconds) for each of time.
|
| 336 |
+
|
| 337 |
+
Unit test may mock this method to force a specific kernel to
|
| 338 |
+
be picked.
|
| 339 |
+
"""
|
| 340 |
+
return [
|
| 341 |
+
do_bench(lambda: kernel_call(True), rep=40, fast_flush=True)
|
| 342 |
+
for kernel_call in kernel_calls
|
| 343 |
+
]
|
| 344 |
+
|
| 345 |
+
# record_choice and lookup_choice are helper functions for cpp-wrapper
|
| 346 |
+
# codegen. The first pass use record_choice to keep the choice and
|
| 347 |
+
# the second pass do lookup by calling lookup_choice.
|
| 348 |
+
#
|
| 349 |
+
# An alternative that reused the multi-kernel cache does not work well
|
| 350 |
+
# since during codegen of the second pass, it's very hard to know the
|
| 351 |
+
# path for the cache file. Also reading the cache file need do some IO
|
| 352 |
+
# which can be slower.
|
| 353 |
+
@staticmethod
|
| 354 |
+
def record_choice(multi_kernel_name, choice):
|
| 355 |
+
"""
|
| 356 |
+
Record the multi-kernel choice for cpp-wrapper first pass codegen
|
| 357 |
+
for the second pass.
|
| 358 |
+
|
| 359 |
+
We should do nothing if this function is not called during codegen.
|
| 360 |
+
"""
|
| 361 |
+
from torch._inductor.graph import GraphLowering
|
| 362 |
+
|
| 363 |
+
if not isinstance(V.graph, GraphLowering):
|
| 364 |
+
return
|
| 365 |
+
|
| 366 |
+
if not V.graph.record_multi_kernel_choice:
|
| 367 |
+
return
|
| 368 |
+
|
| 369 |
+
V.graph.multi_kernel_to_choice[multi_kernel_name] = choice
|
| 370 |
+
|
| 371 |
+
@staticmethod
|
| 372 |
+
def lookup_choice(multi_kernel_name):
|
| 373 |
+
# this should always been done during cpp-wrapper codegen
|
| 374 |
+
assert V.graph.record_multi_kernel_choice
|
| 375 |
+
# there should be no miss
|
| 376 |
+
return V.graph.multi_kernel_to_choice[multi_kernel_name]
|
| 377 |
+
|
| 378 |
+
def run_with_argless_kernels(self, kernel_calls):
|
| 379 |
+
if self.picked_kernel is None:
|
| 380 |
+
timings = self.benchmark_sub_kernels(kernel_calls)
|
| 381 |
+
self.picked_kernel = timings.index(min(timings))
|
| 382 |
+
k0 = self.kernels[0]
|
| 383 |
+
log.debug(
|
| 384 |
+
"pick %dth sub-kernel in %s. Size hints %s. Reduction hint %s. Timings %s",
|
| 385 |
+
self.picked_kernel,
|
| 386 |
+
[k.inductor_meta.get("kernel_name") for k in self.kernels],
|
| 387 |
+
k0.size_hints,
|
| 388 |
+
k0.inductor_meta.get("reduction_hint"),
|
| 389 |
+
timings,
|
| 390 |
+
)
|
| 391 |
+
|
| 392 |
+
def get_kernel_path(k):
|
| 393 |
+
return k.fn.fn.__code__.co_filename
|
| 394 |
+
|
| 395 |
+
get_metric_table("persistent_red_perf").add_row(
|
| 396 |
+
lambda: {
|
| 397 |
+
"kernel1_name": get_kernel_path(self.kernels[0]),
|
| 398 |
+
"kernel2_name": get_kernel_path(self.kernels[1]),
|
| 399 |
+
"kernel1_latency": timings[0],
|
| 400 |
+
"kernel2_latency": timings[1],
|
| 401 |
+
"size_hints": k0.size_hints,
|
| 402 |
+
"reduction_hint": k0.inductor_meta.get("reduction_hint"),
|
| 403 |
+
"speedup": timings[1] / timings[0],
|
| 404 |
+
}
|
| 405 |
+
)
|
| 406 |
+
|
| 407 |
+
if not self.disable_cache:
|
| 408 |
+
self.store_cache()
|
| 409 |
+
|
| 410 |
+
if not self._recorded:
|
| 411 |
+
self._recorded = True
|
| 412 |
+
self.record_choice(self.multi_kernel_name, self.picked_kernel)
|
| 413 |
+
kernel_calls[self.picked_kernel]()
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/triton_foreach.py
ADDED
|
@@ -0,0 +1,250 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import itertools
|
| 2 |
+
from collections import defaultdict
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
from typing import Dict, List, Tuple
|
| 5 |
+
|
| 6 |
+
from sympy import Integer
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
|
| 10 |
+
from .. import metrics
|
| 11 |
+
from ..scheduler import SchedulerNode
|
| 12 |
+
from ..utils import ceildiv, Placeholder
|
| 13 |
+
from ..virtualized import V
|
| 14 |
+
from .common import IndentedBuffer, Kernel
|
| 15 |
+
from .triton import gen_common_triton_imports, TritonKernel
|
| 16 |
+
from .triton_utils import config_of, signature_to_meta
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
@dataclass
|
| 20 |
+
class PartitionState:
|
| 21 |
+
partitions: List[
|
| 22 |
+
List[Tuple[List[SchedulerNode], Tuple[Integer, ...], Integer, Integer]]
|
| 23 |
+
]
|
| 24 |
+
cur_partition: List[
|
| 25 |
+
Tuple[List[SchedulerNode], Tuple[Integer, ...], Integer, Integer]
|
| 26 |
+
]
|
| 27 |
+
cur_count: int
|
| 28 |
+
|
| 29 |
+
def finalize(self):
|
| 30 |
+
if self.cur_partition:
|
| 31 |
+
self.partitions.append(self.cur_partition)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class ForeachKernel(Kernel):
|
| 35 |
+
MAX_NUM_ARGS = 250 # number where I would no longer get triton errors
|
| 36 |
+
|
| 37 |
+
@staticmethod
|
| 38 |
+
def _update_partition(partition_state, node_rw_count, node_info):
|
| 39 |
+
if partition_state.cur_count + node_rw_count > ForeachKernel.MAX_NUM_ARGS:
|
| 40 |
+
partition_state.partitions.append(partition_state.cur_partition)
|
| 41 |
+
partition_state.cur_partition = [node_info]
|
| 42 |
+
partition_state.cur_count = node_rw_count
|
| 43 |
+
else:
|
| 44 |
+
partition_state.cur_count += node_rw_count
|
| 45 |
+
partition_state.cur_partition.append(node_info)
|
| 46 |
+
|
| 47 |
+
@staticmethod
|
| 48 |
+
def horizontal_partition(subkernel_nodes, triton_scheduling):
|
| 49 |
+
"""Generates a list of lists of node info tuples which consist of (fused_nodes, tiling, numel, rnumel)
|
| 50 |
+
for each subkernel node where each sublist is guaranteed to not exceed CUDA limits for number of args
|
| 51 |
+
(read/writes) and to have the same 2D or 1D blocking strategy."""
|
| 52 |
+
assert len(subkernel_nodes) >= 1
|
| 53 |
+
|
| 54 |
+
partition_state_1d = PartitionState([], [], 0)
|
| 55 |
+
yelem_to_partition_state_2d: Dict[Integer, PartitionState] = defaultdict(
|
| 56 |
+
lambda: PartitionState([], [], 0)
|
| 57 |
+
)
|
| 58 |
+
|
| 59 |
+
for node in subkernel_nodes:
|
| 60 |
+
fused_nodes = node.get_nodes()
|
| 61 |
+
_, (numel, rnumel) = max(
|
| 62 |
+
fused_nodes, key=lambda x: int(x.is_reduction())
|
| 63 |
+
).group
|
| 64 |
+
tiled_groups = triton_scheduling.select_tiling(fused_nodes, numel, rnumel)
|
| 65 |
+
node_info = fused_nodes, tiled_groups, numel, rnumel
|
| 66 |
+
|
| 67 |
+
read_writes = node.read_writes
|
| 68 |
+
read_write_count = len(read_writes.reads) + len(read_writes.writes)
|
| 69 |
+
|
| 70 |
+
if tiled_groups[1] == 1:
|
| 71 |
+
ForeachKernel._update_partition(
|
| 72 |
+
partition_state_1d, read_write_count, node_info
|
| 73 |
+
)
|
| 74 |
+
else:
|
| 75 |
+
y_elem = tiled_groups[0]
|
| 76 |
+
partition_state_2d = yelem_to_partition_state_2d[y_elem]
|
| 77 |
+
ForeachKernel._update_partition(
|
| 78 |
+
partition_state_2d, read_write_count, node_info
|
| 79 |
+
)
|
| 80 |
+
|
| 81 |
+
partition_state_1d.finalize()
|
| 82 |
+
all_partitions = partition_state_1d.partitions
|
| 83 |
+
for partition_state_2d in yelem_to_partition_state_2d.values():
|
| 84 |
+
partition_state_2d.finalize()
|
| 85 |
+
all_partitions.extend(partition_state_2d.partitions)
|
| 86 |
+
|
| 87 |
+
return all_partitions
|
| 88 |
+
|
| 89 |
+
def __init__(self):
|
| 90 |
+
super().__init__()
|
| 91 |
+
self.blocking_2d = False
|
| 92 |
+
self.block_size_1d = 1024 # Try tuning this value
|
| 93 |
+
self.block_size_2d = 32
|
| 94 |
+
self.num_warps = 8
|
| 95 |
+
self.sub_kernels = []
|
| 96 |
+
self.iter_vars_count = itertools.count()
|
| 97 |
+
self.x_block_count = 0
|
| 98 |
+
self.y_block_count = 0
|
| 99 |
+
|
| 100 |
+
def get_block_size(self):
|
| 101 |
+
if self.blocking_2d:
|
| 102 |
+
return self.block_size_2d
|
| 103 |
+
else:
|
| 104 |
+
return self.block_size_1d
|
| 105 |
+
|
| 106 |
+
@staticmethod
|
| 107 |
+
def codegen_pid_offsets(code, block_count, lower_bound, prefix):
|
| 108 |
+
if block_count == 0:
|
| 109 |
+
code.splice(f"{prefix}pid_offset = {prefix}pid")
|
| 110 |
+
else:
|
| 111 |
+
code.splice(f"{prefix}pid_offset = {prefix}pid - {lower_bound}")
|
| 112 |
+
|
| 113 |
+
def codegen_pid_range(self, code, x_elems):
|
| 114 |
+
num_x_blocks = ceildiv(x_elems, self.get_block_size())
|
| 115 |
+
upper_bound_x_pid = self.x_block_count + num_x_blocks
|
| 116 |
+
lower_bound_x_pid = self.x_block_count
|
| 117 |
+
|
| 118 |
+
if self.x_block_count == 0:
|
| 119 |
+
cond = "if"
|
| 120 |
+
else:
|
| 121 |
+
cond = "elif"
|
| 122 |
+
|
| 123 |
+
x_pid_bounds_check = (
|
| 124 |
+
f"xpid >= {lower_bound_x_pid} and xpid < {upper_bound_x_pid}"
|
| 125 |
+
)
|
| 126 |
+
code.splice(f"{cond} {x_pid_bounds_check}:")
|
| 127 |
+
|
| 128 |
+
with code.indent():
|
| 129 |
+
ForeachKernel.codegen_pid_offsets(
|
| 130 |
+
code, num_x_blocks, lower_bound_x_pid, "x"
|
| 131 |
+
)
|
| 132 |
+
self.x_block_count += num_x_blocks
|
| 133 |
+
|
| 134 |
+
def create_sub_kernel(self, *groups, index_dtype, mutations, reduction_hint):
|
| 135 |
+
sub_kernel = TritonKernel(
|
| 136 |
+
*groups,
|
| 137 |
+
index_dtype=index_dtype,
|
| 138 |
+
mutations=mutations,
|
| 139 |
+
pid_cache={
|
| 140 |
+
"tl.program_id(0)": "xpid_offset",
|
| 141 |
+
"tl.program_id(1)": "ypid",
|
| 142 |
+
},
|
| 143 |
+
reduction_hint=reduction_hint,
|
| 144 |
+
)
|
| 145 |
+
if self.blocking_2d:
|
| 146 |
+
assert len(groups) == 3
|
| 147 |
+
|
| 148 |
+
self.blocking_2d |= groups[1] != 1 and len(groups) == 3
|
| 149 |
+
metrics.generated_kernel_count -= 1
|
| 150 |
+
sub_kernel.args = self.args
|
| 151 |
+
sub_kernel.iter_vars_count = self.iter_vars_count
|
| 152 |
+
sub_kernel.cse.iter_buffer_ids = self.cse.iter_buffer_ids
|
| 153 |
+
self.sub_kernels.append(sub_kernel)
|
| 154 |
+
return sub_kernel
|
| 155 |
+
|
| 156 |
+
def jit_lines(self):
|
| 157 |
+
can_use_32bit = all(k.index_dtype == "tl.int32" for k in self.sub_kernels)
|
| 158 |
+
size_dtype = "tl.int32" if can_use_32bit else "tl.int64"
|
| 159 |
+
_, _, signature = self.args.python_argdefs()
|
| 160 |
+
triton_meta = {
|
| 161 |
+
"signature": signature_to_meta(signature, size_dtype=size_dtype),
|
| 162 |
+
"device": V.graph.scheduler.current_device.index,
|
| 163 |
+
"device_type": V.graph.scheduler.current_device.type,
|
| 164 |
+
"constants": {},
|
| 165 |
+
}
|
| 166 |
+
triton_meta["configs"] = [config_of(signature)]
|
| 167 |
+
inductor_meta = {
|
| 168 |
+
"kernel_name": str(Placeholder.DESCRIPTIVE_NAME),
|
| 169 |
+
"backend_hash": torch.utils._triton.triton_hash_with_backend(),
|
| 170 |
+
}
|
| 171 |
+
return f"""
|
| 172 |
+
@triton_heuristics.foreach(
|
| 173 |
+
num_warps={self.num_warps},
|
| 174 |
+
triton_meta={triton_meta!r},
|
| 175 |
+
inductor_meta={inductor_meta!r},
|
| 176 |
+
)
|
| 177 |
+
@triton.jit
|
| 178 |
+
"""
|
| 179 |
+
|
| 180 |
+
def grid(self):
|
| 181 |
+
return (
|
| 182 |
+
self.x_block_count,
|
| 183 |
+
ceildiv(int(self.sub_kernels[0].numels[0]), self.block_size_2d)
|
| 184 |
+
if self.blocking_2d
|
| 185 |
+
else 1,
|
| 186 |
+
1,
|
| 187 |
+
)
|
| 188 |
+
|
| 189 |
+
def codegen_kernel(self, name=None):
|
| 190 |
+
code = IndentedBuffer()
|
| 191 |
+
|
| 192 |
+
code.splice(gen_common_triton_imports())
|
| 193 |
+
argdefs, _, _ = self.args.python_argdefs()
|
| 194 |
+
code.splice(self.jit_lines())
|
| 195 |
+
code.writeline(
|
| 196 |
+
f"def {name or str(Placeholder.KERNEL_NAME)}({', '.join(argdefs)}):"
|
| 197 |
+
)
|
| 198 |
+
|
| 199 |
+
with code.indent():
|
| 200 |
+
code.splice("xpid = tl.program_id(0)")
|
| 201 |
+
if self.blocking_2d:
|
| 202 |
+
code.splice("ypid = tl.program_id(1)")
|
| 203 |
+
code.splice(f"XBLOCK: tl.constexpr = {self.block_size_2d}")
|
| 204 |
+
code.splice(f"YBLOCK: tl.constexpr = {self.block_size_2d}")
|
| 205 |
+
else:
|
| 206 |
+
code.splice(f"XBLOCK: tl.constexpr = {self.block_size_1d}")
|
| 207 |
+
|
| 208 |
+
for sub_kernel in self.sub_kernels:
|
| 209 |
+
assert len(sub_kernel.numels) <= 3
|
| 210 |
+
# TODO mlazos: support dynamic shapes
|
| 211 |
+
numel_ind = 0 if not self.blocking_2d else 1
|
| 212 |
+
self.codegen_pid_range(code, int(sub_kernel.numels[numel_ind]))
|
| 213 |
+
with code.indent():
|
| 214 |
+
if self.blocking_2d:
|
| 215 |
+
code.splice(f"ynumel = {sub_kernel.numels[0]}")
|
| 216 |
+
code.splice(f"xnumel = {sub_kernel.numels[1]}")
|
| 217 |
+
else:
|
| 218 |
+
code.splice(f"xnumel = {sub_kernel.numels[0]}")
|
| 219 |
+
|
| 220 |
+
sub_kernel.codegen_body()
|
| 221 |
+
code.splice(sub_kernel.body)
|
| 222 |
+
|
| 223 |
+
code.splice("else:")
|
| 224 |
+
with code.indent():
|
| 225 |
+
code.splice("pass")
|
| 226 |
+
|
| 227 |
+
return code.getvalue()
|
| 228 |
+
|
| 229 |
+
def call_kernel(self, code, name: str):
|
| 230 |
+
_, call_args, _ = self.args.python_argdefs()
|
| 231 |
+
# dynamo wraps unspec variable as 0d CPU tensor, need convert to scalar
|
| 232 |
+
for i in range(len(call_args)):
|
| 233 |
+
if V.graph.is_unspec_arg(call_args[i]):
|
| 234 |
+
call_args[i] = call_args[i] + ".item()"
|
| 235 |
+
if V.graph.cpp_wrapper:
|
| 236 |
+
V.graph.wrapper_code.generate_kernel_call(
|
| 237 |
+
name,
|
| 238 |
+
call_args,
|
| 239 |
+
device_index=V.graph.scheduler.current_device.index,
|
| 240 |
+
grid=self.grid(),
|
| 241 |
+
)
|
| 242 |
+
else:
|
| 243 |
+
# TODO: refactor generate_kernel_call
|
| 244 |
+
call_args_str = ", ".join(call_args)
|
| 245 |
+
stream_name = code.write_get_raw_stream(
|
| 246 |
+
V.graph.scheduler.current_device.index
|
| 247 |
+
)
|
| 248 |
+
code.writeline(
|
| 249 |
+
f"{name}.run({call_args_str}, grid=({self.grid()}), stream={stream_name})"
|
| 250 |
+
)
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/triton_split_scan.py
ADDED
|
@@ -0,0 +1,180 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import functools
|
| 2 |
+
|
| 3 |
+
from typing import Optional, Set
|
| 4 |
+
|
| 5 |
+
from torch._inductor import config, ir
|
| 6 |
+
|
| 7 |
+
from torch._inductor.codegen.triton import (
|
| 8 |
+
IterationRangesRoot,
|
| 9 |
+
triton_compute_type,
|
| 10 |
+
TritonKernel,
|
| 11 |
+
TritonKernelOverrides,
|
| 12 |
+
)
|
| 13 |
+
|
| 14 |
+
from torch._prims_common import prod
|
| 15 |
+
|
| 16 |
+
from torch.utils._sympy.functions import CeilDiv
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class TritonSplitScanKernel(TritonKernel):
|
| 20 |
+
"""Generates a triton kernel that supports ops.scan calls while also splitting
|
| 21 |
+
the reduction dimension over multiple triton programs.
|
| 22 |
+
|
| 23 |
+
For this kernel, loop numels will always take the form ``(xdim, rdim)``
|
| 24 |
+
and the grid has the shape ``(CeilDiv(rdim, RBLOCK), xdim)``. Communication
|
| 25 |
+
between blocks occurs within a global memory workspace buffer, which
|
| 26 |
+
must be zero-filled before launching the kernel.
|
| 27 |
+
|
| 28 |
+
Note that generation for ``ops.reduction`` is not supported.
|
| 29 |
+
|
| 30 |
+
For details of the communication strategy, see
|
| 31 |
+
https://research.nvidia.com/publication/2016-03_single-pass-parallel-prefix-scan-decoupled-look-back
|
| 32 |
+
|
| 33 |
+
"""
|
| 34 |
+
|
| 35 |
+
def __init__(
|
| 36 |
+
self,
|
| 37 |
+
*groups,
|
| 38 |
+
index_dtype: str,
|
| 39 |
+
mutations: Optional[Set[str]] = None,
|
| 40 |
+
reduction_hint=ir.ReductionHint.DEFAULT,
|
| 41 |
+
min_elem_per_thread=0,
|
| 42 |
+
):
|
| 43 |
+
super().__init__(
|
| 44 |
+
*groups,
|
| 45 |
+
index_dtype=index_dtype,
|
| 46 |
+
mutations=mutations,
|
| 47 |
+
pid_cache=None,
|
| 48 |
+
reduction_hint=reduction_hint,
|
| 49 |
+
min_elem_per_thread=min_elem_per_thread,
|
| 50 |
+
)
|
| 51 |
+
self.no_x_dim = True
|
| 52 |
+
|
| 53 |
+
def initialize_range_tree(self, pid_cache):
|
| 54 |
+
prefixes = "yxr"
|
| 55 |
+
assert len(self.numels) <= len(
|
| 56 |
+
prefixes
|
| 57 |
+
), "z dimension not supported for split scan"
|
| 58 |
+
active_prefixes = prefixes[len(prefixes) - len(self.numels) :]
|
| 59 |
+
|
| 60 |
+
grid_dims = "rxy"
|
| 61 |
+
for numel, prefix in zip(self.numels, active_prefixes):
|
| 62 |
+
is_reduction = prefix == "r"
|
| 63 |
+
tensor_dim = 0 if is_reduction else None
|
| 64 |
+
grid_dim = grid_dims.find(prefix)
|
| 65 |
+
self.range_trees.append(
|
| 66 |
+
IterationRangesRoot(
|
| 67 |
+
f"{prefix}index",
|
| 68 |
+
numel,
|
| 69 |
+
prefix,
|
| 70 |
+
grid_dim,
|
| 71 |
+
self,
|
| 72 |
+
pid_cache=pid_cache,
|
| 73 |
+
is_loop=False,
|
| 74 |
+
tensor_dim=tensor_dim,
|
| 75 |
+
grid_dim=grid_dim,
|
| 76 |
+
)
|
| 77 |
+
)
|
| 78 |
+
for tree in self.range_trees:
|
| 79 |
+
tree.codegen_header(self.body)
|
| 80 |
+
|
| 81 |
+
def reduction(self, dtype, src_dtype, reduction_type, value):
|
| 82 |
+
raise NotImplementedError("NYI TritonSplitDimKernel reductions")
|
| 83 |
+
|
| 84 |
+
def scan(self, dtype, combine_fn, value, init):
|
| 85 |
+
import triton.language as tl
|
| 86 |
+
|
| 87 |
+
compute_type = triton_compute_type(dtype)
|
| 88 |
+
compute_type_triton = getattr(tl, compute_type[3:])
|
| 89 |
+
|
| 90 |
+
element_nbits = compute_type_triton.primitive_bitwidth
|
| 91 |
+
|
| 92 |
+
scratch_type = "tl.uint32" if element_nbits <= 16 else "tl.uint64"
|
| 93 |
+
scratch_type_triton = getattr(tl, scratch_type[3:])
|
| 94 |
+
scratch_elems_per_block = 3 if element_nbits == 64 else 1
|
| 95 |
+
scratch_nbytes_per_block = scratch_elems_per_block * (
|
| 96 |
+
scratch_type_triton.primitive_bitwidth // 8
|
| 97 |
+
)
|
| 98 |
+
|
| 99 |
+
cse_load = functools.partial(self.cse.generate, self.loads)
|
| 100 |
+
cse_compute = functools.partial(self.cse.generate, self.compute)
|
| 101 |
+
|
| 102 |
+
assert len(self.numels) == 2, "Unexpected tiling"
|
| 103 |
+
min_rblock = config.triton.min_split_scan_rblock
|
| 104 |
+
max_blocks = prod(self.numels[:-1]) * CeilDiv(self.numels[-1], min_rblock)
|
| 105 |
+
nbytes = scratch_nbytes_per_block * max_blocks
|
| 106 |
+
scratch_base, offset = self.args.workspace(nbytes=nbytes, zero_fill=True)
|
| 107 |
+
if offset != 0:
|
| 108 |
+
scratch_base = cse_load(f"{scratch_base} + {self.index_to_str(offset)}")
|
| 109 |
+
runtime_rblocks = cse_load(f"tl.num_programs({self.range_trees[-1].index})")
|
| 110 |
+
scratch_base = cse_load(
|
| 111 |
+
f"{scratch_base}.to(tl.pointer_type({scratch_type})) + xoffset * "
|
| 112 |
+
f"{scratch_elems_per_block} * {runtime_rblocks}"
|
| 113 |
+
)
|
| 114 |
+
|
| 115 |
+
masks = {f"{tree.prefix}mask" for tree in self.range_trees}
|
| 116 |
+
self.filter_masks(masks)
|
| 117 |
+
masks = sorted(masks)
|
| 118 |
+
if self._load_mask:
|
| 119 |
+
masks.append(self._load_mask)
|
| 120 |
+
|
| 121 |
+
value = cse_compute(f"{value}.to({compute_type})")
|
| 122 |
+
value = cse_compute(f"tl.broadcast_to({value}, {self.dense_size_str()})")
|
| 123 |
+
init = cse_compute(f"tl.full([], {init}, {compute_type})")
|
| 124 |
+
if masks:
|
| 125 |
+
cond = " & ".join(masks)
|
| 126 |
+
masked_value = cse_compute(TritonKernelOverrides.where(cond, value, init))
|
| 127 |
+
else:
|
| 128 |
+
masked_value = value
|
| 129 |
+
|
| 130 |
+
combine_helper_fn = self._lift_helper(combine_fn, 2)
|
| 131 |
+
dim = self.triton_tensor_ndim() - 1
|
| 132 |
+
assert dim == 0, ""
|
| 133 |
+
|
| 134 |
+
block_sum = cse_compute(
|
| 135 |
+
f"tl.reduce({masked_value}, {dim}, {combine_helper_fn})"
|
| 136 |
+
)
|
| 137 |
+
exclusive_prefix = self.cse.newvar()
|
| 138 |
+
if element_nbits == 64:
|
| 139 |
+
self.compute.splice(
|
| 140 |
+
f"""
|
| 141 |
+
{exclusive_prefix} = triton_helpers.exclusive_scan_decoupled_lookback_64(
|
| 142 |
+
{scratch_base},
|
| 143 |
+
{block_sum},
|
| 144 |
+
{self.range_trees[-1].get_pid()},
|
| 145 |
+
{combine_helper_fn},
|
| 146 |
+
{init},
|
| 147 |
+
)
|
| 148 |
+
""",
|
| 149 |
+
strip=True,
|
| 150 |
+
)
|
| 151 |
+
|
| 152 |
+
else:
|
| 153 |
+
assert element_nbits <= 32
|
| 154 |
+
value_as_uint_dtype = f"tl.uint{element_nbits}"
|
| 155 |
+
|
| 156 |
+
self.compute.splice(
|
| 157 |
+
f"""
|
| 158 |
+
{exclusive_prefix} = triton_helpers.exclusive_scan_decoupled_lookback(
|
| 159 |
+
{scratch_base},
|
| 160 |
+
{block_sum},
|
| 161 |
+
{self.range_trees[-1].get_pid()},
|
| 162 |
+
{combine_helper_fn},
|
| 163 |
+
{init},
|
| 164 |
+
DTYPE_VALUE_AS_UINT={value_as_uint_dtype},
|
| 165 |
+
DTYPE_PACK={scratch_type},
|
| 166 |
+
)
|
| 167 |
+
""",
|
| 168 |
+
strip=True,
|
| 169 |
+
)
|
| 170 |
+
# Compute final cumsum
|
| 171 |
+
block_scan = cse_compute(
|
| 172 |
+
f"tl.associative_scan({masked_value}, {dim}, {combine_helper_fn})"
|
| 173 |
+
)
|
| 174 |
+
return cse_compute(f"{combine_helper_fn}({exclusive_prefix}, {block_scan})")
|
| 175 |
+
|
| 176 |
+
def _get_heuristic(self):
|
| 177 |
+
return "split_scan"
|
| 178 |
+
|
| 179 |
+
def _get_grid_fn(self):
|
| 180 |
+
return "split_scan_grid"
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/triton_utils.py
ADDED
|
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Any, Dict, List, Optional
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
from .. import config
|
| 6 |
+
from ..utils import _type_of, instance_descriptor
|
| 7 |
+
from ..virtualized import V
|
| 8 |
+
from .common import KernelArgType, SizeArg, TensorArg, WorkspaceArg
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def signature_of(arg: KernelArgType, *, size_dtype: str) -> str:
|
| 12 |
+
if isinstance(arg, TensorArg):
|
| 13 |
+
# TODO: Remove fp8 special handling when Triton supports PyTorch fp8 dtypes.
|
| 14 |
+
# Related PR: https://github.com/openai/triton/pull/2279/
|
| 15 |
+
if arg.dtype == torch.float8_e4m3fn:
|
| 16 |
+
tye = "*fp8e4nv"
|
| 17 |
+
elif arg.dtype == torch.float8_e5m2:
|
| 18 |
+
tye = "*fp8e5"
|
| 19 |
+
elif arg.dtype == torch.float8_e4m3fnuz:
|
| 20 |
+
tye = "*fp8e4b8"
|
| 21 |
+
elif arg.dtype == torch.float8_e5m2fnuz:
|
| 22 |
+
tye = "*fp8e5b16"
|
| 23 |
+
else:
|
| 24 |
+
tye = _type_of(arg.dtype)
|
| 25 |
+
if V.graph.is_unspec_arg(arg.buffer):
|
| 26 |
+
# had unwrapped 0d tensor as scalar
|
| 27 |
+
new_tye = tye.lstrip("*")
|
| 28 |
+
if new_tye in ["fp16", "bf16"]:
|
| 29 |
+
return "fp32"
|
| 30 |
+
else:
|
| 31 |
+
return new_tye
|
| 32 |
+
else:
|
| 33 |
+
return tye
|
| 34 |
+
if isinstance(arg, SizeArg):
|
| 35 |
+
if arg.expr is None:
|
| 36 |
+
# From triton/runtime/jit.py
|
| 37 |
+
# `None` is nullptr. Implicitly convert to *i8.
|
| 38 |
+
return "*i8"
|
| 39 |
+
elif isinstance(arg.expr, float):
|
| 40 |
+
return "fp32"
|
| 41 |
+
if size_dtype == "tl.int32":
|
| 42 |
+
return "i32"
|
| 43 |
+
elif size_dtype == "tl.int64":
|
| 44 |
+
return "i64"
|
| 45 |
+
else:
|
| 46 |
+
raise NotImplementedError(f"unhandled size_dtype {size_dtype}")
|
| 47 |
+
if isinstance(arg, WorkspaceArg):
|
| 48 |
+
return "*i8"
|
| 49 |
+
raise NotImplementedError(f"unhandled {type(arg)}: {arg}")
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def signature_to_meta(
|
| 53 |
+
signature: List[KernelArgType],
|
| 54 |
+
*,
|
| 55 |
+
size_dtype: str,
|
| 56 |
+
indices: Optional[List[int]] = None,
|
| 57 |
+
) -> Dict[int, str]:
|
| 58 |
+
if indices is None:
|
| 59 |
+
indices = list(range(len(signature)))
|
| 60 |
+
return {
|
| 61 |
+
i: signature_of(arg, size_dtype=size_dtype)
|
| 62 |
+
for i, arg in zip(indices, signature)
|
| 63 |
+
}
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def config_of(
|
| 67 |
+
args: List[KernelArgType],
|
| 68 |
+
*,
|
| 69 |
+
indices: Optional[List[int]] = None,
|
| 70 |
+
) -> Any:
|
| 71 |
+
if indices is None:
|
| 72 |
+
indices = list(range(len(args)))
|
| 73 |
+
|
| 74 |
+
def is_aligned(x: KernelArgType, alignment: int, include_tensor: bool) -> bool:
|
| 75 |
+
"""
|
| 76 |
+
Roughly follow triton code here:
|
| 77 |
+
https://github.com/openai/triton/blob/5282ed890d453e10b9ee30076ef89115dd197761/python/triton/runtime/jit.py#L208-L222
|
| 78 |
+
"""
|
| 79 |
+
if isinstance(x, TensorArg):
|
| 80 |
+
if include_tensor:
|
| 81 |
+
offset_aligned = V.graph.sizevars.statically_known_multiple_of(
|
| 82 |
+
x.offset * x.dtype.itemsize, alignment # type: ignore[arg-type]
|
| 83 |
+
)
|
| 84 |
+
return offset_aligned and not V.graph.scheduler.is_unaligned_buffer(
|
| 85 |
+
x.buffer
|
| 86 |
+
)
|
| 87 |
+
else:
|
| 88 |
+
return False
|
| 89 |
+
if isinstance(x, SizeArg):
|
| 90 |
+
# TODO(voz): These are kinda redundant, if we can solve out statically_known_multiple_of with
|
| 91 |
+
# _maybe_evaluate_static...
|
| 92 |
+
if x.name.startswith("load_seed_offset"):
|
| 93 |
+
return False
|
| 94 |
+
if x.expr is None:
|
| 95 |
+
return False
|
| 96 |
+
if isinstance(x.expr, float):
|
| 97 |
+
return False
|
| 98 |
+
return V.graph.sizevars.statically_known_multiple_of(x.expr, alignment) # type: ignore[arg-type]
|
| 99 |
+
if isinstance(x, WorkspaceArg):
|
| 100 |
+
return V.graph.sizevars.statically_known_multiple_of(x.nbytes, alignment) # type: ignore[arg-type]
|
| 101 |
+
raise NotImplementedError(f"unhandled {type(x)}: {x}")
|
| 102 |
+
|
| 103 |
+
if config.triton.divisible_by_16:
|
| 104 |
+
divisible_by_16 = tuple(
|
| 105 |
+
i
|
| 106 |
+
for i, arg in zip(indices, args)
|
| 107 |
+
if is_aligned(arg, alignment=16, include_tensor=True)
|
| 108 |
+
)
|
| 109 |
+
else:
|
| 110 |
+
divisible_by_16 = ()
|
| 111 |
+
divisible_by_8 = tuple(
|
| 112 |
+
i
|
| 113 |
+
for i, arg in zip(indices, args)
|
| 114 |
+
if is_aligned(arg, alignment=8, include_tensor=False)
|
| 115 |
+
)
|
| 116 |
+
|
| 117 |
+
equal_to_1 = tuple(
|
| 118 |
+
i
|
| 119 |
+
for i, arg in zip(indices, args)
|
| 120 |
+
if isinstance(arg, SizeArg)
|
| 121 |
+
and arg.expr is not None
|
| 122 |
+
and V.graph.sizevars.statically_known_equals(arg.expr, 1) # type: ignore[arg-type]
|
| 123 |
+
)
|
| 124 |
+
# ids_of_folded_args is set from equal_to_1
|
| 125 |
+
# and None args by the Triton compiler
|
| 126 |
+
ids_of_folded_args = tuple(equal_to_1)
|
| 127 |
+
|
| 128 |
+
return instance_descriptor(
|
| 129 |
+
divisible_by_16, equal_to_1, ids_of_folded_args, divisible_by_8
|
| 130 |
+
)
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/wrapper.py
ADDED
|
@@ -0,0 +1,1543 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import collections
|
| 2 |
+
import contextlib
|
| 3 |
+
import dataclasses
|
| 4 |
+
import functools
|
| 5 |
+
import inspect
|
| 6 |
+
import operator
|
| 7 |
+
import re
|
| 8 |
+
from itertools import count
|
| 9 |
+
from typing import (
|
| 10 |
+
Any,
|
| 11 |
+
Callable,
|
| 12 |
+
Dict,
|
| 13 |
+
Iterator,
|
| 14 |
+
List,
|
| 15 |
+
Optional,
|
| 16 |
+
Set,
|
| 17 |
+
Tuple,
|
| 18 |
+
TYPE_CHECKING,
|
| 19 |
+
Union,
|
| 20 |
+
)
|
| 21 |
+
|
| 22 |
+
import sympy
|
| 23 |
+
from sympy import Expr
|
| 24 |
+
|
| 25 |
+
import torch
|
| 26 |
+
import torch._ops
|
| 27 |
+
from torch._dynamo.utils import counters, dynamo_timed
|
| 28 |
+
|
| 29 |
+
from torch._inductor.codegen.multi_kernel import MultiKernelState
|
| 30 |
+
from torch.fx.experimental.symbolic_shapes import SymTypes
|
| 31 |
+
from torch.fx.node import _get_qualified_name
|
| 32 |
+
from torch.utils._sympy.singleton_int import SingletonInt
|
| 33 |
+
|
| 34 |
+
from .. import codecache, config, ir
|
| 35 |
+
from ..ir import ReinterpretView
|
| 36 |
+
from ..utils import (
|
| 37 |
+
cache_on_self,
|
| 38 |
+
get_benchmark_name,
|
| 39 |
+
LineContext,
|
| 40 |
+
sympy_product,
|
| 41 |
+
sympy_str,
|
| 42 |
+
)
|
| 43 |
+
from ..virtualized import V
|
| 44 |
+
from .common import CodeGen, DeferredLine, IndentedBuffer, PythonPrinter
|
| 45 |
+
from .triton_utils import config_of, signature_to_meta
|
| 46 |
+
|
| 47 |
+
if TYPE_CHECKING:
|
| 48 |
+
import triton
|
| 49 |
+
|
| 50 |
+
from ..graph import GraphLowering
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
pexpr = PythonPrinter().doprint
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
ReuseKey = Tuple[torch.device, torch.dtype, str]
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def buffer_reuse_key(node: ir.Buffer) -> ReuseKey:
|
| 60 |
+
return (
|
| 61 |
+
node.get_device(),
|
| 62 |
+
node.get_dtype(),
|
| 63 |
+
# NB: this is symbolic so that we don't try to reuse a buffer
|
| 64 |
+
# for s0 for s1, just because they happen to share the same
|
| 65 |
+
# size hint
|
| 66 |
+
sympy_str(V.graph.sizevars.simplify(node.layout.storage_size())),
|
| 67 |
+
)
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def convert_arg_type(arg: torch.Argument) -> str:
    """Map a schema argument *arg* to the C++ type used in the op's C++
    signature.

    Raises:
        AssertionError: if the python type has no known C++ equivalent.
    """
    from .cpp import CONTAINER_PYTHON_TO_CPP, PYTHON_TO_CPP

    # use x.real_type instead of x.type so that we get ScalarType instead of int
    python_type = repr(arg.real_type)  # type: ignore[attr-defined]

    if python_type == "Tensor":
        # Conversions rules follow https://github.com/pytorch/pytorch/tree/main/aten/src/ATen/native#func
        # Mutable (aliased-and-written) tensors pass by non-const reference.
        if arg.alias_info is not None and arg.alias_info.is_write:
            return f"at::{python_type}&"
        else:
            return f"at::{python_type} const&"

    if python_type in PYTHON_TO_CPP:
        cpp_type = PYTHON_TO_CPP[python_type]
        return cpp_type

    # Convert args of container types e.g. Optional[*]
    for py_container, cpp_container in CONTAINER_PYTHON_TO_CPP.items():
        container_match = re.findall(py_container + r"\[([a-zA-Z_]+)]", python_type)
        if len(container_match) == 1:
            contained_type = container_match[0]
            assert (
                contained_type in PYTHON_TO_CPP
            ), f"unsupported {py_container} type in convert_arg_type: {contained_type}"
            cpp_contained_type = PYTHON_TO_CPP[contained_type]
            return f"{cpp_container}<{cpp_contained_type}>"

    # Fixed typo in the error message: "unsupport" -> "unsupported".
    raise AssertionError(f"unsupported python_type: {python_type}")
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def convert_return_type(ret: torch.Argument) -> str:
    """Map a schema return *ret* to its C++ type string.

    Only Tensor and List[Tensor] returns are supported (asserts otherwise).
    """
    # use x.real_type instead of x.type so that we get ScalarType instead of int
    python_type = repr(ret.real_type)  # type: ignore[attr-defined]
    cpp_by_python = {
        "Tensor": "at::Tensor",
        "List[Tensor]": "std::vector<at::Tensor>",
    }

    assert python_type in cpp_by_python, f"NYI return type: {python_type}"
    cpp_type = cpp_by_python[python_type]
    # An output aliasing an input is returned by reference only when it's a
    # Tensor, not when it's a Tensor[]. For example, aten.split.Tensor's output
    # aliases the input tensor, but the op returns a vector by value.
    if python_type == "Tensor" and ret.alias_info is not None:
        cpp_type += "&"
    return cpp_type
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
def get_cpp_op_schema(kernel: torch._ops.OpOverload) -> str:
    """Render *kernel*'s schema as a C++ function-type string:
    ``ReturnType(ArgType name, ...)``."""
    schema = kernel._schema
    returns = schema.returns

    assert len(returns) > 0, "must have at least one return value"

    if len(returns) == 1:
        cpp_return_value = convert_return_type(returns[0])
    else:
        # Multiple returns become a std::tuple of the converted types.
        tuple_returns = ", ".join(convert_return_type(r) for r in returns)
        cpp_return_value = f"std::tuple<{tuple_returns}>"

    cpp_params = ", ".join(
        f"{convert_arg_type(arg)} {arg.name}" for arg in schema.arguments
    )
    return f"{cpp_return_value}({cpp_params})"
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
# TODO: Move to a well known place
# Meta-parameters chosen for a triton kernel (e.g. block sizes), by name.
TritonMetaParams = Dict[str, int]
# A triton launch grid: either a tuple of dimensions (ints or symbolic
# exprs), or a callable mapping the kernel's meta-parameters to a grid tuple.
TritonGrid = Union[
    Tuple[Union[int, sympy.Expr], ...], Callable[[TritonMetaParams], Tuple[int, ...]]
]
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
def user_defined_kernel_grid_fn_code(
    name: str,
    configs: List["triton.Config"],
    grids: List[TritonGrid],
    wrapper: Optional["WrapperCodeGen"] = None,
) -> Tuple[str, str]:
    """Build the source of a grid-wrapper function for a user-defined triton
    kernel.

    Returns ``(fn_name, fn_source)``.  The generated function takes the
    kernel's meta-parameter dict; with a single config it returns that grid
    unconditionally, otherwise it matches the meta values against each
    autotune config to pick the corresponding grid.
    """
    output = IndentedBuffer()

    def _convert_to_sympy_expr(item: Union[int, sympy.Expr]) -> sympy.Expr:
        # Normalize plain ints so the whole grid can go through the expr printer.
        return item if isinstance(item, sympy.Expr) else sympy.Integer(item)

    def determine_grid(grid: TritonGrid):
        if wrapper is None or callable(grid):
            # return as-is when used in eager mode or when grid is callable
            return grid
        # Grid contains ints/Expr, so utilize wrapper's expr printer for codegen
        sympy_grid = tuple(_convert_to_sympy_expr(g) for g in grid)
        return wrapper.codegen_shape_tuple(sympy_grid)

    fn_name = f"grid_wrapper_for_{name}"
    output.writeline(f"def {fn_name}(meta):")
    with output.indent():
        if len(grids) == 1:
            # Single config: the grid is unconditional.
            grid = determine_grid(grids[0])
            output.writeline(f"return {grid}")
        else:
            assert len(grids) > 1
            assert len(grids) == len(configs)
            seen = set()
            for grid, c in zip(grids, configs):
                # Guard each grid on the meta values of its autotune config.
                guards = [f"meta['{name}'] == {val}" for name, val in c.kwargs.items()]
                guards = " and ".join(guards)
                grid = determine_grid(grid)
                statement = f"if {guards}: return {grid}"
                if statement in seen:
                    # Skip exact duplicate guard/grid pairs.
                    continue
                seen.add(statement)
                output.writeline(statement)

    return fn_name, output.getvalue()
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
@dataclasses.dataclass
class SymbolicCallArg:
    """A kernel-call argument standing for a symbolic (shape) expression."""

    # The generated-code text passed at the call site.
    inner: str
    # the original symbolic expression represented by inner
    inner_expr: sympy.Expr

    def __str__(self):
        return str(self.inner)
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
# Default thread stack sizes vary by platform:
# - Linux: 8 MB
# - macOS: 512 KB
# - Windows: 1 MB
# Just pick something comfortably smaller than the smallest for now.
# (100 KB; used as the cutoff for allow_stack_allocation in memory_plan_reuse.)
MAX_STACK_ALLOCATION_SIZE = 1024 * 100
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
class MemoryPlanningState:
    """Mutable state threaded through the buffer-reuse planning pass.

    Tracks freed-but-not-yet-reused buffers (bucketed by ReuseKey) and a
    running total of statically-sized CPU allocations.
    """

    def __init__(self):
        super().__init__()
        # Freed buffers available for reuse, keyed by (device, dtype, size).
        self.reuse_pool: Dict[
            "ReuseKey", List["FreeIfNotReusedLine"]
        ] = collections.defaultdict(list)
        # Accumulated by AllocateLine.plan for statically-shaped CPU buffers.
        self.total_allocated_buffer_size: int = 0

    def __contains__(self, key: "ReuseKey") -> bool:
        # .get() avoids materializing an empty defaultdict bucket.
        return bool(self.reuse_pool.get(key, None))

    def pop(self, key: "ReuseKey") -> "FreeIfNotReusedLine":
        candidate = self.reuse_pool[key].pop()
        assert not candidate.is_reused
        return candidate

    def push(self, key: "ReuseKey", item: "FreeIfNotReusedLine") -> None:
        assert not item.is_reused
        self.reuse_pool[key].append(item)
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
class WrapperLine:
    """Base marker class for deferred lines recorded by the wrapper codegen."""

    pass
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
@dataclasses.dataclass
class EnterSubgraphLine(WrapperLine):
    """Marks the start of codegen for a (nested) subgraph: pushes it onto the
    wrapper's codegened-graph stack and indents until ExitSubgraphLine."""

    wrapper: "WrapperCodeGen"
    graph: "GraphLowering"

    def codegen(self, code: IndentedBuffer) -> None:
        self.wrapper.push_codegened_graph(self.graph)
        code.do_indent()
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
@dataclasses.dataclass
class ExitSubgraphLine(WrapperLine):
    """Marks the end of a subgraph: pops the codegened-graph stack and
    unindents, matching EnterSubgraphLine."""

    wrapper: "WrapperCodeGen"

    def codegen(self, code: IndentedBuffer) -> None:
        self.wrapper.pop_codegened_graph()
        code.do_unindent()
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
@dataclasses.dataclass
class EnterDeviceContextManagerLine(WrapperLine):
    """Opens a device-guard scope (and, for C++ AOT mode, a stream guard)
    before the following kernel calls."""

    device_idx: int
    # Device index of the previous guard, or None if this is the first one.
    last_seen_device_guard_index: Optional[int]

    def codegen(self, code: IndentedBuffer) -> None:
        if V.graph.cpp_wrapper:
            code.writeline("\n")
            if V.graph.aot_mode:
                # In AOT mode, we have a stream provided as a param. A stream is
                # associated with a device, so we never expect the device to change.
                # CUDAStreamGuard sets the stream and the device.
                if self.last_seen_device_guard_index is None:
                    if config.abi_compatible:
                        code.writeline(
                            "AOTICudaStreamGuard stream_guard(stream, this->device_idx_);"
                        )
                    else:
                        code.writeline(
                            "at::cuda::CUDAStreamGuard stream_guard("
                            + "at::cuda::getStreamFromExternal(stream, this->device_idx_));"
                        )
                else:
                    assert (
                        self.last_seen_device_guard_index == self.device_idx
                    ), "AOTInductor only supports running on one CUDA device"
            else:
                # Non-AOT C++ wrapper: emit a device guard the first time,
                # then just retarget the existing guard on later switches.
                if self.last_seen_device_guard_index is None:
                    code.writeline(
                        f"AOTICudaGuard device_guard({self.device_idx});"
                        if config.abi_compatible
                        else f"at::cuda::CUDAGuard device_guard({self.device_idx});"
                    )
                else:
                    code.writeline(f"device_guard.set_index({self.device_idx});")
        else:
            # Note _DeviceGuard has less overhead than device, but only accepts
            # integers
            code.writeline(f"with {V.graph.device_ops.device_guard(self.device_idx)}:")
            code.do_indent()
            code.writeline(V.graph.device_ops.set_device(self.device_idx))
|
| 288 |
+
|
| 289 |
+
|
| 290 |
+
class ExitDeviceContextManagerLine(WrapperLine):
    """Closes the scope opened by EnterDeviceContextManagerLine."""

    def codegen(self, code: IndentedBuffer) -> None:
        # C++ guards are scope-bound objects and need no explicit close;
        # only the Python `with` block requires an unindent.
        if V.graph.cpp_wrapper:
            return
        code.do_unindent()
|
| 294 |
+
|
| 295 |
+
|
| 296 |
+
@dataclasses.dataclass
class MemoryPlanningLine(WrapperLine):
    """Base for deferred allocation/free lines that go through two-pass
    memory planning: ``plan`` (find reuse), then ``codegen`` (emit code)."""

    wrapper: "WrapperCodeGen"

    def plan(self, state: MemoryPlanningState) -> "MemoryPlanningLine":
        """First pass to find reuse"""
        return self

    def codegen(self, code: IndentedBuffer) -> None:
        """Second pass to output code"""
        pass

    def __str__(self) -> str:
        """
        Emits a string representation that fits on one line.
        """

        def render(f) -> str:
            val = getattr(self, f.name)
            shown = val.get_name() if f.type is ir.Buffer else val
            return f"{f.name}={shown}"

        shown_fields: List[str] = [
            render(f) for f in dataclasses.fields(self) if f.name != "wrapper"
        ]
        return f"{type(self).__name__}({', '.join(shown_fields)})"
|
| 321 |
+
|
| 322 |
+
|
| 323 |
+
@dataclasses.dataclass
class AllocateLine(MemoryPlanningLine):
    """Deferred allocation of `node`, possibly satisfied during planning by
    reusing a previously freed buffer with a matching ReuseKey."""

    node: ir.Buffer

    def plan(self, state: MemoryPlanningState) -> MemoryPlanningLine:
        if self.node.get_name() in V.graph.removed_buffers:
            return NullLine(self.wrapper)

        # try to reuse a recently freed buffer
        key = buffer_reuse_key(self.node)
        if config.allow_buffer_reuse and key in state:
            free_line = state.pop(key)
            free_line.is_reused = True
            return ReuseLine(self.wrapper, free_line.node, self.node)

        if self.node.get_device().type == "cpu":
            static_shape = self.wrapper.static_shape_for_buffer_or_none(self.node)
            if static_shape is not None:
                # Track statically-shaped CPU allocations toward the
                # stack-allocation budget (product of dims, not bytes).
                state.total_allocated_buffer_size += int(
                    functools.reduce(operator.mul, static_shape, 1)
                )

        return self

    def codegen(self, code: IndentedBuffer) -> None:
        assert self.node.get_name() not in V.graph.removed_buffers
        line = self.wrapper.make_buffer_allocation(self.node)
        code.writeline(line)
|
| 351 |
+
|
| 352 |
+
|
| 353 |
+
@dataclasses.dataclass
class FreeIfNotReusedLine(MemoryPlanningLine):
    """Frees `node` at codegen time unless a later AllocateLine claimed it
    for reuse during planning (is_reused flipped by AllocateLine.plan)."""

    node: ir.Buffer
    is_reused: bool = False

    def plan(self, state: MemoryPlanningState) -> MemoryPlanningLine:
        # Aliased/multi-output layouts don't own their storage; never pool them.
        if isinstance(self.node.layout, (ir.AliasedLayout, ir.MultiOutputLayout)):
            return self
        assert not self.is_reused
        if self.node.get_name() in V.graph.removed_buffers:
            return NullLine(self.wrapper)
        if config.allow_buffer_reuse:
            state.push(buffer_reuse_key(self.node), self)
        return self

    def codegen(self, code: IndentedBuffer) -> None:
        assert self.node.get_name() not in V.graph.removed_buffers
        if not self.is_reused:
            code.writeline(self.wrapper.make_buffer_free(self.node))
|
| 372 |
+
|
| 373 |
+
|
| 374 |
+
@dataclasses.dataclass
class ReuseLine(MemoryPlanningLine):
    """Hands `node`'s storage over to `reused_as`, optionally deleting the
    old name."""

    node: ir.Buffer
    reused_as: ir.Buffer
    delete_old: bool = True

    def plan(self, state: MemoryPlanningState) -> MemoryPlanningLine:
        removed = V.graph.removed_buffers
        if self.node.get_name() in removed:
            # If the source buffer was removed, the target must be gone too.
            assert self.reused_as.get_name() in removed
            return NullLine(self.wrapper)
        assert self.reused_as.get_name() not in removed
        return self

    def codegen(self, code: IndentedBuffer) -> None:
        assert self.node.get_name() not in V.graph.removed_buffers
        assert self.reused_as.get_name() not in V.graph.removed_buffers
        reuse_stmt = self.wrapper.make_buffer_reuse(
            self.node, self.reused_as, self.delete_old
        )
        code.writeline(reuse_stmt)
|
| 393 |
+
|
| 394 |
+
|
| 395 |
+
class NullLine(MemoryPlanningLine):
    """No-op placeholder for a planning line whose buffer was removed."""

    pass
|
| 397 |
+
|
| 398 |
+
|
| 399 |
+
BufferName = str
|
| 400 |
+
|
| 401 |
+
|
| 402 |
+
class WrapperCodeGen(CodeGen):
|
| 403 |
+
"""
|
| 404 |
+
Generate outer wrapper in Python that calls the kernels.
|
| 405 |
+
"""
|
| 406 |
+
|
| 407 |
+
    def __init__(self):
        """Set up codegen buffers and bookkeeping, then emit the module
        header/prefix and constant placeholders."""
        super().__init__()
        self._names_iter: Iterator[int] = count()
        # Output sections, assembled in order by generate():
        # header, prefix, wrapper_call (indented), suffix.
        self.header = IndentedBuffer()
        self.prefix = IndentedBuffer()
        self.suffix = IndentedBuffer()
        self.wrapper_call = IndentedBuffer()
        # If the generated source code is exactly the same, reuse the
        # pre-existing kernel for it
        self.src_to_kernel: Dict[str, str] = {}
        self.kernel_numel_expr: Set[Tuple[str, "GraphLowering"]] = set()
        self.lines: List[Union[MemoryPlanningLine, LineContext]] = []
        # Python-dialect syntax tokens; subclasses (C++ wrapper) override them.
        self.declare = ""
        self.declare_maybe_reference = ""
        self.ending = ""
        self.open_bracket = "["
        self.closed_bracket = "]"
        self.comment = "#"
        self.namespace = ""
        self.none_str = "None"
        self.size = "size()"
        self.stride = "stride()"
        self.last_seen_device_guard_index: Optional[int] = None
        self.supports_intermediate_hooks = True
        self.expr_printer = pexpr
        self.user_defined_kernel_cache: Dict[Tuple[Any, ...], Tuple[str, Any]] = {}
        self.unbacked_symbol_decls: Set[str] = set()  # str of sympy.Symbol
        self.allow_stack_allocation: Optional[bool] = None
        self.stack_allocated_buffers: Dict[BufferName, ir.Buffer] = {}
        self.computed_sizes: Set[sympy.Symbol] = set()

        # this is used for tracking which GraphLowering instance---parent graph
        # or (nested) subgraph---is currently codegened; the primary use case is
        # including the graph instance into a cache key to avoid cross-graph
        # caching during lowering of nested subgraphs
        self.codegened_graph_stack = [V.graph]

        self.write_header()
        self.write_prefix()

        if not V.graph.aot_mode:
            for name, hashed in V.graph.constant_reprs.items():
                # include a hash so our code cache puts different constants into different files
                self.write_constant(name, hashed)

        self.allocated: Set[BufferName] = set()
        self.freed: Set[BufferName] = set()

        # maps from reusing buffer to reused buffer
        self.reuses: Dict[BufferName, BufferName] = dict()

        # Memoize write_get_raw_stream per (device_idx, graph) on this instance.
        self.write_get_raw_stream = functools.lru_cache(None)(  # type: ignore[assignment]
            self.write_get_raw_stream
        )

        # De-duplicated header imports (each distinct line written once).
        @functools.lru_cache(None)
        def add_import_once(line: str) -> None:
            self.header.writeline(line)

        self.add_import_once = add_import_once
        self._metas: Dict[str, str] = {}
        self.multi_kernel_state = MultiKernelState()
|
| 469 |
+
|
| 470 |
+
def write_constant(self, name: str, hashed: str) -> None:
|
| 471 |
+
self.header.writeline(f"{name} = None # {hashed}")
|
| 472 |
+
|
| 473 |
+
    def write_header(self) -> None:
        """Write the common import/setup preamble of the generated module."""
        self.header.splice(
            f"""
                from ctypes import c_void_p, c_long
                import torch
                import math
                import random
                import os
                import tempfile
                from math import inf, nan
                from torch._inductor.hooks import run_intermediate_hooks
                from torch._inductor.utils import maybe_profile
                from torch._inductor.codegen.memory_planning import _align as align

                from torch import device, empty_strided
                from {codecache.__name__} import AsyncCompile
                from torch._inductor.select_algorithm import extern_kernels
                from torch._inductor.codegen.multi_kernel import MultiKernelCall

                aten = torch.ops.aten
                inductor_ops = torch.ops.inductor
                assert_size_stride = torch._C._dynamo.guards.assert_size_stride
                empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
                empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
                alloc_from_pool = torch.ops.inductor._alloc_from_pool
                reinterpret_tensor = torch.ops.inductor._reinterpret_tensor
                async_compile = AsyncCompile()

            """
        )
|
| 503 |
+
|
| 504 |
+
    @cache_on_self
    def write_triton_header_once(self) -> None:
        """Emit (at most once) the triton imports and get_raw_stream helper."""
        self.header.splice(
            """
            import triton
            import triton.language as tl
            from torch._inductor.triton_heuristics import grid, split_scan_grid, start_graph, end_graph
            {}
            """.format(
                V.graph.device_ops.import_get_raw_stream_as("get_raw_stream")
            )
        )
|
| 516 |
+
|
| 517 |
+
def add_meta_once(self, meta: TritonMetaParams) -> str:
|
| 518 |
+
meta = repr(meta)
|
| 519 |
+
if meta not in self._metas:
|
| 520 |
+
var = f"meta{len(self._metas)}"
|
| 521 |
+
self._metas[meta] = var
|
| 522 |
+
self.header.writeline(f"{var} = {meta}")
|
| 523 |
+
return self._metas[meta]
|
| 524 |
+
|
| 525 |
+
    @cache_on_self
    def get_output_refs(self) -> List[str]:
        """Return (and cache) the codegen references for all graph outputs."""
        return [x.codegen_reference(self.wrapper_call) for x in V.graph.graph_outputs]
|
| 528 |
+
|
| 529 |
+
    def mark_output_type(self) -> None:
        # Hook for subclasses to record output types; no-op in the Python wrapper.
        return
|
| 531 |
+
|
| 532 |
+
    def codegen_input_size_asserts(self) -> None:
        """Emit assert_size_stride checks for every tensor graph input."""
        for name, buf in V.graph.graph_inputs.items():
            # Symbolic scalar inputs have no size/stride to check.
            if isinstance(buf, sympy.Expr):
                continue

            # comparing strides for 0 size tensor is tricky. Ignore them for now.
            if sympy_product(buf.get_size()) == 0:
                continue
            size = self.codegen_shape_tuple(buf.get_size())
            stride = self.codegen_shape_tuple(buf.get_stride())
            self.prefix.writeline(f"assert_size_stride({name}, {size}, {stride})")
|
| 543 |
+
|
| 544 |
+
def codegen_input_nan_asserts(self) -> None:
|
| 545 |
+
self.prefix.writeline("# make sure graph inputs are not nan/inf")
|
| 546 |
+
for name, buf in V.graph.graph_inputs.items():
|
| 547 |
+
if isinstance(buf, sympy.Expr):
|
| 548 |
+
continue
|
| 549 |
+
|
| 550 |
+
line = f"assert not {name}.isnan().any().item()"
|
| 551 |
+
self.prefix.writeline(line)
|
| 552 |
+
line = f"assert not {name}.isinf().any().item()"
|
| 553 |
+
self.prefix.writeline(line)
|
| 554 |
+
|
| 555 |
+
    def write_prefix(self) -> None:
        """Write the `def call(args)` entry point: wait for async compiles,
        unpack inputs, then bind symbolic shapes and optional asserts."""
        self.prefix.splice(
            """

            async_compile.wait(globals())
            del async_compile

            def call(args):
            """
        )
        with self.prefix.indent():
            if config.triton.debug_sync_graph:
                self.prefix.writeline(V.graph.device_ops.synchronize())
            if V.graph.graph_inputs:
                lhs = ", ".join(V.graph.graph_input_names)
                # A single name needs a trailing comma to unpack a 1-element list.
                if len(V.graph.graph_input_names) == 1:
                    lhs += ","
                self.prefix.writeline(f"{lhs} = args")
                self.prefix.writeline("args.clear()")

            self.codegen_inputs(self.prefix, V.graph.graph_inputs)
            if config.size_asserts:
                self.codegen_input_size_asserts()
            if config.nan_asserts:
                self.codegen_input_nan_asserts()
|
| 580 |
+
|
| 581 |
+
    # this function (and below) takes a graph as input so
    # that stream caching happens per graph instance. this
    # is important for nested subgraph codegening.
    def write_get_raw_stream(self, device_idx: int, graph=None) -> str:
        """Emit a line fetching the raw stream for *device_idx* and return its
        variable name.  (Memoized per (device_idx, graph) via the lru_cache
        wrapping applied in __init__.)"""
        self.write_triton_header_once()
        name = f"stream{device_idx}"
        self.writeline(f"{name} = get_raw_stream({device_idx})")
        return name
|
| 589 |
+
|
| 590 |
+
    def get_codegened_graph(self):
        """Return the graph (or subgraph) currently being codegened."""
        return self.codegened_graph_stack[-1]
|
| 592 |
+
|
| 593 |
+
    def push_codegened_graph(self, graph):
        """Enter codegen of a nested subgraph."""
        self.codegened_graph_stack.append(graph)
|
| 595 |
+
|
| 596 |
+
    def pop_codegened_graph(self):
        """Leave codegen of the current subgraph, returning it."""
        return self.codegened_graph_stack.pop()
|
| 598 |
+
|
| 599 |
+
def next_kernel_suffix(self) -> str:
|
| 600 |
+
return f"{next(self._names_iter)}"
|
| 601 |
+
|
| 602 |
+
    def codegen_device_guard_enter(self, device_idx: int) -> None:
        """Record a deferred device-guard entry and remember the device index
        so later guards can detect a repeat."""
        self.writeline(
            EnterDeviceContextManagerLine(device_idx, self.last_seen_device_guard_index)
        )
        self.last_seen_device_guard_index = device_idx
|
| 607 |
+
|
| 608 |
+
    def codegen_device_guard_exit(self) -> None:
        """Record a deferred device-guard exit."""
        self.writeline(ExitDeviceContextManagerLine())
|
| 610 |
+
|
| 611 |
+
def generate_return(self, output_refs: List[str]) -> None:
|
| 612 |
+
if output_refs:
|
| 613 |
+
self.wrapper_call.writeline("return (" + ", ".join(output_refs) + ", )")
|
| 614 |
+
else:
|
| 615 |
+
self.wrapper_call.writeline("return ()")
|
| 616 |
+
|
| 617 |
+
    def generate_before_suffix(self, result: IndentedBuffer) -> None:
        # Hook for subclasses to emit code between the call body and suffix.
        return
|
| 619 |
+
|
| 620 |
+
    def generate_end(self, result: IndentedBuffer) -> None:
        # Hook for subclasses to emit trailing code; no-op here.
        return
|
| 622 |
+
|
| 623 |
+
    def generate_fallback_kernel(self, fallback_kernel, args):
        # The Python wrapper treats fallbacks like any allocating extern call.
        self.generate_extern_kernel_alloc(fallback_kernel, args)
|
| 625 |
+
|
| 626 |
+
    def generate_extern_kernel_alloc(self, extern_kernel, args):
        """Emit a call to an allocating extern kernel, assigning its result,
        and fire intermediate hooks when enabled."""
        output_name = extern_kernel.get_name()
        origin_node = extern_kernel.get_origin_node()
        kernel_name = extern_kernel.get_kernel_name()
        ending = self.ending
        if config.memory_planning and "view_as_complex" in kernel_name:
            # view operation fallbacks cause issues since inductor
            # doesn't know the memory is still needed and might reuse it.
            ending = f".clone(){ending}"
        self.writeline(
            f"{self.declare}{output_name} = {kernel_name}({', '.join(args)}){ending}"
        )
        if (
            self.supports_intermediate_hooks
            and config.generate_intermediate_hooks
            and origin_node is not None
        ):
            counters["inductor"]["intermediate_hooks"] += 1
            self.writeline(
                f"run_intermediate_hooks({origin_node.name!r}, {output_name})"
            )
|
| 647 |
+
|
| 648 |
+
def generate_extern_kernel_out(self, output_view, codegen_reference, args, kernel):
|
| 649 |
+
if output_view:
|
| 650 |
+
args.append(f"out={output_view.codegen_reference()}")
|
| 651 |
+
else:
|
| 652 |
+
args.append(f"out={codegen_reference}")
|
| 653 |
+
self.writeline(f"{kernel}({', '.join(args)})")
|
| 654 |
+
|
| 655 |
+
    def generate_user_defined_triton_kernel(
        self, kernel_name, grid, configs, args, triton_meta
    ):
        """Emit the grid-wrapper function and launch call for a user-defined
        triton kernel."""
        grid, code = user_defined_kernel_grid_fn_code(
            kernel_name, configs, grid, wrapper=self
        )
        # Must happen after free symbols are already codegened
        # Emit the grid wrapper function right before the call
        for line in code.split("\n"):
            self.writeline(line)

        stream_name = self.write_get_raw_stream(
            V.graph.scheduler.current_device.index, V.graph
        )
        self.writeline(
            f"{kernel_name}.run({', '.join(args)}, grid={grid}, stream={stream_name})"
        )
|
| 672 |
+
|
| 673 |
+
def generate_scatter_fallback(
|
| 674 |
+
self, output, inputs, kernel, python_kernel_name, src_is_tensor, reduce, kwargs
|
| 675 |
+
):
|
| 676 |
+
line = f"{kernel}({','.join(map(str, inputs))}"
|
| 677 |
+
if kernel == "aten.scatter_":
|
| 678 |
+
if reduce:
|
| 679 |
+
line += f", reduce={repr(reduce)}"
|
| 680 |
+
else:
|
| 681 |
+
line += ", ".join([""] + kwargs)
|
| 682 |
+
line += f"){self.ending}"
|
| 683 |
+
self.writeline(line)
|
| 684 |
+
|
| 685 |
+
def generate_index_put_fallback(self, kernel, x, indices, values, accumulate):
|
| 686 |
+
indices_str = f"{self.open_bracket}{', '.join(indices)}{self.closed_bracket}"
|
| 687 |
+
args = [x, indices_str, values, accumulate]
|
| 688 |
+
self.writeline(self.wrap_kernel_call(kernel, args))
|
| 689 |
+
|
| 690 |
+
    def generate_extern_kernel_alloc_and_find_schema_if_needed(
        self,
        name,
        kernel,
        codegen_args,
        cpp_op_schema,
        cpp_kernel_key,
        cpp_kernel_overload_name="",
        op_overload=None,
        raw_args=None,
        outputs=None,
    ):
        # Python wrapper ignores the cpp schema parameters and just calls the
        # kernel directly; presumably subclasses use them for C++ codegen.
        self.writeline(f"{name} = {kernel}({', '.join(codegen_args)})")
|
| 703 |
+
|
| 704 |
+
    def generate_inf_and_nan_checker(self, node):
        """Hook for emitting inf/nan checks; no-op in the Python wrapper."""
        # TODO: Add check for python too.
        pass
|
| 707 |
+
|
| 708 |
+
    @dynamo_timed
    def generate(self, is_inference):
        """Assemble the final wrapper source: header, memory-planned call
        body, prefix, suffix, and optional benchmark harness.

        Returns the value of IndentedBuffer.getvaluewithlinemap().
        """
        if config.profile_bandwidth:
            self.write_triton_header_once()
        result = IndentedBuffer()
        result.splice(self.header)

        with contextlib.ExitStack() as stack:
            stack.enter_context(self.wrapper_call.indent())
            if config.profiler_mark_wrapper_call:
                self.generate_profiler_mark_wrapper_call(stack)
            if config.profile_bandwidth:
                self.generate_start_graph()

            # We disable planning during training because it presently increases peak memory consumption.
            if is_inference and config.memory_planning:
                self.memory_plan()
                # TODO: integrate memory planning & stack allocation?
                self.allow_stack_allocation = False
            else:
                self.memory_plan_reuse()

            if config.triton.store_cubin:
                self.generate_reset_kernel_saved_flags()

            # Replay recorded lines; WrapperLine objects render themselves.
            for line in self.lines:
                if isinstance(line, WrapperLine):
                    line.codegen(self.wrapper_call)
                else:
                    self.wrapper_call.writeline(line)

            output_refs = self.get_output_refs()
            self.mark_output_type()
            if config.triton.debug_sync_graph:
                self.wrapper_call.writeline(V.graph.device_ops.synchronize())

            if config.profile_bandwidth:
                self.generate_end_graph()

            if config.triton.store_cubin:
                self.generate_save_uncompiled_kernels()

            self.generate_return(output_refs)

        self.finalize_prefix()
        result.splice(self.prefix)

        with result.indent():
            result.splice(self.wrapper_call)

        self.generate_before_suffix(result)
        result.splice(self.suffix)

        self.generate_end(result)

        self.add_benchmark_harness(result)

        return result.getvaluewithlinemap()
|
| 766 |
+
|
| 767 |
+
    def memory_plan(self):
        """Run the full MemoryPlanner pass over the recorded planning lines."""
        from .memory_planning import MemoryPlanner

        self.lines = MemoryPlanner(self).plan(self.lines)
|
| 771 |
+
|
| 772 |
+
    def memory_plan_reuse(self):
        """Lightweight planning: drop trailing dead planning lines, run the
        per-(sub)graph reuse pass, then decide whether stack allocation is
        allowed based on the total statically-sized allocation budget."""
        out_names = V.graph.get_output_names()

        while (
            self.lines
            and isinstance(self.lines[-1], MemoryPlanningLine)
            # TODO: this seems legit, NullLine has no node
            and self.lines[-1].node.name not in out_names  # type: ignore[attr-defined]
        ):
            # these lines will be pointless
            self.lines.pop()

        # codegen allocations in two passes
        # One planning state per (sub)graph scope; the innermost is on top.
        planning_states = [MemoryPlanningState()]
        past_planning_states = []
        for i in range(len(self.lines)):
            line = self.lines[i]
            if isinstance(line, MemoryPlanningLine):
                self.lines[i] = line.plan(planning_states[-1])
            elif isinstance(line, EnterSubgraphLine):
                planning_states.append(MemoryPlanningState())
            elif isinstance(line, ExitSubgraphLine):
                past_planning_states.append(planning_states.pop())
        # Retire the root (outermost) state as well.
        past_planning_states.append(planning_states.pop())
        assert len(planning_states) == 0

        # conservatively use the sum of all allocated buffer sizes
        # in potentially nested scopes as the total allocated size
        total_allocated_buffer_size = sum(
            s.total_allocated_buffer_size for s in past_planning_states
        )

        self.allow_stack_allocation = (
            self.allow_stack_allocation is not False
            and config.allow_stack_allocation
            and total_allocated_buffer_size <= MAX_STACK_ALLOCATION_SIZE
        )
|
| 809 |
+
|
| 810 |
+
    def codegen_input_size_var_decl(self, code: IndentedBuffer, name):
        """Emit `{name}_size = {name}.size()` in the wrapper's dialect."""
        code.writeline(f"{self.declare}{name}_size = {name}.{self.size}{self.ending}")
|
| 812 |
+
|
| 813 |
+
    def codegen_input_stride_var_decl(self, code: IndentedBuffer, name):
        """Emit `{name}_stride = {name}.stride()` in the wrapper's dialect."""
        code.writeline(
            f"{self.declare}{name}_stride = {name}.{self.stride}{self.ending}"
        )
|
| 817 |
+
|
| 818 |
+
    def codegen_inputs(
        self, code: IndentedBuffer, graph_inputs: Dict[str, ir.TensorBox]
    ):
        """Assign all symbolic shapes to locals"""

        @functools.lru_cache(None)
        def sizeof(name):
            # Emit the `{name}_size` declaration only once per input.
            self.codegen_input_size_var_decl(code, name)
            return f"{name}_size"

        @functools.lru_cache(None)
        def strideof(name):
            # Emit the `{name}_stride` declaration only once per input.
            self.codegen_input_stride_var_decl(code, name)
            return f"{name}_stride"

        # Assign all symbolic shapes needed to local variables
        needed = V.graph.sizevars.free_symbols()

        def is_expr(x):
            # x is a (name, value) item; exprs are symbolic scalar inputs.
            return isinstance(x[1], sympy.Expr)

        graph_inputs_expr = list(filter(is_expr, graph_inputs.items()))
        graph_inputs_tensors = list(
            filter(lambda x: not is_expr(x), graph_inputs.items())
        )

        # Symbolic scalar inputs bind directly to their argument name.
        for name, shape in graph_inputs_expr:
            shape = V.graph.sizevars.simplify(shape)  # type: ignore[arg-type]
            if shape in needed:
                needed.remove(shape)  # type: ignore[arg-type]
                code.writeline(f"{self.declare}{shape} = {name}{self.ending}")

        # Remaining symbols are bound from tensor sizes ...
        for name, value in graph_inputs_tensors:
            shapes = value.get_size()
            for dim, shape in enumerate(shapes):
                shape = V.graph.sizevars.simplify(shape)  # type: ignore[arg-type]
                if shape in needed:
                    needed.remove(shape)  # type: ignore[arg-type]
                    code.writeline(
                        f"{self.declare}{shape} = {sizeof(name)}[{dim}]{self.ending}"
                    )

        # ... and finally from tensor strides.
        for name, value in graph_inputs_tensors:
            shapes = value.get_stride()
            for dim, shape in enumerate(shapes):
                shape = V.graph.sizevars.simplify(shape)  # type: ignore[arg-type]
                if shape in needed:
                    needed.remove(shape)  # type: ignore[arg-type]
                    code.writeline(
                        f"{self.declare}{shape} = {strideof(name)}[{dim}]{self.ending}"
                    )
|
| 869 |
+
|
| 870 |
+
def ensure_size_computed(self, sym: sympy.Symbol):
|
| 871 |
+
if isinstance(sym, sympy.Symbol) and sym.name.startswith("ps"):
|
| 872 |
+
if sym in self.computed_sizes:
|
| 873 |
+
return
|
| 874 |
+
self.computed_sizes.add(sym)
|
| 875 |
+
expr = V.graph.sizevars.inv_precomputed_replacements[sym]
|
| 876 |
+
self.writeline(
|
| 877 |
+
f"{self.declare}{sym} = {self.expr_printer(expr)}{self.ending}"
|
| 878 |
+
)
|
| 879 |
+
|
| 880 |
+
def finalize_prefix(self):
|
| 881 |
+
pass
|
| 882 |
+
|
| 883 |
+
def codegen_python_sizevar(self, x: Expr) -> str:
|
| 884 |
+
return pexpr(V.graph.sizevars.simplify(x))
|
| 885 |
+
|
| 886 |
+
def codegen_sizevar(self, x: Expr) -> str:
|
| 887 |
+
return self.codegen_python_sizevar(x)
|
| 888 |
+
|
| 889 |
+
def codegen_tuple_access(self, basename: str, name: str, index: str) -> str:
|
| 890 |
+
return f"{basename}[{index}]"
|
| 891 |
+
|
| 892 |
+
def codegen_python_shape_tuple(self, shape: Tuple[Expr, ...]) -> str:
|
| 893 |
+
parts = list(map(self.codegen_python_sizevar, shape))
|
| 894 |
+
if len(parts) == 0:
|
| 895 |
+
return "()"
|
| 896 |
+
if len(parts) == 1:
|
| 897 |
+
return f"({parts[0]}, )"
|
| 898 |
+
return f"({', '.join(parts)})"
|
| 899 |
+
|
| 900 |
+
def codegen_shape_tuple(self, shape: Tuple[Expr, ...]) -> str:
|
| 901 |
+
return self.codegen_python_shape_tuple(shape)
|
| 902 |
+
|
| 903 |
+
def codegen_alloc_from_pool(self, name, offset, dtype, shape, stride) -> str:
|
| 904 |
+
return "alloc_from_pool({})".format(
|
| 905 |
+
", ".join(
|
| 906 |
+
[
|
| 907 |
+
name,
|
| 908 |
+
pexpr(offset), # bytes not numel
|
| 909 |
+
str(dtype),
|
| 910 |
+
self.codegen_shape_tuple(shape),
|
| 911 |
+
self.codegen_shape_tuple(stride),
|
| 912 |
+
]
|
| 913 |
+
)
|
| 914 |
+
)
|
| 915 |
+
|
| 916 |
+
def codegen_reinterpret_view(self, data, size, stride, offset, writer) -> str:
|
| 917 |
+
size = self.codegen_shape_tuple(size)
|
| 918 |
+
stride = self.codegen_shape_tuple(stride)
|
| 919 |
+
offset = self.codegen_sizevar(offset)
|
| 920 |
+
return f"reinterpret_tensor({data.get_name()}, {size}, {stride}, {offset})"
|
| 921 |
+
|
| 922 |
+
def codegen_device_copy(self, src, dst):
|
| 923 |
+
self.writeline(f"{dst}.copy_({src})")
|
| 924 |
+
|
| 925 |
+
def codegen_multi_output(self, name, value):
|
| 926 |
+
self.writeline(f"{self.declare}{name} = {value}{self.ending}")
|
| 927 |
+
|
| 928 |
+
def codegen_dynamic_scalar(self, node):
|
| 929 |
+
(data,) = (t.codegen_reference() for t in node.inputs)
|
| 930 |
+
if node.is_bool:
|
| 931 |
+
self.writeline(f"{node.sym} = 1 if {data}.item() else 0")
|
| 932 |
+
else:
|
| 933 |
+
self.writeline(f"{node.sym} = {data}.item()")
|
| 934 |
+
# No one should ever use this buffer, but for uniformity
|
| 935 |
+
# define the variable and assign it None
|
| 936 |
+
self.writeline(f"{node.get_name()} = None")
|
| 937 |
+
|
| 938 |
+
def benchmark_compiled_module(self, output):
|
| 939 |
+
def add_fake_input(name, shape, stride, device, dtype):
|
| 940 |
+
output.writeline(
|
| 941 |
+
f"{name} = rand_strided("
|
| 942 |
+
f"{self.codegen_python_shape_tuple(shape)}, "
|
| 943 |
+
f"{self.codegen_python_shape_tuple(stride)}, "
|
| 944 |
+
f"device='{device}', dtype={dtype})"
|
| 945 |
+
)
|
| 946 |
+
|
| 947 |
+
def add_expr_input(name, val):
|
| 948 |
+
output.writeline(f"{name} = {val}")
|
| 949 |
+
|
| 950 |
+
output.writelines(
|
| 951 |
+
["", "", "def benchmark_compiled_module(times=10, repeat=10):"]
|
| 952 |
+
)
|
| 953 |
+
with output.indent():
|
| 954 |
+
output.splice(
|
| 955 |
+
"""
|
| 956 |
+
from torch._dynamo.testing import rand_strided
|
| 957 |
+
from torch._inductor.utils import print_performance
|
| 958 |
+
""",
|
| 959 |
+
strip=True,
|
| 960 |
+
)
|
| 961 |
+
|
| 962 |
+
for name, value in V.graph.constants.items():
|
| 963 |
+
# all the constants are global variables, that's why we need
|
| 964 |
+
# these 'global var_name' lines
|
| 965 |
+
output.writeline(f"global {name}")
|
| 966 |
+
add_fake_input(
|
| 967 |
+
name, value.size(), value.stride(), value.device, value.dtype
|
| 968 |
+
)
|
| 969 |
+
|
| 970 |
+
for name, value in V.graph.graph_inputs.items():
|
| 971 |
+
if isinstance(value, sympy.Symbol) and isinstance(
|
| 972 |
+
V.graph.sizevars.var_to_val.get(value, None), SingletonInt
|
| 973 |
+
):
|
| 974 |
+
# Inductor should only work with dense -> dense graph, and
|
| 975 |
+
# SingletonInts belong to metadata that should only live on
|
| 976 |
+
# the subclass.
|
| 977 |
+
continue
|
| 978 |
+
if isinstance(value, sympy.Expr): # Don't need to add symbolic
|
| 979 |
+
add_expr_input(name, V.graph.sizevars.size_hint(value))
|
| 980 |
+
else:
|
| 981 |
+
shape = [V.graph.sizevars.size_hint(x) for x in value.get_size()]
|
| 982 |
+
stride = [V.graph.sizevars.size_hint(x) for x in value.get_stride()]
|
| 983 |
+
add_fake_input(
|
| 984 |
+
name, shape, stride, value.get_device(), value.get_dtype()
|
| 985 |
+
)
|
| 986 |
+
|
| 987 |
+
call_str = f"call([{', '.join(V.graph.graph_inputs.keys())}])"
|
| 988 |
+
output.writeline(f"fn = lambda: {call_str}")
|
| 989 |
+
output.writeline("return print_performance(fn, times=times, repeat=repeat)")
|
| 990 |
+
|
| 991 |
+
def add_benchmark_harness(self, output):
|
| 992 |
+
"""
|
| 993 |
+
Append a benchmark harness to generated code for debugging
|
| 994 |
+
"""
|
| 995 |
+
if not config.benchmark_harness:
|
| 996 |
+
return
|
| 997 |
+
|
| 998 |
+
self.benchmark_compiled_module(output)
|
| 999 |
+
|
| 1000 |
+
output.writelines(["", "", 'if __name__ == "__main__":'])
|
| 1001 |
+
with output.indent():
|
| 1002 |
+
output.writelines(
|
| 1003 |
+
[
|
| 1004 |
+
"from torch._inductor.wrapper_benchmark import compiled_module_main",
|
| 1005 |
+
f"compiled_module_main('{get_benchmark_name()}', benchmark_compiled_module)",
|
| 1006 |
+
]
|
| 1007 |
+
)
|
| 1008 |
+
|
| 1009 |
+
def define_kernel(
|
| 1010 |
+
self, name: str, kernel: str, metadata: Optional[str] = None, cuda=True
|
| 1011 |
+
):
|
| 1012 |
+
metadata_comment = f"{metadata}\n" if metadata else ""
|
| 1013 |
+
self.header.splice(f"\n\n{metadata_comment}{name} = {kernel}")
|
| 1014 |
+
|
| 1015 |
+
def define_user_defined_triton_kernel(self, kernel, configs, kwargs):
|
| 1016 |
+
original_name = kernel.__name__
|
| 1017 |
+
|
| 1018 |
+
from .common import KernelArgType, SizeArg, TensorArg
|
| 1019 |
+
|
| 1020 |
+
signature: List[KernelArgType] = []
|
| 1021 |
+
constants: Dict[int, Any] = {}
|
| 1022 |
+
non_constant_indices = []
|
| 1023 |
+
equal_to_1_arg_idx: List[int] = []
|
| 1024 |
+
for idx, key in enumerate(kernel.arg_names):
|
| 1025 |
+
if key not in kwargs:
|
| 1026 |
+
continue
|
| 1027 |
+
arg = kwargs[key]
|
| 1028 |
+
if idx in kernel.constexprs:
|
| 1029 |
+
constants[idx] = arg
|
| 1030 |
+
else:
|
| 1031 |
+
non_constant_indices.append(idx)
|
| 1032 |
+
if isinstance(arg, ir.Buffer):
|
| 1033 |
+
signature.append(
|
| 1034 |
+
TensorArg(
|
| 1035 |
+
name=key,
|
| 1036 |
+
buffer=arg.get_name(),
|
| 1037 |
+
dtype=arg.get_dtype(),
|
| 1038 |
+
)
|
| 1039 |
+
)
|
| 1040 |
+
elif isinstance(arg, ir.ReinterpretView):
|
| 1041 |
+
# for ReinterpretView we use the underlying
|
| 1042 |
+
# buffer name and note the (possibly non-zero)
|
| 1043 |
+
# offset relative to the underlying buffer
|
| 1044 |
+
signature.append(
|
| 1045 |
+
TensorArg(
|
| 1046 |
+
name=key,
|
| 1047 |
+
buffer=arg.data.get_name(),
|
| 1048 |
+
dtype=arg.get_dtype(),
|
| 1049 |
+
offset=arg.layout.offset,
|
| 1050 |
+
)
|
| 1051 |
+
)
|
| 1052 |
+
else:
|
| 1053 |
+
signature.append(SizeArg(key, arg))
|
| 1054 |
+
if arg is not None and V.graph.sizevars.statically_known_equals(arg, 1): # type: ignore[arg-type]
|
| 1055 |
+
equal_to_1_arg_idx.append(idx)
|
| 1056 |
+
index_dtype = "tl.int32"
|
| 1057 |
+
triton_meta = {
|
| 1058 |
+
"signature": signature_to_meta(
|
| 1059 |
+
signature,
|
| 1060 |
+
size_dtype=index_dtype,
|
| 1061 |
+
indices=non_constant_indices,
|
| 1062 |
+
),
|
| 1063 |
+
"device": V.graph.scheduler.current_device.index,
|
| 1064 |
+
"device_type": V.graph.scheduler.current_device.type,
|
| 1065 |
+
# Triton compiler includes equal_to_1 args into constants even
|
| 1066 |
+
# when they are not constexpr. otherwise there may be a segfault
|
| 1067 |
+
# during launching the Inductor-compiled Triton kernel.
|
| 1068 |
+
# TODO(aakhundov): add None args to constants, too. currently, this
|
| 1069 |
+
# causes CUDA errors in test_aot_inductor.test_triton_kernel_with_none_input.
|
| 1070 |
+
# https://github.com/pytorch/pytorch/issues/120478#issuecomment-1962822307
|
| 1071 |
+
# https://github.com/openai/triton/blob/231efe9ed2d200be0f69a07c298e4342b08efe3d/python/triton/runtime/jit.py#L384
|
| 1072 |
+
"constants": {
|
| 1073 |
+
**constants,
|
| 1074 |
+
**{idx: 1 for idx in equal_to_1_arg_idx},
|
| 1075 |
+
},
|
| 1076 |
+
"configs": [
|
| 1077 |
+
config_of(
|
| 1078 |
+
signature,
|
| 1079 |
+
indices=non_constant_indices,
|
| 1080 |
+
)
|
| 1081 |
+
],
|
| 1082 |
+
}
|
| 1083 |
+
|
| 1084 |
+
# Distinguish between different functions using function id
|
| 1085 |
+
cache_key: List[Any] = [id(kernel.fn)]
|
| 1086 |
+
if len(configs) > 0:
|
| 1087 |
+
for arg in kwargs.values():
|
| 1088 |
+
# We need to key on non tensor arg only in autotune mode
|
| 1089 |
+
if not isinstance(arg, (ir.Buffer, ir.ReinterpretView)):
|
| 1090 |
+
cache_key.append(arg)
|
| 1091 |
+
cache_key.append(str(triton_meta))
|
| 1092 |
+
cache_key = tuple(cache_key)
|
| 1093 |
+
|
| 1094 |
+
if cache_key in self.user_defined_kernel_cache:
|
| 1095 |
+
return self.user_defined_kernel_cache[cache_key]
|
| 1096 |
+
|
| 1097 |
+
name = f"{original_name}_{len(self.user_defined_kernel_cache)}"
|
| 1098 |
+
# Add to the cache for the next use
|
| 1099 |
+
self.user_defined_kernel_cache[cache_key] = (name, triton_meta)
|
| 1100 |
+
|
| 1101 |
+
compile_wrapper = IndentedBuffer()
|
| 1102 |
+
compile_wrapper.writeline(f"async_compile.triton({original_name!r}, '''")
|
| 1103 |
+
|
| 1104 |
+
from .triton import gen_common_triton_imports
|
| 1105 |
+
|
| 1106 |
+
compile_wrapper.splice(gen_common_triton_imports())
|
| 1107 |
+
|
| 1108 |
+
inductor_meta = {
|
| 1109 |
+
"kernel_name": name,
|
| 1110 |
+
"backend_hash": torch.utils._triton.triton_hash_with_backend(),
|
| 1111 |
+
}
|
| 1112 |
+
|
| 1113 |
+
configs = [
|
| 1114 |
+
{
|
| 1115 |
+
"kwargs": config.kwargs,
|
| 1116 |
+
"num_warps": config.num_warps,
|
| 1117 |
+
"num_stages": config.num_stages,
|
| 1118 |
+
}
|
| 1119 |
+
for config in configs
|
| 1120 |
+
]
|
| 1121 |
+
|
| 1122 |
+
compile_wrapper.splice(
|
| 1123 |
+
f"""
|
| 1124 |
+
@triton_heuristics.user_autotune(
|
| 1125 |
+
configs={configs!r},
|
| 1126 |
+
inductor_meta={inductor_meta!r},
|
| 1127 |
+
triton_meta={triton_meta!r},
|
| 1128 |
+
filename=__file__,
|
| 1129 |
+
custom_kernel=True,
|
| 1130 |
+
)
|
| 1131 |
+
@triton.jit
|
| 1132 |
+
"""
|
| 1133 |
+
)
|
| 1134 |
+
compile_wrapper.splice(kernel.src, strip=True)
|
| 1135 |
+
|
| 1136 |
+
# Also include any possible kernel being called indirectly
|
| 1137 |
+
from triton import JITFunction
|
| 1138 |
+
|
| 1139 |
+
symbols_included = {original_name}
|
| 1140 |
+
|
| 1141 |
+
def traverse(cur_kernel):
|
| 1142 |
+
for symbol_name in cur_kernel.fn.__code__.co_names:
|
| 1143 |
+
if symbol_name in symbols_included:
|
| 1144 |
+
continue
|
| 1145 |
+
if symbol_name in cur_kernel.fn.__globals__:
|
| 1146 |
+
symbol = cur_kernel.fn.__globals__[symbol_name]
|
| 1147 |
+
if isinstance(symbol, JITFunction):
|
| 1148 |
+
compile_wrapper.newline()
|
| 1149 |
+
compile_wrapper.writeline("@triton.jit")
|
| 1150 |
+
compile_wrapper.splice(symbol.src, strip=True)
|
| 1151 |
+
symbols_included.add(symbol_name)
|
| 1152 |
+
traverse(symbol)
|
| 1153 |
+
elif isinstance(symbol, (int, str, bool)):
|
| 1154 |
+
compile_wrapper.newline()
|
| 1155 |
+
compile_wrapper.writeline(f"{symbol_name} = {symbol!r}")
|
| 1156 |
+
symbols_included.add(symbol_name)
|
| 1157 |
+
|
| 1158 |
+
traverse(kernel)
|
| 1159 |
+
|
| 1160 |
+
compile_wrapper.writeline(
|
| 1161 |
+
f"''', device_str='{V.graph.scheduler.current_device.type}')"
|
| 1162 |
+
)
|
| 1163 |
+
_, lineno = inspect.getsourcelines(kernel.fn)
|
| 1164 |
+
srcfile = inspect.getsourcefile(kernel.fn)
|
| 1165 |
+
metadata = f"# Original path: {srcfile}:{lineno}"
|
| 1166 |
+
self.define_kernel(
|
| 1167 |
+
name,
|
| 1168 |
+
compile_wrapper.getvalue(),
|
| 1169 |
+
metadata,
|
| 1170 |
+
)
|
| 1171 |
+
return name, triton_meta
|
| 1172 |
+
|
| 1173 |
+
def generate_numel_expr(self, kernel_name: str, tree):
|
| 1174 |
+
expr = f"{kernel_name}_{tree.prefix}numel"
|
| 1175 |
+
if (expr, V.graph) not in self.kernel_numel_expr:
|
| 1176 |
+
# declare expr once in each graph (scope)
|
| 1177 |
+
self.kernel_numel_expr.add((expr, V.graph))
|
| 1178 |
+
self.writeline(
|
| 1179 |
+
f"{self.declare}{expr} = {self.expr_printer(tree.numel)}{self.ending}"
|
| 1180 |
+
)
|
| 1181 |
+
else:
|
| 1182 |
+
self.writeline(f"{expr} = {self.expr_printer(tree.numel)}{self.ending}")
|
| 1183 |
+
# We can get symbolic expressions here, like s0*64
|
| 1184 |
+
# It is fine to have them here, but we need to handle them correctly as their own type
|
| 1185 |
+
# This is tricky to do, so we wrap in a custom type, distinct from scalars, but also from sympy*
|
| 1186 |
+
# scalars as well.
|
| 1187 |
+
# This is handled in `generate_args_decl` which has a correct comment of: TODO: only works for
|
| 1188 |
+
# constant now, need type info. I agree, this needs type info, and while this is not true type info
|
| 1189 |
+
# it suffices as a type hint for the purposes of producing the correct code for this type.
|
| 1190 |
+
return SymbolicCallArg(expr, tree.numel)
|
| 1191 |
+
|
| 1192 |
+
def generate_workspace_allocation(self, nbytes, device, zero_fill):
|
| 1193 |
+
line = self.make_allocation(
|
| 1194 |
+
"workspace", device, torch.uint8, shape=(nbytes,), stride=(1,)
|
| 1195 |
+
)
|
| 1196 |
+
self.writeline(line)
|
| 1197 |
+
if zero_fill:
|
| 1198 |
+
self.writeline(f"workspace.zero_(){self.ending}")
|
| 1199 |
+
|
| 1200 |
+
def wrap_kernel_call(self, name, call_args):
|
| 1201 |
+
return f"{name}({', '.join(call_args)}){self.ending}"
|
| 1202 |
+
|
| 1203 |
+
def generate_profiler_mark_wrapper_call(self, stack):
|
| 1204 |
+
self.wrapper_call.writeline("from torch.profiler import record_function")
|
| 1205 |
+
self.wrapper_call.writeline(
|
| 1206 |
+
f"with record_function('graph_{V.graph.graph_id}_inductor_wrapper_call'):"
|
| 1207 |
+
)
|
| 1208 |
+
stack.enter_context(self.wrapper_call.indent())
|
| 1209 |
+
|
| 1210 |
+
def generate_start_graph(self):
|
| 1211 |
+
self.wrapper_call.writeline("start_graph()")
|
| 1212 |
+
|
| 1213 |
+
def generate_end_graph(self):
|
| 1214 |
+
self.wrapper_call.writeline("end_graph()")
|
| 1215 |
+
|
| 1216 |
+
def generate_reset_kernel_saved_flags(self):
|
| 1217 |
+
self.wrapper_call.splice(
|
| 1218 |
+
"""
|
| 1219 |
+
for kernel in globals().values():
|
| 1220 |
+
if isinstance(kernel, torch._inductor.triton_heuristics.CachingAutotuner):
|
| 1221 |
+
kernel.cuda_kernel_saved = False
|
| 1222 |
+
"""
|
| 1223 |
+
)
|
| 1224 |
+
|
| 1225 |
+
def generate_save_uncompiled_kernels(self):
|
| 1226 |
+
"""
|
| 1227 |
+
Precompile and save the CUBINs of the Triton kernels that haven't
|
| 1228 |
+
been precompiled and saved as a side effect of running the generated
|
| 1229 |
+
JIT model (Python wrapper). This can happen when the model contains
|
| 1230 |
+
control flow: only one pass through the control flow operators covers
|
| 1231 |
+
the kernels that are saved, the remaining kernels are not launched,
|
| 1232 |
+
hence not saved. The main purpose of this codegen is to compile and
|
| 1233 |
+
save the Triton kernels outside the active control flow path for
|
| 1234 |
+
subsequent AOTInductor code generation and compilation.
|
| 1235 |
+
"""
|
| 1236 |
+
self.wrapper_call.splice(
|
| 1237 |
+
"""
|
| 1238 |
+
for kernel in globals().values():
|
| 1239 |
+
if isinstance(kernel, torch._inductor.triton_heuristics.CachingAutotuner):
|
| 1240 |
+
if not kernel.cuda_kernel_saved:
|
| 1241 |
+
if len(kernel.launchers) == 0:
|
| 1242 |
+
kernel.precompile()
|
| 1243 |
+
kernel.save_cuda_kernel(
|
| 1244 |
+
grid=(0, 0, 0), # use dummy grid
|
| 1245 |
+
stream="stream", # use dummy stream
|
| 1246 |
+
launcher=kernel.launchers[0],
|
| 1247 |
+
)
|
| 1248 |
+
"""
|
| 1249 |
+
)
|
| 1250 |
+
|
| 1251 |
+
def generate_default_grid(self, name: str, grid_args: List[Any]):
|
| 1252 |
+
return grid_args
|
| 1253 |
+
|
| 1254 |
+
def generate_kernel_call(
|
| 1255 |
+
self,
|
| 1256 |
+
name,
|
| 1257 |
+
call_args,
|
| 1258 |
+
grid=None,
|
| 1259 |
+
device_index=None,
|
| 1260 |
+
cuda=True,
|
| 1261 |
+
triton=True,
|
| 1262 |
+
arg_types=None,
|
| 1263 |
+
grid_fn: str = "grid",
|
| 1264 |
+
triton_meta=None,
|
| 1265 |
+
):
|
| 1266 |
+
"""
|
| 1267 |
+
Generates kernel call code.
|
| 1268 |
+
|
| 1269 |
+
cuda: Defines whether the backend is GPU. Otherwise the backend is CPU.
|
| 1270 |
+
|
| 1271 |
+
triton: Defines whether the GPU backend uses Triton for codegen.
|
| 1272 |
+
Otherwise it uses the CUDA language for codegen.
|
| 1273 |
+
Only valid when cuda == True.
|
| 1274 |
+
"""
|
| 1275 |
+
if cuda:
|
| 1276 |
+
call_args_str = ", ".join(pexpr(item) for item in call_args)
|
| 1277 |
+
stream_name = self.write_get_raw_stream(
|
| 1278 |
+
V.graph.scheduler.current_device.index, V.graph
|
| 1279 |
+
)
|
| 1280 |
+
if triton:
|
| 1281 |
+
grid_str = ", ".join(pexpr(item) for item in grid)
|
| 1282 |
+
grid_str = f"{grid_fn}({grid_str})"
|
| 1283 |
+
self.writeline(
|
| 1284 |
+
f"{name}.run({call_args_str}, grid={grid_str}, stream={stream_name})"
|
| 1285 |
+
)
|
| 1286 |
+
else:
|
| 1287 |
+
stream_ptr = f"c_void_p({stream_name})"
|
| 1288 |
+
self.writeline(f"{name}.{name}({call_args_str}, {stream_ptr})")
|
| 1289 |
+
else:
|
| 1290 |
+
self.writeline(self.wrap_kernel_call(name, call_args))
|
| 1291 |
+
|
| 1292 |
+
def writeline(self, line):
|
| 1293 |
+
self.lines.append(line)
|
| 1294 |
+
|
| 1295 |
+
def enter_context(self, ctx):
|
| 1296 |
+
self.lines.append(LineContext(ctx))
|
| 1297 |
+
|
| 1298 |
+
def val_to_cpp_arg_str(self, type_, val, is_legacy_abi) -> str:
|
| 1299 |
+
raise NotImplementedError()
|
| 1300 |
+
|
| 1301 |
+
def val_to_arg_str(self, s):
|
| 1302 |
+
if isinstance(s, SymTypes):
|
| 1303 |
+
return pexpr(sympy.expand(repr(s)))
|
| 1304 |
+
elif isinstance(s, sympy.Expr):
|
| 1305 |
+
return pexpr(s)
|
| 1306 |
+
elif isinstance(s, (tuple, list)):
|
| 1307 |
+
|
| 1308 |
+
@dataclasses.dataclass
|
| 1309 |
+
class Shim:
|
| 1310 |
+
ref: Any
|
| 1311 |
+
|
| 1312 |
+
def __repr__(self):
|
| 1313 |
+
return self.ref
|
| 1314 |
+
|
| 1315 |
+
return repr(type(s)(Shim(self.val_to_arg_str(a)) for a in s))
|
| 1316 |
+
elif isinstance(s, torch._ops.OpOverload):
|
| 1317 |
+
return _get_qualified_name(s)
|
| 1318 |
+
elif isinstance(s, (ir.Buffer, ReinterpretView)):
|
| 1319 |
+
return s.codegen_reference()
|
| 1320 |
+
else:
|
| 1321 |
+
return repr(s)
|
| 1322 |
+
|
| 1323 |
+
# The following methods are for memory management
|
| 1324 |
+
def make_buffer_allocation(self, buffer):
|
| 1325 |
+
device = buffer.get_device()
|
| 1326 |
+
dtype = buffer.get_dtype()
|
| 1327 |
+
shape = tuple(buffer.get_size())
|
| 1328 |
+
stride = tuple(buffer.get_stride())
|
| 1329 |
+
return self.make_allocation(buffer.get_name(), device, dtype, shape, stride)
|
| 1330 |
+
|
| 1331 |
+
def make_allocation(self, name, device, dtype, shape, stride):
|
| 1332 |
+
if device.type in ("cpu", "cuda"):
|
| 1333 |
+
# optimized path for faster allocations, saving ~2us versus the stuff below
|
| 1334 |
+
return (
|
| 1335 |
+
f"{name} = empty_strided_{device.type}("
|
| 1336 |
+
f"{self.codegen_shape_tuple(shape)}, "
|
| 1337 |
+
f"{self.codegen_shape_tuple(stride)}, "
|
| 1338 |
+
f"{dtype})"
|
| 1339 |
+
)
|
| 1340 |
+
# all other devices:
|
| 1341 |
+
return (
|
| 1342 |
+
f"{name} = empty_strided("
|
| 1343 |
+
f"{self.codegen_shape_tuple(shape)}, "
|
| 1344 |
+
f"{self.codegen_shape_tuple(stride)}, "
|
| 1345 |
+
f"device='{device.type}', dtype={dtype})"
|
| 1346 |
+
)
|
| 1347 |
+
|
| 1348 |
+
def make_tensor_alias(self, new_name, old_name, comment=""):
|
| 1349 |
+
return f"{self.declare}{new_name} = {old_name}{self.ending} {self.comment} {comment}"
|
| 1350 |
+
|
| 1351 |
+
def make_buffer_free(self, buffer):
|
| 1352 |
+
return f"del {buffer.get_name()}"
|
| 1353 |
+
|
| 1354 |
+
def make_free_by_names(self, names_to_del: List[str]):
|
| 1355 |
+
return f"del {', '.join(name for name in names_to_del)}"
|
| 1356 |
+
|
| 1357 |
+
def codegen_exact_buffer_reuse(self, old_name: str, new_name: str, del_line: str):
|
| 1358 |
+
return f"{self.declare_maybe_reference}{new_name} = {old_name}{del_line}{self.ending} {self.comment} reuse"
|
| 1359 |
+
|
| 1360 |
+
def make_buffer_reuse(self, old, new, delete_old: bool):
|
| 1361 |
+
assert old.get_dtype() == new.get_dtype()
|
| 1362 |
+
old_name = old.get_name()
|
| 1363 |
+
new_name = new.get_name()
|
| 1364 |
+
del_line = ";"
|
| 1365 |
+
if old_name not in V.graph.get_output_names() and delete_old:
|
| 1366 |
+
del_line = f"; {self.make_buffer_free(old)}"
|
| 1367 |
+
|
| 1368 |
+
if old.get_size() == new.get_size() and old.get_stride() == new.get_stride():
|
| 1369 |
+
if old_name in self.stack_allocated_buffers:
|
| 1370 |
+
self.stack_allocated_buffers[new_name] = new
|
| 1371 |
+
return self.codegen_exact_buffer_reuse(old_name, new_name, del_line)
|
| 1372 |
+
|
| 1373 |
+
reinterpret_view = self.codegen_reinterpret_view(
|
| 1374 |
+
old, new.get_size(), new.get_stride(), 0, self.wrapper_call
|
| 1375 |
+
)
|
| 1376 |
+
if reinterpret_view in self.stack_allocated_buffers:
|
| 1377 |
+
self.stack_allocated_buffers[new_name] = new
|
| 1378 |
+
return f"{self.declare_maybe_reference}{new_name} = {reinterpret_view}{del_line} {self.comment} reuse"
|
| 1379 |
+
|
| 1380 |
+
def codegen_deferred_allocation(self, name, layout):
|
| 1381 |
+
self.writeline(
|
| 1382 |
+
DeferredLine(
|
| 1383 |
+
name,
|
| 1384 |
+
f"{self.declare_maybe_reference}{name} = {layout.view.codegen_reference()}{self.ending} "
|
| 1385 |
+
f"{self.comment} alias",
|
| 1386 |
+
)
|
| 1387 |
+
)
|
| 1388 |
+
|
| 1389 |
+
def codegen_allocation(self, buffer):
|
| 1390 |
+
assert (
|
| 1391 |
+
buffer.get_workspace_size() == 0
|
| 1392 |
+
), "Only support zero workspace size for now!"
|
| 1393 |
+
|
| 1394 |
+
name = buffer.get_name()
|
| 1395 |
+
|
| 1396 |
+
if name in V.graph.removed_buffers or name in self.allocated:
|
| 1397 |
+
return
|
| 1398 |
+
self.allocated.add(name)
|
| 1399 |
+
if isinstance(
|
| 1400 |
+
buffer,
|
| 1401 |
+
(ir.ExternKernelAlloc, ir.MultiOutput),
|
| 1402 |
+
):
|
| 1403 |
+
return
|
| 1404 |
+
|
| 1405 |
+
layout = buffer.get_layout()
|
| 1406 |
+
if isinstance(layout, ir.MutationLayout):
|
| 1407 |
+
return
|
| 1408 |
+
if isinstance(layout, ir.AliasedLayout):
|
| 1409 |
+
assert isinstance(
|
| 1410 |
+
layout.view, ir.ReinterpretView
|
| 1411 |
+
), f"unexpected {type(layout.view)}: {layout.view}"
|
| 1412 |
+
self.codegen_allocation(layout.view.data)
|
| 1413 |
+
self.codegen_deferred_allocation(name, layout)
|
| 1414 |
+
return
|
| 1415 |
+
|
| 1416 |
+
self.writeline(AllocateLine(self, buffer))
|
| 1417 |
+
|
| 1418 |
+
def codegen_free(self, buffer):
|
| 1419 |
+
assert (
|
| 1420 |
+
buffer.get_workspace_size() == 0
|
| 1421 |
+
), "Only support zero workspace size for now!"
|
| 1422 |
+
|
| 1423 |
+
name = buffer.get_name()
|
| 1424 |
+
|
| 1425 |
+
# can be freed but not reused
|
| 1426 |
+
if isinstance(buffer, ir.InputBuffer):
|
| 1427 |
+
self.writeline(self.make_buffer_free(buffer))
|
| 1428 |
+
return
|
| 1429 |
+
|
| 1430 |
+
if not self.can_reuse(buffer):
|
| 1431 |
+
return
|
| 1432 |
+
self.freed.add(name)
|
| 1433 |
+
|
| 1434 |
+
self.writeline(FreeIfNotReusedLine(self, buffer))
|
| 1435 |
+
|
| 1436 |
+
def can_reuse(self, input_buffer, output_buffer=None):
|
| 1437 |
+
name = input_buffer.get_name()
|
| 1438 |
+
if (
|
| 1439 |
+
name in V.graph.removed_buffers
|
| 1440 |
+
or name in V.graph.graph_inputs
|
| 1441 |
+
or name in V.graph.constants
|
| 1442 |
+
or name in V.graph.never_reuse_buffers
|
| 1443 |
+
or name in self.freed
|
| 1444 |
+
):
|
| 1445 |
+
return False
|
| 1446 |
+
|
| 1447 |
+
return True
|
| 1448 |
+
|
| 1449 |
+
def did_reuse(self, buffer, reused_buffer):
|
| 1450 |
+
# Check whether a given buffer was reused by a possible reuser in the wrapper codegen
|
| 1451 |
+
# Can be consulted from inside ir codegen, e.g. to determine whether a copy is needed
|
| 1452 |
+
return (
|
| 1453 |
+
buffer.get_name() in self.reuses
|
| 1454 |
+
and self.reuses[buffer.get_name()] == reused_buffer.get_name()
|
| 1455 |
+
)
|
| 1456 |
+
|
| 1457 |
+
def codegen_inplace_reuse(self, input_buffer, output_buffer):
|
| 1458 |
+
assert buffer_reuse_key(input_buffer) == buffer_reuse_key(output_buffer)
|
| 1459 |
+
self.codegen_allocation(input_buffer)
|
| 1460 |
+
self.freed.add(input_buffer.get_name())
|
| 1461 |
+
self.allocated.add(output_buffer.get_name())
|
| 1462 |
+
self.reuses[output_buffer.get_name()] = input_buffer.get_name()
|
| 1463 |
+
self.writeline(ReuseLine(self, input_buffer, output_buffer))
|
| 1464 |
+
|
| 1465 |
+
def codegen_unbacked_symbol_decl(self, symbol):
|
| 1466 |
+
name = str(symbol)
|
| 1467 |
+
if name in self.unbacked_symbol_decls:
|
| 1468 |
+
return name
|
| 1469 |
+
else:
|
| 1470 |
+
# When in CppWrapperCpu, we should only generate the declaration once
|
| 1471 |
+
self.unbacked_symbol_decls.add(name)
|
| 1472 |
+
return self.declare + name
|
| 1473 |
+
|
| 1474 |
+
def codegen_subgraph_prefix(self, subgraph, outer_inputs, outer_outputs):
|
| 1475 |
+
for inner_input, outer_input in zip(subgraph.graph.graph_inputs, outer_inputs):
|
| 1476 |
+
self.writeline(f"{self.declare}{inner_input} = {outer_input}{self.ending}")
|
| 1477 |
+
|
| 1478 |
+
def codegen_subgraph_suffix(self, subgraph, outer_inputs, outer_outputs):
|
| 1479 |
+
for inner_output, outer_output in zip(
|
| 1480 |
+
subgraph.graph.graph_outputs, outer_outputs
|
| 1481 |
+
):
|
| 1482 |
+
self.writeline(
|
| 1483 |
+
f"{outer_output} = {inner_output.codegen_reference()}{self.ending}"
|
| 1484 |
+
)
|
| 1485 |
+
|
| 1486 |
+
def codegen_subgraph(self, subgraph, outer_inputs, outer_outputs):
|
| 1487 |
+
try:
|
| 1488 |
+
self.push_codegened_graph(subgraph.graph)
|
| 1489 |
+
self.writeline(f"{self.comment} subgraph: {subgraph.name}")
|
| 1490 |
+
self.codegen_subgraph_prefix(subgraph, outer_inputs, outer_outputs)
|
| 1491 |
+
parent_graph = V.graph
|
| 1492 |
+
with V.set_graph_handler(subgraph.graph):
|
| 1493 |
+
subgraph.graph.codegen_subgraph(
|
| 1494 |
+
parent_graph=parent_graph,
|
| 1495 |
+
)
|
| 1496 |
+
self.codegen_subgraph_suffix(subgraph, outer_inputs, outer_outputs)
|
| 1497 |
+
finally:
|
| 1498 |
+
self.pop_codegened_graph()
|
| 1499 |
+
|
| 1500 |
+
def codegen_conditional(self, conditional):
|
| 1501 |
+
name = conditional.get_name()
|
| 1502 |
+
outer_inputs = [buf.codegen_reference() for buf in conditional.operands]
|
| 1503 |
+
outer_outputs = [f"{name}[{i}]" for i in range(len(conditional.outputs))]
|
| 1504 |
+
|
| 1505 |
+
self.writeline(f"{name} = [None] * {len(conditional.outputs)}")
|
| 1506 |
+
self.writeline(f"if {conditional.predicate.codegen_reference()}.item():")
|
| 1507 |
+
self.writeline(EnterSubgraphLine(self, conditional.true_subgraph.graph))
|
| 1508 |
+
self.codegen_subgraph(conditional.true_subgraph, outer_inputs, outer_outputs)
|
| 1509 |
+
self.writeline(ExitSubgraphLine(self))
|
| 1510 |
+
self.writeline("else:")
|
| 1511 |
+
self.writeline(EnterSubgraphLine(self, conditional.false_subgraph.graph))
|
| 1512 |
+
self.codegen_subgraph(conditional.false_subgraph, outer_inputs, outer_outputs)
|
| 1513 |
+
self.writeline(ExitSubgraphLine(self))
|
| 1514 |
+
|
| 1515 |
+
@staticmethod
|
| 1516 |
+
def statically_known_int_or_none(x):
|
| 1517 |
+
try:
|
| 1518 |
+
val = V.graph._shape_env._maybe_evaluate_static(x)
|
| 1519 |
+
return int(x)
|
| 1520 |
+
except Exception:
|
| 1521 |
+
return None
|
| 1522 |
+
|
| 1523 |
+
@staticmethod
|
| 1524 |
+
def statically_known_list_of_ints_or_none(lst):
|
| 1525 |
+
result = []
|
| 1526 |
+
for x in lst:
|
| 1527 |
+
num = WrapperCodeGen.statically_known_int_or_none(x)
|
| 1528 |
+
if num is None:
|
| 1529 |
+
return None
|
| 1530 |
+
result.append(num)
|
| 1531 |
+
return result
|
| 1532 |
+
|
| 1533 |
+
@staticmethod
|
| 1534 |
+
def is_statically_known_list_of_ints(lst):
|
| 1535 |
+
return WrapperCodeGen.statically_known_list_of_ints_or_none(lst) is not None
|
| 1536 |
+
|
| 1537 |
+
@staticmethod
|
| 1538 |
+
def static_shape_for_buffer_or_none(buffer):
|
| 1539 |
+
return WrapperCodeGen.statically_known_list_of_ints_or_none(buffer.get_size())
|
| 1540 |
+
|
| 1541 |
+
@staticmethod
|
| 1542 |
+
def can_prove_buffer_has_static_shape(buffer):
|
| 1543 |
+
return WrapperCodeGen.static_shape_for_buffer_or_none(buffer) is not None
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/comm_analysis.py
ADDED
|
@@ -0,0 +1,273 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
from enum import IntEnum
|
| 3 |
+
|
| 4 |
+
import sympy
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
from . import ir
|
| 8 |
+
|
| 9 |
+
from .utils import get_dtype_size, sympy_product
|
| 10 |
+
from .virtualized import V
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class NCCL_COLL(IntEnum):
|
| 14 |
+
ALL_REDUCE = 0
|
| 15 |
+
ALL_GATHER = 1
|
| 16 |
+
REDUCE_SCATTER = 2
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class NVIDIA_GPU_TYPE(IntEnum):
|
| 20 |
+
VOLTA = 0
|
| 21 |
+
AMPERE = 1
|
| 22 |
+
HOPPER = 2
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def get_gpu_type() -> NVIDIA_GPU_TYPE:
|
| 26 |
+
gpu_info = torch.utils.collect_env.get_gpu_info(torch.utils.collect_env.run) or ""
|
| 27 |
+
if "V100" in gpu_info:
|
| 28 |
+
return NVIDIA_GPU_TYPE.VOLTA
|
| 29 |
+
elif "A100" in gpu_info:
|
| 30 |
+
return NVIDIA_GPU_TYPE.AMPERE
|
| 31 |
+
elif "H100" in gpu_info:
|
| 32 |
+
return NVIDIA_GPU_TYPE.HOPPER
|
| 33 |
+
else:
|
| 34 |
+
# for other gpu types, assume Ampere
|
| 35 |
+
return NVIDIA_GPU_TYPE.AMPERE
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def get_collective_type(node: ir.IRNode) -> NCCL_COLL:
|
| 39 |
+
if isinstance(node, ir._CollectiveKernel):
|
| 40 |
+
kernel_name = node.python_kernel_name
|
| 41 |
+
assert kernel_name is not None
|
| 42 |
+
if "all_reduce" in kernel_name:
|
| 43 |
+
return NCCL_COLL.ALL_REDUCE
|
| 44 |
+
elif "all_gather" in kernel_name:
|
| 45 |
+
return NCCL_COLL.ALL_GATHER
|
| 46 |
+
elif "reduce_scatter" in kernel_name:
|
| 47 |
+
return NCCL_COLL.REDUCE_SCATTER
|
| 48 |
+
else:
|
| 49 |
+
raise Exception(f"Unsupported collective kernel: {kernel_name}")
|
| 50 |
+
|
| 51 |
+
if isinstance(node, (ir.AllReduce, ir.AllReduceCoalesced)):
|
| 52 |
+
return NCCL_COLL.ALL_REDUCE
|
| 53 |
+
elif isinstance(node, (ir.AllGatherIntoTensor, ir.AllGatherIntoTensorCoalesced)):
|
| 54 |
+
return NCCL_COLL.ALL_GATHER
|
| 55 |
+
elif isinstance(node, (ir.ReduceScatterTensor, ir.ReduceScatterTensorCoalesced)):
|
| 56 |
+
return NCCL_COLL.REDUCE_SCATTER
|
| 57 |
+
else:
|
| 58 |
+
raise Exception(f"Unsupported collective type: {node}")
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def get_collective_input_size_bytes(node: ir.IRNode) -> int:
|
| 62 |
+
sz_bytes = 0
|
| 63 |
+
for inp in node.inputs: # type: ignore[attr-defined]
|
| 64 |
+
shape = inp.layout.size
|
| 65 |
+
numel = sympy_product(inp.layout.size)
|
| 66 |
+
if isinstance(numel, sympy.Integer):
|
| 67 |
+
# For ease of testing
|
| 68 |
+
numel = int(numel)
|
| 69 |
+
else:
|
| 70 |
+
numel = V.graph.sizevars.size_hint(numel)
|
| 71 |
+
sz_bytes += numel * get_dtype_size(inp.layout.dtype)
|
| 72 |
+
return sz_bytes
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def get_collective_group_size(node: ir.IRNode) -> int:
|
| 76 |
+
if type(node) == ir._CollectiveKernel:
|
| 77 |
+
from torch.distributed.distributed_c10d import _get_group_size_by_name
|
| 78 |
+
|
| 79 |
+
return _get_group_size_by_name(node.constant_args[-1])
|
| 80 |
+
elif isinstance(node, ir.CollectiveKernel):
|
| 81 |
+
return node.constant_args[2] # type: ignore[attr-defined]
|
| 82 |
+
else:
|
| 83 |
+
raise TypeError(f"Unsupported collective type: {node}")
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
####################################################################################################################
|
| 87 |
+
# The following code and constants are adapted from https://github.com/NVIDIA/nccl/blob/master/src/graph/tuning.cc #
|
| 88 |
+
####################################################################################################################
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
class NCCL_HW(IntEnum):
|
| 92 |
+
NVLINK = 0
|
| 93 |
+
PCI = 1
|
| 94 |
+
NET = 2
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
class NCCL_ALGO(IntEnum):
|
| 98 |
+
TREE = 0
|
| 99 |
+
RING = 1
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
class NCCL_PROTO(IntEnum):
|
| 103 |
+
# The ordering and enum values here matches original in
|
| 104 |
+
# https://github.com/NVIDIA/nccl/blob/0b083e52096c387bad7a5c5c65b26a9dca54de8c/src/include/devcomm.h#L28
|
| 105 |
+
# For difference between these protocols, see https://github.com/NVIDIA/nccl/issues/281#issuecomment-571816990
|
| 106 |
+
LL = 0 # Low-latency
|
| 107 |
+
# LL128 = 1 # Low-latency 128-byte
|
| 108 |
+
# SIMPLE = 2
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
# Latencies in us
|
| 112 |
+
# len(NCCL_ALGO) x len(NCCL_PROTO)
|
| 113 |
+
# NOTE: use array instead of tensor to prevent incompatibility with fake mode
|
| 114 |
+
baseLat = [
|
| 115 |
+
# Tree
|
| 116 |
+
[
|
| 117 |
+
6.8, # LL
|
| 118 |
+
],
|
| 119 |
+
# Ring
|
| 120 |
+
[
|
| 121 |
+
6.6, # LL
|
| 122 |
+
],
|
| 123 |
+
]
|
| 124 |
+
|
| 125 |
+
# Latencies in us
|
| 126 |
+
# len(NCCL_HW) x len(NCCL_ALGO) x len(NCCL_PROTO)
|
| 127 |
+
hwLat = [
|
| 128 |
+
# NVLINK
|
| 129 |
+
[
|
| 130 |
+
[0.6], # Tree (LL)
|
| 131 |
+
[0.6], # Ring (LL)
|
| 132 |
+
],
|
| 133 |
+
# PCI
|
| 134 |
+
[
|
| 135 |
+
[1.0], # Tree (LL)
|
| 136 |
+
[1.0], # Ring (LL)
|
| 137 |
+
],
|
| 138 |
+
# NET
|
| 139 |
+
[
|
| 140 |
+
[5.0], # Tree (LL)
|
| 141 |
+
[2.7], # Ring (LL)
|
| 142 |
+
],
|
| 143 |
+
]
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
# LL128 max BW per channel
|
| 147 |
+
llMaxBws = [
|
| 148 |
+
# Volta-N1/Intel-N2/Intel-N4
|
| 149 |
+
[
|
| 150 |
+
39.0,
|
| 151 |
+
39.0,
|
| 152 |
+
20.4,
|
| 153 |
+
],
|
| 154 |
+
# Ampere-N1/AMD-N2/AMD-N4
|
| 155 |
+
[
|
| 156 |
+
87.7,
|
| 157 |
+
22.5, # avg of ring & tree
|
| 158 |
+
19.0,
|
| 159 |
+
],
|
| 160 |
+
# Hopper-N1/AMD-N2/AMD-N4
|
| 161 |
+
[
|
| 162 |
+
87.7,
|
| 163 |
+
22.5, # avg of ring & tree
|
| 164 |
+
19.0,
|
| 165 |
+
],
|
| 166 |
+
]
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
def estimate_nccl_collective_runtime(node: ir.IRNode) -> float:
|
| 170 |
+
"""
|
| 171 |
+
Returns estimated NCCL collective runtime in nanoseconds (ns).
|
| 172 |
+
|
| 173 |
+
The following heuristics are copied from https://github.com/NVIDIA/nccl/blob/master/src/graph/tuning.cc.
|
| 174 |
+
We aim to estimate the runtime as accurately as possible.
|
| 175 |
+
|
| 176 |
+
Assumptions:
|
| 177 |
+
- only ring algorithm (NCCL_ALGO_RING) is used
|
| 178 |
+
- only Low-Latency protocol (NCCL_PROTO_LL) is used, i.e. Simple or LL128 is not used
|
| 179 |
+
- 8 gpus per node # TODO: Need to find a way to get accurate "gpus per node" and "# nodes" info.
|
| 180 |
+
- collective is one of: allreduce, reducescatter, allgather
|
| 181 |
+
"""
|
| 182 |
+
tensor_storage_size_bytes = get_collective_input_size_bytes(node)
|
| 183 |
+
# Convert bytes to GB
|
| 184 |
+
tensor_storage_size_GB = tensor_storage_size_bytes / 1024 / 1024 / 1024
|
| 185 |
+
|
| 186 |
+
# Currently assumes each node has 8 gpus. And when >1 node is used, assumes each node uses all 8 gpus.
|
| 187 |
+
# TODO: Need to find a way to get accurate "gpus per node" and "# nodes" info.
|
| 188 |
+
num_gpus_per_node = 8
|
| 189 |
+
group_size = get_collective_group_size(node)
|
| 190 |
+
nNodes = math.ceil(group_size / num_gpus_per_node)
|
| 191 |
+
nRanks = group_size # this is total # of gpus globally that participate in this collective op
|
| 192 |
+
|
| 193 |
+
if nRanks <= 1:
|
| 194 |
+
return 0
|
| 195 |
+
|
| 196 |
+
# Assumes ring algorithm
|
| 197 |
+
nccl_algo = NCCL_ALGO.RING
|
| 198 |
+
nccl_proto = NCCL_PROTO.LL
|
| 199 |
+
coll = get_collective_type(node)
|
| 200 |
+
|
| 201 |
+
# =============== bandwidth computation ===============
|
| 202 |
+
# First compute bandwidth in GB/s; then at the end, convert it to GB/ns
|
| 203 |
+
|
| 204 |
+
bwIntra = torch._inductor.config.intra_node_bw
|
| 205 |
+
bwInter = torch._inductor.config.inter_node_bw
|
| 206 |
+
|
| 207 |
+
compCapIndex = get_gpu_type()
|
| 208 |
+
index2 = nNodes - 1 if nNodes <= 2 else 2
|
| 209 |
+
# LL: for single node, we look at GPU type; for multi-node, we look at CPU type
|
| 210 |
+
index1 = compCapIndex if nNodes == 1 else 0
|
| 211 |
+
llMaxBw = llMaxBws[index1][index2]
|
| 212 |
+
|
| 213 |
+
# NOTE: each step of ring algorithm is synchronized,
|
| 214 |
+
# and is bottlenecked by the slowest link which is the inter-node interconnect.
|
| 215 |
+
# hence when nNodes >= 2, bw is inter-node bandwidth.
|
| 216 |
+
# NOTE: the original code in https://github.com/NVIDIA/nccl/blob/master/src/graph/tuning.cc
|
| 217 |
+
# have this as `if nNodes <= 2` which seems wrong. Corrected it here.
|
| 218 |
+
bw = bwIntra if nNodes == 1 else bwInter
|
| 219 |
+
nChannels = 2 # Assume # channels is 2
|
| 220 |
+
busBw = nChannels * bw
|
| 221 |
+
|
| 222 |
+
# Various model refinements
|
| 223 |
+
busBw = min(
|
| 224 |
+
llMaxBw,
|
| 225 |
+
busBw
|
| 226 |
+
* (1.0 / 4.0 if (nNodes > 1 or coll == NCCL_COLL.ALL_REDUCE) else 1.0 / 3.0),
|
| 227 |
+
)
|
| 228 |
+
|
| 229 |
+
if coll == NCCL_COLL.ALL_REDUCE:
|
| 230 |
+
nsteps = 2 * (nRanks - 1)
|
| 231 |
+
elif coll in (NCCL_COLL.REDUCE_SCATTER, NCCL_COLL.ALL_GATHER):
|
| 232 |
+
nsteps = nRanks - 1
|
| 233 |
+
|
| 234 |
+
# Convert bus BW to algorithm BW (tensor bytes / algoBW = actual execution time)
|
| 235 |
+
ratio = (1.0 * nRanks) / nsteps # type: ignore[possibly-undefined]
|
| 236 |
+
bandwidth = busBw * ratio
|
| 237 |
+
# Convert GB/s to GB/ns
|
| 238 |
+
bandwidth_GB_per_ns = bandwidth / 1e9
|
| 239 |
+
|
| 240 |
+
# =============== latency computation ===============
|
| 241 |
+
intraHw = NCCL_HW.NVLINK
|
| 242 |
+
hw = intraHw if nNodes == 1 else NCCL_HW.NET
|
| 243 |
+
|
| 244 |
+
if coll == NCCL_COLL.ALL_REDUCE:
|
| 245 |
+
if nNodes > 1:
|
| 246 |
+
nInterSteps = 2 * nNodes
|
| 247 |
+
else:
|
| 248 |
+
nInterSteps = 0
|
| 249 |
+
elif coll in (NCCL_COLL.REDUCE_SCATTER, NCCL_COLL.ALL_GATHER):
|
| 250 |
+
nInterSteps = nNodes - 1
|
| 251 |
+
|
| 252 |
+
# First compute latency in us; then at the end, convert it to ns
|
| 253 |
+
latency = baseLat[nccl_algo][nccl_proto]
|
| 254 |
+
intraLat = hwLat[intraHw][nccl_algo][nccl_proto]
|
| 255 |
+
interLat = hwLat[NCCL_HW.NET][nccl_algo][nccl_proto]
|
| 256 |
+
|
| 257 |
+
# Inter-node rings still have to launch nsteps * net overhead.
|
| 258 |
+
netOverhead = 0.0
|
| 259 |
+
if nNodes > 1:
|
| 260 |
+
netOverhead = 1.0 # getNetOverhead(comm);
|
| 261 |
+
intraLat = max(intraLat, netOverhead)
|
| 262 |
+
latency += (nsteps - nInterSteps) * intraLat + nInterSteps * interLat # type: ignore[possibly-undefined]
|
| 263 |
+
# Convert us to ns
|
| 264 |
+
latency_ns = latency * 1e3
|
| 265 |
+
|
| 266 |
+
# =============== final result ===============
|
| 267 |
+
transport_ns = tensor_storage_size_GB / bandwidth_GB_per_ns
|
| 268 |
+
return transport_ns + latency_ns
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
################################################################################################################
|
| 272 |
+
# The above code and constants are adapted from https://github.com/NVIDIA/nccl/blob/master/src/graph/tuning.cc #
|
| 273 |
+
################################################################################################################
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/cudagraph_trees.py
ADDED
|
@@ -0,0 +1,2159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
CUDA graph trees are a safety abstraction over CUDAGraphs, similar to make_graph_callables,
|
| 3 |
+
which share the same memory pool. Sharing a memory pool is an extremely
|
| 4 |
+
important optimization when chaining multiple CUDA graphs together, as it
|
| 5 |
+
prevents you from needing to copy intermediate tensors from one graph to the
|
| 6 |
+
next, and reduces overall memory usage by allowing dead memory from the first
|
| 7 |
+
pool to be reused in the second.
|
| 8 |
+
|
| 9 |
+
The standard graph/make_graph_callables support sharing memory pool, but
|
| 10 |
+
with a lot of caveats. CUDA graph trees remove these restrictions:
|
| 11 |
+
|
| 12 |
+
* Previously, if you recorded graphs A, B, you had to replay A, B in that
|
| 13 |
+
order. With CUDA graph trees, after replaying A, you can change your
|
| 14 |
+
mind and record/replay a different graph B'; we will support efficient
|
| 15 |
+
execution of both A, B and A, B', using only max(mem(A, B), mem(A, B')). In
|
| 16 |
+
other words: we support arbitrary trees of CUDA graph operations, not just
|
| 17 |
+
sequences (this is why this feature is called CUDA graph trees.)
|
| 18 |
+
|
| 19 |
+
* Previously, if you executed graph A, some non-CUDA graph code, and then
|
| 20 |
+
graph B, after executing graph B, it was not safe to retain any references
|
| 21 |
+
to intermediates produced by A. With CUDA graph trees, we track if any
|
| 22 |
+
outputs of graph A are still live by the time graph B is run, and make
|
| 23 |
+
sure graph B doesn't clobber there memory when reusing the CUDA graphs
|
| 24 |
+
pool. You'll get a separate recording of B depending on what tensors
|
| 25 |
+
stay live or dead.
|
| 26 |
+
|
| 27 |
+
CUDA graph trees are flexible enough to be used in Dynamo across graph breaks,
|
| 28 |
+
which is their primary use case.
|
| 29 |
+
|
| 30 |
+
The ability to switch from replay to record is fairly nontrivial: remember that
|
| 31 |
+
when you replay a CUDA graph, you only replay CUDA operations; no CPU side state
|
| 32 |
+
is updated. In particular, the CPU-side book-keeping for the allocator is not
|
| 33 |
+
reconstructed. However, to record a new child CUDA graph, we must restore this
|
| 34 |
+
book-keeping. This is what checkpoint pool state is used for.
|
| 35 |
+
"""
|
| 36 |
+
|
| 37 |
+
from __future__ import annotations
|
| 38 |
+
|
| 39 |
+
import contextlib
|
| 40 |
+
import dataclasses
|
| 41 |
+
import functools
|
| 42 |
+
import gc
|
| 43 |
+
import itertools
|
| 44 |
+
import operator
|
| 45 |
+
import sys
|
| 46 |
+
import threading
|
| 47 |
+
import traceback
|
| 48 |
+
import warnings
|
| 49 |
+
import weakref
|
| 50 |
+
from collections import defaultdict
|
| 51 |
+
|
| 52 |
+
from enum import auto, Enum
|
| 53 |
+
from typing import (
|
| 54 |
+
Any,
|
| 55 |
+
Callable,
|
| 56 |
+
cast,
|
| 57 |
+
Dict,
|
| 58 |
+
Iterator,
|
| 59 |
+
List,
|
| 60 |
+
Optional,
|
| 61 |
+
Sequence,
|
| 62 |
+
Set,
|
| 63 |
+
Tuple,
|
| 64 |
+
Union,
|
| 65 |
+
)
|
| 66 |
+
|
| 67 |
+
import torch.fx
|
| 68 |
+
from torch import Tensor
|
| 69 |
+
from torch._dynamo.mutation_guard import GenerationTracker
|
| 70 |
+
from torch._dynamo.utils import preserve_rng_state
|
| 71 |
+
from torch._inductor.compile_fx import (
|
| 72 |
+
align_inputs_from_check_idxs,
|
| 73 |
+
copy_misaligned_inputs,
|
| 74 |
+
get_expanded_dims,
|
| 75 |
+
get_input_idxs_to_check,
|
| 76 |
+
index_expanded_dims,
|
| 77 |
+
remove_unaligned_input_idxs,
|
| 78 |
+
static_input,
|
| 79 |
+
)
|
| 80 |
+
from torch.multiprocessing.reductions import StorageWeakRef
|
| 81 |
+
from torch.storage import UntypedStorage
|
| 82 |
+
from torch.types import _bool
|
| 83 |
+
from torch.utils import _pytree as pytree
|
| 84 |
+
from torch.utils.weak import TensorWeakRef
|
| 85 |
+
|
| 86 |
+
StorageWeakRefPointer = int
|
| 87 |
+
StorageDataPtr = int
|
| 88 |
+
NBytes = int
|
| 89 |
+
|
| 90 |
+
if torch.backends.cuda.is_built():
|
| 91 |
+
from torch._C import (
|
| 92 |
+
_cuda_CUDAAllocator_AllocatorState as AllocatorState,
|
| 93 |
+
_set_cached_tensors_enabled as _set_cached_tensors_enabled,
|
| 94 |
+
)
|
| 95 |
+
else:
|
| 96 |
+
|
| 97 |
+
class AllocatorState: # type: ignore[no-redef]
|
| 98 |
+
pass
|
| 99 |
+
|
| 100 |
+
def _set_cached_tensors_enabled(enabled: _bool) -> None:
|
| 101 |
+
pass
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
log = torch._logging.getArtifactLogger(__name__, "cudagraphs")
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
from . import config
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
@dataclasses.dataclass(frozen=True)
|
| 111 |
+
class GraphID:
|
| 112 |
+
"Unique counter of a cuda graph recording"
|
| 113 |
+
id: int
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
@dataclasses.dataclass(frozen=True)
|
| 117 |
+
class FunctionID:
|
| 118 |
+
"Unique counter of a function wrapped in cudagraphify_impl"
|
| 119 |
+
id: int
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
@dataclasses.dataclass(frozen=True)
|
| 123 |
+
class WrappedFunction:
|
| 124 |
+
"""
|
| 125 |
+
Represents a function that you want to record for CUDA graph replay,
|
| 126 |
+
with a little more metadata so we can identify if we have an applicable
|
| 127 |
+
CUDA graph in our CUDA graph tree for it.
|
| 128 |
+
"""
|
| 129 |
+
|
| 130 |
+
model: Callable[..., Any]
|
| 131 |
+
static_input_idxs: Sequence[int]
|
| 132 |
+
id: FunctionID
|
| 133 |
+
constants: Tuple[torch.Tensor, ...]
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def clear_cublass_cache():
|
| 137 |
+
"""
|
| 138 |
+
Cublas keeps a persistent workspace allocation for running matmuls. This poses a problem for
|
| 139 |
+
doing warmup within a CUDAGraph private pool because we do not want persistent allocations from
|
| 140 |
+
one one run to the next. When we begin a new run of a cudagraphs path (generation), all tensors
|
| 141 |
+
from the previous generation are freed. This frees them the memory pool, but not elsewhere.
|
| 142 |
+
A tensor in the cublas workspace would continue to be in use the workspace but would also get allocated
|
| 143 |
+
in the next run. The memory would be in use in two places.
|
| 144 |
+
|
| 145 |
+
To solve this, we clear cublas caches before and after warming up or recording. If a workspace is required
|
| 146 |
+
it will be allocated to the cudagraph private pool and accounted for in the allocator for the duration of the
|
| 147 |
+
program. There is no overhead to this on replay since cudagraphs removes allocation overhead.
|
| 148 |
+
"""
|
| 149 |
+
torch._C._cuda_clearCublasWorkspaces()
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
@contextlib.contextmanager
|
| 153 |
+
def clear_cublas_manager():
|
| 154 |
+
"Context manager around clearing cublas caches that will clear on enter and exit"
|
| 155 |
+
clear_cublass_cache()
|
| 156 |
+
try:
|
| 157 |
+
yield
|
| 158 |
+
finally:
|
| 159 |
+
clear_cublass_cache()
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
@contextlib.contextmanager
|
| 163 |
+
def disable_conv_cache_emptying():
|
| 164 |
+
prev = torch._C._cuda_get_conv_benchmark_empty_cache()
|
| 165 |
+
torch._C._cudnn_set_conv_benchmark_empty_cache(False)
|
| 166 |
+
try:
|
| 167 |
+
yield
|
| 168 |
+
finally:
|
| 169 |
+
torch._C._cudnn_set_conv_benchmark_empty_cache(prev)
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
@contextlib.contextmanager
|
| 173 |
+
def enable_history_recording():
|
| 174 |
+
"Turns on history recording in the CUDA Caching Allocator"
|
| 175 |
+
enabled = torch._C._cuda_isHistoryEnabled()
|
| 176 |
+
try:
|
| 177 |
+
if not enabled:
|
| 178 |
+
torch.cuda.memory._record_memory_history()
|
| 179 |
+
yield
|
| 180 |
+
finally:
|
| 181 |
+
if not enabled:
|
| 182 |
+
torch.cuda.memory._record_memory_history(None)
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
def get_history_recording():
|
| 186 |
+
# TODO - remove, prevents cleanup
|
| 187 |
+
if not config.triton.cudagraph_trees_history_recording:
|
| 188 |
+
return contextlib.nullcontext()
|
| 189 |
+
return enable_history_recording()
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
class TreeManagerContainer:
    """
    Manages the lifetime of the tree manager. Like `PrivatePool` in cuda caching allocator,
    the tree and its corresponding memory pool should be kept alive as long as any outstanding
    graph or tensor which is an output of a graph remains alive.

    There is a single tree manager container per device.

    The lifecycle of a tree_manager is:
    -  Is constructed, no graph, no fns, no tensors
    -  Tree manager is fetched, resulting in tree manager being allocated
    -  We generate a bunch of functions, calling add_strong_reference
    -  These functions die, calling finalize_reference
    -  When all the functions die, we finalize_tree_manager.

    TODO: in the future, we would like to do the following once storage weak refs land
    -  We look for all the live storages and add references to THOSE
    -  We count as storages die
    -  All the storages are dead, we deallocate the tree manager
    """

    def __init__(self, device_index):
        # This class keeps a strong reference to tree_manager,
        # and once all other strong references to the tree_manager die it resets it to None.
        # We need a strong reference so that we can still access its attributes upon cleanup.
        self.tree_manager: Optional[CUDAGraphTreeManager] = None

        # Number of outstanding cudagraphify_fn references to the current tree manager;
        # incremented in add_strong_reference, decremented by weakref.finalize callbacks.
        self.live_cudagraphify_fns = 0

        # CUDA device ordinal this container (and its manager) is bound to.
        self.device_index = device_index

        # Following two objects are only set in the case that Tensor outputs outlive
        # the cudagraphify_fns. Reference to the Graph is needed to keep the private pool from
        # deallocation.
        self.live_storages_count = 0
        self.graph: Optional[torch.cuda.CUDAGraph] = None

        # Guards all mutation of the counters and references above.
        self.lock = threading.Lock()

    def _finalize_tensor(self):
        """weakref.finalize callback: one tracked output storage has died."""
        with self.lock:
            self.live_storages_count -= 1
            if self.live_storages_count == 0:
                # No outputs remain; the graph no longer needs to pin the pool.
                self.graph = None

                # manager was used again after existing cleanup,
                # we shouldn't set it to None
                if self.live_cudagraphify_fns == 0:
                    self.tree_manager = None

    def finalize_cudagraphify_fn(self):
        """weakref.finalize callback: one cudagraphify fn has died."""
        with self.lock:
            self.live_cudagraphify_fns -= 1
            if self.live_cudagraphify_fns == 0:
                self._finalize_tree_manager()

    def _finalize_tree_manager(self):
        """Drop the manager once no cudagraphify fns remain. Caller must hold self.lock."""
        assert self.lock.locked()
        self.tree_manager = None

        # TODO - when issue #91395 is landed, we can set a weakref on
        # storages and trigger a deallocation when all outputs of the
        # cudagraph are dead.

        # live_storages = list(
        #     tree_manager.live_cudagraph_pool_storages_in_curr_execution()
        # )

        # # Maintain reference to graph to keep tensors alive
        # assert len(tree_manager.roots) > 0, "expected at least one use"
        # root = next(tree_manager.get_roots())
        # self.graph = root.graph
        # seen_storages = set()
        # for stor in live_storages:
        #     if stor in seen_storages:
        #         continue
        #     seen_storages.add(stor)
        #     self.live_storages_count += 1
        # .   weakref.finalize(stor, self._finalize_tensor)

    def add_strong_reference(self, fn: Callable[..., Any]):
        """Register a cudagraphify fn whose lifetime should keep the tree manager alive."""
        with self.lock:
            self.live_cudagraphify_fns += 1

        # Finalizer fires when `fn` is garbage collected; registered outside the lock
        # since weakref.finalize itself does not touch our state.
        weakref.finalize(fn, self.finalize_cudagraphify_fn)

    def get_tree_manager(self) -> CUDAGraphTreeManager:
        """Return the device's tree manager, lazily constructing it on first use."""
        with self.lock:
            if self.tree_manager is None:
                self.tree_manager = CUDAGraphTreeManager(self.device_index)
            return self.tree_manager
|
| 284 |
+
|
| 285 |
+
|
| 286 |
+
# Thread-local storage for the per-device containers; stashed into C++ TLS below
# so autograd worker threads see the same objects.
local = threading.local()

# one tree manager per device
local.tree_manager_containers = {}
local.tree_manager_locks = defaultdict(threading.Lock)


# only decremented by user call of mark_step_begin (it counts DOWN so its values
# stay distinguishable from the GenerationTracking counter; see mark_step_begin)
class MarkStepBox:
    # Mutable module-wide counter box shared by mark_step_begin / reset_cudagraph_trees.
    mark_step_counter = 0


# We need to register this as an object that will be copied over as TLS when new
# threads are created in autograd
torch._C._stash_obj_in_tls("tree_manager_containers", local.tree_manager_containers)
torch._C._stash_obj_in_tls("tree_manager_locks", local.tree_manager_locks)
|
| 302 |
+
|
| 303 |
+
|
| 304 |
+
def mark_step_begin():
    """Indicates that a new iteration of inference or training is about to begin."""
    # Count downward so the values stay distinguishable from the
    # GenerationTracking counter.
    MarkStepBox.mark_step_counter = MarkStepBox.mark_step_counter - 1
|
| 309 |
+
|
| 310 |
+
|
| 311 |
+
def reset_cudagraph_trees():
    "Clear all cudagraph trees"
    # see shutdown below for why this is necessary
    container_dict = get_obj(local, "tree_manager_containers")
    locks_dict = get_obj(local, "tree_manager_locks")
    # Shut down each device's manager under its own lock before clearing
    # the container dict, so no concurrent user resurrects a manager mid-reset.
    for device, lock in locks_dict.items():
        with lock:
            container = container_dict.get(device)
            if not container or not container.tree_manager:
                continue

            container.tree_manager.shutdown()

    _set_cached_tensors_enabled(False)
    container_dict.clear()

    # Reset the user-facing mark_step counter back to its initial state.
    MarkStepBox.mark_step_counter = 0
|
| 328 |
+
|
| 329 |
+
|
| 330 |
+
def get_obj(local, attr_name):
    """Fetch ``attr_name`` from the thread-local object, falling back to the C++ TLS stash.

    Autograd worker threads do not inherit Python thread-local attributes, but the
    stashed TLS objects are copied over; the assert guarantees the key was stashed.
    """
    sentinel = object()
    value = getattr(local, attr_name, sentinel)
    if value is not sentinel:
        return value
    assert torch._C._is_key_in_tls(attr_name)
    return torch._C._get_obj_in_tls(attr_name)
|
| 336 |
+
|
| 337 |
+
|
| 338 |
+
def get_container(device_index: int):
    """Return the TreeManagerContainer for ``device_index``, creating it on first use."""
    containers = get_obj(local, "tree_manager_containers")
    device_lock = get_obj(local, "tree_manager_locks")[device_index]

    with device_lock:
        container = containers.get(device_index)
        if container is None:
            container = TreeManagerContainer(device_index)
            containers[device_index] = container
        return container
|
| 347 |
+
|
| 348 |
+
|
| 349 |
+
def get_manager(
    device_index: int, create_if_none_exists=True
) -> Optional[CUDAGraphTreeManager]:
    """Return the device's CUDAGraphTreeManager; optionally create it if absent."""
    container = get_container(device_index)
    if create_if_none_exists:
        return container.get_tree_manager()
    return container.tree_manager
|
| 355 |
+
|
| 356 |
+
|
| 357 |
+
def cudagraphify_impl(model, inputs, static_input_idxs, *args, **kwargs):
    """
    Memoizing wrapper around ``cudagraphify``: returns a callable that records one
    cudagraph tree entry per distinct combination of int ("symint") input values,
    and dispatches subsequent calls to the matching recorded function.
    """
    fn_cache: Dict[Tuple[int, ...], Callable[..., Any]] = {}

    # Detect int inputs: we need to index on these
    int_key = [i for i, v in enumerate(inputs) if isinstance(v, int)]
    # NB: operator.itemgetter with a single index returns a scalar rather than a
    # tuple, so cache keys are None (no ints), a single int, or a tuple of ints.
    get_ints: Any = operator.itemgetter(*int_key) if int_key else lambda _: None

    # Only the int *positions* are needed beyond this point; drop our reference so
    # the closure does not keep the input tensors alive.
    del inputs

    def deferred_cudagraphify(inputs):
        int_key = get_ints(inputs)
        fn = fn_cache.get(int_key)
        if fn is not None:
            # Already recorded for these int values; run the cached fn directly.
            return fn(inputs)

        if int_key is None:
            log.info("recording cudagraph tree for graph without symints")
        else:
            log.info("recording cudagraph tree for symint key %s", int_key)

        # first get indices we need to check to align, then update our static inputs,
        # and finally copy
        check_input_idxs = get_input_idxs_to_check(inputs, static_input_idxs)
        new_static_input_idxs = remove_unaligned_input_idxs(inputs, static_input_idxs)
        copy_misaligned_inputs(inputs, check_input_idxs)

        fn, out = cudagraphify(model, inputs, new_static_input_idxs, *args, **kwargs)
        fn = align_inputs_from_check_idxs(fn, inputs_to_check=check_input_idxs)
        fn_cache[int_key] = fn

        return out

    return deferred_cudagraphify
|
| 390 |
+
|
| 391 |
+
|
| 392 |
+
def cudagraphify(
    model,
    inputs,
    static_input_idxs=(),
    *,
    device_index: int,
    is_backward: bool,
    is_inference: bool,
    stack_traces: Optional[StackTraces] = None,
    constants: Tuple[torch.Tensor, ...] = (),
):
    """Register ``model`` with the per-device cudagraph tree manager.

    Returns whatever ``manager.add_function`` produces (a callable and its first
    outputs, as consumed by ``cudagraphify_impl``).
    """
    # A compilation cannot simultaneously be a backward and an inference graph.
    assert not (is_backward and is_inference)

    if is_backward:
        mode = CompilationMode.BACKWARD
    elif is_inference:
        mode = CompilationMode.INFERENCE
    else:
        mode = CompilationMode.FORWARD

    manager = get_container(device_index).get_tree_manager()
    return manager.add_function(
        model,
        inputs,
        static_input_idxs,
        stack_traces,
        mode,
        constants,
    )
|
| 419 |
+
|
| 420 |
+
|
| 421 |
+
class StorageWeakRefWrapper:
    """
    Wrapper around a storage weak ref. Will deallocate it upon expiration if invoked.
    """

    __slots__ = ["ref", "_data_ptr", "extra_ref_check"]

    # NOTE(review): this annotation name looks stale — the slot holding the weak ref
    # is ``ref``; confirm before relying on __annotations__.
    storage_ref: Optional[StorageWeakRef]

    def __init__(
        self,
        inp: Union[Tensor, UntypedStorage],
        extra_ref_check: Optional[Callable[[], bool]] = None,
    ):
        """
        extra_ref_check is an additional check we need to run to check if the
        weak ref has expired. in checking storage use count we assume extra_ref_check
        will hold an additional reference to the storage.
        """
        if isinstance(inp, Tensor):
            stor = inp.untyped_storage()
        else:
            assert isinstance(inp, UntypedStorage)
            stor = inp
        self.ref = StorageWeakRef(stor)
        # data_ptr recorded at construction; stays valid for bookkeeping even after
        # the storage dies (see data_ptr()).
        self._data_ptr = stor.data_ptr()
        self.extra_ref_check = extra_ref_check

    @classmethod
    def from_weakref_and_data_ptr(cls, cdata, data_ptr, extra_ref_check=None):
        """Alternate constructor from a raw weakref cdata and a known data_ptr."""
        # __new__ bypasses __init__ since we already have the pieces.
        instance = cls.__new__(cls)
        instance._data_ptr = data_ptr
        instance.ref = StorageWeakRef.from_weakref(cdata)
        instance.extra_ref_check = extra_ref_check
        return instance

    def __call__(self) -> Optional[StorageWeakRefPointer]:
        """Return the weakref's cdata pointer, or None if the storage has expired."""
        if self.expired():
            return None

        return self.ref.cdata

    def swap_weakref(self, cdata):
        """Replace the wrapped weakref's cdata, releasing the current one first."""
        # Direct __del__ call releases the existing weak ref immediately rather
        # than waiting for this wrapper's own finalization.
        self.ref.__del__()
        self.ref.cdata = cdata

    def data_ptr(self) -> int:
        "NB: returns the data ptr even if the storage has expired"
        return self._data_ptr

    def remove_extra_reference(self):
        # Stop consulting (and accounting for) the extra reference check.
        self.extra_ref_check = None

    def expired(self):
        """Return True when no references to the storage remain (beyond expected ones)."""
        # If the extra check reports False we report not-expired without consulting
        # the use count at all.
        if self.extra_ref_check is not None and not self.extra_ref_check():
            return False

        # if extra_ref_check is not None we expect an additional reference
        stor_count = torch._C._storage_Use_Count(self.ref.cdata)
        return (stor_count - (self.extra_ref_check is not None)) == 0

    def __repr__(self):
        if self.ref is None or self.ref.expired():
            return f"StorageWeakRefWrapper to {self.data_ptr()}; dead"
        else:
            return f"StorageWeakRefWrapper to {self.data_ptr()}; alive"
|
| 487 |
+
|
| 488 |
+
|
| 489 |
+
def is_live(weak_ref: Optional[StorageWeakRefWrapper]) -> bool:
    """True iff the wrapped storage can still be dereferenced."""
    deref = maybe_deref(weak_ref)
    return deref is not None
|
| 491 |
+
|
| 492 |
+
|
| 493 |
+
def maybe_deref(
    weak_ref: Optional[StorageWeakRefWrapper],
) -> Optional[Tuple[StorageWeakRefPointer, int]]:
    """Dereference the wrapper, returning (storage cdata, recorded data ptr), or None
    when the wrapper is absent or its storage has expired."""
    if weak_ref is None:
        return None
    storage_ptr = weak_ref()
    if storage_ptr is None:
        return None
    # NB: the live storage's data_ptr() does not necessarily equal
    # weak_ref.data_ptr(), which was captured at wrapper construction.
    return storage_ptr, weak_ref.data_ptr()
|
| 503 |
+
|
| 504 |
+
|
| 505 |
+
@contextlib.contextmanager
def _use_cuda_memory_pool_manager(device, mem_pool, stream):
    """
    Context manager to use cuda graph pool for new allocations. If you use this manager
    all cudagraph tensors in use should be reflected in the allocator or they will be overwritten.
    existing_graph should already have been used in a capture, and the mem_pool must already exist,
    because this manager will not preserve a reference to the pool which keeps it alive.
    """
    # Ensure all prior work on every stream has finished before redirecting
    # allocations, then make `stream` wait on the current stream.
    torch.cuda.synchronize()
    stream.wait_stream(torch.cuda.current_stream())

    with torch.cuda.stream(stream), torch.device(device):
        # Route allocations made on the current stream into the private pool;
        # end + release must always run, hence the try/finally.
        torch._C._cuda_beginAllocateCurrentStreamToPool(device, mem_pool)
        try:
            yield
        finally:
            torch._C._cuda_endAllocateCurrentStreamToPool(device, mem_pool)
            torch._C._cuda_releasePool(device, mem_pool)

    # Re-join the caller's stream with the work done under `stream`.
    torch.cuda.current_stream().wait_stream(stream)
|
| 525 |
+
|
| 526 |
+
|
| 527 |
+
def map_to_ref(t: Optional[Tensor]) -> Optional[StorageWeakRefWrapper]:
    """Wrap a tensor's storage in a StorageWeakRefWrapper; None maps to None."""
    if isinstance(t, torch.Tensor):
        return StorageWeakRefWrapper(t)
    # Only tensors and None are expected here.
    assert t is None
    return None
|
| 532 |
+
|
| 533 |
+
|
| 534 |
+
# A path index of (depth, offset) indices into a graph that is `depth` number of
# nodes from the root, at graph output offset
PathOutputIndex = Tuple[int, int]

# For each node in the path, for each output, is the output alive
PathLiveness = List[List[bool]]

# One optional stack-trace string per graph output (None when unavailable)
StackTraces = List[Optional[str]]
|
| 542 |
+
|
| 543 |
+
|
| 544 |
+
class CUDAWarmupNode:
    """
    Simplified Wrapper around A CUDA Model that wraps outputs in storage refs and exposes
    apis to get the live storages in the current chain of warmup.

    A CUDAWarmupNode may have either CUDAGraphNode or CUDAWarmupNode as a parent, but may only have
    CUDAWarmupNode as children, because we cannot record or execute with tensors which do not have stable
    memory addresses.

    CUDAWarmupNode and CUDAGraphNode have a number of differences that make it easier to use separate classes.
    - Much of the CUDAGraphNode logic & initialization is based on the tensor properties of first recording. In the
    first instance of warmup, these are not finalized yet.
    - All Inputs to the RecordedFunction must be copied over to the cuda graph memory pool, this is unnecessary in warmup.
    - CUDAWarmup is only used once and so does not need to optimize as much bookkeeping. It is much simpler.

    NB: this class and CUDAGraphNode need to expose `path_live_weakrefs`, `all_outputs_are_dead`, and
    `self.outputs_weakrefs`, `stack_traces`, and `tensor_weakrefs` for compatibility.
    """

    def __init__(
        self,
        wrapped_function: WrappedFunction,
        parent,
        cuda_graphs_pool: Tuple[int, int],
        existing_cuda_graph: Optional[torch.cuda.CUDAGraph],
        device_index: int,
        stack_traces: Optional[StackTraces],
        stream: torch.cuda.Stream,
        already_warm: bool,
    ):
        self.wrapped_function = wrapped_function
        # parent may be a CUDAGraphNode or a CUDAWarmupNode (see class docstring)
        self.parent = parent
        self.cuda_graphs_pool = cuda_graphs_pool
        # weak refs to this node's output storages / tensors, filled in by run()
        self.outputs_weakrefs: List[Optional[StorageWeakRefWrapper]] = []
        self.tensor_weakrefs: List[Optional[TensorWeakRef]] = []
        self.existing_cuda_graph = existing_cuda_graph
        self.has_run = False
        self.device_index = device_index
        self.stack_traces = stack_traces
        self.stream = stream
        # skip the slow-path memory-pool asserts when this fn was already warmed up
        self.already_warm = already_warm

    def run(self, new_inputs):
        """Execute the wrapped model once inside the shared cudagraph memory pool,
        recording weak refs to the freshly-created output storages."""
        # NOTE(review): has_run is asserted but never set True in this method —
        # presumably flipped by the caller; confirm.
        assert not self.has_run, "Wrapped function should never be run twice"

        # See: output_is_alias_of_persistent_static_inputs below. We should only be returning freshly created
        # storages in path_live_weakrefs.
        existing_path_data_ptrs = {
            t.data_ptr() for t in self.path_live_weakrefs() if t()
        }

        def get_non_cudagraph_inps():
            # data ptrs of inputs/constants not produced by this warmup path
            non_cudagraph_inps = set()
            for t in itertools.chain(new_inputs, self.wrapped_function.constants):
                if (
                    isinstance(t, torch.Tensor)
                    and t.untyped_storage().data_ptr() not in existing_path_data_ptrs
                ):
                    non_cudagraph_inps.add(t.untyped_storage().data_ptr())
            return non_cudagraph_inps

        non_cudagraph_inps = get_non_cudagraph_inps()

        if config.triton.slow_path_cudagraph_asserts and not self.already_warm:
            refs = list(self.path_live_weakrefs())
            check_memory_pool(self.device_index, self.cuda_graphs_pool, refs)

        # Run the model with allocations routed into the shared pool, with conv
        # cache emptying disabled and cublas caches cleared around the call.
        with torch.cuda.device(
            self.device_index
        ), disable_conv_cache_emptying(), clear_cublas_manager(), _use_cuda_memory_pool_manager(
            self.device_index, self.cuda_graphs_pool, self.stream
        ), get_history_recording():
            out = self.wrapped_function.model(new_inputs)

        # the model consumes (clears) its input list
        assert len(new_inputs) == 0

        # sdpa returns cpu tensors when not recording cuda graph
        def add_ref(o):
            return (
                o is not None
                and isinstance(o, torch.Tensor)
                and o.is_cuda
                and o.untyped_storage().data_ptr() not in non_cudagraph_inps
                and o.untyped_storage().data_ptr() != 0
            )

        self.outputs_weakrefs.extend(
            [map_to_ref(o) if add_ref(o) else None for o in out]
        )
        self.tensor_weakrefs.extend(
            [TensorWeakRef(o) if add_ref(o) else None for o in out]
        )

        if config.triton.slow_path_cudagraph_asserts and not self.already_warm:
            out_refs = self.path_live_weakrefs()
            new_storages = [
                t for t in out_refs if t.data_ptr() not in non_cudagraph_inps
            ]
            check_memory_pool(self.device_index, self.cuda_graphs_pool, new_storages)

        return out

    @property
    def _path_from_root(self):
        """Yield the nodes from the tree root down to (and including) this node."""
        nodes = []
        node = self
        while node:
            nodes.append(node)
            node = node.parent

        yield from reversed(nodes)

    def path_live_weakrefs(self) -> Iterator[StorageWeakRefWrapper]:
        "Returns all live storages weakrefs that created by nodes in this path"
        for node in self._path_from_root:
            for output in node.outputs_weakrefs:
                if is_live(output):
                    yield output

    def all_outputs_are_dead(self):
        """True when no output storage along the path from root is still alive."""
        return not list(self.path_live_weakrefs())
|
| 665 |
+
|
| 666 |
+
|
| 667 |
+
# Aliases for List that document what the list is indexed by
InputList = List  # input indexes
OutputList = List  # output indexes
LevelList = List  # levels (distance from root of tree)
|
| 671 |
+
|
| 672 |
+
|
| 673 |
+
class OutputAliasInfo:
    """Base marker type describing how a graph output's storage aliases other storage."""
|
| 675 |
+
|
| 676 |
+
|
| 677 |
+
class _UnaliasedStorage(OutputAliasInfo):
    "Singleton to mark that the graph output constructs a new alias or is None"
    pass


# Module-level singleton instance of the marker above.
UnaliasedStorage = _UnaliasedStorage()
|
| 683 |
+
|
| 684 |
+
|
| 685 |
+
class AliasesPriorGraphOutput(OutputAliasInfo):
    "Marks that the graph output aliases an output of a prior graph"
    __slots__ = ["index"]

    # (depth, output-offset) path index into the tree (see PathOutputIndex)
    index: PathOutputIndex

    def __init__(self, index: PathOutputIndex):
        assert isinstance(index, tuple)
        self.index = index
|
| 694 |
+
|
| 695 |
+
|
| 696 |
+
class AliasesNewOutput(OutputAliasInfo):
    "Marks that the graph output aliases an index in the new, returned outputs"

    __slots__ = ["index"]

    # position within the current graph's returned outputs
    index: int

    def __init__(self, index):
        assert isinstance(index, int)
        self.index = index
|
| 706 |
+
|
| 707 |
+
|
| 708 |
+
class CUDAGraphNode:
|
| 709 |
+
"""
|
| 710 |
+
A single recording of a function into a CUDA Graph. Recordings of CUDA Graphs share a single memory pool
|
| 711 |
+
and are structured into a tree, where there is a single recording that can precede it (parent) and multiple
|
| 712 |
+
subsequent recordings that may follow (children). A node will have no parent if it is the first recording
|
| 713 |
+
in a tree; i.e., when it is first recorded, there are no live tensors from a previous recording which
|
| 714 |
+
would force a dependency.
|
| 715 |
+
|
| 716 |
+
On first recording, all of the live tensors in the current CUDA Graph Node path will be
|
| 717 |
+
reflected in the corresponding private pool. On subsequent executions, the caching allocator
|
| 718 |
+
is unaffected when the graph is replayed.
|
| 719 |
+
|
| 720 |
+
In order to support recording a subsequent cuda graph recording after execution of this graph,
|
| 721 |
+
we checkpoint the state of the memory pool so that it may later be resumed.
|
| 722 |
+
|
| 723 |
+
WrappedFunction should have already been warmed up prior to invocation.
|
| 724 |
+
|
| 725 |
+
See [setCheckpointPoolState] for further explanation, as well as
|
| 726 |
+
https://user-images.githubusercontent.com/13564/222815509-374f3400-f83d-4f7d-8fa6-4a092b3250bb.png
|
| 727 |
+
"""
|
| 728 |
+
|
| 729 |
+
    def __init__(
        self,
        wrapped_function: WrappedFunction,
        id: GraphID,
        parent: Optional[CUDAGraphNode],
        inputs: List[Tensor],
        cuda_graphs_pool: Tuple[int, int],
        device_index: int,
        stack_traces: Optional[StackTraces],
        stream: torch.cuda.Stream,
    ):
        """Record ``wrapped_function`` into a new CUDA graph within the shared pool.

        NB: recording happens here, in the constructor (see the "DO THE RECORDING"
        comment below); run() must be invoked immediately afterwards.
        """
        assert isinstance(inputs, (list, tuple))

        self.wrapped_function = wrapped_function
        self.id = id
        self.device = device_index
        self.stack_traces = stack_traces
        self.stream = stream

        # if this is a root parent will be None. use weakref to prevent reference cycle
        self._parent = weakref.ref(parent) if parent is not None else None
        # reference to the shared memory pool for the entire cuda graphs tree
        self.cuda_graphs_pool = cuda_graphs_pool

        # A single wrapped function may be recorded multiple times if memory patterns or
        # invariants change from one execution to the next
        self.children: Dict[FunctionID, List[CUDAGraphNode]] = defaultdict(list)

        # StorageWeakRef maintains whether the Storage C++ object remains allocated,
        # not whether the corresponding memory has been deallocated. In order
        # to use them to track memory deallocations we must maintain a single StorageWeakRef
        # for all Storages that reference that memory (even if we are constructing Storages
        # that do not have a deallocator function). We maintain one single storage_cache
        # as we execute any tree path. When we retrieve a storage from the cache we
        # check that it is still alive, and we hash based on observed recording data ptr
        # and storage cdata.

        # we preserve a single reference to executed outputs that is then referenced
        # in children to avoid children having to chase parent pointers in the hot path
        # DO NOT reassign output_weakrefs, only call `clear()`
        # Path is a series of nodes from root to the current node
        self.outputs_weakrefs: OutputList[Optional[StorageWeakRefWrapper]] = []
        self.path_weakrefs: LevelList[OutputList[Optional[StorageWeakRefWrapper]]] = [
            node.outputs_weakrefs for node in self._path_from_root
        ]
        self.path_stacktraces: LevelList[StackTraces] = [
            node.stack_traces for node in self._path_from_root
        ]
        self.tensor_weakrefs: OutputList[Optional[TensorWeakRef]] = []

        # tensors which are outputs of previous graphs in the tree
        self.cudagraph_managed_idxs: List[int] = [
            idx
            for idx, t in enumerate(inputs)
            if isinstance(t, torch.Tensor) and self._is_cuda_graph_recorded_tensor(t)
        ]

        # static inputs = caller-declared static inputs + tree-managed inputs
        self.static_input_idxs: List[int] = list(
            set(wrapped_function.static_input_idxs) | set(self.cudagraph_managed_idxs)
        )

        # data ptrs for static tensor inputs; None for non-static or non-tensor slots
        self.static_input_data_ptrs: InputList[Optional[int]] = [
            (
                inputs[i].data_ptr()
                if isinstance(inputs[i], torch.Tensor) and i in self.static_input_idxs
                else None
            )
            for i in range(len(inputs))
        ]

        # When we checkpoint, and free generations, we will be manually freeing the outputs
        # of CUDAGraphNodes. We should not be freeing parameters, nor do we need to account for
        # their liveness (they are static), so we need to compute which outputs are aliases of
        # parameters. Some static inputs are saved tensors from the forward that die in the backward.
        # Their locations are static but lifetimes are not. We only include the persistent static
        # data ptrs below because the non persistent data ptrs may be outputs of this record and
        # fresh allocations.

        # precompute expanded dims to avoid computing in the hot path
        self.expanded_dims: List[List[int]] = [
            get_expanded_dims(x)
            if isinstance(x, torch.Tensor) and idx not in self.static_input_idxs
            else []
            for idx, x in enumerate(inputs)
        ]

        # For each node in path, which outputs were observed to be live
        # before invoking graph recording, and after graph recording
        self.recorded_liveness_before_graph: LevelList[OutputList[bool]] = []
        self.recorded_liveness_after_graph: LevelList[OutputList[bool]] = []

        # List of Tuples of (depth, output_index) that index into node at depth
        # number of nodes from root and output_index of outputs. Will index into
        # path_weakrefs.
        self.expected_dead_indices_before_graph: List[PathOutputIndex] = []
        self.expected_dead_indices_after_graph: List[PathOutputIndex] = []

        # all live indices after graph recording
        self.live_indices_after_graph: List[PathOutputIndex] = []

        if self.parent is not None:
            # diff liveness against the parent's post-recording snapshot so we know
            # which outputs are expected to have died before this graph runs
            previous_liveness = self.parent.recorded_liveness_after_graph
            curr_liveness = self._get_liveness(self.path_weakrefs)

            different_indices = self._get_different_indices(
                previous_liveness, curr_liveness
            )

            self.recorded_liveness_before_graph = curr_liveness
            self.expected_dead_indices_before_graph = different_indices

        recording_inputs = self._allocate_and_copy_recording_inputs(inputs)
        # recording inputs will copy over memory, so we can free non recording inputs
        inputs.clear()
        del inputs

        # graph used for recording model invocation
        self.graph: Optional[torch.cuda.CUDAGraph] = torch.cuda.CUDAGraph()

        # we allocate non-static inputs within the same memory pool as the CUDAGraph
        # which we will record the model with. For memory efficiency, it is important
        # to reclaim the input memory when the inputs are no longer live. To accomplish this,
        # we reconstruct tensors at the correct data pointers of our inputs which are
        # non owning and do not prevent deallocation. On subsequent executions, input values
        # will be copied over to these tensors.
        self.reconstructed_inputs: InputList[Union[Tensor, int]] = [
            self._reconstruct_from_tensor_metadata(self._tensor_metadata(x))
            if isinstance(x, torch.Tensor)
            else x
            for x in recording_inputs
        ]

        # DO THE RECORDING!!!
        # We record the CUDA graph in the constructor of CUDAGraphNode, which
        # gives you what the CPU side compute of the function would do. We
        # don't throw the recording outputs away: their memory is
        # correctly accounted for in the CUDAGraphs caching allocator. This
        # means on the very FIRST run of the CUDA graph node, we can directly
        # do more recording, because we have a valid caching allocator state.
        # NB: This relies on run() being called immediately after the
        # constructor, otherwise this optimization would not be valid.

        # initialized below in _record

        self.checkpointed_caching_state: Optional[AllocatorState] = None

        # Output Storage Alias information, can be:
        # - A new, unaliased storage, or the output is None
        # - An alias of an output of a prior graph
        # - An alias of an output already created in the reconstructed outputs
        # This is None if the output in question is an int
        self.output_storage_alias: OutputList[Optional[OutputAliasInfo]] = []

        # is the output Storage unaliased in subsequent outputs, of all subsequent paths
        # if it is, we cached the output tensor and adjust storage liveness tracking to also
        # check if the output tensor does not have an additional python reference.
        # If a descendent node discovers it has an alias of a prior output, then the output
        # will no longer be cached in the ancestor.
        # The large majority of tensors are unaliased, and preserving aliased output tensors would add
        # significant additional complexity with marginal gains
        # The cached tensor outputs are added on the first execution, and cleared whenever we need
        # to do subsequent recording
        self.unaliased_in_all_paths: OutputList[bool] = []
        self.cached_tensor_outputs: OutputList[Optional[Tensor]] = []

        # if an output aliases a static, persistent input then the corresponding Tensor will
        # be set here. These are different than cached tensors, because they are tensors that
        # are aliases of parameters that are always live.
        self.static_output_tensors: OutputList[Optional[Tensor]] = []

        # Cleared after recording
        self.recording_outputs: Optional[
            OutputList[Union[torch.Tensor, int]]
        ] = self._record(wrapped_function.model, recording_inputs)
        self.outputs_metadata: OutputList[Union[Dict[str, Any], int, None]] = []

        # As with inputs, we do not want to keep the outputs permanently alive because that would prevent
        # their memory being reclaimed in subsequent cuda graph recordings. We record the tensor metadata
        # needed to reconstruct instead.
        assert self.recording_outputs is not None
        for out in self.recording_outputs:
            if isinstance(out, torch.Tensor):
                self.outputs_metadata.append(
                    self._tensor_metadata(out, ignore_storage_offset=False)
                )
            else:
                assert isinstance(out, (int, type(None))), type(out)
                self.outputs_metadata.append(out)

        self.graph.replay()
|
| 919 |
+
|
| 920 |
+
def _copy_input(self, idx, dst, src):
    "Copy src into dst, restricted to the non-expanded view of input idx."
    dims = self.expanded_dims[idx]
    # index out the expanded dims so only the unique data is copied
    dst_view = index_expanded_dims(dst, dims)
    src_view = index_expanded_dims(src, dims)
    # TODO - one jit kernel across multiple inputs
    dst_view.copy_(src_view)
|
| 926 |
+
|
| 927 |
+
def run_first_inputs(self, new_inputs):
    "Return the outputs produced during recording; the graph already ran in __init__."
    if config.triton.fast_path_cudagraph_asserts:
        self.debug_check_invariants_before_invocation()

    # The graph was already invoked in __init__, and the inputs were copied
    # over in _allocate_recording_inputs and subsequently cleared, so nothing
    # may be passed in here.
    assert len(new_inputs) == 0
    first_outputs = self.recording_outputs
    self.recording_outputs = None
    return first_outputs
|
| 937 |
+
|
| 938 |
+
def run(self, new_inputs):
    "Copy non-static inputs into place, replay the graph, and rebuild outputs."
    if config.triton.fast_path_cudagraph_asserts:
        self.debug_check_invariants_before_invocation()

    assert len(self.static_input_data_ptrs) == len(new_inputs)
    # NB: this ranges over non-static inputs too
    for idx, data_ptr in enumerate(self.static_input_data_ptrs):
        if idx in self.cudagraph_managed_idxs:
            continue
        inp = new_inputs[idx]
        if not isinstance(inp, torch.Tensor):
            pass
        elif data_ptr is not None:
            # static input, e.g., parameter: its address must not have moved
            assert data_ptr == inp.data_ptr()
        else:
            # non-static input, need to copy it into the CUDA-graph-owned buffer
            self._copy_input(idx, self.reconstructed_inputs[idx], inp)

    new_inputs.clear()
    self.run_graph()

    outputs = self.reconstruct_outputs()
    self.debug_check_invariants_after_invocation()

    return outputs
|
| 965 |
+
|
| 966 |
+
def reconstruct_outputs(self):
    "Reconstruct output tensors according to their saved metadata and alias information"

    # Cached tensors will not yet be set on the first execution.
    # They are also cleared in checkpointing, so if we checkpoint this node
    # and then execute it again we will need to repopulate cached tensors.
    if not self.cached_tensor_outputs:
        self._initialize_cached_tensors()

    outputs: List[Optional[Union[int, torch.Tensor]]] = []

    for idx, (alias_info, metadata) in enumerate(
        zip(self.output_storage_alias, self.outputs_metadata)
    ):
        # non-tensor outputs were recorded directly as ints (or None)
        if not isinstance(metadata, dict):
            assert isinstance(metadata, (int, type(None)))
            outputs.append(metadata)
            continue

        cached = self.cached_tensor_outputs[idx]
        if cached is not None:
            # No need to update weakrefs, already correctly initialized
            outputs.append(cached)
            continue

        static = self.static_output_tensors[idx]
        if static is not None:
            assert self.outputs_weakrefs[idx] is None
            outputs.append(static)
            continue

        storage = self.prepare_alias_info_for_tensor_construction(
            alias_info, metadata
        )

        if isinstance(storage, UntypedStorage) or storage is None:
            out = self._reconstruct_from_tensor_metadata(metadata, storage)
        else:
            # an int here indexes an earlier output whose storage this aliases
            assert isinstance(storage, int)
            out = self._reconstruct_from_tensor_metadata(
                metadata, cast(torch.Tensor, outputs[storage]).untyped_storage()
            )

        outputs.append(out)
        weak = self.outputs_weakrefs[idx]
        assert weak is not None
        weak.swap_weakref(out.untyped_storage()._weak_ref())

    return outputs
|
| 1015 |
+
|
| 1016 |
+
def prepare_alias_info_for_tensor_construction(
    self,
    out_alias_info: Optional[OutputAliasInfo],
    metadata: Union[Dict[str, Any], int, None],
) -> Union[UntypedStorage, None, int]:
    "Map one output's alias info to the storage (or output index) used to rebuild it."
    # non-tensor outputs and unaliased tensors need no pre-existing storage
    if isinstance(metadata, (int, type(None))) or out_alias_info is UnaliasedStorage:
        return None

    if isinstance(out_alias_info, AliasesPriorGraphOutput):
        depth, prior_output_index = out_alias_info.index
        weak = self.path_weakrefs[depth][prior_output_index]
        assert weak is not None
        return torch.UntypedStorage._new_with_weak_ptr(weak())

    assert isinstance(out_alias_info, AliasesNewOutput)
    return out_alias_info.index
|
| 1035 |
+
|
| 1036 |
+
def prepare_storages_for_construction(
    self,
) -> List[Union[UntypedStorage, None, int]]:
    "Compute the storage / alias handle for every recorded output."
    return [
        self.prepare_alias_info_for_tensor_construction(alias_info, metadata)
        for alias_info, metadata in zip(
            self.output_storage_alias, self.outputs_metadata
        )
    ]
|
| 1050 |
+
|
| 1051 |
+
def run_graph(self):
    "Replay the recorded CUDA graph."
    assert self.graph is not None
    self.graph.replay()
|
| 1054 |
+
|
| 1055 |
+
def all_outputs_are_dead(self):
    "All outputs of the path from this node to its root are dead"
    return not any(
        is_live(self.path_weakrefs[depth][output_index])
        for depth, output_index in self.live_indices_after_graph
    )
|
| 1061 |
+
|
| 1062 |
+
def _record(self, model, inputs):
    "Run the model under CUDA graph capture and record its outputs."

    def static_input_iter():
        # static inputs that are not themselves outputs of an earlier graph
        for idx in self.wrapped_function.static_input_idxs:
            if isinstance(
                inputs[idx], torch.Tensor
            ) and not self._is_cuda_graph_recorded_tensor(inputs[idx]):
                yield inputs[idx]

    # see: output_is_alias_of_persistent_static_inputs above
    static_input_persistent_storage_ptrs: Dict[int, StorageWeakRefWrapper] = {
        inp.untyped_storage().data_ptr(): StorageWeakRefWrapper(inp)
        for inp in itertools.chain(
            static_input_iter(), self.wrapped_function.constants
        )
    }

    if config.triton.slow_path_cudagraph_asserts:
        # need to use parent live weakrefs because live_indices isnt set yet
        memory = (
            [] if self.parent is None else list(self.parent.path_live_weakrefs())
        )
        memory += [
            StorageWeakRefWrapper(elem)
            for i, elem in enumerate(inputs)
            if isinstance(elem, torch.Tensor)
            and i not in self.wrapped_function.static_input_idxs
            and elem.untyped_storage().data_ptr() != 0
        ]
        check_memory_pool(self.device, self.cuda_graphs_pool, memory)

    # capture into our private pool with rng / cublas state isolated
    with preserve_rng_state(), torch.cuda.device(
        self.device
    ), clear_cublas_manager(), torch.cuda.graph(
        self.graph,
        stream=self.stream,
        pool=self.cuda_graphs_pool,
        capture_error_mode="thread_local",
    ), get_history_recording():
        static_outputs = model(inputs)

    # running model should reclaim memory
    assert len(inputs) == 0

    if not isinstance(static_outputs, (list, tuple)):
        static_outputs = (static_outputs,)

    self._add_first_outputs(static_outputs, static_input_persistent_storage_ptrs)

    return static_outputs
|
| 1113 |
+
|
| 1114 |
+
def _add_first_outputs(
    self,
    outputs,
    static_input_persistent_storage_ptrs: Dict[int, StorageWeakRefWrapper],
):
    "Add the outputs from the first invocation of the node and set up metadata"

    # getting liveness before we have added the outputs to path, so the length
    # of the two lists is equal
    prev_liveness = self.recorded_liveness_before_graph
    curr_liveness = self._get_liveness(self.path_weakrefs)

    delta = self._get_different_indices(prev_liveness, curr_liveness)
    self.expected_dead_indices_after_graph = delta

    assert len(self.outputs_weakrefs) == 0
    # index from data pointer to index in outputs
    output_new_storages_index: Dict[StorageDataPtr, int] = {}

    self.unaliased_in_all_paths = [False for _ in range(len(outputs))]
    self.static_output_tensors = [None for _ in range(len(outputs))]

    for i, o in enumerate(outputs):
        if o is None or not isinstance(o, torch.Tensor):
            self.output_storage_alias.append(UnaliasedStorage)
            continue

        # BUG FIX: this call previously had a stray trailing comma after its
        # closing paren, turning the statement into a discarded 1-tuple.
        torch._check(
            o.is_cuda or o.untyped_storage().data_ptr() == 0,
            lambda: (
                "Expected all cuda outputs in cuda graph recording. Non cuda output "
                f"from {self.stack_traces[i] if self.stack_traces else '(unknown)'}"
            ),
        )

        ref = static_input_persistent_storage_ptrs.get(
            o.untyped_storage().data_ptr(), None
        )
        # also treat empty storages as static outputs because we do not need to manage their lifetime
        # and they should not participate in checkpointing
        is_empty_storage = o.untyped_storage().data_ptr() == 0
        if (ref and ref() is not None) or is_empty_storage:
            self.output_storage_alias.append(None)
            self.static_output_tensors[i] = o
            continue

        path_ref = self._is_alias_of_live_recorded_tensor(o)
        if path_ref is not None:
            self._mark_prior_graph_output_as_aliased(path_ref)
            self.output_storage_alias.append(AliasesPriorGraphOutput(path_ref))
            continue

        if o.untyped_storage().data_ptr() in output_new_storages_index:
            index = output_new_storages_index[o.untyped_storage().data_ptr()]
            self.unaliased_in_all_paths[index] = False
            self.output_storage_alias.append(AliasesNewOutput(index))
            continue

        output_new_storages_index[o.untyped_storage().data_ptr()] = i
        self.output_storage_alias.append(UnaliasedStorage)
        self.unaliased_in_all_paths[i] = True

    if self.stack_traces is None:
        self.stack_traces = [None for _ in range(len(outputs))]
    else:
        assert len(self.stack_traces) == len(
            outputs
        ), "Wrong number of stack traces passed in"

    assert not self.outputs_weakrefs
    for out, static_output_tensor in zip(outputs, self.static_output_tensors):
        if not isinstance(out, torch.Tensor) or static_output_tensor is not None:
            self.outputs_weakrefs.append(None)
            self.tensor_weakrefs.append(None)
        else:
            self.outputs_weakrefs.append(StorageWeakRefWrapper(out))
            self.tensor_weakrefs.append(TensorWeakRef(out))

    self.recorded_liveness_after_graph = self._get_liveness(self.path_weakrefs)
    self.checkpointed_caching_state = torch._C._cuda_getCheckpointState(
        self.device, self.cuda_graphs_pool
    )

    # now, get liveness with outputs added
    for depth in range(len(self.path_weakrefs)):
        for output_index in range(len(self.path_weakrefs[depth])):
            if is_live(self.path_weakrefs[depth][output_index]):
                self.live_indices_after_graph.append((depth, output_index))

    self.debug_check_invariants_after_invocation()
    if config.triton.slow_path_cudagraph_asserts:
        check_memory_pool(
            self.device, self.cuda_graphs_pool, list(self.path_live_weakrefs())
        )
|
| 1208 |
+
|
| 1209 |
+
def _mark_prior_graph_output_as_aliased(self, index: PathOutputIndex):
    "Remove a graph output from the unaliased, cached tensors in an ancestor node"
    depth, output_index = index
    ancestor = list(self._path_from_root)[depth]
    ancestor.unaliased_in_all_paths[output_index] = False
    weak = self.path_weakrefs[depth][output_index]
    assert weak is not None
    weak.remove_extra_reference()
|
| 1217 |
+
|
| 1218 |
+
def _initialize_cached_tensors(self):
    "Materialize and cache output tensors that are unaliased along every path."
    # we should not be clearing output_weakrefs, and they should be set in the first
    # record run
    assert len(self.outputs_weakrefs) == len(self.outputs_metadata)

    for i, (alias_info, metadata, make_cached) in enumerate(
        zip(
            self.output_storage_alias,
            self.outputs_metadata,
            self.unaliased_in_all_paths,
        )
    ):
        if not make_cached:
            self.cached_tensor_outputs.append(None)
            continue

        assert alias_info is UnaliasedStorage
        assert isinstance(metadata, dict)
        storage = self.create_storage(metadata)
        out = self._reconstruct_from_tensor_metadata(metadata, storage=storage)

        # XXX: let autograd know that there will be an additional reference to the tensor
        # that can be ignored when deciding whether to do gradient buffer inplacing.
        # Otherwise, inplacing could differ between tracing and subsequent execution.
        # For some models we tested this led to inputs no longer being in cudagraph pools,
        # leading to spurious re-recordings.
        # It also tells AMP cache that even though the tensor impls cannot be cached
        # in dtype conversions.
        torch._C._add_cached_tensor(out)

        self_ref = weakref.ref(self)

        # one reference in our array, and calling sys.getrefcount bumps the refcount by one
        def check_refcount(i):
            self_loc = self_ref()
            if self_loc is None:
                return False
            return self_loc.get_output_refcount(i) == 2

        check = functools.partial(check_refcount, i=i)

        self.outputs_weakrefs[i] = StorageWeakRefWrapper(out, extra_ref_check=check)
        self.cached_tensor_outputs.append(out)
|
| 1262 |
+
|
| 1263 |
+
def get_output_refcount(self, index):
    "Refcount of the cached output at index (includes getrefcount's own temp ref)."
    return sys.getrefcount(self.cached_tensor_outputs[index])
|
| 1265 |
+
|
| 1266 |
+
@property
def parent(self):
    "Dereference the weakref held in _parent (None if there is no parent)."
    if self._parent is None:
        return None
    return self._parent()
|
| 1270 |
+
|
| 1271 |
+
@property
|
| 1272 |
+
def _path_to_root(self):
|
| 1273 |
+
"Returns all nodes in the path starting at self and ending at root"
|
| 1274 |
+
node = self
|
| 1275 |
+
while node:
|
| 1276 |
+
yield node
|
| 1277 |
+
node = node.parent
|
| 1278 |
+
|
| 1279 |
+
@property
|
| 1280 |
+
def _path_from_root(self):
|
| 1281 |
+
"Returns all nodes in the path starting at the root and ending at self"
|
| 1282 |
+
nodes = reversed(list(self._path_to_root))
|
| 1283 |
+
yield from nodes
|
| 1284 |
+
|
| 1285 |
+
def _is_cuda_graph_recorded_tensor(self, t: torch.Tensor):
|
| 1286 |
+
"Is this tensor an output of a node in this path"
|
| 1287 |
+
for output_refs in self.path_weakrefs:
|
| 1288 |
+
for storage_weak_ref in output_refs:
|
| 1289 |
+
if storage_weak_ref is None:
|
| 1290 |
+
continue
|
| 1291 |
+
# don't need to check liveness of storage since the cuda graph managed
|
| 1292 |
+
# memory is never released.
|
| 1293 |
+
data_ptr = storage_weak_ref.data_ptr()
|
| 1294 |
+
if t.untyped_storage().data_ptr() == data_ptr:
|
| 1295 |
+
return True
|
| 1296 |
+
|
| 1297 |
+
return False
|
| 1298 |
+
|
| 1299 |
+
def _is_alias_of_live_recorded_tensor(
    self, t: torch.Tensor
) -> Optional[PathOutputIndex]:
    "Find the (depth, output index) of a live recorded output whose storage t shares."
    target_ptr = t.untyped_storage().data_ptr()
    for depth, output_refs in enumerate(self.path_weakrefs):
        for output_index, storage_ref in enumerate(output_refs):
            deref = maybe_deref(storage_ref)
            if deref is not None:
                _storage, ptr = deref
                if ptr == target_ptr:
                    return (depth, output_index)
    return None
|
| 1310 |
+
|
| 1311 |
+
@staticmethod
def _check_liveness(
    indices: List[PathOutputIndex],
    output_refs: List[List[Optional[StorageWeakRefWrapper]]],
):
    "Check that all of the indices specified are dead references"
    for depth, output_index in indices:
        weak = output_refs[depth][output_index]
        assert weak is not None
        # a non-None deref means the storage is still alive
        if weak() is not None:
            return False
    return True
|
| 1323 |
+
|
| 1324 |
+
def add_child(self, function_id: FunctionID, node: CUDAGraphNode):
    "Register node as a child of self under the given function id."
    self.children[function_id].append(node)
|
| 1327 |
+
|
| 1328 |
+
@staticmethod
def _get_different_indices(
    prev: List[List[bool]], curr: List[List[bool]]
) -> List[PathOutputIndex]:
    "Find indices where the two lists differ."
    changed = []
    assert len(prev) <= len(curr)
    for depth, (before, after) in enumerate(zip(prev, curr)):
        assert len(before) == len(after)
        for idx, (b, a) in enumerate(zip(before, after)):
            if b != a:
                changed.append((depth, idx))

    return changed
|
| 1342 |
+
|
| 1343 |
+
@staticmethod
def _get_liveness(
    weakrefs: List[List[Optional[StorageWeakRefWrapper]]],
) -> List[List[bool]]:
    "Maps weakrefs to true if the reference is alive and false otherwise"
    if not weakrefs:
        return []
    return [pytree.tree_map(is_live, outputs) for outputs in weakrefs]
|
| 1352 |
+
|
| 1353 |
+
def debug_assert_invariants(
    self, expected_liveness: List[List[bool]], newly_dead: List[PathOutputIndex]
):
    "Debug-only check that recorded liveness matches the state of the memory pool."
    if not config.triton.fast_path_cudagraph_asserts:
        return

    for i, node in enumerate(self._path_from_root):
        assert self.path_weakrefs[i] is node.outputs_weakrefs

    nodes = list(self._path_from_root)

    live_blocks = get_block_addrs(self.cuda_graphs_pool)

    live_storage_data_ptrs = set()
    live_storage_weak_ptrs = set()

    for depth, outputs_liveness in enumerate(expected_liveness):
        for output_idx, output_liveness in enumerate(outputs_liveness):
            # tensor can die early, but it can't be alive when it should be dead
            w = self.path_weakrefs[depth][output_idx]
            if (stor_weak_ptr_and_data_ptr := maybe_deref(w)) is not None:
                assert output_liveness
                stor_weak_ptr, stor_data_ptr = stor_weak_ptr_and_data_ptr
                # the two sets must stay in sync
                assert (stor_data_ptr in live_storage_data_ptrs) == (
                    stor_weak_ptr in live_storage_weak_ptrs
                )
                live_storage_data_ptrs.add(stor_data_ptr)
                live_storage_weak_ptrs.add(stor_weak_ptr)

                is_persistent_alias = (
                    nodes[depth].static_output_tensors[output_idx] is not None
                )

                if is_persistent_alias:
                    assert stor_data_ptr not in live_blocks

    for depth, output_index in newly_dead:
        assert not is_live(self.path_weakrefs[depth][output_index])
|
| 1391 |
+
|
| 1392 |
+
def debug_check_invariants_before_invocation(self):
    "Assert pre-invocation liveness invariants (debug mode only)."
    self.debug_assert_invariants(
        self.recorded_liveness_before_graph, self.expected_dead_indices_before_graph
    )
|
| 1396 |
+
|
| 1397 |
+
def debug_check_invariants_after_invocation(self):
    "Assert post-invocation liveness invariants (debug mode only)."
    self.debug_assert_invariants(
        self.recorded_liveness_before_graph, self.expected_dead_indices_after_graph
    )
|
| 1401 |
+
|
| 1402 |
+
def data_ptrs_dead_since_invocation(self) -> List[int]:
    """
    Since this node was invoked, return data ptrs of all tensor outputs that have died
    in the current executing tree path.
    """
    curr_liveness = self._get_liveness(self.path_weakrefs)
    newly_dead = self._get_different_indices(
        self.recorded_liveness_after_graph, curr_liveness
    )

    path = list(self._path_from_root)
    return [
        path[depth].outputs_metadata[output_index]["data_ptr"]
        for depth, output_index in newly_dead
    ]
|
| 1420 |
+
|
| 1421 |
+
def path_live_weakrefs(self) -> Iterator[StorageWeakRefWrapper]:
    "Yield the weakrefs that are still alive among outputs recorded live after this graph."
    for depth, idx in self.live_indices_after_graph:
        ref = self.path_weakrefs[depth][idx]
        if ref is not None and is_live(ref):
            yield ref
|
| 1426 |
+
|
| 1427 |
+
def remove_node_cached_tensors(self):
    "Drop this node's cached output tensors and release their extra references."
    for cached in self.cached_tensor_outputs:
        if cached is not None:
            torch._C._remove_cached_tensor(cached)
    self.cached_tensor_outputs.clear()

    for idx, unaliased in enumerate(self.unaliased_in_all_paths):
        if unaliased:
            weak = self.outputs_weakrefs[idx]
            assert weak is not None
            weak.remove_extra_reference()
|
| 1438 |
+
|
| 1439 |
+
def remove_path_cached_tensors(self):
    "Remove cached tensors from every node on the root-to-self path."
    for node in self._path_from_root:
        node.remove_node_cached_tensors()
|
| 1442 |
+
|
| 1443 |
+
def clear_path_state(self):
    "Clear the path state in this current executing node"
    # intentionally a no-op for now; kept so callers have a stable hook
    pass
|
| 1447 |
+
|
| 1448 |
+
@staticmethod
|
| 1449 |
+
def _tensor_metadata(x, ignore_storage_offset=True):
|
| 1450 |
+
assert isinstance(x, torch.Tensor)
|
| 1451 |
+
# We ignore the storage offset for inputs, but not for outputs
|
| 1452 |
+
# TODO: - should we make the storage resizable ?
|
| 1453 |
+
return {
|
| 1454 |
+
"nbytes": x.untyped_storage().nbytes(),
|
| 1455 |
+
"data_ptr": x.untyped_storage().data_ptr(),
|
| 1456 |
+
"size": x.shape,
|
| 1457 |
+
"stride": x.stride(),
|
| 1458 |
+
"dtype": x.dtype,
|
| 1459 |
+
"device": x.device,
|
| 1460 |
+
"storage_offset": x.storage_offset() if not ignore_storage_offset else 0,
|
| 1461 |
+
}
|
| 1462 |
+
|
| 1463 |
+
def _reconstruct_from_tensor_metadata(
    self, metadata: Dict[str, Any], storage=None
) -> Tensor:
    "Build a tensor over the given (or a freshly created) storage from saved metadata."
    target = storage if storage is not None else self.create_storage(metadata)
    return torch._C._construct_CUDA_Tensor_From_Storage_And_Metadata(metadata, target)
|
| 1468 |
+
|
| 1469 |
+
def create_storage(self, metadata):
    "Wrap the recorded data pointer in a storage of the recorded size/device."
    return torch._C._construct_storage_from_data_pointer(
        metadata["data_ptr"], metadata["device"], metadata["nbytes"]
    )
|
| 1473 |
+
|
| 1474 |
+
def _allocate_and_copy_recording_inputs(
    self, inputs
) -> List[Union[torch.Tensor, int]]:
    """
    Allocate inputs for non static, non cudagraph managed tensors in the memory pool
    and copy over the tensor values.
    """

    torch.cuda.synchronize()
    self.stream.wait_stream(torch.cuda.current_stream())
    recording_inputs: List[Union[Tensor, int]] = []

    # allocate into the cudagraph private pool on our recording stream
    with warnings.catch_warnings(record=True), torch.cuda.device(
        self.device
    ), _use_cuda_memory_pool_manager(
        self.device,
        mem_pool=self.cuda_graphs_pool,
        stream=self.stream,
    ):
        for i, inp in enumerate(inputs):
            if not isinstance(inp, torch.Tensor):
                assert isinstance(inp, int)
                recording_inputs.append(inp)
            elif i not in self.static_input_idxs:
                # static_input does an allocation!
                recording_inputs.append(static_input(inp))
                # copy over and clear non recording input
                self._copy_input(i, recording_inputs[-1], inp)
                inputs[i] = None
                del inp
            else:
                recording_inputs.append(inp)

    return recording_inputs
|
| 1508 |
+
|
| 1509 |
+
def check_invariants(self, inputs: List[Tensor]) -> bool:
    """
    Checks if this node can be run. The same pattern of tensor liveness and tensors
    managed in the cudagraph private pool must remain stable.
    """

    # previously managed data pointers remain stable
    for idx in self.cudagraph_managed_idxs:
        if inputs[idx].data_ptr() != self.static_input_data_ptrs[idx]:
            return False

    if not self._check_liveness(
        self.expected_dead_indices_before_graph, self.path_weakrefs
    ):
        return False

    # the cudagraph managed tensors which died upon recording must also die upon
    # this invocation. it is too late to check after we've replayed the graph,
    # because we would have already written over their memory.
    for idx in self.cudagraph_managed_idxs:
        inputs[idx] = None  # type: ignore[call-overload]

    torch._check(
        self._check_liveness(
            self.expected_dead_indices_after_graph, self.path_weakrefs
        ),
        lambda: "TODO: graph recording observed an input tensor deallocate during graph "
        " recording that did not occur during replay. Please file an issue.",
    )
    return True
|
| 1539 |
+
|
| 1540 |
+
def num_descendants(self) -> int:
    "Total number of descendents of this node"
    # each child counts once for itself plus all of its own descendants
    return sum(
        1 + child.num_descendants()
        for children in self.children.values()
        for child in children
    )
|
| 1548 |
+
|
| 1549 |
+
|
| 1550 |
+
def get_cudagraph_segments(pool_id):
    "Return the memory-snapshot segments that belong to the given pool."
    return [
        seg
        for seg in torch.cuda.memory_snapshot()
        if seg["segment_pool_id"] == pool_id
    ]
|
| 1553 |
+
|
| 1554 |
+
|
| 1555 |
+
def get_block_addrs(pool_id, live_only=True):
    "Addresses of (by default only live) blocks in the given cudagraph pool."
    addrs = []

    for segment in get_cudagraph_segments(pool_id):
        addr = segment["address"]
        for block in segment["blocks"]:
            if not live_only or block["state"] == "active_allocated":
                addrs.append(addr)

            addr += block["size"]

    return addrs
|
| 1567 |
+
|
| 1568 |
+
|
| 1569 |
+
def format_tb(frames):
    "Render allocator history frames as a python-style traceback string."
    summaries = [
        traceback.FrameSummary(frame["filename"], frame["line"], frame["name"])
        for frame in frames
    ]
    return "".join(traceback.format_list(summaries))
|
| 1578 |
+
|
| 1579 |
+
|
| 1580 |
+
def check_memory_pool(device, pool_id, live_storages_ptrs: List[StorageWeakRefWrapper]):
    """
    Verify the live allocations in the cudagraph pool exactly match the storages
    we expect to be live; raise with allocation histories on a mismatch.
    """
    assert all(
        isinstance(elem, StorageWeakRefWrapper) for elem in live_storages_ptrs
    )  # noqa: C419
    unique_storages = {stor.data_ptr() for stor in live_storages_ptrs if stor()}

    # check if there is a divergence first, then do the expensive snapshot call after
    # we know it will error
    if torch._C._cuda_checkPoolLiveAllocations(device, pool_id, unique_storages):
        return

    # at this point we are past the fast-path. we have seen rare cases where a dead tensor is dead,
    # but hasn't been gc'd yet, and gives false positive for allocated_not_in_live_storages
    gc.collect()

    segments = get_cudagraph_segments(pool_id)

    allocated_not_in_live_storages = {}

    for segment in segments:
        addr = segment["address"]
        for block in segment["blocks"]:
            if block["state"] == "active_allocated":
                if addr not in unique_storages:
                    allocated_not_in_live_storages[addr] = block
                else:
                    unique_storages.remove(addr)

            addr += block["size"]

    torch._check(
        len(unique_storages) == 0,
        lambda: f"These storage data ptrs are not allocated in pool {pool_id} but should be {unique_storages}",
    )

    # BUG FIX: previously `allocated_not_in_live_storages != 0` compared the dict
    # itself to an int, which is always True, so an empty mismatch dict still
    # raised a spurious RuntimeError.
    if len(allocated_not_in_live_storages) != 0:
        formatted = []
        for dp, block in allocated_not_in_live_storages.items():
            trace = format_tb(block.get("frames", []))
            formatted.append(f"Data Pointer: {dp}, history: \n{trace}")
        formatted_s = "\n".join(formatted)
        msg = (
            f"These live storage data ptrs are in the cudagraph pool but not "
            f"accounted for as an output of cudagraph trees: \n\n{formatted_s}"
        )
        raise RuntimeError(msg)
|
| 1626 |
+
|
| 1627 |
+
|
| 1628 |
+
class ExecutionState(Enum):
    """
    Represents the state of the CUDAGraph Tree. Will be None if there is no live current memory allocated
    in the cuda graph pool. Otherwise will reflect the state of the most recently executed node.
    """

    NONE = auto()
    WARMUP = auto()
    RECORDING = auto()
    EXECUTION = auto()
|
| 1638 |
+
|
| 1639 |
+
|
| 1640 |
+
class CompilationMode(Enum):
    "Which kind of compiled graph a wrapped function corresponds to."

    FORWARD = auto()
    BACKWARD = auto()
    INFERENCE = auto()
|
| 1644 |
+
|
| 1645 |
+
|
| 1646 |
+
class CUDAGraphTreeManager:
|
| 1647 |
+
"""
|
| 1648 |
+
Groups individual recordings or executions of cuda graphs into a tree of recordings,
|
| 1649 |
+
and checks required invariants, and manages warmups of graphs.
|
| 1650 |
+
|
| 1651 |
+
When graphs are recorded in the same tree, it enforces subsequent execution
|
| 1652 |
+
to follow the same order and have the same output tensor livespans. To remove
|
| 1653 |
+
unnecessary coupling of cuda graphs (and additional imposed invariants),
|
| 1654 |
+
the tree manager will end a currently recording tree whenever it is valid - when
|
| 1655 |
+
the memory pool no longer has any live allocations.
|
| 1656 |
+
|
| 1657 |
+
We ignore outputs from a previous generation that correspond to prior model outputs.
|
| 1658 |
+
Currently this is hardcoded `GenerationTracker.generation` tracked in torch dynamo.
|
| 1659 |
+
# TODO: make generation increment configurable, warn on overwrite.
|
| 1660 |
+
|
| 1661 |
+
We run graph warmups in the cudagraph memory pool and return the result on the first invocation
|
| 1662 |
+
of a function. For many models it is important to reclaim activations as you run the backward.
|
| 1663 |
+
If we were to warm up the model and keep an extra copy of the inputs around to subsequently
|
| 1664 |
+
use for recording, we would incur a memory penalty. Additionally, if we are part way through training
|
| 1665 |
+
your model and need to recompile, memory will be allocated to the cuda graph pool, so we run this
|
| 1666 |
+
warmup run in the cuda graph memory pool. As for recording, warm up needs the state of live tensors
|
| 1667 |
+
to be accurately reflected so we checkpoint the allocator state if we need to warm up following graph
|
| 1668 |
+
replay.
|
| 1669 |
+
"""
|
| 1670 |
+
|
| 1671 |
+
def __init__(self, device_index: int):
|
| 1672 |
+
# roots are functions which have no dependencies on an other node. I.e.,
|
| 1673 |
+
# when they are first invoked, none of their inputs are outputs are outputs
|
| 1674 |
+
# of another node, nor are there any live outputs of another node whose
|
| 1675 |
+
# liveness would create a dependency.
|
| 1676 |
+
self.roots: Dict[FunctionID, List[CUDAGraphNode]] = defaultdict(list)
|
| 1677 |
+
|
| 1678 |
+
# mapping from function id to wrapped function
|
| 1679 |
+
self.ids_to_funcs: Dict[FunctionID, WrappedFunction] = {}
|
| 1680 |
+
|
| 1681 |
+
self.ids_to_stack_traces: Dict[FunctionID, StackTraces] = {}
|
| 1682 |
+
|
| 1683 |
+
self.warmed_up_functions: Set[FunctionID] = set()
|
| 1684 |
+
# if we fail to increment generation, and are stuck warming up,
|
| 1685 |
+
# only warn on each function once
|
| 1686 |
+
self.warned_functions: Set[FunctionID] = set()
|
| 1687 |
+
torch._C._set_cached_tensors_enabled(True)
|
| 1688 |
+
|
| 1689 |
+
# NB: cuda caching allocator will remember the stream a segment is allocated to
|
| 1690 |
+
# and only allocate that segment to the same stream. we need to use a single stream
|
| 1691 |
+
# for all allocations to the memory pool, otherwise the allocations to separate streams
|
| 1692 |
+
# will not be reused; separate recordings would have use the same memory pool, but not
|
| 1693 |
+
# the same memory.
|
| 1694 |
+
|
| 1695 |
+
with torch.cuda.device(device_index):
|
| 1696 |
+
torch.cuda.synchronize()
|
| 1697 |
+
self.stream = torch.cuda.Stream()
|
| 1698 |
+
self.stream.wait_stream(torch.cuda.current_stream())
|
| 1699 |
+
|
| 1700 |
+
# Keeps Memory Pool Alive
|
| 1701 |
+
self.graph: Optional[torch.cuda.CUDAGraph] = torch.cuda.CUDAGraph()
|
| 1702 |
+
self.cuda_graphs_thread_pool = torch.cuda.graph_pool_handle()
|
| 1703 |
+
|
| 1704 |
+
with warnings.catch_warnings(record=True), torch.cuda.graph(
|
| 1705 |
+
self.graph,
|
| 1706 |
+
pool=self.cuda_graphs_thread_pool,
|
| 1707 |
+
stream=self.stream,
|
| 1708 |
+
capture_error_mode="thread_local",
|
| 1709 |
+
):
|
| 1710 |
+
pass
|
| 1711 |
+
|
| 1712 |
+
self.graph_counter = itertools.count(0)
|
| 1713 |
+
self.func_counter = itertools.count(0)
|
| 1714 |
+
|
| 1715 |
+
# whether we the current node is in a state of warmup, recording, execution. If
|
| 1716 |
+
# there is no current node the state will be ExecutionState.None.
|
| 1717 |
+
self.path_state = ExecutionState.NONE
|
| 1718 |
+
self.device_index = device_index
|
| 1719 |
+
|
| 1720 |
+
# the most recently invoked cudagraph wrapping of a function. Will be None
|
| 1721 |
+
# when there is no output from a previous recording or execution whose memory
|
| 1722 |
+
# we need to respect in the cuda caching allocation. If you incremented generation,
|
| 1723 |
+
# this will also be none, as ignore those allocations.
|
| 1724 |
+
self.current_node: Optional[CUDAGraphNode] = None
|
| 1725 |
+
|
| 1726 |
+
# current generation of cudagraph invocations. when torch.compile is run
|
| 1727 |
+
# we increment the current generation. are willing to ignore live outputs
|
| 1728 |
+
# of a previous generation in checking liveness.
|
| 1729 |
+
self.current_gen: int = -1
|
| 1730 |
+
|
| 1731 |
+
# number of instances we are in execution and failed to match to an
|
| 1732 |
+
# existing child
|
| 1733 |
+
self.debug_fail_counter = 0
|
| 1734 |
+
# number of instances we had to checkpoint the function
|
| 1735 |
+
self.debug_checkpointing_counter = 0
|
| 1736 |
+
|
| 1737 |
+
self.id_to_mode: Dict[FunctionID, CompilationMode] = {}
|
| 1738 |
+
|
| 1739 |
+
# Note: [Backward Generation Handling]
|
| 1740 |
+
# We generally perform a sequence of forward executions followed by backward executions.
|
| 1741 |
+
# If multiple torch.compile wrapped forwards are executed with their backwards pending,
|
| 1742 |
+
# we should not disregard the outputs from a prior torch.compile since the entire training
|
| 1743 |
+
# loop hasn't completed. Occasionally, a backward pass corresponding to a forward pass may
|
| 1744 |
+
# not be executed, so we cannot wait for all pending forward pass backward completions, so
|
| 1745 |
+
# we cannot wait for all backwards to have been invoked. Instead we wait for a single backward
|
| 1746 |
+
# invocation. Triggering a backward pass typically doesn't lead to another torch.compile
|
| 1747 |
+
# invocation, making it less likely for the generation to increase between multiple
|
| 1748 |
+
# backward calls. The following use case is covered by this approach:
|
| 1749 |
+
# mod1 = torch.compile(...)
|
| 1750 |
+
# mod2 = torch.compile(...)
|
| 1751 |
+
# mod2(mod1(x)).sum().backward()
|
| 1752 |
+
|
| 1753 |
+
self.running_forwards_with_pending_backwards = False
|
| 1754 |
+
|
| 1755 |
+
def run(self, new_inputs: List[Tensor], function_id: FunctionID):
|
| 1756 |
+
assert self.graph is not None, "Running CUDAGraph after shutdown"
|
| 1757 |
+
out = self._run(new_inputs, function_id)
|
| 1758 |
+
|
| 1759 |
+
# The forwards are only pending following invocation, not before
|
| 1760 |
+
mode = self.id_to_mode[function_id]
|
| 1761 |
+
if mode == CompilationMode.FORWARD:
|
| 1762 |
+
self.running_forwards_with_pending_backwards = True
|
| 1763 |
+
elif mode == CompilationMode.BACKWARD:
|
| 1764 |
+
self.running_forwards_with_pending_backwards = False
|
| 1765 |
+
|
| 1766 |
+
return out
|
| 1767 |
+
|
| 1768 |
+
def set_to_running_backward(self):
|
| 1769 |
+
self.running_forwards_with_pending_backwards = False
|
| 1770 |
+
|
| 1771 |
+
def _run(self, new_inputs: List[Tensor], function_id: FunctionID):
|
| 1772 |
+
# we will try to end the current execution lazily, since
|
| 1773 |
+
# we dont want to do unnecessary checking of the existing outputs
|
| 1774 |
+
# on the hot path, but both recording and warmup only happen once
|
| 1775 |
+
# so we check up front
|
| 1776 |
+
if self.in_recording:
|
| 1777 |
+
self.try_end_curr_recording(function_id)
|
| 1778 |
+
|
| 1779 |
+
if self.in_warmup:
|
| 1780 |
+
self.try_end_curr_warmup(function_id)
|
| 1781 |
+
|
| 1782 |
+
# warming up a function and subsequentally recording may use different memory addresses
|
| 1783 |
+
# because both depend on the state of the caching allocator. if we warm up graph A,
|
| 1784 |
+
# then warm up graph B and make more allocations, the subsequent recording of A will not
|
| 1785 |
+
# necessarily use the same addresses as in the warm up. Thus any warm up of a node can only
|
| 1786 |
+
# be followed by warm up runs.
|
| 1787 |
+
if (
|
| 1788 |
+
not (
|
| 1789 |
+
function_id in self.warmed_up_functions
|
| 1790 |
+
or config.triton.skip_cudagraph_warmup
|
| 1791 |
+
)
|
| 1792 |
+
) or self.in_warmup:
|
| 1793 |
+
# If we are in the middle of executing cuda graphs, then we need to checkpoint memory state.
|
| 1794 |
+
# Both Recording and Warmup will be reflected in the allocator and dont need changes
|
| 1795 |
+
if self.path_state == ExecutionState.EXECUTION:
|
| 1796 |
+
self.apply_checkpoint_execution_state_in_allocator()
|
| 1797 |
+
|
| 1798 |
+
return self.run_eager(new_inputs, function_id)
|
| 1799 |
+
|
| 1800 |
+
child_nodes = (
|
| 1801 |
+
self.roots if self.current_node is None else self.current_node.children
|
| 1802 |
+
)
|
| 1803 |
+
|
| 1804 |
+
if not self.in_recording:
|
| 1805 |
+
for child in child_nodes[function_id]:
|
| 1806 |
+
# here we are checking memory consistency between recording and execution,
|
| 1807 |
+
# as well as things like stability of tensor locations, etc
|
| 1808 |
+
# and other
|
| 1809 |
+
if child.check_invariants(new_inputs):
|
| 1810 |
+
return self.execute_node(child, new_inputs)
|
| 1811 |
+
|
| 1812 |
+
# now that we know the new function can't be run as a child of the
|
| 1813 |
+
# current node, if it is a root, try to end the current execution.
|
| 1814 |
+
# as noted above, we want to do this lazily to avoid having to
|
| 1815 |
+
# check all existing outputs
|
| 1816 |
+
if self.current_node is not None and function_id in self.roots:
|
| 1817 |
+
self.try_end_curr_execution()
|
| 1818 |
+
|
| 1819 |
+
# run again to hit the root matching case which must succeed
|
| 1820 |
+
if self.current_node is None:
|
| 1821 |
+
return self.run(new_inputs, function_id)
|
| 1822 |
+
|
| 1823 |
+
# at this point, we necessarily will do a new recording
|
| 1824 |
+
self.debug_fail_counter += 1
|
| 1825 |
+
|
| 1826 |
+
self.try_end_curr_execution()
|
| 1827 |
+
if self.current_node is not None:
|
| 1828 |
+
self.apply_checkpoint_execution_state_in_allocator()
|
| 1829 |
+
|
| 1830 |
+
# now, we are in a recording state !
|
| 1831 |
+
return self.record_function(new_inputs, function_id)
|
| 1832 |
+
|
| 1833 |
+
def shutdown(self):
|
| 1834 |
+
"""
|
| 1835 |
+
Remove all cached tensors in all nodes. Because cached tensors can hold gradients which in turn
|
| 1836 |
+
might reference a backward which invokes a CUDA Graph Node, we have to manually clear them on shutdown
|
| 1837 |
+
to avoid a reference cycle.
|
| 1838 |
+
"""
|
| 1839 |
+
nodes = []
|
| 1840 |
+
for roots in self.roots.values():
|
| 1841 |
+
nodes.extend(roots)
|
| 1842 |
+
|
| 1843 |
+
while nodes:
|
| 1844 |
+
node = nodes.pop()
|
| 1845 |
+
for children in node.children.values():
|
| 1846 |
+
nodes.extend(children)
|
| 1847 |
+
node.remove_node_cached_tensors()
|
| 1848 |
+
node.graph = None
|
| 1849 |
+
|
| 1850 |
+
self.graph = None
|
| 1851 |
+
self.roots = None # type: ignore[assignment]
|
| 1852 |
+
self.current_node = None
|
| 1853 |
+
|
| 1854 |
+
def record_function(self, new_inputs, function_id) -> List[Optional[Tensor]]:
|
| 1855 |
+
graph_id = self.new_graph_id()
|
| 1856 |
+
log.debug(
|
| 1857 |
+
"Recording function %d of graph recording id %d",
|
| 1858 |
+
function_id.id,
|
| 1859 |
+
graph_id.id,
|
| 1860 |
+
)
|
| 1861 |
+
torch.cuda.synchronize()
|
| 1862 |
+
node = CUDAGraphNode(
|
| 1863 |
+
self.ids_to_funcs[function_id],
|
| 1864 |
+
graph_id,
|
| 1865 |
+
self.current_node,
|
| 1866 |
+
new_inputs,
|
| 1867 |
+
self.cuda_graphs_thread_pool,
|
| 1868 |
+
self.device_index,
|
| 1869 |
+
self.ids_to_stack_traces[function_id],
|
| 1870 |
+
self.stream,
|
| 1871 |
+
)
|
| 1872 |
+
if self.current_node is None:
|
| 1873 |
+
self.roots[function_id].append(node)
|
| 1874 |
+
else:
|
| 1875 |
+
self.current_node.add_child(function_id, node)
|
| 1876 |
+
self.current_node = node
|
| 1877 |
+
self.path_state = ExecutionState.RECORDING
|
| 1878 |
+
self.update_generation()
|
| 1879 |
+
torch.cuda.synchronize()
|
| 1880 |
+
return node.run_first_inputs(new_inputs)
|
| 1881 |
+
|
| 1882 |
+
def execute_node(self, node: CUDAGraphNode, new_inputs) -> List[Optional[Tensor]]:
|
| 1883 |
+
self.current_node = node
|
| 1884 |
+
self.path_state = ExecutionState.EXECUTION
|
| 1885 |
+
self.update_generation()
|
| 1886 |
+
return node.run(new_inputs)
|
| 1887 |
+
|
| 1888 |
+
def run_eager(self, new_inputs, function_id: FunctionID):
|
| 1889 |
+
# this is only stored on current node, because when we start a new path,
|
| 1890 |
+
# we will deallocate it
|
| 1891 |
+
already_warm = function_id in self.warmed_up_functions
|
| 1892 |
+
if not already_warm:
|
| 1893 |
+
log.debug("Running warmup of function %d", function_id.id)
|
| 1894 |
+
else:
|
| 1895 |
+
log.debug(
|
| 1896 |
+
"Running eager of function %d because ancestor needed to warm up",
|
| 1897 |
+
function_id.id,
|
| 1898 |
+
)
|
| 1899 |
+
self.warmed_up_functions.add(function_id)
|
| 1900 |
+
node = CUDAWarmupNode(
|
| 1901 |
+
self.ids_to_funcs[function_id],
|
| 1902 |
+
self.current_node,
|
| 1903 |
+
self.cuda_graphs_thread_pool,
|
| 1904 |
+
self.graph,
|
| 1905 |
+
self.device_index,
|
| 1906 |
+
self.ids_to_stack_traces[function_id],
|
| 1907 |
+
self.stream,
|
| 1908 |
+
already_warm,
|
| 1909 |
+
)
|
| 1910 |
+
self.current_node = node
|
| 1911 |
+
self.path_state = ExecutionState.WARMUP
|
| 1912 |
+
self.update_generation()
|
| 1913 |
+
return node.run(new_inputs)
|
| 1914 |
+
|
| 1915 |
+
def new_graph_id(self) -> GraphID:
|
| 1916 |
+
return GraphID(next(self.graph_counter))
|
| 1917 |
+
|
| 1918 |
+
def new_func_id(self) -> FunctionID:
|
| 1919 |
+
return FunctionID(next(self.func_counter))
|
| 1920 |
+
|
| 1921 |
+
def add_function(
|
| 1922 |
+
self,
|
| 1923 |
+
model,
|
| 1924 |
+
inputs,
|
| 1925 |
+
static_input_idxs,
|
| 1926 |
+
stack_traces,
|
| 1927 |
+
mode,
|
| 1928 |
+
constants,
|
| 1929 |
+
) -> Tuple[Callable[..., Any], List[Optional[Tensor]]]:
|
| 1930 |
+
id = self.new_func_id()
|
| 1931 |
+
self.ids_to_stack_traces[id] = stack_traces
|
| 1932 |
+
self.ids_to_funcs[id] = WrappedFunction(
|
| 1933 |
+
model,
|
| 1934 |
+
static_input_idxs,
|
| 1935 |
+
id,
|
| 1936 |
+
tuple(t for t in constants if isinstance(t, torch.Tensor) and t.is_cuda),
|
| 1937 |
+
)
|
| 1938 |
+
self.id_to_mode[id] = mode
|
| 1939 |
+
fn = functools.partial(self.run, function_id=id)
|
| 1940 |
+
|
| 1941 |
+
# container needs to set clean up when fn dies
|
| 1942 |
+
get_container(self.device_index).add_strong_reference(fn)
|
| 1943 |
+
return fn, fn(inputs)
|
| 1944 |
+
|
| 1945 |
+
@property
|
| 1946 |
+
def in_recording(self):
|
| 1947 |
+
return self.path_state == ExecutionState.RECORDING
|
| 1948 |
+
|
| 1949 |
+
@property
|
| 1950 |
+
def in_warmup(self):
|
| 1951 |
+
return self.path_state == ExecutionState.WARMUP
|
| 1952 |
+
|
| 1953 |
+
def get_roots(self) -> Iterator[CUDAGraphNode]:
|
| 1954 |
+
for nodes in self.roots.values():
|
| 1955 |
+
yield from nodes
|
| 1956 |
+
|
| 1957 |
+
@property
|
| 1958 |
+
def current_node(self):
|
| 1959 |
+
return self._current_node
|
| 1960 |
+
|
| 1961 |
+
@current_node.setter
|
| 1962 |
+
def current_node(self, value):
|
| 1963 |
+
self._current_node = value
|
| 1964 |
+
if value is None:
|
| 1965 |
+
self.path_state = ExecutionState.NONE
|
| 1966 |
+
|
| 1967 |
+
def update_generation(self):
|
| 1968 |
+
self.current_gen = self.get_curr_generation()
|
| 1969 |
+
|
| 1970 |
+
@staticmethod
|
| 1971 |
+
def get_curr_generation() -> int:
|
| 1972 |
+
if MarkStepBox.mark_step_counter != 0:
|
| 1973 |
+
return MarkStepBox.mark_step_counter
|
| 1974 |
+
|
| 1975 |
+
return GenerationTracker.generation
|
| 1976 |
+
|
| 1977 |
+
@staticmethod
|
| 1978 |
+
def user_invoked_mark_step():
|
| 1979 |
+
return MarkStepBox.mark_step_counter != 0
|
| 1980 |
+
|
| 1981 |
+
def can_start_new_generation(self) -> bool:
|
| 1982 |
+
if not self.in_new_torch_compile_invocation():
|
| 1983 |
+
return False
|
| 1984 |
+
|
| 1985 |
+
if self.user_invoked_mark_step():
|
| 1986 |
+
return True
|
| 1987 |
+
|
| 1988 |
+
return not self.running_forwards_with_pending_backwards
|
| 1989 |
+
|
| 1990 |
+
def in_new_torch_compile_invocation(self):
|
| 1991 |
+
return self.current_gen != self.get_curr_generation()
|
| 1992 |
+
|
| 1993 |
+
def try_end_curr_recording(self, function_id: FunctionID) -> None:
|
| 1994 |
+
"""
|
| 1995 |
+
Check if the current recording can be terminated, either because all outputs of the
|
| 1996 |
+
previously recorded node are dead or because it was executed in a different
|
| 1997 |
+
generation. Will set current_node to None and in_recording to False if successful.
|
| 1998 |
+
"""
|
| 1999 |
+
assert self.in_recording
|
| 2000 |
+
assert self.current_node is not None
|
| 2001 |
+
|
| 2002 |
+
# multiple invocations, allow overwriting the previous generation
|
| 2003 |
+
if self.can_start_new_generation():
|
| 2004 |
+
self.dealloc_current_path_weakrefs()
|
| 2005 |
+
self.clear_current_path_state_and_set_to_none()
|
| 2006 |
+
return
|
| 2007 |
+
|
| 2008 |
+
if self.current_node.all_outputs_are_dead():
|
| 2009 |
+
self.clear_current_path_state_and_set_to_none()
|
| 2010 |
+
return
|
| 2011 |
+
|
| 2012 |
+
self.check_warn_on_unable_to_start_executing(function_id)
|
| 2013 |
+
|
| 2014 |
+
def try_end_curr_execution(self) -> None:
|
| 2015 |
+
"""
|
| 2016 |
+
Check if the current executing node can be terminated, either because all outputs of the
|
| 2017 |
+
previously executed node are dead or because it was executed in a different generation.
|
| 2018 |
+
Will set current_node to None if successful.
|
| 2019 |
+
"""
|
| 2020 |
+
|
| 2021 |
+
assert not self.in_recording
|
| 2022 |
+
if self.current_node is None:
|
| 2023 |
+
return
|
| 2024 |
+
|
| 2025 |
+
if self.can_start_new_generation():
|
| 2026 |
+
self.clear_current_path_state_and_set_to_none()
|
| 2027 |
+
return
|
| 2028 |
+
|
| 2029 |
+
if self.current_node.all_outputs_are_dead():
|
| 2030 |
+
self.clear_current_path_state_and_set_to_none()
|
| 2031 |
+
|
| 2032 |
+
def try_end_curr_warmup(self, function_id: FunctionID):
|
| 2033 |
+
if self.can_start_new_generation():
|
| 2034 |
+
self.dealloc_current_path_weakrefs()
|
| 2035 |
+
self.current_node = None
|
| 2036 |
+
return
|
| 2037 |
+
|
| 2038 |
+
if self.current_node.all_outputs_are_dead():
|
| 2039 |
+
self.current_node = None
|
| 2040 |
+
return
|
| 2041 |
+
|
| 2042 |
+
self.check_warn_on_unable_to_start_executing(function_id)
|
| 2043 |
+
|
| 2044 |
+
def check_warn_on_unable_to_start_executing(self, function_id: FunctionID):
|
| 2045 |
+
"Warn if we in a potential loop where we are unable to hit fast path"
|
| 2046 |
+
if (
|
| 2047 |
+
function_id in self.warned_functions
|
| 2048 |
+
or not self.in_new_torch_compile_invocation()
|
| 2049 |
+
):
|
| 2050 |
+
return
|
| 2051 |
+
|
| 2052 |
+
existing_nodes = [
|
| 2053 |
+
node
|
| 2054 |
+
for node in self.current_node._path_from_root
|
| 2055 |
+
if node.wrapped_function.id == function_id
|
| 2056 |
+
]
|
| 2057 |
+
|
| 2058 |
+
if len(existing_nodes) <= 1:
|
| 2059 |
+
return
|
| 2060 |
+
|
| 2061 |
+
# repeated same pattern
|
| 2062 |
+
parents = {
|
| 2063 |
+
n.parent.wrapped_function.id
|
| 2064 |
+
for n in itertools.chain(existing_nodes, (self.current_node,))
|
| 2065 |
+
if n.parent is not None
|
| 2066 |
+
}
|
| 2067 |
+
if len(parents) == len(existing_nodes):
|
| 2068 |
+
return
|
| 2069 |
+
|
| 2070 |
+
self.warned_functions.add(function_id)
|
| 2071 |
+
warnings.warn(
|
| 2072 |
+
"Unable to hit fast path of CUDAGraphs because of pending, uninvoked backwards. "
|
| 2073 |
+
"Consider running with torch.no_grad() or using torch.compiler.cudagraph_mark_step_begin() "
|
| 2074 |
+
"before each model invocation"
|
| 2075 |
+
)
|
| 2076 |
+
|
| 2077 |
+
def dealloc_current_path_weakrefs(self):
|
| 2078 |
+
# TODO: we could also allow the these weak refs to continue to be allocated,
|
| 2079 |
+
# but that adds some complications.
|
| 2080 |
+
for node in self.current_node._path_from_root:
|
| 2081 |
+
assert len(node.tensor_weakrefs) == len(node.stack_traces)
|
| 2082 |
+
for t, stack_trace in zip(node.tensor_weakrefs, node.stack_traces):
|
| 2083 |
+
ten = None if t is None else t()
|
| 2084 |
+
if ten is None:
|
| 2085 |
+
continue
|
| 2086 |
+
|
| 2087 |
+
stack_trace = (
|
| 2088 |
+
stack_trace.strip()
|
| 2089 |
+
if stack_trace
|
| 2090 |
+
else "[Could not find stack trace]"
|
| 2091 |
+
)
|
| 2092 |
+
msg = (
|
| 2093 |
+
"Error: accessing tensor output of CUDAGraphs that has been overwritten by a subsequent run. "
|
| 2094 |
+
f"Stack trace: {stack_trace}. "
|
| 2095 |
+
"To prevent overwriting, clone the tensor outside of torch.compile() "
|
| 2096 |
+
"or call torch.compiler.cudagraph_mark_step_begin() before each model invocation."
|
| 2097 |
+
)
|
| 2098 |
+
torch._C._set_storage_access_error_msg(ten, msg)
|
| 2099 |
+
|
| 2100 |
+
deleted = set()
|
| 2101 |
+
for storage_ref in self.current_node.path_live_weakrefs():
|
| 2102 |
+
if storage_ref() and storage_ref.data_ptr() not in deleted:
|
| 2103 |
+
deleted.add(storage_ref.data_ptr())
|
| 2104 |
+
torch._C._free_And_Remove_DeleterFn(storage_ref())
|
| 2105 |
+
|
| 2106 |
+
def clear_current_path_state_and_set_to_none(self):
|
| 2107 |
+
self.current_node.clear_path_state()
|
| 2108 |
+
self.current_node = None
|
| 2109 |
+
|
| 2110 |
+
def apply_checkpoint_execution_state_in_allocator(self):
|
| 2111 |
+
"""
|
| 2112 |
+
Checkpoint the current execution state in the caching allocator so that
|
| 2113 |
+
additional cudagraph recordings can be made respecting existent live storages.
|
| 2114 |
+
"""
|
| 2115 |
+
self.debug_checkpointing_counter += 1
|
| 2116 |
+
log.debug(
|
| 2117 |
+
"Checkpointing cuda caching allocator state. Number of checkpoints %d",
|
| 2118 |
+
self.debug_checkpointing_counter,
|
| 2119 |
+
)
|
| 2120 |
+
|
| 2121 |
+
state = self.current_node.checkpointed_caching_state
|
| 2122 |
+
device = self.current_node.device
|
| 2123 |
+
assert state is not None and device is not None
|
| 2124 |
+
|
| 2125 |
+
# currently we deallocate on instead of allowing stale recordings
|
| 2126 |
+
stale_storages: List[int] = []
|
| 2127 |
+
|
| 2128 |
+
# remove cached tensors, otherwise they would prevent memory from being
|
| 2129 |
+
# reclaimed in subsequent recordings
|
| 2130 |
+
self.current_node.remove_path_cached_tensors()
|
| 2131 |
+
live_storages_wrappers = list(self.current_node.path_live_weakrefs())
|
| 2132 |
+
|
| 2133 |
+
live_storages_weak_refs = [t() for t in live_storages_wrappers]
|
| 2134 |
+
ptrs_to_deallocate = self.current_node.data_ptrs_dead_since_invocation()
|
| 2135 |
+
torch._C._cuda_setCheckpointPoolState(
|
| 2136 |
+
device, state, stale_storages, live_storages_weak_refs
|
| 2137 |
+
)
|
| 2138 |
+
|
| 2139 |
+
# NB: deduplicate aliased outputs
|
| 2140 |
+
for ptr in set(ptrs_to_deallocate):
|
| 2141 |
+
torch._C._cuda_cudaCachingAllocator_raw_delete(ptr)
|
| 2142 |
+
|
| 2143 |
+
# Now the live blocks should be exactly equal to the live storages in private pool
|
| 2144 |
+
if config.triton.slow_path_cudagraph_asserts:
|
| 2145 |
+
check_memory_pool(
|
| 2146 |
+
self.device_index, self.cuda_graphs_thread_pool, live_storages_wrappers
|
| 2147 |
+
)
|
| 2148 |
+
for wrapper in live_storages_wrappers:
|
| 2149 |
+
assert wrapper()
|
| 2150 |
+
assert torch._C._has_Standard_Deleter(wrapper())
|
| 2151 |
+
assert wrapper.data_ptr() not in ptrs_to_deallocate
|
| 2152 |
+
|
| 2153 |
+
def live_cudagraph_pool_storages_in_curr_execution(
|
| 2154 |
+
self,
|
| 2155 |
+
) -> List[StorageWeakRefPointer]:
|
| 2156 |
+
if self.current_node is None:
|
| 2157 |
+
return []
|
| 2158 |
+
# explicitly ignoring previous recorded outputs from past path
|
| 2159 |
+
return [t() for t in self.current_node.path_live_weakrefs()]
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/hooks.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import contextlib
|
| 2 |
+
from typing import Callable, List, TYPE_CHECKING
|
| 3 |
+
|
| 4 |
+
if TYPE_CHECKING:
|
| 5 |
+
import torch
|
| 6 |
+
|
| 7 |
+
# Executed in the order they're registered
|
| 8 |
+
INTERMEDIATE_HOOKS: List[Callable[[str, "torch.Tensor"], None]] = []
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
@contextlib.contextmanager
|
| 12 |
+
def intermediate_hook(fn):
|
| 13 |
+
INTERMEDIATE_HOOKS.append(fn)
|
| 14 |
+
try:
|
| 15 |
+
yield
|
| 16 |
+
finally:
|
| 17 |
+
INTERMEDIATE_HOOKS.pop()
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def run_intermediate_hooks(name, val):
|
| 21 |
+
global INTERMEDIATE_HOOKS
|
| 22 |
+
hooks = INTERMEDIATE_HOOKS
|
| 23 |
+
INTERMEDIATE_HOOKS = []
|
| 24 |
+
try:
|
| 25 |
+
for hook in hooks:
|
| 26 |
+
hook(name, val)
|
| 27 |
+
finally:
|
| 28 |
+
INTERMEDIATE_HOOKS = hooks
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/ops_handler.py
ADDED
|
@@ -0,0 +1,655 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import itertools
|
| 2 |
+
from typing import Any, Callable, Generic, Literal, Optional, Tuple, TypeVar, Union
|
| 3 |
+
from unittest.mock import patch
|
| 4 |
+
|
| 5 |
+
import sympy
|
| 6 |
+
from typing_extensions import Protocol
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
import torch.utils._pytree as pytree
|
| 10 |
+
from torch.fx.graph import inplace_methods, magic_methods
|
| 11 |
+
from .utils import IndentedBuffer, reduction_num_outputs, sympy_index_symbol, sympy_str
|
| 12 |
+
|
| 13 |
+
T = TypeVar("T")
|
| 14 |
+
StoreMode = Optional[Literal["atomic_add"]]
|
| 15 |
+
ReductionType = Literal[
|
| 16 |
+
"argmax",
|
| 17 |
+
"argmin",
|
| 18 |
+
"welford_reduce",
|
| 19 |
+
"welford_combine",
|
| 20 |
+
"any",
|
| 21 |
+
"max",
|
| 22 |
+
"min",
|
| 23 |
+
"prod",
|
| 24 |
+
"sum",
|
| 25 |
+
"xor_sum",
|
| 26 |
+
]
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def _arg_str(a) -> str:
|
| 30 |
+
if isinstance(a, sympy.Expr):
|
| 31 |
+
return sympy_str(a)
|
| 32 |
+
return str(a)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
# NB: This is not done as a parent class, because our ops handlers
|
| 36 |
+
# implementations make heavy use of __getattr__ magic, and pre-existing
|
| 37 |
+
# stubs for methods would interfere with this mechanism.
|
| 38 |
+
#
|
| 39 |
+
# TODO: A superclass that does desugaring for operations like
|
| 40 |
+
# reciprocal/square might be useful.
|
| 41 |
+
class OpsHandler(Protocol[T]):
|
| 42 |
+
"""
|
| 43 |
+
Protocol describing the set of valid operations on ``torch._inductor.virtualized.ops``,
|
| 44 |
+
as well as the contract for op handlers. The type T signifies the domain
|
| 45 |
+
of the abstract analysis AKA what all of the functions return / take as arguments
|
| 46 |
+
anywhere compute occurs.
|
| 47 |
+
|
| 48 |
+
While these operators are typically dtype polymorphic (e.g., you can use mul
|
| 49 |
+
on both integers and floats), they do NOT do promotion and usually return the
|
| 50 |
+
same dtype as the input. You are expected to have handled type promotion
|
| 51 |
+
during ATen decompositions. Most operators correspond exactly to pointwise
|
| 52 |
+
operations as defined by torch, so when in doubt about semantics, check the
|
| 53 |
+
corresponding torch documentation. These are all scalar operations (so they
|
| 54 |
+
are defined to operate on a single element at a time.)
|
| 55 |
+
|
| 56 |
+
For convenience, many operators take a src_dtype which indicates what the dtype
|
| 57 |
+
of the input argument is. Although in principle this can be derived by an
|
| 58 |
+
analysis, providing this for ops where it is useful helps avoid having to repeatedly
|
| 59 |
+
recompute dtype in code generation.
|
| 60 |
+
|
| 61 |
+
Note that this often describes a class of static methods, for stateless
|
| 62 |
+
ops handlers.
|
| 63 |
+
|
| 64 |
+
Handlers are often defined using ``__getattr__`` metaprogramming, which means
|
| 65 |
+
that you cannot declare that a type implements a protocol by inheriting from
|
| 66 |
+
it (as the type stubs count as attribute declarations and impede the getattr
|
| 67 |
+
magic method from being called). Instead, define a function that casts an
|
| 68 |
+
argument of your type to the protocol, which is sufficient to induce mypy to
|
| 69 |
+
test that the protocol is implemented correctly. Search for ``_typecheck_``
|
| 70 |
+
in this file to see some examples. If you see an obscure error where a
|
| 71 |
+
class doesn't implement a Protocol, but mypy doesn't say why, check to see
|
| 72 |
+
that ``__getattr__`` is typed correctly (typically, it is not possible to
|
| 73 |
+
type ``__getattr__`` without typing it as ``Callable[..., Any]``)
|
| 74 |
+
"""
|
| 75 |
+
|
| 76 |
+
def constant(self, value: Union[bool, float, int], dtype: torch.dtype) -> T:
|
| 77 |
+
"""Produces a scalar constant of type dtype."""
|
| 78 |
+
...
|
| 79 |
+
|
| 80 |
+
def load_seed(self, name: str, offset: T):
|
| 81 |
+
"""Computes inductor_prims.lookup_seed."""
|
| 82 |
+
...
|
| 83 |
+
|
| 84 |
+
def rand(self, seed: T, offset: T) -> T:
|
| 85 |
+
"""Computes inductor_prims.random with mode="rand". offset has dtype int32."""
|
| 86 |
+
...
|
| 87 |
+
|
| 88 |
+
def randn(self, seed: T, offset: T) -> T:
|
| 89 |
+
"""Computes inductor_prims.random with mode="randn". offset has dtype int32."""
|
| 90 |
+
...
|
| 91 |
+
|
| 92 |
+
def randint64(self, seed: T, offset: T, low: T, high: T) -> T:
|
| 93 |
+
"""Computes inductor_prims.randint. offset has dtype int32."""
|
| 94 |
+
...
|
| 95 |
+
|
| 96 |
+
def masked(self, mask: T, body: Callable[[], T], other: T) -> T:
|
| 97 |
+
"""
|
| 98 |
+
Computes body, but only perform loads/stores if the boolean mask
|
| 99 |
+
evaluates to true. For example, you would use this if you needed to
|
| 100 |
+
perform an indirect load that may not be valid on some elements;
|
| 101 |
+
without masking, invalid accesses can cause IMAs. When mask is true,
|
| 102 |
+
the result is the result of body; otherwise it is other.
|
| 103 |
+
|
| 104 |
+
Contrast this with ops.where, which can multiplex between two values
|
| 105 |
+
that have been unconditionally computed.
|
| 106 |
+
"""
|
| 107 |
+
...
|
| 108 |
+
|
| 109 |
+
def where(self, condition: T, input: T, other: T) -> T:
|
| 110 |
+
"""
|
| 111 |
+
Computes torch.where: when condition is true, return input; otherwise return other.
|
| 112 |
+
"""
|
| 113 |
+
...
|
| 114 |
+
|
| 115 |
+
def index_expr(self, expr: sympy.Expr, dtype: torch.dtype) -> T:
|
| 116 |
+
"""
|
| 117 |
+
Converts a sympy expression into a scalar of type dtype. expr is typically
|
| 118 |
+
an indexing expression, thus the name; however, it can also be used in
|
| 119 |
+
non-indexing situations.
|
| 120 |
+
"""
|
| 121 |
+
...
|
| 122 |
+
|
| 123 |
+
def to_dtype(
|
| 124 |
+
self, x: T, dtype: torch.dtype, src_dtype: Optional[torch.dtype] = None
|
| 125 |
+
) -> T:
|
| 126 |
+
"""
|
| 127 |
+
Convert x to dtype. src_dtype can be optionally set to specify what the original
|
| 128 |
+
dtype of x was, which can improve code generation (used by torch to(dtype=dtype)).
|
| 129 |
+
"""
|
| 130 |
+
...
|
| 131 |
+
|
| 132 |
+
def to_dtype_bitcast(self, x: T, dtype: torch.dtype, src_dtype: torch.dtype) -> T:
|
| 133 |
+
"""
|
| 134 |
+
Reinterpret cast x to dtype (reinterpreting the bits in memory as another dtype.)
|
| 135 |
+
src_dtype must be the original type of x.
|
| 136 |
+
"""
|
| 137 |
+
...
|
| 138 |
+
|
| 139 |
+
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 140 |
+
# These operations are only available in a "kernel" context. Check
|
| 141 |
+
# torch._inductor.codegen.common.CSEProxy for their typical implementation
|
| 142 |
+
# in op handler (routing to their respective implementations in the kernel
|
| 143 |
+
# handler)
|
| 144 |
+
#
|
| 145 |
+
# Importantly, inside a kernel, indexing and mask variables are available
|
| 146 |
+
# in scope, which are typically used by sympy.Expr indexing.
|
| 147 |
+
|
| 148 |
+
def indirect_indexing(
|
| 149 |
+
self, x: T, size: sympy.Expr, check: bool = True
|
| 150 |
+
) -> sympy.Expr:
|
| 151 |
+
"""
|
| 152 |
+
Convert an integral x into a sympy.Expr that can be subsequently used in
|
| 153 |
+
indexing computation. 'size' represents an upper bound on the what valid
|
| 154 |
+
indexes can be; when 'check' is True, we check that the x is in bounds.
|
| 155 |
+
|
| 156 |
+
NB: This is typically mandatory to implement for any analysis, because you
|
| 157 |
+
MUST return a valid sympy.Expr of some sort (even if it's a meaningless symbol).
|
| 158 |
+
"""
|
| 159 |
+
...
|
| 160 |
+
|
| 161 |
+
def load(self, name: str, index: sympy.Expr) -> T:
|
| 162 |
+
"""
|
| 163 |
+
Load from the memory location 'name', offset by some indexing expression 'index'.
|
| 164 |
+
"""
|
| 165 |
+
...
|
| 166 |
+
|
| 167 |
+
def store(
|
| 168 |
+
self,
|
| 169 |
+
name: str,
|
| 170 |
+
index: sympy.Expr,
|
| 171 |
+
value: T,
|
| 172 |
+
mode: StoreMode = None,
|
| 173 |
+
) -> None:
|
| 174 |
+
"""
|
| 175 |
+
Store 'value' to the memory location 'name' offset by 'expr'. If
|
| 176 |
+
specified, 'mode' can require the store to be an atomic addition.
|
| 177 |
+
"""
|
| 178 |
+
...
|
| 179 |
+
|
| 180 |
+
# TODO: Better explain how the "collective" semantics of these ops;
|
| 181 |
+
# remember that the input value is a scalar, you can't reduce on it in the
|
| 182 |
+
# traditional sense!
|
| 183 |
+
def reduction(
|
| 184 |
+
self,
|
| 185 |
+
dtype: torch.dtype,
|
| 186 |
+
src_dtype: torch.dtype,
|
| 187 |
+
reduction_type: ReductionType,
|
| 188 |
+
value: T,
|
| 189 |
+
) -> Union[T, Tuple[T, ...]]:
|
| 190 |
+
"""
|
| 191 |
+
Perform a 'reduction_type' reduction on 'value' of dtype 'src_dtype',
|
| 192 |
+
using 'dtype' as the accumulation dtype for the reduction. The result
|
| 193 |
+
is an intermediate computation which should be stored to the final
|
| 194 |
+
location using 'ops.store_reduction'.
|
| 195 |
+
|
| 196 |
+
Valid reduction types are . For Welford reduction types, this
|
| 197 |
+
function returns multiple outputs; consult reduction_num_outputs to
|
| 198 |
+
determine the amount in metaprogramming applications.
|
| 199 |
+
"""
|
| 200 |
+
...
|
| 201 |
+
|
| 202 |
+
# TODO: in practice, this seems to actually return None, but not returning
|
| 203 |
+
# a T makes common __getattr__ idioms not type correctly. Figure out if
|
| 204 |
+
# this should be returning something.
|
| 205 |
+
def store_reduction(self, name: str, index: sympy.Expr, value: T) -> T:
|
| 206 |
+
"""
|
| 207 |
+
Store the fully accumulated result of 'reduction' to the memory
|
| 208 |
+
location 'name' offset by 'expr'.
|
| 209 |
+
"""
|
| 210 |
+
...
|
| 211 |
+
|
| 212 |
+
def scan(
|
| 213 |
+
self, dtype: torch.dtype, combine_fn: Callable[[T, T], T], value: T, init: int
|
| 214 |
+
) -> T:
|
| 215 |
+
"""
|
| 216 |
+
Perform an associative scan on 'value'.
|
| 217 |
+
"""
|
| 218 |
+
# TODO: Improve the description with some pseudocode
|
| 219 |
+
...
|
| 220 |
+
|
| 221 |
+
def bucketize(
|
| 222 |
+
self,
|
| 223 |
+
values: T,
|
| 224 |
+
offsets_name: str,
|
| 225 |
+
offsets_size: sympy.Expr,
|
| 226 |
+
indexing_dtype: torch.dtype,
|
| 227 |
+
right: bool,
|
| 228 |
+
) -> T:
|
| 229 |
+
# See [Note: Inductor bucketize op]
|
| 230 |
+
...
|
| 231 |
+
|
| 232 |
+
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 233 |
+
# The following ops have semantics that correspond exactly to the torch
|
| 234 |
+
# operation with the same corresponding name.
|
| 235 |
+
|
| 236 |
+
def abs(self, x0: T) -> T:
|
| 237 |
+
...
|
| 238 |
+
|
| 239 |
+
def exp(self, x0: T) -> T:
|
| 240 |
+
...
|
| 241 |
+
|
| 242 |
+
def exp2(self, x0: T) -> T:
|
| 243 |
+
...
|
| 244 |
+
|
| 245 |
+
def expm1(self, x0: T) -> T:
|
| 246 |
+
...
|
| 247 |
+
|
| 248 |
+
def sqrt(self, x0: T) -> T:
|
| 249 |
+
...
|
| 250 |
+
|
| 251 |
+
def relu(self, x0: T) -> T:
|
| 252 |
+
...
|
| 253 |
+
|
| 254 |
+
def minimum(self, x0: T, x1: T) -> T:
|
| 255 |
+
...
|
| 256 |
+
|
| 257 |
+
def maximum(self, x0: T, x1: T) -> T:
|
| 258 |
+
...
|
| 259 |
+
|
| 260 |
+
def cos(self, x0: T) -> T:
|
| 261 |
+
...
|
| 262 |
+
|
| 263 |
+
def sin(self, x0: T) -> T:
|
| 264 |
+
...
|
| 265 |
+
|
| 266 |
+
def lgamma(self, x0: T) -> T:
|
| 267 |
+
...
|
| 268 |
+
|
| 269 |
+
def erf(self, x0: T) -> T:
|
| 270 |
+
...
|
| 271 |
+
|
| 272 |
+
def cosh(self, x0: T) -> T:
|
| 273 |
+
...
|
| 274 |
+
|
| 275 |
+
def sinh(self, x0: T) -> T:
|
| 276 |
+
...
|
| 277 |
+
|
| 278 |
+
def acos(self, x0: T) -> T:
|
| 279 |
+
...
|
| 280 |
+
|
| 281 |
+
def acosh(self, x0: T) -> T:
|
| 282 |
+
...
|
| 283 |
+
|
| 284 |
+
def asin(self, x0: T) -> T:
|
| 285 |
+
...
|
| 286 |
+
|
| 287 |
+
def asinh(self, x0: T) -> T:
|
| 288 |
+
...
|
| 289 |
+
|
| 290 |
+
def atan2(self, x0: T, x1: T) -> T:
|
| 291 |
+
...
|
| 292 |
+
|
| 293 |
+
def atan(self, x0: T) -> T:
|
| 294 |
+
...
|
| 295 |
+
|
| 296 |
+
def atanh(self, x0: T) -> T:
|
| 297 |
+
...
|
| 298 |
+
|
| 299 |
+
def copysign(self, x0: T, x1: T) -> T:
|
| 300 |
+
...
|
| 301 |
+
|
| 302 |
+
def erfc(self, x0: T) -> T:
|
| 303 |
+
...
|
| 304 |
+
|
| 305 |
+
def erfinv(self, x0: T) -> T:
|
| 306 |
+
...
|
| 307 |
+
|
| 308 |
+
def frexp(self, x0: T):
|
| 309 |
+
...
|
| 310 |
+
|
| 311 |
+
def hypot(self, x0: T, x1: T) -> T:
|
| 312 |
+
...
|
| 313 |
+
|
| 314 |
+
def log10(self, x0: T) -> T:
|
| 315 |
+
...
|
| 316 |
+
|
| 317 |
+
def nextafter(self, x0: T, x1: T) -> T:
|
| 318 |
+
...
|
| 319 |
+
|
| 320 |
+
def logical_and(self, x0: T, x1: T) -> T:
|
| 321 |
+
...
|
| 322 |
+
|
| 323 |
+
def logical_not(self, x0: T) -> T:
|
| 324 |
+
...
|
| 325 |
+
|
| 326 |
+
def logical_or(self, x0: T, x1: T) -> T:
|
| 327 |
+
...
|
| 328 |
+
|
| 329 |
+
def logical_xor(self, x0: T, x1: T) -> T:
|
| 330 |
+
...
|
| 331 |
+
|
| 332 |
+
def bitwise_and(self, x0: T, x1: T) -> T:
|
| 333 |
+
...
|
| 334 |
+
|
| 335 |
+
def bitwise_not(self, x0: T) -> T:
|
| 336 |
+
...
|
| 337 |
+
|
| 338 |
+
def bitwise_or(self, x0: T, x1: T) -> T:
|
| 339 |
+
...
|
| 340 |
+
|
| 341 |
+
def bitwise_xor(self, x0: T, x1: T) -> T:
|
| 342 |
+
...
|
| 343 |
+
|
| 344 |
+
def bitwise_left_shift(self, x0: T, x1: T) -> T:
|
| 345 |
+
...
|
| 346 |
+
|
| 347 |
+
def bitwise_right_shift(self, x0: T, x1: T) -> T:
|
| 348 |
+
...
|
| 349 |
+
|
| 350 |
+
def rsqrt(self, x0: T) -> T:
|
| 351 |
+
...
|
| 352 |
+
|
| 353 |
+
def log1p(self, x0: T) -> T:
|
| 354 |
+
...
|
| 355 |
+
|
| 356 |
+
def tan(self, x0: T) -> T:
|
| 357 |
+
...
|
| 358 |
+
|
| 359 |
+
def tanh(self, x0: T) -> T:
|
| 360 |
+
...
|
| 361 |
+
|
| 362 |
+
def sigmoid(self, x0: T) -> T:
|
| 363 |
+
...
|
| 364 |
+
|
| 365 |
+
def signbit(self, x0: T) -> T:
|
| 366 |
+
...
|
| 367 |
+
|
| 368 |
+
def fmod(self, x0: T, x1: T) -> T:
|
| 369 |
+
...
|
| 370 |
+
|
| 371 |
+
def log(self, x0: T) -> T:
|
| 372 |
+
...
|
| 373 |
+
|
| 374 |
+
def isinf(self, x0: T) -> T:
|
| 375 |
+
...
|
| 376 |
+
|
| 377 |
+
def isnan(self, x0: T) -> T:
|
| 378 |
+
...
|
| 379 |
+
|
| 380 |
+
def round(self, x0: T) -> T:
|
| 381 |
+
...
|
| 382 |
+
|
| 383 |
+
def floor(self, x0: T) -> T:
|
| 384 |
+
...
|
| 385 |
+
|
| 386 |
+
def sign(self, x0: T) -> T:
|
| 387 |
+
...
|
| 388 |
+
|
| 389 |
+
def to_int(self, x0: T) -> T:
|
| 390 |
+
...
|
| 391 |
+
|
| 392 |
+
def trunc(self, x0: T) -> T:
|
| 393 |
+
...
|
| 394 |
+
|
| 395 |
+
def truncdiv(self, x0: T, x1: T) -> T:
|
| 396 |
+
...
|
| 397 |
+
|
| 398 |
+
def ceil(self, x0: T) -> T:
|
| 399 |
+
...
|
| 400 |
+
|
| 401 |
+
def neg(self, x0: T) -> T:
|
| 402 |
+
...
|
| 403 |
+
|
| 404 |
+
def reciprocal(self, x0: T) -> T:
|
| 405 |
+
...
|
| 406 |
+
|
| 407 |
+
def eq(self, x0: T, x1: T) -> T:
|
| 408 |
+
...
|
| 409 |
+
|
| 410 |
+
def ne(self, x0: T, x1: T) -> T:
|
| 411 |
+
...
|
| 412 |
+
|
| 413 |
+
def lt(self, x0: T, x1: T) -> T:
|
| 414 |
+
...
|
| 415 |
+
|
| 416 |
+
def gt(self, x0: T, x1: T) -> T:
|
| 417 |
+
...
|
| 418 |
+
|
| 419 |
+
def le(self, x0: T, x1: T) -> T:
|
| 420 |
+
...
|
| 421 |
+
|
| 422 |
+
def ge(self, x0: T, x1: T) -> T:
|
| 423 |
+
...
|
| 424 |
+
|
| 425 |
+
def add(self, x0: T, x1: T) -> T:
|
| 426 |
+
...
|
| 427 |
+
|
| 428 |
+
def sub(self, x0: T, x1: T) -> T:
|
| 429 |
+
...
|
| 430 |
+
|
| 431 |
+
def mul(self, x0: T, x1: T) -> T:
|
| 432 |
+
...
|
| 433 |
+
|
| 434 |
+
def floordiv(self, x0: T, x1: T) -> T:
|
| 435 |
+
...
|
| 436 |
+
|
| 437 |
+
def truediv(self, x0: T, x1: T) -> T:
|
| 438 |
+
...
|
| 439 |
+
|
| 440 |
+
def div(self, x0: T, x1: T) -> T:
|
| 441 |
+
...
|
| 442 |
+
|
| 443 |
+
def mod(self, x0: T, x1: T) -> T:
|
| 444 |
+
...
|
| 445 |
+
|
| 446 |
+
def pow(self, x0: T, x1: T) -> T:
|
| 447 |
+
...
|
| 448 |
+
|
| 449 |
+
def and_(self, x0: T, x1: T) -> T:
|
| 450 |
+
...
|
| 451 |
+
|
| 452 |
+
def or_(self, x0: T, x1: T) -> T:
|
| 453 |
+
...
|
| 454 |
+
|
| 455 |
+
def xor(self, x0: T, x1: T) -> T:
|
| 456 |
+
...
|
| 457 |
+
|
| 458 |
+
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 459 |
+
# In CUDA, optimized implementations of other mathematical operations are
|
| 460 |
+
# offered separately via libdevice for double precision computation (in
|
| 461 |
+
# Triton, these go to tl.math rather than tl). We lower to these
|
| 462 |
+
# operators when doing FP64 on CUDA. Note that some operators
|
| 463 |
+
# unconditional go to tl.math.
|
| 464 |
+
#
|
| 465 |
+
# TODO(ezyang): Is this really the best way to do this? What if we have
|
| 466 |
+
# abs internally route to tl.math automatically when given a double
|
| 467 |
+
# precision input? One reason is that when doing codegen, we often don't
|
| 468 |
+
# know what the dtype of the inputs are! (In principle we do know, but
|
| 469 |
+
# for many analyses it's not conveniently available.)
|
| 470 |
+
|
| 471 |
+
def libdevice_abs(self, x0: T) -> T:
|
| 472 |
+
...
|
| 473 |
+
|
| 474 |
+
def libdevice_exp(self, x0: T) -> T:
|
| 475 |
+
...
|
| 476 |
+
|
| 477 |
+
def libdevice_sqrt(self, x0: T) -> T:
|
| 478 |
+
...
|
| 479 |
+
|
| 480 |
+
def libdevice_cos(self, x0: T) -> T:
|
| 481 |
+
...
|
| 482 |
+
|
| 483 |
+
def libdevice_sin(self, x0: T) -> T:
|
| 484 |
+
...
|
| 485 |
+
|
| 486 |
+
def libdevice_sigmoid(self, x0: T) -> T:
|
| 487 |
+
...
|
| 488 |
+
|
| 489 |
+
def libdevice_log(self, x0: T) -> T:
|
| 490 |
+
...
|
| 491 |
+
|
| 492 |
+
|
| 493 |
+
class MockHandler:
|
| 494 |
+
def __getattr__(self, name):
|
| 495 |
+
if name == "name":
|
| 496 |
+
return "MockHandler"
|
| 497 |
+
|
| 498 |
+
def inner(*args, **kwargs):
|
| 499 |
+
fargs = [_arg_str(a) for a in args]
|
| 500 |
+
fargs.extend(f"{k}={v}" for k, v in kwargs.items())
|
| 501 |
+
return f"ops.{name}({', '.join(fargs)})"
|
| 502 |
+
|
| 503 |
+
return inner
|
| 504 |
+
|
| 505 |
+
@staticmethod
|
| 506 |
+
def masked(mask, body, other) -> str:
|
| 507 |
+
return f"ops.masked({mask}, {body()}, {other})"
|
| 508 |
+
|
| 509 |
+
@staticmethod
|
| 510 |
+
def frexp(x):
|
| 511 |
+
return (f"ops.frexp({x})[0]", f"ops.frexp({x})[1]")
|
| 512 |
+
|
| 513 |
+
@staticmethod
|
| 514 |
+
def indirect_indexing(index_var, size, check=True) -> sympy.Symbol:
|
| 515 |
+
return sympy_index_symbol(f"({str(index_var)})")
|
| 516 |
+
|
| 517 |
+
@classmethod
|
| 518 |
+
def _init_cls(cls):
|
| 519 |
+
def make_handler(format_string):
|
| 520 |
+
@staticmethod # type: ignore[misc]
|
| 521 |
+
def inner(*args):
|
| 522 |
+
return format_string.format(*args)
|
| 523 |
+
|
| 524 |
+
return inner
|
| 525 |
+
|
| 526 |
+
for name, format_string in itertools.chain(
|
| 527 |
+
magic_methods.items(), inplace_methods.items()
|
| 528 |
+
):
|
| 529 |
+
setattr(cls, name, make_handler(format_string))
|
| 530 |
+
|
| 531 |
+
|
| 532 |
+
MockHandler._init_cls()
|
| 533 |
+
|
| 534 |
+
|
| 535 |
+
# Use mypy to check protocol implemented correctly
|
| 536 |
+
def _typecheck_MockHandler(h: MockHandler) -> OpsHandler[str]:
|
| 537 |
+
return h
|
| 538 |
+
|
| 539 |
+
|
| 540 |
+
class KernelFormatterHandler:
|
| 541 |
+
def __init__(self, parent_handler):
|
| 542 |
+
self.parent_handler = parent_handler
|
| 543 |
+
self.output = IndentedBuffer(1)
|
| 544 |
+
self.var_counter = itertools.count()
|
| 545 |
+
|
| 546 |
+
@staticmethod
|
| 547 |
+
def ir_to_string(ir_fn, index, rindex=None) -> str:
|
| 548 |
+
from .ir import FlexibleLayout
|
| 549 |
+
from .virtualized import V
|
| 550 |
+
|
| 551 |
+
args = [index, rindex] if rindex is not None else [index]
|
| 552 |
+
names = ["index", "rindex"] if rindex is not None else ["index"]
|
| 553 |
+
formatter = KernelFormatterHandler(MockHandler())
|
| 554 |
+
|
| 555 |
+
with formatter.output.indent(-1):
|
| 556 |
+
formatter.output.writeline(f"def inner_fn({', '.join(names)}):")
|
| 557 |
+
for name, arg in zip(names, args):
|
| 558 |
+
if arg:
|
| 559 |
+
lhs = ", ".join(
|
| 560 |
+
[
|
| 561 |
+
str("_" if isinstance(v, (int, sympy.Integer)) else v)
|
| 562 |
+
for v in arg
|
| 563 |
+
]
|
| 564 |
+
)
|
| 565 |
+
formatter.output.writeline(f"{lhs} = {name}")
|
| 566 |
+
|
| 567 |
+
with V.set_ops_handler(formatter), patch.object(
|
| 568 |
+
FlexibleLayout, "allow_indexing", True
|
| 569 |
+
):
|
| 570 |
+
result = ir_fn(*args)
|
| 571 |
+
return formatter.getvalue(result)
|
| 572 |
+
|
| 573 |
+
def __getattr__(self, name) -> Callable[..., Any]:
|
| 574 |
+
def inner(*args, **kwargs):
|
| 575 |
+
line = getattr(self.parent_handler, name)(*args, **kwargs)
|
| 576 |
+
if name == "indirect_indexing":
|
| 577 |
+
return line
|
| 578 |
+
|
| 579 |
+
def write(line):
|
| 580 |
+
# replace line with a new variable name
|
| 581 |
+
varname = f"tmp{next(self.var_counter)}"
|
| 582 |
+
self.output.writeline(f"{varname} = {line}")
|
| 583 |
+
return varname
|
| 584 |
+
|
| 585 |
+
return pytree.tree_map(write, line)
|
| 586 |
+
|
| 587 |
+
return inner
|
| 588 |
+
|
| 589 |
+
def reduction(
|
| 590 |
+
self,
|
| 591 |
+
dtype: torch.dtype,
|
| 592 |
+
src_dtype: torch.dtype,
|
| 593 |
+
reduction_type: ReductionType,
|
| 594 |
+
value: Union[str, Tuple[str, ...]],
|
| 595 |
+
) -> Union[str, Tuple[str, ...]]:
|
| 596 |
+
line = self.parent_handler.reduction(dtype, src_dtype, reduction_type, value)
|
| 597 |
+
num_values = reduction_num_outputs(reduction_type)
|
| 598 |
+
varnames = [f"tmp{next(self.var_counter)}" for _ in range(num_values)]
|
| 599 |
+
self.output.writeline(f"{','.join(varnames)} = {line}")
|
| 600 |
+
return tuple(varnames) if num_values > 1 else varnames[0]
|
| 601 |
+
|
| 602 |
+
def getvalue(self, result):
|
| 603 |
+
self.output.writeline(f"return {result}")
|
| 604 |
+
return self.output.getvalue()
|
| 605 |
+
|
| 606 |
+
|
| 607 |
+
# Use mypy to check protocol implemented correctly
|
| 608 |
+
def _typecheck_KernelFormatterHandler(h: KernelFormatterHandler) -> OpsHandler[str]:
|
| 609 |
+
return h
|
| 610 |
+
|
| 611 |
+
|
| 612 |
+
class WrapperHandler(Generic[T]):
|
| 613 |
+
def __init__(self, inner: OpsHandler[T]):
|
| 614 |
+
self._inner = inner
|
| 615 |
+
|
| 616 |
+
def __getattr__(self, item):
|
| 617 |
+
return getattr(self._inner, item)
|
| 618 |
+
|
| 619 |
+
|
| 620 |
+
# Use mypy to check protocol implemented correctly
|
| 621 |
+
def _typecheck_WrapperHandler(h: WrapperHandler[T]) -> OpsHandler[T]:
|
| 622 |
+
return h
|
| 623 |
+
|
| 624 |
+
|
| 625 |
+
class OpCounterCSE:
|
| 626 |
+
"""Shim to count how many ops are used"""
|
| 627 |
+
|
| 628 |
+
def __init__(self, inner):
|
| 629 |
+
super().__init__()
|
| 630 |
+
self.parent_handler = inner
|
| 631 |
+
self.op_count = 0
|
| 632 |
+
self.var_names = {}
|
| 633 |
+
|
| 634 |
+
def __getattr__(self, name):
|
| 635 |
+
def inner(*args, **kwargs):
|
| 636 |
+
val = getattr(self.parent_handler, name)(*args, **kwargs)
|
| 637 |
+
if name == "indirect_indexing":
|
| 638 |
+
return val
|
| 639 |
+
|
| 640 |
+
def count(val):
|
| 641 |
+
if val not in self.var_names:
|
| 642 |
+
varname = f"tmp{self.op_count}"
|
| 643 |
+
self.op_count += 1
|
| 644 |
+
self.var_names[val] = varname
|
| 645 |
+
return varname
|
| 646 |
+
else:
|
| 647 |
+
return self.var_names[val]
|
| 648 |
+
|
| 649 |
+
return pytree.tree_map(count, val)
|
| 650 |
+
|
| 651 |
+
return inner
|
| 652 |
+
|
| 653 |
+
|
| 654 |
+
def _typecheck_OpCounterCSE(h: OpCounterCSE) -> OpsHandler[str]:
|
| 655 |
+
return h
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/optimize_indexing.py
ADDED
|
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
|
| 3 |
+
import sympy
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
from torch.utils._sympy.value_ranges import ValueRanges
|
| 7 |
+
from .ir import LoopBody
|
| 8 |
+
from .utils import dominated_nodes
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def val_expressable_in_32_bits(val):
|
| 12 |
+
if getattr(val, "is_Boolean", False):
|
| 13 |
+
return True
|
| 14 |
+
|
| 15 |
+
if isinstance(val, sympy.Expr):
|
| 16 |
+
assert val.is_number
|
| 17 |
+
if val.is_Integer or val.is_Boolean:
|
| 18 |
+
val = int(val)
|
| 19 |
+
else:
|
| 20 |
+
val = float(val)
|
| 21 |
+
|
| 22 |
+
# bound within mantissa
|
| 23 |
+
if isinstance(val, float):
|
| 24 |
+
return val <= (2**24) and val >= -(2**24)
|
| 25 |
+
|
| 26 |
+
if isinstance(val, int):
|
| 27 |
+
iinfo = torch.iinfo(torch.int32)
|
| 28 |
+
return val <= iinfo.max and val >= iinfo.min
|
| 29 |
+
|
| 30 |
+
raise Exception(f"Unexpected value {val}")
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def range_expressable_in_32_bits(range):
|
| 34 |
+
return val_expressable_in_32_bits(range.lower) and val_expressable_in_32_bits(
|
| 35 |
+
range.upper
|
| 36 |
+
)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def try_to_reduce_precision(node, bounds, indirect_vars, indices, replacement_vals):
|
| 40 |
+
# if a downstream use of a node explicitly converts to int32, or float16/float32/float64,
|
| 41 |
+
# then it's precision is set for that chain of uses, and we don't need to consider those
|
| 42 |
+
# dominated values
|
| 43 |
+
def skip_filter(node):
|
| 44 |
+
return node.target == "to_dtype" and node.args[2] in (
|
| 45 |
+
torch.int32,
|
| 46 |
+
torch.float32,
|
| 47 |
+
torch.float64,
|
| 48 |
+
)
|
| 49 |
+
|
| 50 |
+
# TODO - there are dominated uses whose dtype does not depend on whether
|
| 51 |
+
# we reduce the precision here, e.g. add(int64, int64) one of the args can be reduced to
|
| 52 |
+
# int32 without changing the output precision of the node. this case hasn't shown up
|
| 53 |
+
for dominated in dominated_nodes([node], skip_filter):
|
| 54 |
+
if dominated.target in ["store", "output"]:
|
| 55 |
+
continue
|
| 56 |
+
|
| 57 |
+
if isinstance(dominated.target, str) and "set_indirect" in dominated.target:
|
| 58 |
+
idx = int(dominated.target[len("set_indirect") :])
|
| 59 |
+
indirect_var = indirect_vars[idx]
|
| 60 |
+
|
| 61 |
+
# We check that we can compute all the indices it's involved in with int32
|
| 62 |
+
for index, expr in indices.items():
|
| 63 |
+
if indirect_var in expr.free_symbols:
|
| 64 |
+
index_val = replacement_vals[index]
|
| 65 |
+
|
| 66 |
+
if math.isinf(index_val.lower) or math.isinf(index_val.upper):
|
| 67 |
+
return
|
| 68 |
+
|
| 69 |
+
# all indices are integers, so make sure that we
|
| 70 |
+
# use the bounds of integers instead of floats.
|
| 71 |
+
# TODO - not sure if we should be doing int/float casts while tracing,
|
| 72 |
+
# might interfere with sympy.
|
| 73 |
+
|
| 74 |
+
index_val_int = ValueRanges[sympy.Expr](
|
| 75 |
+
int(index_val.lower), int(index_val.upper)
|
| 76 |
+
)
|
| 77 |
+
if not range_expressable_in_32_bits(index_val_int):
|
| 78 |
+
return
|
| 79 |
+
|
| 80 |
+
if not range_expressable_in_32_bits(bounds[dominated]):
|
| 81 |
+
return
|
| 82 |
+
|
| 83 |
+
args = list(node.args)
|
| 84 |
+
args[2] = torch.int32
|
| 85 |
+
node.args = tuple(args)
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def indexing_dtype_strength_reduction(loop_body: LoopBody):
|
| 89 |
+
"""
|
| 90 |
+
Performs Value Range Analysis on LoopBody's fx graph to reduce precision of
|
| 91 |
+
intermediaries from int64 to int32
|
| 92 |
+
"""
|
| 93 |
+
bv = loop_body.bounds()
|
| 94 |
+
|
| 95 |
+
int64_dtype_nodes = [
|
| 96 |
+
node
|
| 97 |
+
for node in loop_body.get_nodes()
|
| 98 |
+
if (
|
| 99 |
+
node.target == "to_dtype"
|
| 100 |
+
and node.args[2] == torch.int64
|
| 101 |
+
and node not in bv.unbounded_vars
|
| 102 |
+
)
|
| 103 |
+
]
|
| 104 |
+
if not int64_dtype_nodes:
|
| 105 |
+
return
|
| 106 |
+
|
| 107 |
+
bounds = bv.get_bounds()
|
| 108 |
+
|
| 109 |
+
# TODO - if dominated node of one to_dtype is not expressible in int32,
|
| 110 |
+
# we should short circuit another to_dtype node if that node also dominates
|
| 111 |
+
for node in int64_dtype_nodes:
|
| 112 |
+
try_to_reduce_precision(
|
| 113 |
+
node,
|
| 114 |
+
bounds,
|
| 115 |
+
loop_body.indirect_vars,
|
| 116 |
+
loop_body.indexing_exprs,
|
| 117 |
+
bv.replacement_vals,
|
| 118 |
+
)
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/triton_heuristics.py
ADDED
|
@@ -0,0 +1,1527 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import builtins
|
| 2 |
+
import copy
|
| 3 |
+
import functools
|
| 4 |
+
import hashlib
|
| 5 |
+
import inspect
|
| 6 |
+
import json
|
| 7 |
+
import logging
|
| 8 |
+
import math
|
| 9 |
+
import operator
|
| 10 |
+
import os
|
| 11 |
+
import os.path
|
| 12 |
+
import re
|
| 13 |
+
import threading
|
| 14 |
+
from enum import auto, Enum
|
| 15 |
+
from typing import Any, Callable, Dict, List, Optional, Set, Tuple
|
| 16 |
+
|
| 17 |
+
import torch
|
| 18 |
+
|
| 19 |
+
import torch.autograd.profiler as autograd_profiler
|
| 20 |
+
from torch._dynamo.device_interface import get_interface_for_device
|
| 21 |
+
from torch._dynamo.utils import dynamo_timed, get_first_attr
|
| 22 |
+
from torch.utils._triton import has_triton_package
|
| 23 |
+
|
| 24 |
+
from . import config
|
| 25 |
+
from .codecache import cache_dir, CudaKernelParamCache
|
| 26 |
+
from .coordinate_descent_tuner import CoordescTuner
|
| 27 |
+
|
| 28 |
+
from .ir import ReductionHint, TileHint
|
| 29 |
+
from .utils import (
|
| 30 |
+
ceildiv,
|
| 31 |
+
conditional_product,
|
| 32 |
+
create_bandwidth_info_str,
|
| 33 |
+
do_bench,
|
| 34 |
+
get_max_y_grid,
|
| 35 |
+
get_num_bytes,
|
| 36 |
+
next_power_of_2,
|
| 37 |
+
triton_config_to_hashable,
|
| 38 |
+
)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
log = logging.getLogger(__name__)
|
| 42 |
+
|
| 43 |
+
# Import triton lazily/optionally: this module must remain importable on
# machines without triton installed (e.g. CPU-only builds).
if has_triton_package():
    import triton
    from triton import Config
    from triton.runtime.autotuner import OutOfResources
    from triton.runtime.jit import KernelInterface

    try:
        from triton.compiler.compiler import ASTSource
    except ImportError:
        # Older triton releases predate the ASTSource compile API; the
        # compile path below falls back to passing the metadata dict directly.
        ASTSource = None
else:
    # Stub out the triton names so class definitions below (which subclass
    # KernelInterface and annotate with Config) still evaluate.
    Config = object
    triton = None
    KernelInterface = object
    OutOfResources = object
    ASTSource = None


# Threads per warp used by the occupancy heuristics below.
# NOTE(review): 32 matches NVIDIA hardware; AMD wave64 is not modeled here — confirm.
_NUM_THREADS_PER_WARP = 32
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
class HeuristicType(Enum):
    """Which config-generation heuristic produced a kernel's autotune configs.

    Used by CachingAutotuner to decide, e.g., whether coordinate-descent
    tuning applies (it is skipped for TEMPLATE and USER_AUTOTUNE kernels).
    """

    PERSISTENT_REDUCTION = auto()
    POINTWISE = auto()
    REDUCTION = auto()
    SPLIT_SCAN = auto()
    TEMPLATE = auto()
    USER_AUTOTUNE = auto()
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
class AutotuneHint(Enum):
    """Hints attached to kernel metadata suggesting extra configs to autotune."""

    ELEMENTS_PER_WARP_32 = 0

    # Triton codegen tries to codegen set of AutotuneHints.
    # Enum.__repr__ looks like "<AutotuneHint.ELEMENTS_PER_WARP_32: 0>""
    # which isn't valid python.
    # Enum.__str__ will just return "AutotuneHint.ELEMENTS_PER_WARP_32".
    __repr__ = Enum.__str__
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def autotune_hints_to_configs(
    hints: Set[AutotuneHint], size_hints, block_size: int
) -> List[Config]:
    """
    AutotuneHints can be attached to the metadata of triton kernels for providing
    suggestions about what to try for autotuning. One reason to do this is if there are
    some configs that are only useful in specific scenarios, in which case we can avoid
    wasting compile time on autotuning unless we know we are in one of those scenarios.

    Based on those hints, this function will generate a list of additional autotuning
    configs to try.

    Args:
        hints: set of AutotuneHint values attached to the kernel.
        size_hints: per-dimension numel hints; expected to have 1, 2, or 3 entries.
        block_size: base block size from which the hinted configs are derived.

    Returns:
        Extra triton Configs to add to the autotuning candidate pool.
    """
    xyz_options: Tuple[Tuple[int, Optional[int], Optional[int]], ...]
    configs = []

    for hint in hints:
        if hint == AutotuneHint.ELEMENTS_PER_WARP_32:
            # One candidate per choice of which dimension gets the reduced
            # (block_size // 4) block, targeting 32 elements per warp.
            if len(size_hints) == 1:
                xyz_options = ((block_size // 4, None, None),)
            elif len(size_hints) == 2:
                xyz_options = ((block_size // 4, 1, None), (1, block_size // 4, None))
            elif len(size_hints) == 3:
                xyz_options = (
                    (block_size // 4, 1, 1),
                    (1, block_size // 4, 1),
                    (1, 1, block_size // 4),
                )
            else:
                # Bug fix: previously an unexpected number of size hints left
                # xyz_options unbound, raising NameError at the loop below.
                # An unknown dimensionality simply contributes no extra configs.
                continue
            for xyz in xyz_options:
                configs.append(
                    triton_config(
                        size_hints,
                        *xyz,
                        num_elements_per_warp=32,
                    )
                )

    return configs
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
def disable_pointwise_autotuning():
    """Return True when pointwise-kernel autotuning should be skipped.

    Autotuning benchmarks are noisy from run to run, so deterministic mode
    forces it off; otherwise the config flag decides.
    """
    return (
        torch.are_deterministic_algorithms_enabled()
        or not config.triton.autotune_pointwise
    )
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
class CachingAutotuner(KernelInterface):
    """
    Simplified version of Triton autotuner that has no invalidation
    key and caches the best config to disk to improve cold start times.
    Unlike the main triton Autotuner, this version can precompile all
    configs, and does not rely on the Triton JIT.
    """

    def __init__(
        self,
        fn,
        triton_meta,  # passed directly to triton
        configs,
        save_cache_hook,
        mutated_arg_names,
        heuristic_type,
        size_hints=None,
        inductor_meta=None,  # metadata not relevant to triton
        custom_kernel=False,  # whether the kernel is inductor-generated or custom
    ):
        super().__init__()

        assert len(configs) > 0, "Non-empty TritonConfig list required for compiling"
        self.fn = fn
        self.triton_meta = triton_meta
        self.inductor_meta = {} if inductor_meta is None else inductor_meta
        self.save_cache_hook = save_cache_hook
        # Names of args the kernel writes in place; bench/coordesc clone these
        # so autotuning runs do not corrupt real data.
        self.mutated_arg_names = mutated_arg_names
        self.configs = configs
        self.heuristic_type = heuristic_type
        self.custom_kernel = custom_kernel
        # Set to True once save_cuda_kernel has stashed the compiled binary.
        self.cuda_kernel_saved = False

        # Align the default design that default as cuda
        self.device_type = (
            triton_meta["device_type"] if "device_type" in triton_meta else "cuda"
        )
        self.gpu_device = get_interface_for_device(self.device_type)

        if log.isEnabledFor(logging.DEBUG):
            log.debug(
                "CachingAutotuner gets %d configs for %s",
                len(self.configs),
                self.fn.__name__,
            )
            for c in self.configs:
                log.debug(c)

        # One launcher per successfully compiled config; reduced to exactly
        # one after autotuning.
        self.launchers = []
        self.lock = threading.Lock()
        if os.getenv("TRITON_CACHE_DIR") is None:
            os.environ["TRITON_CACHE_DIR"] = os.path.join(
                cache_dir(),
                "triton",
                str(self.triton_meta.get("device", 0)),
            )

        self.size_hints = size_hints
        self.coordesc_tuner = CoordescTuner(
            is_mm=False, name=self.fn.__name__, size_hints=size_hints
        )

        # pre-create the profiler context manager to reduce latency
        self.record_function_ctx = torch._C._profiler._RecordFunctionFast(
            self.inductor_meta.get("kernel_name", "triton kernel")
        )

    def precompile(self, warm_cache_only_with_cc=None):
        """Compile every config into a launcher (idempotent, thread-safe).

        For large reductions on capable NVIDIA GPUs, may also synthesize
        extra halved-RBLOCK configs when register pressure limits occupancy.
        Sets self.configs to None afterwards to free them.
        """
        with self.lock:
            if self.launchers:
                return
            self.launchers = []
            compiled_binaries = []
            if not self.configs:
                raise RuntimeError("No triton configs are available")

            for c in self.configs:
                try:
                    compiled_binary, launcher = self._precompile_config(
                        c, warm_cache_only_with_cc
                    )
                except OutOfResources:
                    # Skip the config if we run out of resource
                    continue
                self.launchers.append(launcher)
                compiled_binaries.append(compiled_binary)

            if len(self.launchers) == 0:
                raise RuntimeError(
                    "No valid triton configs. Report a fatal compilation error"
                )

            seen_configs = set(self.configs)

            device_prop = self.gpu_device.Worker.get_device_properties(
                self.triton_meta["device"]
            )
            if (
                config.dynamic_scale_rblock
                and self.heuristic_type == HeuristicType.REDUCTION
                and self.size_hints is not None
                # Disable for AMDGPU as Triton is not ready to return n_regs for a compiled_binary.
                and torch.version.hip is None
                and device_prop.major >= 8
            ):
                # NOTE(review): the loop variable `triton_config` shadows the
                # module-level `triton_config` helper within this loop body.
                for triton_config, compiled_binary in zip(
                    self.configs, compiled_binaries
                ):
                    assert len(self.size_hints) == 2
                    xblock = triton_config.kwargs.get("XBLOCK", 1)
                    rblock = triton_config.kwargs["RBLOCK"]
                    total_block = (self.size_hints[0] + xblock - 1) // xblock
                    nreg = getattr(compiled_binary, "n_regs", None)
                    if nreg is None:
                        continue

                    # make sure rblock is not too small
                    if rblock <= 64:
                        continue

                    # each SM of A100 has 65536 32-bit registers. To maximize
                    # the theoretical occupancy, we need run 2048 threads on each
                    # SM. So each thread should use no more than 65536 / 2048
                    # = 32 registers. In cases where occupancy matters, and each
                    # thread uses too many registers, reduce RBLOCK to reduce
                    # the register usage.
                    # For kernel https://gist.github.com/shunting314/e4cccc031fe30d378b9b23c08c238cbd
                    # from PLBartForCausalLM, latency improve from
                    # 7.795ms to 4.883ms.
                    #
                    if (
                        nreg
                        <= device_prop.regs_per_multiprocessor
                        // device_prop.max_threads_per_multi_processor
                    ):
                        continue

                    nreg_per_warp = nreg * 32
                    nreg_per_block = nreg_per_warp * triton_config.num_warps

                    # Previously we set max_blocks_per_sm to 'max_threads_per_multi_processor / (32 * num_warps)'
                    # The formula below is a tighter upper bound since we have the assumption that
                    # nreg > device_prop.regs_per_multiprocessor // device_prop.max_threads_per_multi_processor
                    # due to the if condition above and:
                    # regs_per_multiprocessor / nreg_per_block
                    # = regs_per_multiprocessor / (nreg * 32 * num_warps)
                    # < regs_per_multiprocessor / ((regs_per_multiprocessor / max_threads_per_multi_processor) * 32 * num_warps)
                    # = max_threads_per_multi_processor / (32 * num_warps)
                    # Using a tighter upper bound can reveal more optimization opportunities.
                    max_blocks_per_sm = max(
                        device_prop.regs_per_multiprocessor // nreg_per_block, 1
                    )

                    if (
                        total_block
                        <= max_blocks_per_sm * device_prop.multi_processor_count
                    ):
                        # no need to improve occupancy
                        continue
                    new_config = copy.deepcopy(triton_config)
                    new_config.kwargs["RBLOCK"] = rblock // 2
                    if new_config in seen_configs:
                        continue
                    seen_configs.add(new_config)
                    self.launchers.append(
                        self._precompile_config(new_config, warm_cache_only_with_cc)[1]
                    )
            self.configs = None

    def _precompile_config(self, cfg: Config, warm_cache_only_with_cc: Optional[int]):
        """Ahead of time compile a given autotuner config.

        Returns (compiled_binary, launcher); launcher is None when only
        warming the cache for a given compute capability.
        """
        compile_meta = copy.deepcopy(self.triton_meta)
        for k, v in cfg.kwargs.items():
            compile_meta["constants"][self.fn.arg_names.index(k)] = v
        compile_meta["num_warps"] = cfg.num_warps
        compile_meta["num_stages"] = cfg.num_stages
        compile_meta["debug"] = (
            config.assert_indirect_indexing and torch.version.hip is None
        )

        # Setting device_type="hip" required on ROCm to pass down to triton
        compile_meta["device_type"] = (
            self.device_type if torch.version.hip is None else "hip"
        )

        if warm_cache_only_with_cc:
            cc = warm_cache_only_with_cc
        else:
            # Use device_type 'cuda' for both cuda and hip devices to retrieve
            # the compute capability.
            device_type = self.device_type if torch.version.hip is None else "cuda"
            device_id = compile_meta["device"]
            device = torch.device(device_type, device_id)
            cc = self.gpu_device.get_compute_capability(device)

        compile_meta["cc"] = cc

        # Newer triton exposes the ASTSource API; older versions take the
        # whole metadata dict as keyword args.
        if ASTSource:
            compile_args = (
                ASTSource(
                    self.fn,
                    compile_meta["signature"],
                    compile_meta["constants"],
                    compile_meta["configs"][0],
                ),
            )

            target = (compile_meta["device_type"], cc)
            options = {
                "num_warps": compile_meta["num_warps"],
                "num_stages": compile_meta["num_stages"],
                "debug": compile_meta["debug"],
            }
            compile_kwargs = {
                "target": target,
                "options": options,
            }
        else:
            compile_args = (self.fn,)
            compile_kwargs = compile_meta

        if warm_cache_only_with_cc:
            return (
                triton.compile(*compile_args, **compile_kwargs),
                None,
            )

        # load binary to the correct device
        with self.gpu_device.device(compile_meta["device"]):  # type: ignore[attr-defined]
            # need to initialize context
            self.gpu_device.synchronize(self.gpu_device.current_device())

            try:
                binary = triton.compile(*compile_args, **compile_kwargs)
            except Exception:
                log.exception(
                    "Triton compilation failed: %s\n%s\nmetadata: %s",
                    self.inductor_meta.get("kernel_name", "triton_"),
                    self.fn.src,
                    compile_meta,
                )
                raise
            binary._init_handles()

        # Runtime args exclude constexprs (baked into the binary).
        call_args = [
            arg
            for i, arg in enumerate(self.fn.arg_names)
            if i not in self.fn.constexprs
        ]
        def_args = [name for name in self.fn.arg_names if name not in cfg.kwargs]

        # Globals visible to the exec'd launcher below; binding them here
        # makes every lookup in the hot launch path a fast dict/global access.
        scope = {
            "grid_meta": cfg.kwargs,
            "bin": binary,
            "launch_enter_hook": binary.launch_enter_hook,
            "launch_exit_hook": binary.launch_exit_hook,
            "metadata": binary.metadata,
            "torch": torch,
            "set_device": self.gpu_device.set_device,
            "current_device": self.gpu_device.current_device,
        }

        # get_first_attr papers over attribute renames across triton versions.
        scope["runner"] = get_first_attr(binary, "run", "c_wrapper")
        scope["function"] = get_first_attr(binary, "function", "cu_function")
        scope["cta_args"] = (
            (binary.num_ctas, *get_first_attr(binary, "cluster_dims", "clusterDims"))
            if hasattr(binary, "num_ctas")
            else (
                (binary.metadata.num_ctas, *binary.metadata.cluster_dims)
                if hasattr(binary, "metadata")
                else ()
            )
        )
        scope["num_warps"] = (
            binary.num_warps
            if hasattr(binary, "num_warps")
            else binary.metadata.num_warps
        )
        binary_shared = (
            binary.shared if hasattr(binary, "shared") else binary.metadata.shared
        )
        scope["shared"] = binary_shared

        # Synthesize a specialized launcher function for this config so each
        # kernel launch is a single direct call with no per-launch dict work.
        exec(
            f"""
            def launcher({', '.join(def_args)}, grid, stream):
                if callable(grid):
                    grid_0, grid_1, grid_2 = grid(grid_meta)
                else:
                    grid_0, grid_1, grid_2 = grid

                runner(grid_0, grid_1, grid_2, num_warps,
                            *cta_args, shared,
                            stream, function,
                            launch_enter_hook,
                            launch_exit_hook,
                            metadata,
                            {', '.join(call_args)})
                return bin
            """.lstrip(),
            scope,
        )

        launcher = scope["launcher"]
        launcher.config = cfg
        launcher.n_regs = getattr(binary, "n_regs", None)
        launcher.n_spills = getattr(binary, "n_spills", None)
        launcher.shared = binary_shared
        launcher.store_cubin = config.triton.store_cubin
        # store this global variable to avoid the high overhead of reading it when calling run
        if launcher.store_cubin:
            launcher.fn = self.fn
            launcher.bin = binary

        return binary, launcher

    def bench(self, launcher, *args, grid, **kwargs):
        """Measure the performance of a given launcher"""
        # we don't skip configs with spilled registers when auto-tuning custom
        # (user-written) Triton kernels, as (i) we don't have any knowledge or
        # control over the kernel code; (ii) there is empirical evidence that
        # for some (complicated) custom Triton kernels, a register-spilling
        # config may yield the best latency.
        if not self.custom_kernel and launcher.n_spills > config.triton.spill_threshold:
            log.debug(
                "Skip config %s because of register spilling: %d",
                launcher.config,
                launcher.n_spills,
            )
            return float("inf")

        stream = self.gpu_device.get_raw_stream(  # type: ignore[call-arg]
            self.gpu_device.current_device()
        )

        def kernel_call():
            if launcher.config.pre_hook is not None:
                launcher.config.pre_hook(
                    {**dict(zip(self.arg_names, args)), **launcher.config.kwargs}
                )

            # Clone mutated args each call so repeated benchmark runs see
            # unchanged inputs.
            cloned_args, cloned_kwargs = self.clone_args(*args, **kwargs)
            launcher(
                *cloned_args,
                **cloned_kwargs,
                grid=grid,
                stream=stream,
            )

        return do_bench(kernel_call, rep=40, fast_flush=True)

    def clone_args(self, *args, **kwargs) -> Tuple[List[Any], Dict[str, Any]]:
        """Return (args, kwargs) with in-place-mutated tensors deep-copied."""
        from .compile_fx import clone_preserve_strides

        # clone inplace buffers to avoid autotune contaminating them if
        # the kernel does in-place stores. avoid cloning other buffers because
        # it leads to increase memory use
        cloned_args = []
        for i, arg in enumerate(args):
            if self.fn.arg_names[i] in self.mutated_arg_names:
                assert isinstance(arg, torch.Tensor)
                cloned_args.append(clone_preserve_strides(arg))
            else:
                cloned_args.append(arg)

        cloned_kwargs: Dict[str, Any] = {}
        for name, arg in kwargs.items():
            if name in self.mutated_arg_names:
                assert isinstance(arg, torch.Tensor)
                cloned_kwargs[name] = clone_preserve_strides(arg)
            else:
                cloned_kwargs[name] = arg

        return cloned_args, cloned_kwargs

    @dynamo_timed
    def benchmark_all_configs(self, *args, **kwargs):
        """Benchmark every compiled launcher; return {launcher: latency_ms}."""
        timings = {
            launcher: self.bench(launcher, *args, **kwargs)
            for launcher in self.launchers
        }

        # Feed results to the coordinate-descent tuner's cache so it does not
        # re-benchmark configs it has already seen.
        for k, v in timings.items():
            self.coordesc_tuner.cache_benchmark_result(k.config, v)

        if log.isEnabledFor(logging.DEBUG):
            log.debug("Benchmark all input configs for %s, get:", self.fn.__name__)
            for k, v in timings.items():
                log.debug(
                    "%s: %f, nreg %d, nspill %d, #shared-mem %s",
                    k.config,
                    v,
                    k.n_regs,
                    k.n_spills,
                    k.shared,
                )

        return timings

    def autotune_to_one_config(self, *args, **kwargs):
        """Do the actual autotuning"""
        timings = self.benchmark_all_configs(*args, **kwargs)
        self.launchers = [builtins.min(timings, key=timings.get)]
        if self.save_cache_hook:
            self.save_cache_hook(self.launchers[0].config)

    def save_cuda_kernel(self, grid, stream, launcher):
        """Stash the winning kernel's binary + launch params for AOT codegen."""
        if callable(grid):
            grid_x, grid_y, grid_z = grid(launcher.config.kwargs)
        else:
            grid_x, grid_y, grid_z = grid

        key = self.inductor_meta.get("kernel_name", None)  # unique kernel name
        assert key is not None, "kernel_name can not be None"
        # hasattr checks bridge metadata being an object vs a dict across
        # triton versions.
        params = {
            "mangled_name": launcher.bin.metadata.name
            if hasattr(launcher.bin.metadata, "name")
            else launcher.bin.metadata["name"],
            "grid_x": grid_x,
            "grid_y": grid_y,
            "grid_z": grid_z,
            "x_block": launcher.config.kwargs.get("XBLOCK", 1),
            "y_block": launcher.config.kwargs.get("YBLOCK", None),
            "z_block": launcher.config.kwargs.get("ZBLOCK", None),
            "num_warps": launcher.bin.num_warps
            if hasattr(launcher.bin, "num_warps")
            else launcher.bin.metadata.num_warps,
            "shared_mem": launcher.bin.shared
            if hasattr(launcher.bin, "shared")
            else launcher.bin.metadata.shared,
            "stream": stream,
            # User defined triton kernels will have arbitrary kwarg names
            "meta": launcher.config.kwargs,
        }

        if torch.version.hip is None:
            CudaKernelParamCache.set(key, params, launcher.bin.asm["cubin"])
        else:
            # There is some divergence between CUDA and ROCm here.
            # On ROCm's triton we only have the path to the binary, not the binary itself.
            # For ROCm we will copy the binary to the new location instead of writing to file
            import pathlib

            launcher.bin.asm["hsaco"] = pathlib.Path(
                launcher.bin.asm["hsaco_path"]
            ).read_bytes()
            CudaKernelParamCache.set(key, params, launcher.bin.asm["hsaco"])

        self.cuda_kernel_saved = True

    def coordinate_descent_tuning(self, launcher, *args, **kwargs):
        """
        Coordinate descent tuning can be run with or without max-autotune.

        The only difference between these two is the starting config for coordinate_descent tuning.
        E.g., assuming regular autotune only get one config C1; while max-autotune get 4 configs C1, C2, C3, C4
        and max-autotune figure out C3 is the best.

        Then if coordinate descent tuning is run with max-autotune disabled, it will start from C1;
        while if coordinate descent tuning is run with max-autotune enabled, it will start from C3.
        """
        if (
            self.heuristic_type == HeuristicType.TEMPLATE
            or self.heuristic_type == HeuristicType.USER_AUTOTUNE
        ):
            # skip triton template
            return launcher

        # Clone mutated args once up front; every candidate benchmarks
        # against the same cloned inputs.
        cloned_args, _ = self.clone_args(*args)
        config2launcher = {launcher.config: launcher}

        def benchmark_one_config(config):
            with self.lock:
                _, launcher = self._precompile_config(config, None)
            config2launcher[config] = launcher

            out = self.bench(launcher, *cloned_args, **kwargs)
            log.debug(
                "COORDESC: %s: %f, nreg %d, nspill %d, #shared-mem %d",
                launcher.config,
                out,
                launcher.n_regs,
                launcher.n_spills,
                launcher.shared,
            )
            return out

        assert not (
            self.heuristic_type == HeuristicType.PERSISTENT_REDUCTION
            and "RBLOCK" in launcher.config.kwargs
        ), "Coordinate descent tuner relies on the assumption that persistent reduction's triton config does not have RBLOCK"
        best_config = self.coordesc_tuner.autotune(
            benchmark_one_config, launcher.config, None
        )
        best_config.found_by_coordesc = True

        if self.save_cache_hook:
            self.save_cache_hook(best_config, found_by_coordesc=True)
        return config2launcher.get(best_config)

    def run(self, *args, grid, stream, **kwargs):
        """Launch the kernel, lazily precompiling/autotuning on first call."""
        if len(self.launchers) != 1:
            if len(self.launchers) == 0:
                self.precompile()
            if len(self.launchers) > 1:
                self.autotune_to_one_config(*args, grid=grid, **kwargs)

        if (
            not getattr(self.launchers[0].config, "found_by_coordesc", False)
            and config.coordinate_descent_tuning
        ):
            self.launchers = [
                self.coordinate_descent_tuning(
                    self.launchers[0], *args, grid=grid, **kwargs
                )
            ]

        (launcher,) = self.launchers
        if launcher.store_cubin:
            self.save_cuda_kernel(grid, stream, launcher)

        if launcher.config.pre_hook is not None:
            launcher.config.pre_hook(
                {**dict(zip(self.arg_names, args)), **launcher.config.kwargs, **kwargs}
            )

        # guard the record_function_ctx and only call it if profiling is currently
        # in progress, to reduce latency when profiler is not turned on. Note that
        # the "if" statement (instead of, say, a contextlib.nullcontext) is intentional;
        # it is faster than entering and exiting a context manager, even if the context
        # manager is a nullcontext.
        if autograd_profiler._is_profiler_enabled:
            with self.record_function_ctx:
                return launcher(
                    *args,
                    **kwargs,
                    grid=grid,
                    stream=stream,
                )
        else:
            return launcher(
                *args,
                **kwargs,
                grid=grid,
                stream=stream,
            )
|
| 676 |
+
|
| 677 |
+
|
| 678 |
+
def _find_names(obj):
|
| 679 |
+
import gc
|
| 680 |
+
import inspect
|
| 681 |
+
|
| 682 |
+
frame = inspect.currentframe()
|
| 683 |
+
while frame is not None:
|
| 684 |
+
frame.f_locals
|
| 685 |
+
frame = frame.f_back
|
| 686 |
+
obj_names = []
|
| 687 |
+
for referrer in gc.get_referrers(obj):
|
| 688 |
+
if isinstance(referrer, dict):
|
| 689 |
+
for k, v in referrer.items():
|
| 690 |
+
if v is obj:
|
| 691 |
+
obj_names.append(k)
|
| 692 |
+
return obj_names
|
| 693 |
+
|
| 694 |
+
|
| 695 |
+
# Per-graph log of (ms, num_gb, gb_per_s, kernel_name) tuples, appended to by
# DebugAutotuner.run between start_graph()/end_graph() calls.
collected_calls: List[Any] = []


def start_graph():
    """Reset the bandwidth-profiling log at the start of a new graph."""
    del collected_calls[:]
|
| 700 |
+
|
| 701 |
+
|
| 702 |
+
def end_graph():
    """Print a bandwidth summary of collected_calls and optionally append the
    per-kernel breakdown to config.profile_bandwidth_output."""
    if len(collected_calls) == 0:
        return
    overall_time = sum(call[0] for call in collected_calls)
    overall_gb = sum(call[1] for call in collected_calls)
    # Report the caller's file so summaries from different graphs are
    # distinguishable in the log.
    cur_file = inspect.stack()[1].filename
    summary_str = (
        f"SUMMARY ({cur_file})\n"
        f"{overall_time:.2f}ms \t {overall_gb:.2f} GB\t {overall_gb/(overall_time/1e3):.2f}GB/s"
    )
    print(summary_str)
    print()
    output_file = config.profile_bandwidth_output
    if output_file is not None:
        # sort perf numbers in descending order, i.e. placing the
        # most runtime-heavy kernels at the top of the list
        sorted_calls = sorted(collected_calls, key=lambda c: float(c[0]), reverse=True)
        try:
            # Append (not truncate): multiple graphs accumulate in one file.
            with open(output_file, "a") as file:
                log.debug("Save profile bandwidth results to %s", output_file)
                file.write("====================\n")
                file.write(f"TRITON KERNELS BANDWIDTH INFO ({cur_file})\n")
                for ms, num_gb, gb_per_s, kernel_name in sorted_calls:
                    # also display the runtime percentage for each kernel
                    percentage = f"{ms/overall_time*100:.2f}%"
                    suffix = f" \t {percentage} \t {kernel_name}"
                    bw_info_str = create_bandwidth_info_str(
                        ms,
                        num_gb,
                        gb_per_s,
                        suffix=suffix,
                        color=False,
                    )
                    file.write(bw_info_str + "\n")
                file.write(f"{summary_str}\n\n")
        except Exception as e:
            # Best-effort: profiling output must never crash compilation.
            log.warning(
                "failed to write profile bandwidth result into %s: %s",
                output_file,
                e,
            )
|
| 743 |
+
|
| 744 |
+
|
| 745 |
+
class DebugAutotuner(CachingAutotuner):
    """CachingAutotuner variant that measures and reports achieved memory
    bandwidth for each kernel launch (used by profile-bandwidth mode)."""

    def __init__(self, *args, regex_filter="", **kwargs):
        # Only kernels whose name matches this regex are profiled.
        self.regex_filter = regex_filter
        super().__init__(*args, **kwargs)
        # Memoized (ms, num_gb, gb_per_s, kernel_name); benchmarking happens
        # only on the first run of each kernel.
        self.cached = None

    def run(self, *args, grid, stream):
        # Recover the generated kernel's variable name via gc referrers;
        # the longest candidate is assumed to be the descriptive name.
        possible_names = _find_names(self)
        kernel_name = f"{max(possible_names, key=len)}"
        if not re.match(self.regex_filter, kernel_name):
            return
        super().run(*args, grid=grid, stream=stream)
        (launcher,) = self.launchers

        if self.cached is None:
            ms = self.bench(launcher, *args, grid=grid)
            # in_out args are both read and written, so they count twice
            # toward bytes moved in get_num_bytes.
            num_in_out_ptrs = len(
                [
                    arg_name
                    for arg_name in self.fn.arg_names
                    if arg_name.startswith("in_out_ptr")
                ]
            )
            num_gb = self.inductor_meta.get("kernel_num_gb", None)
            if num_gb is None:
                num_gb = get_num_bytes(*args, num_in_out_args=num_in_out_ptrs) / 1e9
            gb_per_s = num_gb / (ms / 1e3)
            self.cached = (ms, num_gb, gb_per_s, kernel_name)
        else:
            ms, num_gb, gb_per_s, kernel_name = self.cached
        collected_calls.append((ms, num_gb, gb_per_s, kernel_name))
        print(
            create_bandwidth_info_str(ms, num_gb, gb_per_s, suffix=f" \t {kernel_name}")
        )
|
| 779 |
+
|
| 780 |
+
|
| 781 |
+
def hash_configs(configs: List[Config]):
    """
    Hash used to check for changes in configurations
    """
    digest = hashlib.sha256()
    for cfg in configs:
        # Sort kwargs so the hash is independent of dict insertion order.
        line = f"{sorted(cfg.kwargs.items())} {cfg.num_warps} {cfg.num_stages}\n"
        digest.update(line.encode())
    return digest.hexdigest()
|
| 791 |
+
|
| 792 |
+
|
| 793 |
+
def load_cached_autotuning(
    best_config,
    configs_hash: str,
    configs: List[Config],
):
    """Reload a previously saved best config.

    Validates the saved entry against configs_hash and the current candidate
    list; returns the matching Config, or None when the cache is stale or
    ambiguous.
    """
    if best_config is None:
        return None
    if best_config.pop("configs_hash", None) != configs_hash:
        return None

    # A config found by coordinate descent may not appear in `configs`,
    # so rebuild it directly from the saved kwargs.
    if config.coordinate_descent_tuning and best_config.pop("found_by_coordesc", False):
        num_warps = best_config.pop("num_warps")
        num_stages = best_config.pop("num_stages")
        triton_config = Config(best_config, num_warps=num_warps, num_stages=num_stages)
        triton_config.found_by_coordesc = True
        return triton_config

    matches = [
        candidate
        for candidate in configs
        if all(best_config.get(key) == val for key, val in candidate.kwargs.items())
        and candidate.num_warps == best_config.get("num_warps")
        and candidate.num_stages == best_config.get("num_stages")
    ]
    # Require exactly one match; anything else means the cache is ambiguous.
    if len(matches) != 1:
        return None

    return matches[0]
|
| 821 |
+
|
| 822 |
+
|
| 823 |
+
def cached_autotune(
    size_hints: Optional[List[int]],
    configs: List[Config],
    triton_meta,
    heuristic_type,
    filename=None,
    inductor_meta=None,
    custom_kernel=False,
):
    """
    A copy of triton.autotune that calls our subclass. Our subclass
    has additional debugging, error handling, and on-disk caching.

    Args:
        size_hints: per-dimension numel hints (may be None for templates).
        configs: candidate triton Configs to autotune over.
        triton_meta: metadata forwarded to the autotuner subclass.
        heuristic_type: HeuristicType tag describing the kernel kind.
        filename: source file of the kernel; required when there is more
            than one config (used to derive the on-disk cache path).
        inductor_meta: extra inductor-side metadata (mutated: pops
            "mutated_arg_names").
        custom_kernel: forwarded to the autotuner subclass.

    Returns a decorator that wraps the JIT'd kernel in either a
    DebugAutotuner (when bandwidth profiling is on) or a CachingAutotuner.
    """
    configs = unique_configs(configs)
    assert len(configs) == 1 or filename
    save_cache_hook: Optional[Callable[[Any, Any], Any]]
    inductor_meta = {} if inductor_meta is None else inductor_meta

    # on disk caching logic and/or remote caching
    # Only worth caching when there is actually a choice to remember
    # (multiple configs, or coordinate descent may refine the single one).
    if filename is not None and (len(configs) > 1 or config.coordinate_descent_tuning):
        configs_hash = hash_configs(configs)

        cache_filename = None
        remote_cache = None
        remote_cache_key = None
        if config.use_autotune_local_cache:
            cache_filename = os.path.splitext(filename)[0] + ".best_config"
        if config.use_autotune_remote_cache or (
            config.is_fbcode()
            and torch._utils_internal.justknobs_check(
                "pytorch/autotune_remote_cache:enable"
            )
        ):
            backend_hash = inductor_meta.get("backend_hash", None)
            if backend_hash is not None:
                # Remote key covers compiler backend + candidate set so stale
                # entries from other builds/config sets are never matched.
                key = backend_hash + configs_hash + "autotune-best-config"
                key = hashlib.sha256(key.encode("utf-8")).hexdigest()

                try:
                    if config.is_fbcode():
                        remote_cache = (
                            triton.runtime.fb_memcache.FbMemcacheRemoteCacheBackend(
                                key, is_autotune=True
                            )
                        )
                    else:
                        remote_cache = triton.runtime.cache.RedisRemoteCacheBackend(key)
                except Exception:
                    # Remote cache is best-effort: fall back to no cache.
                    remote_cache = None
                    log.warning("Unable to create a remote cache", exc_info=True)
                # we already sha256 hash the source contents
                remote_cache_key = os.path.basename(filename)
            else:
                log.debug(
                    "backend_hash is not passed on the inductor_meta, unable to use autotune remote cache"
                )

        # Local cache takes priority over the remote cache when both exist.
        best_config = None
        if cache_filename is not None and os.path.exists(cache_filename):
            with open(cache_filename) as fd:
                best_config = json.loads(fd.read())
        elif remote_cache is not None and remote_cache_key is not None:
            cache_outs = remote_cache.get([remote_cache_key])
            cache_out = cache_outs.get(remote_cache_key, None)
            best_config = json.loads(cache_out) if cache_out else None

        best_config = load_cached_autotuning(best_config, configs_hash, configs)
        if best_config:
            # Cache hit: skip benchmarking by narrowing to the cached winner.
            configs = [best_config]

        def save_cache_hook(cfg, found_by_coordesc=False):
            # Persist the winning config (plus bookkeeping keys consumed by
            # load_cached_autotuning) to local and/or remote caches.
            data = json.dumps(
                {
                    **cfg.kwargs,
                    "num_warps": cfg.num_warps,
                    "num_stages": cfg.num_stages,
                    "configs_hash": configs_hash,
                    "found_by_coordesc": found_by_coordesc,
                }
            )
            if cache_filename is not None:
                with open(cache_filename, "w") as fd:
                    fd.write(data)
            if remote_cache is not None and remote_cache_key is not None:
                remote_cache.put(remote_cache_key, data)

            if log.isEnabledFor(logging.DEBUG):
                type_str = "coordesc" if found_by_coordesc else "heuristic"
                log.debug("Save %s tuning result to %s", type_str, cache_filename)

    else:
        save_cache_hook = None

    mutated_arg_names = inductor_meta.pop("mutated_arg_names", ())

    def decorator(fn):
        # Remove XBLOCK from config if it's not a function argument.
        # This way, coordinate descent tuning will not try to tune it.
        #
        # Context: When TritonKernel.no_x_dim is True, we hardcode XBLOCK to 1.
        import inspect

        if "XBLOCK" not in inspect.signature(fn.fn).parameters:
            for tconfig in configs:
                if "XBLOCK" in tconfig.kwargs:
                    assert tconfig.kwargs["XBLOCK"] == 1
                    tconfig.kwargs.pop("XBLOCK")

        if config.profile_bandwidth:
            return DebugAutotuner(
                fn,
                triton_meta=triton_meta,
                inductor_meta=inductor_meta,
                regex_filter=config.profile_bandwidth_regex,
                configs=configs,
                save_cache_hook=save_cache_hook,
                mutated_arg_names=mutated_arg_names,
                heuristic_type=heuristic_type,
                size_hints=size_hints,
                custom_kernel=custom_kernel,
            )
        return CachingAutotuner(
            fn,
            triton_meta=triton_meta,
            inductor_meta=inductor_meta,
            configs=configs,
            save_cache_hook=save_cache_hook,
            mutated_arg_names=mutated_arg_names,
            heuristic_type=heuristic_type,
            size_hints=size_hints,
            custom_kernel=custom_kernel,
        )

    return decorator
|
| 957 |
+
|
| 958 |
+
|
| 959 |
+
def unique_configs(configs: List[Config]):
    """Remove duplicate configurations"""
    pruned_configs = []
    seen = set()
    for candidate in configs:
        # Configs aren't hashable themselves; dedupe on a hashable key
        # while preserving the first-seen order.
        fingerprint = triton_config_to_hashable(candidate)
        if fingerprint in seen:
            continue
        seen.add(fingerprint)
        pruned_configs.append(candidate)
    return pruned_configs
|
| 970 |
+
|
| 971 |
+
|
| 972 |
+
def check_config(cfg, *, xnumel=None, ynumel=None, znumel=None):
    """
    Validate a config dict's {X,Y,Z}BLOCK entries against the given numels.

    For every dimension whose numel is provided: a numel of 1 requires a
    BLOCK of 1, and every BLOCK must evenly divide config.triton.max_block
    for that axis. Raises AssertionError on violation.
    """
    for label, numel in zip("XYZ", (xnumel, ynumel, znumel)):
        if numel is None:
            # Dimension not present in this kernel; nothing to validate.
            continue
        block = cfg[f"{label}BLOCK"]
        if numel == 1:
            assert block == 1, (
                f"TritonKernel.indexing assumes numel == 1 => BLOCK == 1"
                f" but {label.lower()}numel=={numel} and {label}BLOCK={block} (cfg={cfg})."
            )
        max_block = config.triton.max_block[label]
        max_block_str = f'config.triton.max_block["{label}"]'
        assert max_block % block == 0, (
            f"TritonKernel.indexing assumes {label}BLOCK divides {max_block_str}"
            f" but {label}BLOCK={block} and {max_block_str}={max_block} (cfg={cfg})."
        )
|
| 988 |
+
|
| 989 |
+
|
| 990 |
+
def triton_config(
    size_hints,
    x,
    y=None,
    z=None,
    num_stages=1,
    num_elements_per_warp=256,
    min_elem_per_thread=0,
) -> Config:
    """
    Construct a pointwise triton config with some adjustment heuristics
    based on size_hints. Size_hints is a tuple of numels in each tile
    dimension and will be rounded up to the nearest power of 2.

    num_elements_per_warp is a suggestion for controlling how many warps
    the triton config should contain. e.g.: if x=16, y=8, z=4 then
    num_elements = 16*8*4 = 512. Then if we set num_elements_per_warp=128,
    we'll launch 512 (elem) / 128 (elem/warp) = 4 warps. Note that it's
    just a suggestion, and sometimes other adjustment heuristics will
    override the num_elements_per_warp.

    min_elem_per_thread controls the minimum number of elements
    processed by each thread. It's always enforced.
    """
    # Ideally we want to read this from some device config

    # for a 2d size_hints [a, b], a should be mapped to YBLOCK rather than XBLOCK
    size_hints = list(reversed(size_hints))

    # CUDA grid limits per axis (x, y, z); used below to decide when a
    # block dimension must be bumped so the grid fits.
    maxGridSize = [2147483647, 65535, 65535]

    target = conditional_product(x, y, z)
    if conditional_product(*size_hints) < target:
        # Problem is smaller than the requested tile; aim lower.
        target //= 8

    # shrink sizes to size hints
    x = min(x, size_hints[0])
    if y:
        y = min(y, size_hints[1])
    if z:
        z = min(z, size_hints[2])

    # if we are below original block size, scale up where we can;
    # or if the calculated grid size is larger than the limit, we bump up the corresponding dimension
    while x < min(size_hints[0], config.triton.max_block["X"]) and (
        x * maxGridSize[0] < size_hints[0] or conditional_product(x, y, z) < target
    ):
        x *= 2
    while (
        y
        and y < min(size_hints[1], config.triton.max_block["Y"])
        and (
            y * maxGridSize[1] < size_hints[1] or conditional_product(x, y, z) < target
        )
    ):
        y *= 2
    while (
        z
        and z < min(size_hints[2], config.triton.max_block["Z"])
        and (
            z * maxGridSize[2] < size_hints[2] or conditional_product(x, y, z) < target
        )
    ):
        z *= 2

    # Suggested warp count, clamped to [1, 8] and rounded to a power of 2.
    num_warps = next_power_of_2(
        min(max(conditional_product(x, y, z) // num_elements_per_warp, 1), 8)
    )
    # we are going to arrive at 2 warps only if bs was too small due to
    # numel being too small. However to workaround some ptx bugs we still
    # want at least 4 warps if there's enough elements per thread
    # given that this is a rare situation, don't expect this to affect perf
    # in general
    # see https://github.com/pytorch/pytorch/pull/97950
    num_warps = max(num_warps, 4) if conditional_product(x, y, z) >= 128 else num_warps
    xnumel = size_hints[0]
    ynumel = size_hints[1] if y else None
    znumel = size_hints[2] if z else None

    # Increase x to satisfy min_elem_per_thread requirements.
    block_size = max(
        conditional_product(x, y, z),
        min_elem_per_thread * _NUM_THREADS_PER_WARP * num_warps,
    )
    x *= math.ceil(block_size / conditional_product(x, y, z))

    cfg = {"XBLOCK": x}
    if y:
        cfg["YBLOCK"] = y
    if z:
        cfg["ZBLOCK"] = z
    check_config(cfg, xnumel=xnumel, ynumel=ynumel, znumel=znumel)
    return Config(cfg, num_warps=num_warps, num_stages=num_stages)
|
| 1083 |
+
|
| 1084 |
+
|
| 1085 |
+
def triton_config_reduction(size_hints, x, r, num_stages=1, num_warps=None) -> Config:
    """
    Construct a reduction triton config with some adjustment heuristics
    based on size_hints. Size_hints is a tuple of numels in each tile
    dimension and will be rounded up to the nearest power of 2.

    Args:
        size_hints: (xnumel, rnumel) hints.
        x: requested XBLOCK; r: requested RBLOCK (both clamped/scaled below).
        num_stages: pipeline stages for the resulting Config.
        num_warps: warp count; derived from the tile size when None.
    """
    target = conditional_product(x, r)
    if conditional_product(*size_hints) < target:
        # Problem is smaller than the requested tile; aim lower.
        target //= 8

    # shrink sizes to size hints
    x = min(x, size_hints[0])
    r = min(r, size_hints[1])

    # if we are below original block size, scale up where we can
    while x < size_hints[0] and conditional_product(x, r) < target:
        x *= 2
    while r < size_hints[1] and conditional_product(x, r) < target:
        r *= 2

    cfg = {"XBLOCK": x, "RBLOCK": r}
    if num_warps is None:
        num_warps = conditional_product(x, r) // 128
    # Clamp to [2, 8] warps and round to a power of 2.
    num_warps = next_power_of_2(min(max(num_warps, 2), 8))
    check_config(cfg, xnumel=size_hints[0])
    # BUGFIX: the message previously referenced config.triton.MAX_BLOCK['r'],
    # which does not exist; the setting actually checked is max_block["R"].
    assert (
        r <= config.triton.max_block["R"]
    ), f"increase config.triton.max_block['R'] to {r}"
    return Config(cfg, num_warps=num_warps, num_stages=num_stages)
|
| 1115 |
+
|
| 1116 |
+
|
| 1117 |
+
def triton_config_tiled_reduction(size_hints, x, y, r, num_stages=1):
    """
    Construct a tile reduction triton config with some adjustment
    heuristics based on size_hints. Size_hints is a tuple of numels in
    each tile dimension and will be rounded up to the nearest power of 2.

    Args:
        size_hints: (xnumel, ynumel, rnumel) hints.
        x, y, r: requested XBLOCK/YBLOCK/RBLOCK (clamped/scaled below).
        num_stages: pipeline stages for the resulting Config.
    """
    target = conditional_product(x, y, r)
    if conditional_product(*size_hints) < target:
        # Problem is smaller than the requested tile; aim lower.
        target //= 8

    # shrink sizes to size hints
    x = min(x, size_hints[0])
    y = min(y, size_hints[1])
    r = min(r, size_hints[2])

    # if we are below original block size, scale up where we can
    while x < size_hints[0] and conditional_product(x, y, r) < target:
        x *= 2
    while r < size_hints[2] and conditional_product(x, y, r) < target:
        r *= 2
    while y < size_hints[1] and conditional_product(x, y, r) < target:
        y *= 2

    cfg = {"XBLOCK": x, "YBLOCK": y, "RBLOCK": r}
    num_warps = next_power_of_2(min(max(conditional_product(x, y, r) // 256, 1), 8))
    check_config(cfg, xnumel=size_hints[0], ynumel=size_hints[1])
    # BUGFIX: the message previously referenced config.triton.MAX_BLOCK['r'],
    # which does not exist; the setting actually checked is max_block["R"].
    assert (
        r <= config.triton.max_block["R"]
    ), f"increase config.triton.max_block['R'] to {r}"
    return Config(cfg, num_warps=num_warps, num_stages=num_stages)
|
| 1148 |
+
|
| 1149 |
+
|
| 1150 |
+
def pointwise(
    size_hints,
    triton_meta,
    tile_hint=None,
    filename=None,
    min_elem_per_thread=0,
    inductor_meta=None,
):
    """
    Construct @triton.heuristics() based on size_hints.

    Selects the candidate configs for a pointwise kernel (1-, 2-, or
    3-dimensional according to len(size_hints)) and hands them to
    cached_autotune. When pointwise autotuning is disabled (and
    max-autotune is off), a single default config is used instead of a
    candidate sweep. Raises NotImplementedError for >3 dimensions.
    """
    inductor_meta = {} if inductor_meta is None else inductor_meta
    assert not inductor_meta.get("no_x_dim")

    numel = functools.reduce(operator.mul, size_hints)
    # Default 1D block size: numel/128 clamped to [256, 1024].
    bs = max(256, min(numel // 128, 1024))

    hinted_configs = autotune_hints_to_configs(
        inductor_meta.get("autotune_hints", set()), size_hints, bs
    )

    triton_config_with_settings = functools.partial(
        triton_config, min_elem_per_thread=min_elem_per_thread
    )

    if len(size_hints) == 1:
        if disable_pointwise_autotuning() and not (
            config.max_autotune or config.max_autotune_pointwise
        ):
            return cached_autotune(
                size_hints,
                [triton_config_with_settings(size_hints, bs)],
                triton_meta=triton_meta,
                inductor_meta=inductor_meta,
                heuristic_type=HeuristicType.POINTWISE,
                filename=filename,
            )
        else:
            return cached_autotune(
                size_hints,
                [
                    triton_config_with_settings(
                        size_hints, bs, num_elements_per_warp=256
                    ),
                    triton_config_with_settings(
                        size_hints, bs // 2, num_elements_per_warp=64
                    ),
                    *hinted_configs,
                ],
                triton_meta=triton_meta,
                inductor_meta=inductor_meta,
                heuristic_type=HeuristicType.POINTWISE,
                filename=filename,
            )
    if len(size_hints) == 2:
        # A square tile hint also short-circuits the sweep (unless
        # max-autotune forces it).
        if (disable_pointwise_autotuning() or tile_hint == TileHint.SQUARE) and not (
            config.max_autotune or config.max_autotune_pointwise
        ):
            return cached_autotune(
                size_hints,
                [triton_config_with_settings(size_hints, 32, 32)],
                triton_meta=triton_meta,
                inductor_meta=inductor_meta,
                heuristic_type=HeuristicType.POINTWISE,
                filename=filename,
            )
        return cached_autotune(
            size_hints,
            [
                triton_config_with_settings(size_hints, 32, 32),
                triton_config_with_settings(size_hints, 64, 64),  # ~8% better for fp16
                triton_config_with_settings(size_hints, 256, 16),
                triton_config_with_settings(size_hints, 16, 256),
                triton_config_with_settings(size_hints, bs, 1),
                triton_config_with_settings(size_hints, 1, bs),
                *hinted_configs,
            ],
            triton_meta=triton_meta,
            inductor_meta=inductor_meta,
            filename=filename,
            heuristic_type=HeuristicType.POINTWISE,
        )
    if len(size_hints) == 3:
        if disable_pointwise_autotuning():
            return cached_autotune(
                size_hints,
                [triton_config_with_settings(size_hints, 16, 16, 16)],
                triton_meta=triton_meta,
                inductor_meta=inductor_meta,
                heuristic_type=HeuristicType.POINTWISE,
                filename=filename,
            )
        return cached_autotune(
            size_hints,
            [
                triton_config_with_settings(size_hints, 16, 16, 16),
                triton_config_with_settings(size_hints, 64, 8, 8),
                triton_config_with_settings(size_hints, 8, 64, 8),
                triton_config_with_settings(size_hints, 8, 8, 64),
                triton_config_with_settings(size_hints, bs, 1, 1),
                triton_config_with_settings(size_hints, 1, bs, 1),
                triton_config_with_settings(size_hints, 1, 1, bs),
                *hinted_configs,
            ],
            triton_meta=triton_meta,
            inductor_meta=inductor_meta,
            filename=filename,
            heuristic_type=HeuristicType.POINTWISE,
        )
    raise NotImplementedError(f"size_hints: {size_hints}")
|
| 1260 |
+
|
| 1261 |
+
|
| 1262 |
+
def _reduction_configs(
    *, size_hints: List[int], inductor_meta: Dict[str, Any]
) -> List[Config]:
    """
    Build the candidate Config list for a (non-persistent) reduction kernel.

    size_hints must be (xnumel, rnumel). A reduction_hint in inductor_meta
    narrows the sweep to a single shape-appropriate config, except when
    max-autotune is enabled, in which case the full list is returned.
    """
    reduction_hint = inductor_meta.get("reduction_hint", None)
    assert len(size_hints) == 2
    rnumel = size_hints[-1]

    # XBLOCK=1 with a large RBLOCK: good when the reduced dim is contiguous.
    contiguous_config = triton_config_reduction(
        size_hints, 1, (rnumel if 256 <= rnumel < 2048 else 2048)
    )
    outer_config = triton_config_reduction(size_hints, 64, 8)
    tiny_config = triton_config_reduction(
        size_hints, 2 * (256 // rnumel) if rnumel <= 256 else 1, min(rnumel, 2048)
    )
    if config.max_autotune or config.max_autotune_pointwise:
        pass  # skip all these cases
    elif reduction_hint == ReductionHint.INNER:
        return [contiguous_config]
    elif reduction_hint == ReductionHint.OUTER:
        return [outer_config]
    elif reduction_hint == ReductionHint.OUTER_TINY:
        return [tiny_config]
    if disable_pointwise_autotuning():
        return [triton_config_reduction(size_hints, 32, 128)]
    return [
        contiguous_config,
        outer_config,
        tiny_config,
        triton_config_reduction(size_hints, 64, 64),
        triton_config_reduction(size_hints, 8, 512),
        # halve the XBLOCK/RBLOCK compared to outer_config
        # TODO: this may only be beneficial when each iteration of the reduction
        # is quite heavy. E.g. https://gist.github.com/shunting314/189a8ef69f90db9d614a823385147a72
        triton_config_reduction(size_hints, 64, 4, num_warps=8),
    ]
|
| 1297 |
+
|
| 1298 |
+
|
| 1299 |
+
def reduction(
    size_hints,
    reduction_hint=False,
    triton_meta=None,
    filename=None,
    inductor_meta=None,
):
    """args to @triton.heuristics()

    Entry point for reduction-kernel autotuning: records the reduction
    hint in inductor_meta, collapses the x dimension when no_x_dim is
    set, builds candidate configs via _reduction_configs, and returns
    the cached_autotune decorator. Only 2D size_hints are supported.
    """
    inductor_meta = {} if inductor_meta is None else inductor_meta
    inductor_meta["reduction_hint"] = reduction_hint
    if inductor_meta.get("no_x_dim"):
        # Kernel has no x dimension; force xnumel hint to 1.
        size_hints = [1, *size_hints[1:]]

    assert triton_meta is not None
    # (Removed an unused local `rnumel = size_hints[-1]`; nothing read it.)
    if len(size_hints) != 2:
        raise NotImplementedError(f"size_hints: {size_hints}")

    configs = _reduction_configs(size_hints=size_hints, inductor_meta=inductor_meta)
    return cached_autotune(
        size_hints,
        configs=configs,
        triton_meta=triton_meta,
        inductor_meta=inductor_meta,
        heuristic_type=HeuristicType.REDUCTION,
        filename=filename,
    )
|
| 1326 |
+
|
| 1327 |
+
|
| 1328 |
+
def persistent_reduction(
    size_hints,
    reduction_hint=False,
    triton_meta=None,
    filename=None,
    inductor_meta=None,
):
    """
    Autotuning entry point for persistent reductions: the whole reduced
    dimension is processed in one pass, so RBLOCK is fixed to rnumel and
    removed from the tunable kwargs; only XBLOCK is swept.
    """
    inductor_meta = {} if inductor_meta is None else inductor_meta
    inductor_meta["reduction_hint"] = reduction_hint
    if inductor_meta.get("no_x_dim"):
        # Kernel has no x dimension; force xnumel hint to 1.
        size_hints = [1, *size_hints[1:]]

    xnumel, rnumel = size_hints

    # Candidate XBLOCKs, capped so the tile (rnumel * xblock) stays <= 4096
    # and XBLOCK does not exceed xnumel; XBLOCK=1 is always included.
    configs = [
        triton_config_reduction(size_hints, xblock, rnumel)
        for xblock in (1, 8, 32, 128)
        if xblock == 1 or (rnumel * xblock <= 4096 and xblock <= xnumel)
    ]

    # TODO(jansel): we should be able to improve these heuristics
    if reduction_hint == ReductionHint.INNER and rnumel >= 256:
        configs = configs[:1]
    elif reduction_hint == ReductionHint.OUTER:
        configs = configs[-1:]
    elif reduction_hint == ReductionHint.OUTER_TINY:
        configs = [
            triton_config_reduction(
                size_hints, 2 * (256 // rnumel) if rnumel <= 256 else 1, rnumel
            )
        ]
    for c in configs:
        # we don't need RBLOCK for persistent reduction
        c.kwargs.pop("RBLOCK")

    if disable_pointwise_autotuning():
        configs = configs[:1]

    return cached_autotune(
        size_hints,
        configs,
        triton_meta=triton_meta,
        inductor_meta=inductor_meta,
        filename=filename,
        heuristic_type=HeuristicType.PERSISTENT_REDUCTION,
    )
|
| 1374 |
+
|
| 1375 |
+
|
| 1376 |
+
def split_scan(
    size_hints,
    reduction_hint=False,
    triton_meta=None,
    filename=None,
    inductor_meta=None,
):
    """Heuristic for TritonSplitScanKernel

    Reuses the reduction candidate configs, then clamps each config's
    RBLOCK up to config.triton.min_split_scan_rblock. Only 2D size_hints
    are supported.
    """
    inductor_meta = {} if inductor_meta is None else inductor_meta
    inductor_meta["reduction_hint"] = reduction_hint
    if inductor_meta.get("no_x_dim"):
        # Kernel has no x dimension; force xnumel hint to 1.
        size_hints = [1, *size_hints[1:]]

    assert triton_meta is not None
    # (Removed an unused local `rnumel = size_hints[-1]`; nothing read it.)
    if len(size_hints) != 2:
        raise NotImplementedError(f"size_hints: {size_hints}")

    configs = _reduction_configs(size_hints=size_hints, inductor_meta=inductor_meta)

    # Fixup configs to enforce the minimum RBLOCK size
    min_rblock = config.triton.min_split_scan_rblock
    for cfg in configs:
        if cfg.kwargs["RBLOCK"] < min_rblock:
            cfg.kwargs["RBLOCK"] = min_rblock

    return cached_autotune(
        size_hints,
        configs=configs,
        triton_meta=triton_meta,
        inductor_meta=inductor_meta,
        heuristic_type=HeuristicType.SPLIT_SCAN,
        filename=filename,
    )
|
| 1410 |
+
|
| 1411 |
+
|
| 1412 |
+
def template(num_stages, num_warps, triton_meta, filename=None, inductor_meta=None):
    """
    Compile a triton template
    """
    # Templates carry a single fixed launch configuration; no kwargs sweep.
    single_config = triton.Config({}, num_stages=num_stages, num_warps=num_warps)
    return cached_autotune(
        None,
        [single_config],
        triton_meta=triton_meta,
        inductor_meta=inductor_meta,
        heuristic_type=HeuristicType.TEMPLATE,
        filename=filename,
    )
|
| 1424 |
+
|
| 1425 |
+
|
| 1426 |
+
def user_autotune(
    configs, triton_meta, filename=None, inductor_meta=None, custom_kernel=False
):
    """
    Compile a user defined triton kernel
    """
    # Fall back to triton's own signature defaults for anything the user
    # config dicts omit.
    defaults = inspect.signature(triton.Config).parameters
    default_num_stages = defaults["num_stages"].default
    default_num_warps = defaults["num_warps"].default

    def as_triton_config(c):
        # Convert one user-supplied dict into a triton.Config.
        return triton.Config(
            c.get("kwargs", {}),
            num_stages=c.get("num_stages", default_num_stages),
            num_warps=c.get("num_warps", default_num_warps),
        )

    if len(configs) == 0:
        configs = [
            triton.Config(
                {}, num_stages=default_num_stages, num_warps=default_num_warps
            )
        ]
    else:
        configs = [as_triton_config(c) for c in configs]

    return cached_autotune(
        None,
        configs,
        triton_meta=triton_meta,
        heuristic_type=HeuristicType.USER_AUTOTUNE,
        filename=filename,
        inductor_meta=inductor_meta,
        custom_kernel=custom_kernel,
    )
|
| 1461 |
+
|
| 1462 |
+
|
| 1463 |
+
def foreach(triton_meta, num_warps, filename=None, inductor_meta=None):
    """
    Compile a triton foreach kernel
    """
    # Foreach kernels use a single fixed config (one stage, caller-chosen
    # warp count) — there is nothing to sweep.
    fixed_config = triton.Config({}, num_stages=1, num_warps=num_warps)
    return cached_autotune(
        None,
        [fixed_config],
        triton_meta=triton_meta,
        inductor_meta=inductor_meta,
        heuristic_type=HeuristicType.TEMPLATE,
        filename=filename,
    )
|
| 1475 |
+
|
| 1476 |
+
|
| 1477 |
+
def grid(*numels):
    """Helper function to compute triton grids

    Accepts 1-3 numels ordered slowest-to-fastest varying; note the
    unpacking below reverses them so the last numel maps to the x axis.
    Returns a grid_fn(meta) closure that converts a config's block sizes
    into a (x_grid, y_grid, z_grid) launch tuple.
    """
    if len(numels) == 1:
        xnumel, ynumel, znumel = numels[0], None, None
    elif len(numels) == 2:
        xnumel, ynumel, znumel = numels[1], numels[0], None
    elif len(numels) == 3:
        xnumel, ynumel, znumel = numels[2], numels[1], numels[0]
    else:
        raise AssertionError(f"invalid size for numels {len(numels)}")

    def get_grid_dim(numel, block):
        # Missing dimension -> grid size 1; missing block -> one program
        # per element; otherwise ceil-divide numel by the block size.
        if numel is None:
            return 1
        if block is None:
            return numel
        return ceildiv(numel, block)

    max_grid_dims = config.triton.max_tiles

    def grid_fn(meta):
        x_grid = get_grid_dim(xnumel, meta.get("XBLOCK", 1))
        y_grid = get_grid_dim(ynumel, meta.get("YBLOCK", None))

        MAX_Y_GRID = get_max_y_grid()
        if znumel is None and max_grid_dims <= 2:
            # No z dimension in use: fold any y-grid overflow beyond the
            # hardware y limit into the (free) z axis.
            div = ceildiv(y_grid, MAX_Y_GRID)
            y_grid = y_grid // div
            z_grid = div
        else:
            # z axis is occupied, so y-grid overflow cannot be folded.
            z_grid = get_grid_dim(znumel, meta.get("ZBLOCK", None))
            torch._check(
                y_grid <= MAX_Y_GRID,
                lambda: f"Generated y grid beyond 2^16 ({y_grid}) not supported with z dimension present. File issue",
            )

        return (
            x_grid,
            y_grid,
            z_grid,
        )

    return grid_fn
|
| 1520 |
+
|
| 1521 |
+
|
| 1522 |
+
def split_scan_grid(xnumel, rnumel):
    """
    Grid helper for split-scan kernels: programs tile the reduced
    dimension along the x grid axis, with xnumel on the y axis.
    """

    def grid_fn(meta):
        # Split-scan configs must not tile x (asserted: XBLOCK == 1).
        assert meta.get("XBLOCK", 1) == 1
        rblock = meta.get("RBLOCK", 1)
        return (ceildiv(rnumel, rblock), xnumel, 1)

    return grid_fn
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (3.26 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/computation.cpython-311.pyc
ADDED
|
Binary file (1.6 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/extract_compiled_graph.cpython-311.pyc
ADDED
|
Binary file (12 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/ts_backend.cpython-311.pyc
ADDED
|
Binary file (550 Bytes). View file
|
|
|