|
|
|
|
|
|
|
|
|
|
|
from typing import Optional, Callable, Iterable, Union |
|
|
|
|
|
from torch.library import custom_op, CustomOpDef |
|
|
from torch._library.triton import set_wrap_triton_enabled |
|
|
|
|
|
|
|
|
def triton_op( |
|
|
name: str, |
|
|
fn: Optional[Callable] = None, |
|
|
/, |
|
|
*, |
|
|
mutates_args: Union[str, Iterable[str]], |
|
|
schema: Optional[str] = None, |
|
|
|
|
|
|
|
|
|
|
|
allow_decomposition=True, |
|
|
) -> Callable: |
|
|
def dec(fn: Callable[..., object]) -> CustomOpDef: |
|
|
def backend_fn(*args, **kwargs): |
|
|
|
|
|
|
|
|
with set_wrap_triton_enabled(False): |
|
|
return fn(*args, **kwargs) |
|
|
|
|
|
result = custom_op( |
|
|
name, |
|
|
backend_fn, |
|
|
mutates_args=mutates_args, |
|
|
|
|
|
schema=schema, |
|
|
) |
|
|
from torch._subclasses.functional_tensor import FunctionalTensorMode |
|
|
|
|
|
|
|
|
|
|
|
result.register_fake(fn) |
|
|
|
|
|
if allow_decomposition: |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def functional_decomp( |
|
|
mode, op, types, args, kwargs |
|
|
): |
|
|
from torch.export._trace import custom_triton_ops_decomposition_disabled |
|
|
|
|
|
if custom_triton_ops_decomposition_disabled(): |
|
|
return mode.__torch_dispatch__(op, types, args, kwargs) |
|
|
else: |
|
|
with mode: |
|
|
return fn(*args, **kwargs) |
|
|
|
|
|
result.register_torch_dispatch(FunctionalTensorMode, functional_decomp) |
|
|
|
|
|
return result |
|
|
|
|
|
if fn is None: |
|
|
return dec |
|
|
else: |
|
|
return dec(fn) |
|
|
|