import ast
import contextlib
import inspect
import threading
from collections.abc import Generator, Iterable
from typing import Any, Callable, Optional, Union

from torch.utils._exposed_in import exposed_in

from .custom_ops import custom_op, CustomOpDef
from .infer_schema import infer_schema


# Maps the qualified name of a triton_op (e.g. "mylib::add") to the triton
# kernel objects discovered inside its implementation.
triton_ops_to_kernels: dict[str, list[object]] = {}


def get_triton_kernels_for_op(name: str) -> list[object]:
    """Return the triton kernels recorded for the triton_op named ``name``."""
    return triton_ops_to_kernels.get(name, [])


def get_inner_triton_kernels(fn: Callable[..., Any]) -> list[object]:
| """ |
| Inspect the source of an arbitrary callable passed to torch._library.triton_op, |
| and grab all of the triton kernels that are wrapped inside of it. |
| |
| TODO: This check is best effort. It does *not* handle the case where the triton |
| kernel is hidden behind recursive function calls. |
| """ |
|
|
| def find_triton_kernels(fn: Callable[..., Any]) -> list[object]: |
| try: |
| source = inspect.getsource(fn) |
| except (OSError, TypeError): |
| return [] |
|
|
| from torch._inductor.utils import IndentedBuffer |
|
|
| buffer = IndentedBuffer() |
| buffer.splice(source, strip=True) |
| tree = ast.parse(buffer.getrawvalue()) |
|
|
| |
| class Visitor(ast.NodeVisitor): |
            def __init__(self) -> None:
                self.triton_kernels: list[Any] = []

            def visit_Call(self, node: ast.Call) -> None:
                triton_func_names = ("capture_triton", "wrap_triton")
                if isinstance(node.func, ast.Attribute):
                    # Qualified call, e.g. torch._library.wrap_triton(kernel)
                    attr = node.func
                    if (
                        isinstance(attr.value, ast.Attribute)
                        and isinstance(attr.value.value, ast.Name)
                        and attr.value.value.id == "torch"
                        and attr.value.attr == "_library"
                        and attr.attr in triton_func_names
                    ):
                        if node.args and isinstance(node.args[0], ast.Name):
                            self.triton_kernels.append(node.args[0].id)
                elif isinstance(node.func, ast.Name):
                    # Bare call, e.g. wrap_triton(kernel) via a direct import
                    if node.func.id in triton_func_names:
                        if node.args and isinstance(node.args[0], ast.Name):
                            self.triton_kernels.append(node.args[0].id)

                self.generic_visit(node)

        collector = Visitor()
        collector.visit(tree)

        # Map the collected kernel names back to live objects through the
        # callable's closure, globals, and builtins.
        closure_vars = inspect.getclosurevars(fn)
        resolved = []
        for name in collector.triton_kernels:
            if name in closure_vars.nonlocals:
                resolved.append(closure_vars.nonlocals[name])
            elif name in closure_vars.globals:
                resolved.append(closure_vars.globals[name])
            elif name in closure_vars.builtins:
                resolved.append(closure_vars.builtins[name])
        return resolved

    return find_triton_kernels(fn)
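
# Illustrative sketch (hypothetical kernel/op names): for an implementation like
#
#     @triton_op("mylib::add", mutates_args={})
#     def add(x, y):
#         out = torch.empty_like(x)
#         torch._library.wrap_triton(add_kernel)[(1,)](x, y, out, x.numel(), 16)
#         return out
#
# the AST visitor above spots the name passed to wrap_triton and resolves it
# through ``add``'s globals, so triton_op records [add_kernel] under
# "mylib::add" and get_triton_kernels_for_op("mylib::add") returns it.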


@exposed_in("torch.library")
def triton_op(
    name: str,
    fn: Optional[Callable] = None,
    /,
    *,
    mutates_args: Union[str, Iterable[str]],
    schema: Optional[str] = None,
) -> Callable:
| """Create a custom operator whose implementation is backed by 1+ triton kernels. |
| |
| This is a more structured way of using triton kernels with PyTorch. |
| Prefer using triton kernels with no ``torch.library`` custom operator wrappers |
| (like :func:`torch.library.custom_op`, :func:`torch.library.triton_op`) because |
| that is simpler; |
| only use :func:`torch.library.custom_op`/:func:`torch.library.triton_op` if you |
| want to create an operator that behaves like PyTorch built-in operators. |
| For example, you may use a ``torch.library`` wrapper API to define the |
| behavior of the triton kernel when passed a tensor subclass or under |
| a TorchDispatchMode. |
| |
| Use :func:`torch.library.triton_op` instead of :func:`torch.library.custom_op` |
| when the implementation |
| consists of 1+ triton kernels. :func:`torch.library.custom_op` treats |
| custom operators as opaque (:func:`torch.compile` and |
| :func:`torch.export.export` will never trace into them), but ``triton_op`` |
| makes the implementation visible to these subsystems, allowing them |
| to optimize the triton kernel(s). |
| |
| Note that ``fn`` must only consist of calls to PyTorch-understood |
| operators and triton kernels. Any triton kernels called inside ``fn`` |
| must be wrapped in a call to :func:`torch.library.wrap_triton`. |
| |
| Args: |
| name (str): A name for the custom op that looks like "{namespace}::{name}", |
| e.g. "mylib::my_linear". The name is used as the op's stable identifier |
| in PyTorch subsystems (e.g. torch.export, FX graphs). |
| To avoid name collisions, please use your project name as the namespace; |
| e.g. all custom ops in pytorch/fbgemm use "fbgemm" as the namespace. |
| mutates_args (Iterable[str] or "unknown"): The names of args that the function mutates. |
| This MUST be accurate, otherwise, the behavior is undefined. If "unknown", |
| it pessimistically assumes that all inputs to the operator are being mutated. |
| schema (None | str): A schema string for the operator. If None |
| (recommended) we'll infer a schema for the operator from its type |
| annotations. We recommend letting us infer a schema unless you |
| have a specific reason not to. |
| Example: "(Tensor x, int y) -> (Tensor, Tensor)". |
| |
| Example:: |
| |
| >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA) |
| >>> import torch |
| >>> from torch.library import triton_op, wrap_triton |
| >>> |
| >>> import triton |
| >>> from triton import language as tl |
| >>> |
| >>> @triton.jit |
| >>> def add_kernel( |
| >>> in_ptr0, |
| >>> in_ptr1, |
| >>> out_ptr, |
| >>> n_elements, |
| >>> BLOCK_SIZE: "tl.constexpr", |
| >>> ): |
| >>> pid = tl.program_id(axis=0) |
| >>> block_start = pid * BLOCK_SIZE |
| >>> offsets = block_start + tl.arange(0, BLOCK_SIZE) |
| >>> mask = offsets < n_elements |
| >>> x = tl.load(in_ptr0 + offsets, mask=mask) |
| >>> y = tl.load(in_ptr1 + offsets, mask=mask) |
| >>> output = x + y |
| >>> tl.store(out_ptr + offsets, output, mask=mask) |
| >>> |
| >>> @triton_op("mylib::add", mutates_args={}) |
| >>> def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: |
| >>> output = torch.empty_like(x) |
| >>> n_elements = output.numel() |
| >>> |
| >>> def grid(meta): |
| >>> return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) |
| >>> |
| >>> # NB: we need to wrap the triton kernel in a call to wrap_triton |
| >>> wrap_triton(add_kernel)[grid](x, y, output, n_elements, 16) |
| >>> return output |
| >>> |
| >>> @torch.compile |
| >>> def f(x, y): |
| >>> return add(x, y) |
| >>> |
| >>> x = torch.randn(3, device="cuda") |
| >>> y = torch.randn(3, device="cuda") |
| >>> |
| >>> z = f(x, y) |
| >>> assert torch.allclose(z, x + y) |
| |
| """ |
|
|
    def dec(fn: Callable[..., object]) -> CustomOpDef:
        def backend_fn(*args, **kwargs):
            # Eager fast path: the inputs are regular Tensors, so skip HOP
            # dispatch and launch the triton kernel(s) directly.
            with set_wrap_triton_enabled(False):
                return fn(*args, **kwargs)

        result = custom_op(
            name,
            backend_fn,
            mutates_args=mutates_args,
            # Respect a user-provided schema; otherwise infer one from fn's
            # type annotations.
            schema=(
                schema
                if schema is not None
                else infer_schema(fn, mutates_args=mutates_args)
            ),
        )
        from .._subclasses.functional_tensor import FunctionalTensorMode

        # fn must consist only of PyTorch-understood operators and wrapped
        # triton kernels, so it is traceable and can be registered directly
        # as the fake (meta) implementation.
        result.register_fake(fn)

        # Decompose the op when FunctionalTensorMode is active (i.e. inside
        # AOTDispatcher): backends such as Inductor then see the underlying
        # triton kernel call(s) and can optimize them directly.
        def functional_decomp(mode, op, types, args, kwargs):
            # NOTE: torch.export may choose to keep custom triton ops opaque
            # instead of decomposing them; in that case, dispatch the op
            # itself under the ambient mode.
            from torch.export._trace import custom_triton_ops_decomposition_disabled

            if custom_triton_ops_decomposition_disabled():
                return mode.__torch_dispatch__(op, types, args, kwargs)
            else:
                import torch._subclasses

                unrecognized_types = [
                    t
                    for t in types
                    if not issubclass(t, torch._subclasses.FakeTensor)
                    and t
                    not in [
                        torch.Tensor,
                        torch._subclasses.functional_tensor.FunctionalTensor,
                    ]
                ]
                if unrecognized_types:
                    # Let another __torch_dispatch__ handler deal with tensor
                    # subclasses we do not recognize.
                    return NotImplemented
                # Decompose: re-run fn under the mode so its constituent ops
                # and wrapped triton kernels are traced individually.
                with mode:
                    return fn(*args, **kwargs)

        # Record the kernels used by this op so other subsystems can look
        # them up via get_triton_kernels_for_op.
        triton_kernels = get_inner_triton_kernels(fn)
        triton_ops_to_kernels[name] = triton_kernels
        result.register_torch_dispatch(FunctionalTensorMode, functional_decomp)
        return result

    if fn is None:
        return dec
    else:
        return dec(fn)


# Thread-local flag used by set_wrap_triton_enabled below.
wrap_triton_enabled = threading.local()
wrap_triton_enabled_default = True


@contextlib.contextmanager
def set_wrap_triton_enabled(enabled: bool) -> Generator[None, None, None]:
    """Whether triton kernels wrapped with ``wrap_triton`` should dispatch via
    HOP or go straight to triton kernel execution.

    We have this switch because eager-mode performance of HOP dispatch is slow
    enough to matter (~1ms) and wrap_triton isn't necessary in some situations
    (eager mode with regular Tensors).
    """
    try:
        prev = is_wrap_triton_enabled()
        wrap_triton_enabled.value = enabled
        yield
    finally:
        wrap_triton_enabled.value = prev
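
# Usage sketch: triton_op's eager backend runs the user's implementation under
#
#     with set_wrap_triton_enabled(False):
#         fn(*args, **kwargs)
#
# so that wrap_triton(kernel) returns the kernel unchanged and the launch
# bypasses HOP dispatch.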


def is_wrap_triton_enabled() -> bool:
    return getattr(wrap_triton_enabled, "value", wrap_triton_enabled_default)


def capture_triton(triton_kernel: Callable, /) -> Any:
    """This API has been renamed to :func:`wrap_triton`."""
    return wrap_triton(triton_kernel)


@exposed_in("torch.library")
def wrap_triton(triton_kernel: Callable, /) -> Any:
    """Allows capture of a triton kernel into a graph via make_fx or
    non-strict ``torch.export``.

    These technologies perform Dispatcher-based tracing (via
    ``__torch_dispatch__``) and cannot see calls to raw triton kernels.
    The ``wrap_triton`` API wraps a triton kernel into a callable that
    can actually be traced into a graph.

    Please use this API together with :func:`torch.library.triton_op`.

    Examples:

        >>> # xdoctest: +SKIP
        >>> import torch
        >>> import triton
        >>> from triton import language as tl
        >>> from torch.fx.experimental.proxy_tensor import make_fx
        >>> from torch.library import wrap_triton
        >>>
        >>> @triton.jit
        >>> def add_kernel(
        >>>     in_ptr0,
        >>>     in_ptr1,
        >>>     out_ptr,
        >>>     n_elements,
        >>>     BLOCK_SIZE: "tl.constexpr",
        >>> ):
        >>>     pid = tl.program_id(axis=0)
        >>>     block_start = pid * BLOCK_SIZE
        >>>     offsets = block_start + tl.arange(0, BLOCK_SIZE)
        >>>     mask = offsets < n_elements
        >>>     x = tl.load(in_ptr0 + offsets, mask=mask)
        >>>     y = tl.load(in_ptr1 + offsets, mask=mask)
        >>>     output = x + y
        >>>     tl.store(out_ptr + offsets, output, mask=mask)
        >>>
        >>> def add(x, y):
        >>>     output = torch.empty_like(x)
        >>>     n_elements = output.numel()
        >>>
        >>>     def grid_fn(meta):
        >>>         return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
        >>>
        >>>     wrap_triton(add_kernel)[grid_fn](x, y, output, n_elements, 16)
        >>>     return output
        >>>
        >>> x = torch.randn(3, device="cuda")
        >>> y = torch.randn(3, device="cuda")
        >>> gm = make_fx(add)(x, y)
        >>> print(gm.code)
        >>> # def forward(self, x_1, y_1):
        >>> #     empty_like = torch.ops.aten.empty_like.default(x_1, pin_memory = False)
        >>> #     triton_kernel_wrapper_mutation_proxy = triton_kernel_wrapper_mutation(
        >>> #         kernel_idx = 0, constant_args_idx = 0,
        >>> #         grid = [(1, 1, 1)], kwargs = {
        >>> #             'in_ptr0': x_1, 'in_ptr1': y_1, 'out_ptr': empty_like,
        >>> #             'n_elements': 3, 'BLOCK_SIZE': 16
        >>> #         })
        >>> #     return empty_like

    """
    from triton.runtime.autotuner import Autotuner
    from triton.runtime.jit import JITFunction

    from torch._higher_order_ops.triton_kernel_wrap import TraceableTritonKernelWrapper

    if not isinstance(triton_kernel, (JITFunction, Autotuner)):
        raise RuntimeError(
            "wrap_triton only works on functions annotated with triton.jit or triton.autotune"
        )
    if not is_wrap_triton_enabled():
        # Fast path (see set_wrap_triton_enabled): return the raw kernel so it
        # executes directly without HOP dispatch.
        return triton_kernel
    return TraceableTritonKernelWrapper(triton_kernel, None, None)
|