|
|
from __future__ import annotations
|
|
|
|
|
|
import dataclasses
|
|
|
import itertools
|
|
|
import re
|
|
|
from dataclasses import dataclass
|
|
|
from enum import auto, Enum
|
|
|
from typing import Callable, Optional, TYPE_CHECKING
|
|
|
from typing_extensions import assert_never
|
|
|
|
|
|
from torchgen.utils import NamespaceHelper, OrderedSet
|
|
|
|
|
|
|
|
|
if TYPE_CHECKING:
|
|
|
from collections.abc import Iterator, Sequence
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
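# Location of a native_functions.yaml entry, used to attribute errors.
# Illustrative: str(Location(file="native_functions.yaml", line=42))
# == "native_functions.yaml:42".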
@dataclass(frozen=True)
|
|
|
class Location:
|
|
|
file: str
|
|
|
line: int
|
|
|
|
|
|
def __str__(self) -> str:
|
|
|
return f"{self.file}:{self.line}"
|
|
|
|
|
|
|
|
|
|
|
|
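# How an operator is exposed in the frontend: as a free function, as a
# Tensor method, or both (controlled by the `variants` key in
# native_functions.yaml; see NativeFunction.from_yaml).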
class Variant(Enum):
|
|
|
function = auto()
|
|
|
method = auto()
|
|
|
|
|
|
|
|
|
|
|
|
DEFAULT_KERNEL_NAMESPACE = "at::native"
|
|
|
|
|
|
|
|
|
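# The per-backend dispatch keys below are the cross product of
# FUNCTIONALITY_KEYS x BACKEND_COMPONENTS (e.g. "Sparse" + "CPU" gives
# SparseCPU); codegen_per_backend_entries() and the import-time check
# further down keep the DispatchKey enum in sync with this product.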
BACKEND_COMPONENTS = "CPU CUDA HIP XLA MTIA MPS IPU XPU HPU VE Lazy Meta PrivateUse1 PrivateUse2 PrivateUse3".split()
|
|
|
FUNCTIONALITY_KEYS = [
|
|
|
"",
|
|
|
"Quantized",
|
|
|
"Sparse",
|
|
|
"SparseCsr",
|
|
|
"NestedTensor",
|
|
|
"Autograd",
|
|
|
]
|
|
|
|
|
|
|
|
|
|
|
|
AUTOGRAD_KEYS = ["AutogradNestedTensor"] + [
|
|
|
"Autograd" + component for component in BACKEND_COMPONENTS
|
|
|
]
|
|
|
|
|
|
FRAGMENT_NAMESPACES = {"quantized", "quantized_decomposed"}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
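# Python mirror of the dispatch keys the code generator cares about; the
# C++ enum in c10/core/DispatchKey.h is the source of truth. The block of
# per-backend keys (CPU ... AutogradPrivateUse3) must match
# FUNCTIONALITY_KEYS x BACKEND_COMPONENTS, which is verified at import
# time below.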
class DispatchKey(Enum):
|
|
|
Undefined = 0
|
|
|
CatchAll = Undefined
|
|
|
|
|
|
FPGA = auto()
|
|
|
MAIA = auto()
|
|
|
Vulkan = auto()
|
|
|
Metal = auto()
|
|
|
MKLDNN = auto()
|
|
|
OpenGL = auto()
|
|
|
OpenCL = auto()
|
|
|
IDEEP = auto()
|
|
|
CustomRNGKeyId = auto()
|
|
|
MkldnnCPU = auto()
|
|
|
Sparse = auto()
|
|
|
SparseCsr = auto()
|
|
|
NestedTensor = auto()
|
|
|
Dense = auto()
|
|
|
|
|
|
PythonTLSSnapshot = auto()
|
|
|
PreDispatch = auto()
|
|
|
PythonDispatcher = auto()
|
|
|
Python = auto()
|
|
|
FuncTorchDynamicLayerBackMode = auto()
|
|
|
ZeroTensor = auto()
|
|
|
Conjugate = auto()
|
|
|
Negative = auto()
|
|
|
BackendSelect = auto()
|
|
|
Named = auto()
|
|
|
AutogradOther = auto()
|
|
|
AutogradFunctionality = auto()
|
|
|
AutogradNestedTensor = auto()
|
|
|
Tracer = auto()
|
|
|
Autocast = auto()
|
|
|
AutocastCPU = auto()
|
|
|
AutocastCUDA = auto()
|
|
|
Batched = auto()
|
|
|
VmapMode = auto()
|
|
|
FuncTorchGradWrapper = auto()
|
|
|
FuncTorchBatched = auto()
|
|
|
BatchedNestedTensor = auto()
|
|
|
FuncTorchVmapMode = auto()
|
|
|
FuncTorchDynamicLayerFrontMode = auto()
|
|
|
Functionalize = auto()
|
|
|
TESTING_ONLY_GenericWrapper = auto()
|
|
|
TESTING_ONLY_GenericMode = auto()
|
|
|
|
|
|
ADInplaceOrView = auto()
|
|
|
Autograd = auto()
|
|
|
CompositeImplicitAutograd = auto()
|
|
|
CompositeImplicitAutogradNestedTensor = auto()
|
|
|
CompositeExplicitAutograd = auto()
|
|
|
CompositeExplicitAutogradNonFunctional = auto()
|
|
|
FuncTorchBatchedDecomposition = auto()
|
|
|
|
|
|
|
|
|
CPU = auto()
|
|
|
CUDA = auto()
|
|
|
HIP = auto()
|
|
|
XLA = auto()
|
|
|
MTIA = auto()
|
|
|
MPS = auto()
|
|
|
IPU = auto()
|
|
|
XPU = auto()
|
|
|
HPU = auto()
|
|
|
VE = auto()
|
|
|
Lazy = auto()
|
|
|
Meta = auto()
|
|
|
PrivateUse1 = auto()
|
|
|
PrivateUse2 = auto()
|
|
|
PrivateUse3 = auto()
|
|
|
QuantizedCPU = auto()
|
|
|
QuantizedCUDA = auto()
|
|
|
QuantizedHIP = auto()
|
|
|
QuantizedXLA = auto()
|
|
|
QuantizedMTIA = auto()
|
|
|
QuantizedMPS = auto()
|
|
|
QuantizedIPU = auto()
|
|
|
QuantizedXPU = auto()
|
|
|
QuantizedHPU = auto()
|
|
|
QuantizedVE = auto()
|
|
|
QuantizedLazy = auto()
|
|
|
QuantizedMeta = auto()
|
|
|
QuantizedPrivateUse1 = auto()
|
|
|
QuantizedPrivateUse2 = auto()
|
|
|
QuantizedPrivateUse3 = auto()
|
|
|
SparseCPU = auto()
|
|
|
SparseCUDA = auto()
|
|
|
SparseHIP = auto()
|
|
|
SparseXLA = auto()
|
|
|
SparseMTIA = auto()
|
|
|
SparseMPS = auto()
|
|
|
SparseIPU = auto()
|
|
|
SparseXPU = auto()
|
|
|
SparseHPU = auto()
|
|
|
SparseVE = auto()
|
|
|
SparseLazy = auto()
|
|
|
SparseMeta = auto()
|
|
|
SparsePrivateUse1 = auto()
|
|
|
SparsePrivateUse2 = auto()
|
|
|
SparsePrivateUse3 = auto()
|
|
|
SparseCsrCPU = auto()
|
|
|
SparseCsrCUDA = auto()
|
|
|
SparseCsrHIP = auto()
|
|
|
SparseCsrXLA = auto()
|
|
|
SparseCsrMTIA = auto()
|
|
|
SparseCsrMPS = auto()
|
|
|
SparseCsrIPU = auto()
|
|
|
SparseCsrXPU = auto()
|
|
|
SparseCsrHPU = auto()
|
|
|
SparseCsrVE = auto()
|
|
|
SparseCsrLazy = auto()
|
|
|
SparseCsrMeta = auto()
|
|
|
SparseCsrPrivateUse1 = auto()
|
|
|
SparseCsrPrivateUse2 = auto()
|
|
|
SparseCsrPrivateUse3 = auto()
|
|
|
NestedTensorCPU = auto()
|
|
|
NestedTensorCUDA = auto()
|
|
|
NestedTensorHIP = auto()
|
|
|
NestedTensorXLA = auto()
|
|
|
NestedTensorMTIA = auto()
|
|
|
NestedTensorMPS = auto()
|
|
|
NestedTensorIPU = auto()
|
|
|
NestedTensorXPU = auto()
|
|
|
NestedTensorHPU = auto()
|
|
|
NestedTensorVE = auto()
|
|
|
NestedTensorLazy = auto()
|
|
|
NestedTensorMeta = auto()
|
|
|
NestedTensorPrivateUse1 = auto()
|
|
|
NestedTensorPrivateUse2 = auto()
|
|
|
NestedTensorPrivateUse3 = auto()
|
|
|
AutogradCPU = auto()
|
|
|
AutogradCUDA = auto()
|
|
|
AutogradHIP = auto()
|
|
|
AutogradXLA = auto()
|
|
|
AutogradMTIA = auto()
|
|
|
AutogradMPS = auto()
|
|
|
AutogradIPU = auto()
|
|
|
AutogradXPU = auto()
|
|
|
AutogradHPU = auto()
|
|
|
AutogradVE = auto()
|
|
|
AutogradLazy = auto()
|
|
|
AutogradMeta = auto()
|
|
|
AutogradPrivateUse1 = auto()
|
|
|
AutogradPrivateUse2 = auto()
|
|
|
AutogradPrivateUse3 = auto()
|
|
|
|
|
|
|
|
|
def __str__(self) -> str:
|
|
|
return self.name
|
|
|
|
|
|
def lower(self) -> str:
|
|
|
return str(self).lower()
|
|
|
|
|
|
@staticmethod
|
|
|
def parse(value: str) -> DispatchKey:
|
|
|
for k, v in DispatchKey.__members__.items():
|
|
|
if k == value:
|
|
|
return v
|
|
|
raise AssertionError(f"unknown dispatch key {value}")
|
|
|
|
|
|
|
|
|
class _TorchDispatchModeKey(Enum):
|
|
|
FAKE = auto()
|
|
|
PROXY = auto()
|
|
|
FUNCTIONAL = auto()
|
|
|
|
|
|
|
|
|
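# Renders the expected per-backend enum entries as source text, e.g.
# "    SparseCPU = auto()"; only used for the error message in the
# consistency check below.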
def codegen_per_backend_entries() -> str:
|
|
|
r: list[str] = []
|
|
|
for fk in FUNCTIONALITY_KEYS:
|
|
|
r.extend(f" {fk}{bc} = auto()" for bc in BACKEND_COMPONENTS)
|
|
|
return "\n".join(r)
|
|
|
|
|
|
|
|
|
for fk in FUNCTIONALITY_KEYS:
|
|
|
for bc in BACKEND_COMPONENTS:
|
|
|
if not hasattr(DispatchKey, fk + bc):
|
|
|
r = codegen_per_backend_entries()
|
|
|
print(r)
|
|
|
raise RuntimeError(
|
|
|
f"Missing {fk}{bc} from DispatchKey enum. Here is the autogenerated list we expect to have:\n\n{r}"
|
|
|
)
|
|
|
|
|
|
|
|
|
STRUCTURED_DISPATCH_KEYS = {
|
|
|
DispatchKey.MPS,
|
|
|
DispatchKey.CUDA,
|
|
|
DispatchKey.CPU,
|
|
|
DispatchKey.XPU,
|
|
|
DispatchKey.MTIA,
|
|
|
}
|
|
|
UFUNC_DISPATCH_KEYS = {DispatchKey.CUDA, DispatchKey.CPU}
|
|
|
|
|
|
|
|
|
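# Dispatch keys that entries in native_functions.yaml may bind kernels
# to; NativeFunction.from_yaml rejects any key outside this list.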
dispatch_keys = [
|
|
|
DispatchKey.CPU,
|
|
|
DispatchKey.SparseCPU,
|
|
|
DispatchKey.SparseCsrCPU,
|
|
|
DispatchKey.MkldnnCPU,
|
|
|
DispatchKey.CUDA,
|
|
|
DispatchKey.MPS,
|
|
|
DispatchKey.XPU,
|
|
|
DispatchKey.SparseXPU,
|
|
|
DispatchKey.SparseCsrXPU,
|
|
|
DispatchKey.SparseCUDA,
|
|
|
DispatchKey.SparseCsrCUDA,
|
|
|
DispatchKey.QuantizedCPU,
|
|
|
DispatchKey.QuantizedCUDA,
|
|
|
DispatchKey.CompositeImplicitAutograd,
|
|
|
DispatchKey.CompositeImplicitAutogradNestedTensor,
|
|
|
DispatchKey.CompositeExplicitAutograd,
|
|
|
DispatchKey.CompositeExplicitAutogradNonFunctional,
|
|
|
DispatchKey.NestedTensorCPU,
|
|
|
DispatchKey.NestedTensorCUDA,
|
|
|
DispatchKey.NestedTensorXPU,
|
|
|
DispatchKey.NestedTensorHPU,
|
|
|
|
|
|
|
|
|
DispatchKey.Meta,
|
|
|
DispatchKey.SparseMeta,
|
|
|
DispatchKey.SparseCsrMeta,
|
|
|
DispatchKey.QuantizedMeta,
|
|
|
DispatchKey.NestedTensorMeta,
|
|
|
DispatchKey.ZeroTensor,
|
|
|
DispatchKey.MTIA,
|
|
|
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def is_generic_dispatch_key(dk: DispatchKey) -> bool:
|
|
|
return dk in {
|
|
|
DispatchKey.CompositeExplicitAutograd,
|
|
|
DispatchKey.CompositeExplicitAutogradNonFunctional,
|
|
|
DispatchKey.CompositeImplicitAutograd,
|
|
|
DispatchKey.CompositeImplicitAutogradNestedTensor,
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
def is_cuda_dispatch_key(dk: DispatchKey) -> bool:
|
|
|
return dk in {
|
|
|
DispatchKey.CUDA,
|
|
|
DispatchKey.QuantizedCUDA,
|
|
|
DispatchKey.SparseCUDA,
|
|
|
DispatchKey.SparseCsrCUDA,
|
|
|
DispatchKey.NestedTensorCUDA,
|
|
|
DispatchKey.AutogradCUDA,
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
def is_xpu_dispatch_key(dk: DispatchKey) -> bool:
|
|
|
return dk in {
|
|
|
DispatchKey.XPU,
|
|
|
DispatchKey.QuantizedXPU,
|
|
|
DispatchKey.SparseXPU,
|
|
|
DispatchKey.SparseCsrXPU,
|
|
|
DispatchKey.NestedTensorXPU,
|
|
|
DispatchKey.AutogradXPU,
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def is_structured_dispatch_key(dk: DispatchKey) -> bool:
|
|
|
return dk in STRUCTURED_DISPATCH_KEYS
|
|
|
|
|
|
|
|
|
def is_ufunc_dispatch_key(dk: DispatchKey) -> bool:
|
|
|
|
|
|
return dk in UFUNC_DISPATCH_KEYS
|
|
|
|
|
|
|
|
|
dispatch_device_map = {is_cuda_dispatch_key: "cuda", is_xpu_dispatch_key: "xpu"}
|
|
|
|
|
|
|
|
|
|
|
|
class ScalarType(Enum):
|
|
|
Byte = auto()
|
|
|
Char = auto()
|
|
|
Short = auto()
|
|
|
Int = auto()
|
|
|
Long = auto()
|
|
|
Half = auto()
|
|
|
Float = auto()
|
|
|
Double = auto()
|
|
|
ComplexHalf = auto()
|
|
|
ComplexFloat = auto()
|
|
|
ComplexDouble = auto()
|
|
|
Bool = auto()
|
|
|
BFloat16 = auto()
|
|
|
Float8_e5m2 = auto()
|
|
|
Float8_e5m2fnuz = auto()
|
|
|
Float8_e4m3fn = auto()
|
|
|
Float8_e4m3fnuz = auto()
|
|
|
Float8_e8m0fnu = auto()
|
|
|
|
|
|
def __str__(self) -> str:
|
|
|
return self.name
|
|
|
|
|
|
@staticmethod
|
|
|
def maybe_parse(value: str) -> ScalarType | None:
|
|
|
for k, v in ScalarType.__members__.items():
|
|
|
if k == value:
|
|
|
return v
|
|
|
return None
|
|
|
|
|
|
@staticmethod
|
|
|
def parse(value: str) -> ScalarType:
|
|
|
mb_r = ScalarType.maybe_parse(value)
|
|
|
assert mb_r is not None, f"unknown dtype {value}"
|
|
|
return mb_r
|
|
|
|
|
|
@staticmethod
|
|
|
def parse_set(values: str) -> OrderedSet[ScalarType]:
|
|
|
dtypes: OrderedSet[ScalarType] = OrderedSet()
|
|
|
for value in values.split(", "):
|
|
|
if value in DTYPE_CLASSES:
|
|
|
dtypes.update(DTYPE_CLASSES[value])
|
|
|
else:
|
|
|
dtypes.add(ScalarType.parse(value))
|
|
|
return dtypes
|
|
|
|
|
|
|
|
|
DTYPE_CLASSES: dict[str, OrderedSet[ScalarType]] = {}
|
|
|
|
|
|
DTYPE_CLASSES["Integral"] = OrderedSet(
|
|
|
[
|
|
|
ScalarType.Byte,
|
|
|
ScalarType.Char,
|
|
|
ScalarType.Int,
|
|
|
ScalarType.Long,
|
|
|
ScalarType.Short,
|
|
|
]
|
|
|
)
|
|
|
|
|
|
DTYPE_CLASSES["Floating"] = OrderedSet([ScalarType.Float, ScalarType.Double])
|
|
|
DTYPE_CLASSES["Complex"] = OrderedSet(
|
|
|
[ScalarType.ComplexFloat, ScalarType.ComplexDouble]
|
|
|
)
|
|
|
DTYPE_CLASSES["All"] = DTYPE_CLASSES["Integral"] | DTYPE_CLASSES["Floating"]
|
|
|
DTYPE_CLASSES["AllAndComplex"] = DTYPE_CLASSES["All"] | DTYPE_CLASSES["Complex"]
|
|
|
DTYPE_CLASSES["FloatingAndComplex"] = (
|
|
|
DTYPE_CLASSES["Floating"] | DTYPE_CLASSES["Complex"]
|
|
|
)
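
# Illustrative uses of the dtype classes via ScalarType.parse_set (the
# strings here are examples, not lines from native_functions.yaml):
#   ScalarType.parse_set("Float, Half")        -> {Float, Half}
#   ScalarType.parse_set("FloatingAndComplex") -> Floating | Complex dtypes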
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
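# Keys describing which flavor of ufunc inner loop a kernel provides. The
# concrete per-backend keys (CUDAFunctor*, CPUScalar, CPUVector) are what
# the codegen ultimately uses; ScalarOnly and Generic are convenience
# keys that expand to them.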
class UfuncKey(Enum):
|
|
|
|
|
|
|
|
|
CUDAFunctor = auto()
|
|
|
CUDAFunctorOnOther = auto()
|
|
|
CUDAFunctorOnSelf = auto()
|
|
|
|
|
|
CPUScalar = auto()
|
|
|
CPUVector = auto()
|
|
|
|
|
|
|
|
|
|
|
|
ScalarOnly = auto()
|
|
|
Generic = auto()
|
|
|
|
|
|
def __str__(self) -> str:
|
|
|
return self.name
|
|
|
|
|
|
@staticmethod
|
|
|
def parse(value: str) -> UfuncKey:
|
|
|
for k, v in UfuncKey.__members__.items():
|
|
|
if k == value:
|
|
|
return v
|
|
|
raise AssertionError(f"unknown ufunc key {value}")
|
|
|
|
|
|
|
|
|
class DeviceCheckType(Enum):
|
|
|
NoCheck = 0
|
|
|
ExactSame = 1
|
|
|
|
|
|
|
|
|
class ViewSchemaKind(Enum):
|
|
|
aliasing = auto()
|
|
|
aliasing_inplace = auto()
|
|
|
non_aliasing = auto()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
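# The in-memory model of a single entry of native_functions.yaml. Fields
# mirror the YAML keys; from_yaml below is the sanctioned constructor and
# performs all validation, while __post_init__ holds invariants that must
# hold even for locally constructed instances.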
@dataclass(frozen=True)
|
|
|
class NativeFunction:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
namespace: str
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
func: FunctionSchema
|
|
|
|
|
|
|
|
|
|
|
|
use_const_ref_for_mutable_tensors: bool
|
|
|
|
|
|
|
|
|
device_guard: bool
|
|
|
|
|
|
|
|
|
device_check: DeviceCheckType
|
|
|
|
|
|
|
|
|
python_module: str | None
|
|
|
|
|
|
|
|
|
category_override: str | None
|
|
|
|
|
|
|
|
|
|
|
|
variants: set[Variant]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
manual_kernel_registration: bool
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
manual_cpp_binding: bool
|
|
|
|
|
|
|
|
|
|
|
|
loc: Location
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
autogen: list[OperatorName]
|
|
|
|
|
|
|
|
|
|
|
|
ufunc_inner_loop: dict[UfuncKey, UfuncInnerLoop]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
structured: bool
|
|
|
|
|
|
|
|
|
|
|
|
structured_delegate: OperatorName | None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
structured_inherits: str | None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
precomputed: Precompute | None
|
|
|
|
|
|
|
|
|
|
|
|
cpp_no_default_args: set[str]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
is_abstract: bool
|
|
|
|
|
|
|
|
|
has_composite_implicit_autograd_kernel: bool
|
|
|
has_composite_implicit_autograd_nested_tensor_kernel: bool
|
|
|
has_composite_explicit_autograd_kernel: bool
|
|
|
has_composite_explicit_autograd_non_functional_kernel: bool
|
|
|
|
|
|
|
|
|
|
|
|
tags: set[str]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
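    # Shape of an entry this parses (illustrative only, not copied from
    # the real yaml):
    #
    #   - func: my_op(Tensor self, Tensor other) -> Tensor
    #     variants: function, method
    #     dispatch:
    #       CPU: my_op_cpu
    #       CUDA: my_op_cuda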
@staticmethod
|
|
|
def from_yaml(
|
|
|
ei: dict[str, object],
|
|
|
loc: Location,
|
|
|
valid_tags: set[str],
|
|
|
ignore_keys: set[DispatchKey] | None = None,
|
|
|
) -> tuple[NativeFunction, dict[DispatchKey, dict[OperatorName, BackendMetadata]]]:
|
|
|
"""
|
|
|
Parse a NativeFunction from a dictionary as directly parsed
|
|
|
from native_functions.yaml
|
|
|
"""
|
|
|
e = ei.copy()
|
|
|
|
|
|
funcs = e.pop("func")
|
|
|
assert isinstance(funcs, str), f"not a str: {funcs}"
|
|
|
|
|
|
namespace_helper = NamespaceHelper.from_namespaced_entity(
|
|
|
namespaced_entity=funcs, max_level=1
|
|
|
)
|
|
|
namespace = namespace_helper.get_cpp_namespace(default="aten")
|
|
|
func = FunctionSchema.parse(namespace_helper.entity_name)
|
|
|
|
|
|
cpp_no_default_args_list = e.pop("cpp_no_default_args", [])
|
|
|
assert isinstance(cpp_no_default_args_list, list)
|
|
|
cpp_no_default_args = set(cpp_no_default_args_list)
|
|
|
|
|
|
use_const_ref_for_mutable_tensors = e.pop(
|
|
|
"use_const_ref_for_mutable_tensors", False
|
|
|
)
|
|
|
assert isinstance(use_const_ref_for_mutable_tensors, bool)
|
|
|
|
|
|
if use_const_ref_for_mutable_tensors:
|
|
|
assert not func.arguments.out, (
|
|
|
"see https://github.com/pytorch/pytorch/issues/145522"
|
|
|
)
|
|
|
|
|
|
variants_s = e.pop("variants", "function")
|
|
|
assert isinstance(variants_s, str)
|
|
|
variants: set[Variant] = set()
|
|
|
for v in variants_s.split(", "):
|
|
|
if v == "function":
|
|
|
variants.add(Variant.function)
|
|
|
elif v == "method":
|
|
|
variants.add(Variant.method)
|
|
|
else:
|
|
|
raise AssertionError(f"illegal variant {v}")
|
|
|
|
|
|
manual_kernel_registration = e.pop("manual_kernel_registration", False)
|
|
|
assert isinstance(manual_kernel_registration, bool), (
|
|
|
f"not a bool: {manual_kernel_registration}"
|
|
|
)
|
|
|
|
|
|
manual_cpp_binding = e.pop("manual_cpp_binding", False)
|
|
|
assert isinstance(manual_cpp_binding, bool), f"not a bool: {manual_cpp_binding}"
|
|
|
|
|
|
device_guard = e.pop("device_guard", True)
|
|
|
assert isinstance(device_guard, bool), f"not a bool: {device_guard}"
|
|
|
|
|
|
device_check_s = e.pop("device_check", None)
|
|
|
assert device_check_s is None or isinstance(device_check_s, str), (
|
|
|
f"not a str: {device_check_s}"
|
|
|
)
|
|
|
assert (
|
|
|
device_check_s is None or device_check_s in DeviceCheckType.__members__
|
|
|
), f"illegal device_check: {device_check_s}"
|
|
|
device_check: DeviceCheckType
|
|
|
if device_check_s is None:
|
|
|
device_check = DeviceCheckType.ExactSame
|
|
|
else:
|
|
|
device_check = DeviceCheckType[device_check_s]
|
|
|
|
|
|
structured = e.pop("structured", False)
|
|
|
assert isinstance(structured, bool), f"not a bool: {structured}"
|
|
|
|
|
|
structured_delegate_s = e.pop("structured_delegate", None)
|
|
|
assert structured_delegate_s is None or isinstance(
|
|
|
structured_delegate_s, str
|
|
|
), f"not a str: {structured_delegate_s}"
|
|
|
assert structured_delegate_s is None or "::" not in structured_delegate_s, (
|
|
|
"namespace is not supported in structured delegate,"
|
|
|
" using the same namespace as the native function"
|
|
|
)
|
|
|
structured_delegate: OperatorName | None = None
|
|
|
if structured_delegate_s is not None:
|
|
|
structured_delegate = OperatorName.parse(structured_delegate_s)
|
|
|
|
|
|
structured_inherits = e.pop("structured_inherits", None)
|
|
|
assert structured_inherits is None or isinstance(structured_inherits, str), (
|
|
|
f"not a str: {structured_inherits}"
|
|
|
)
|
|
|
assert structured_inherits is None or "::" not in structured_inherits, (
|
|
|
"namespace is not supported in structured inherits,"
|
|
|
" using the same namespace as the native function"
|
|
|
)
|
|
|
|
|
|
python_module = e.pop("python_module", None)
|
|
|
assert python_module is None or isinstance(python_module, str), (
|
|
|
f"not a str: {python_module}"
|
|
|
)
|
|
|
assert python_module is None or Variant.method not in variants, (
|
|
|
"functions in modules cannot be methods"
|
|
|
)
|
|
|
|
|
|
category_override = e.pop("category_override", None)
|
|
|
assert category_override is None or isinstance(category_override, str), (
|
|
|
f"not a str: {category_override}"
|
|
|
)
|
|
|
|
|
|
precomputed_dict = e.pop("precomputed", None)
|
|
|
        assert precomputed_dict is None or structured is True, (
            "precomputed is only allowed on structured operators"
        )
|
|
|
precomputed = Precompute.parse(precomputed_dict) if precomputed_dict else None
|
|
|
|
|
|
tags_inp = e.pop("tags", [])
|
|
|
if isinstance(tags_inp, str):
|
|
|
tags_inp = [tags_inp]
|
|
|
assert isinstance(tags_inp, list)
|
|
|
|
|
|
|
|
|
if namespace == "aten" and "pt2_compliant_tag" in valid_tags:
|
|
|
tags_inp.append("pt2_compliant_tag")
|
|
|
|
|
|
tags: set[str] = set()
|
|
|
for t in tags_inp:
|
|
|
assert len(valid_tags) > 0
|
|
|
|
|
|
if t in valid_tags:
|
|
|
tags.add(t)
|
|
|
else:
|
|
|
raise AssertionError(f"illegal tag {t}")
|
|
|
|
|
|
from torchgen.api import cpp
|
|
|
|
|
|
raw_dispatch = e.pop("dispatch", None)
|
|
|
assert raw_dispatch is None or isinstance(raw_dispatch, dict), e
|
|
|
dispatch: dict[DispatchKey, BackendMetadata] = {}
|
|
|
num_dispatch_keys: int = 0
|
|
|
if raw_dispatch is not None:
|
|
|
assert not manual_kernel_registration, (
|
|
|
"cannot specify both manual_kernel_registration and dispatch; with "
|
|
|
"manual registration, dispatch has no effect!"
|
|
|
)
|
|
|
redundant_composite_implicit_autograd = False
|
|
|
for ks, v in raw_dispatch.items():
|
|
|
if ks == "__line__":
|
|
|
continue
|
|
|
assert isinstance(ks, str), (
|
|
|
f"illegal dispatch key '{ks}' in {raw_dispatch}"
|
|
|
)
|
|
|
assert isinstance(v, str), (
|
|
|
f"illegal dispatch value '{v}' in {raw_dispatch}"
|
|
|
)
|
|
|
for k in ks.split(","):
|
|
|
dispatch_key = DispatchKey.parse(k.strip())
|
|
|
num_dispatch_keys += 1
|
|
|
|
|
|
if ignore_keys and dispatch_key in ignore_keys:
|
|
|
continue
|
|
|
assert dispatch_key in dispatch_keys, (
|
|
|
f"Dispatch key {dispatch_key} of kernel {v} "
|
|
|
"is not a supported dispatch key."
|
|
|
)
|
|
|
|
|
|
|
|
|
namespace_helper = NamespaceHelper.from_namespaced_entity(
|
|
|
v, max_level=3
|
|
|
)
|
|
|
kernel_namespace = namespace_helper.get_cpp_namespace(default="at")
|
|
|
|
|
|
|
|
|
|
|
|
dispatch[dispatch_key] = BackendMetadata(
|
|
|
kernel=namespace_helper.entity_name,
|
|
|
structured=structured
|
|
|
and is_structured_dispatch_key(dispatch_key),
|
|
|
cpp_namespace=(kernel_namespace + "::native"),
|
|
|
)
|
|
|
if (
|
|
|
dispatch_key is DispatchKey.CompositeImplicitAutograd
|
|
|
and v == cpp.name(func)
|
|
|
):
|
|
|
redundant_composite_implicit_autograd = True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
assert not (
|
|
|
num_dispatch_keys == 1 and redundant_composite_implicit_autograd
|
|
|
), (
|
|
|
"unnecessary dispatch table for this function; just delete the dispatch "
|
|
|
"key entirely"
|
|
|
)
|
|
|
|
|
|
|
|
|
assert (
|
|
|
structured_delegate
|
|
|
or dispatch.keys() != {DispatchKey.CompositeImplicitAutograd}
|
|
|
or dispatch[DispatchKey.CompositeImplicitAutograd].supports_symint()
|
|
|
or num_dispatch_keys != 1
|
|
|
), (
|
|
|
f"unexpected name for singleton CompositeImplicitAutograd dispatch entry: expected {cpp.name(func)} "
|
|
|
f"but got {dispatch[DispatchKey.CompositeImplicitAutograd]}. Rename your implementation to the expected "
|
|
|
"name, then delete the dispatch table"
|
|
|
)
|
|
|
elif not structured and structured_delegate is None:
|
|
|
name = str(func.name.name)
|
|
|
assert not (
|
|
|
name.startswith("new_")
|
|
|
or name.endswith("_like")
|
|
|
|
|
|
or (
|
|
|
func.arguments.tensor_options
|
|
|
and not func.arguments.has_tensor_arg()
|
|
|
)
|
|
|
), (
|
|
|
f"expected {name} to have a CompositeExplicitAutograd "
|
|
|
"dispatch entry, but there was no dispatch table. Factory functions "
|
|
|
"should not have implicit dispatch as they should not be decomposed "
|
|
|
"for __torch_dispatch__"
|
|
|
)
|
|
|
dispatch[DispatchKey.CompositeImplicitAutograd] = BackendMetadata(
|
|
|
cpp.name(func), structured=False, cpp_namespace=DEFAULT_KERNEL_NAMESPACE
|
|
|
)
|
|
|
|
|
|
composites_in_dispatch = [
|
|
|
d
|
|
|
for d in dispatch
|
|
|
if d == DispatchKey.CompositeExplicitAutograd
|
|
|
or d == DispatchKey.CompositeExplicitAutogradNonFunctional
|
|
|
or d == DispatchKey.CompositeImplicitAutograd
|
|
|
or d == DispatchKey.CompositeImplicitAutogradNestedTensor
|
|
|
]
|
|
|
|
|
|
assert len(composites_in_dispatch) <= 1 or (
|
|
|
len(composites_in_dispatch) == 2
|
|
|
and (
|
|
|
DispatchKey.CompositeExplicitAutogradNonFunctional
|
|
|
not in composites_in_dispatch
|
|
|
)
|
|
|
and (
|
|
|
DispatchKey.CompositeImplicitAutogradNestedTensor
|
|
|
in composites_in_dispatch
|
|
|
)
|
|
|
), (
|
|
|
"cannot specify more than one of CompositeExplicitAutograd, CompositeExplicitAutogradNonFunctional, "
|
|
|
"or CompositeImplicitAutograd on a single kernel; each "
|
|
|
"strictly subsumes the other. If you wanted to provide an explicit autograd "
|
|
|
"implementation, specify CompositeExplicitAutograd; otherwise specify CompositeImplicitAutograd only"
|
|
|
)
|
|
|
|
|
|
autogen_str = e.pop("autogen", "")
|
|
|
assert isinstance(autogen_str, str)
|
|
|
autogen = (
|
|
|
[]
|
|
|
if autogen_str == ""
|
|
|
else [OperatorName.parse(x) for x in autogen_str.split(", ")]
|
|
|
)
|
|
|
|
|
|
raw_ufunc_inner_loop = e.pop("ufunc_inner_loop", {})
|
|
|
ufunc_inner_loop = {}
|
|
|
if isinstance(raw_ufunc_inner_loop, str):
|
|
|
ufunc_inner_loop[UfuncKey.Generic] = UfuncInnerLoop.parse(
|
|
|
raw_ufunc_inner_loop, UfuncKey.Generic
|
|
|
)
|
|
|
elif isinstance(raw_ufunc_inner_loop, dict):
|
|
|
for k, vo in raw_ufunc_inner_loop.items():
|
|
|
if k == "__line__":
|
|
|
continue
|
|
|
assert isinstance(k, str), f"ufunc_inner_loop key is not a str: {k}"
|
|
|
assert isinstance(vo, str), f"ufunc_inner_loop value is not a str: {v}"
|
|
|
ufunc_key = UfuncKey.parse(k)
|
|
|
ufunc_inner_loop[ufunc_key] = UfuncInnerLoop.parse(vo, ufunc_key)
|
|
|
else:
|
|
|
raise AssertionError(
|
|
|
f"ufunc_inner_loop not str or dict: {raw_ufunc_inner_loop}"
|
|
|
)
|
|
|
|
|
|
if ufunc_inner_loop:
|
|
|
assert structured, "ufunc must be structured"
|
|
|
|
|
|
|
|
|
|
|
|
import torchgen.api.ufunc as ufunc
|
|
|
|
|
|
for dispatch_key in UFUNC_DISPATCH_KEYS:
|
|
|
assert dispatch_key not in dispatch, (
|
|
|
f"ufunc should not have explicit dispatch entry for {dispatch_key}"
|
|
|
)
|
|
|
dispatch[dispatch_key] = BackendMetadata(
|
|
|
kernel=ufunc.schema_kernel_name(func, dispatch_key),
|
|
|
structured=True,
|
|
|
cpp_namespace=DEFAULT_KERNEL_NAMESPACE,
|
|
|
)
|
|
|
|
|
|
if structured_delegate:
|
|
|
|
|
|
is_abstract = True
|
|
|
else:
|
|
|
is_abstract = (
|
|
|
dispatch.keys() != {DispatchKey.CompositeImplicitAutograd}
|
|
|
and dispatch.keys()
|
|
|
!= {DispatchKey.CompositeImplicitAutogradNestedTensor}
|
|
|
and dispatch.keys()
|
|
|
!= {
|
|
|
DispatchKey.CompositeImplicitAutograd,
|
|
|
DispatchKey.CompositeImplicitAutogradNestedTensor,
|
|
|
}
|
|
|
)
|
|
|
|
|
|
has_composite_implicit_autograd_kernel = (
|
|
|
DispatchKey.CompositeImplicitAutograd in dispatch
|
|
|
)
|
|
|
has_composite_implicit_autograd_nested_tensor_kernel = (
|
|
|
DispatchKey.CompositeImplicitAutogradNestedTensor in dispatch
|
|
|
)
|
|
|
has_composite_explicit_autograd_kernel = (
|
|
|
DispatchKey.CompositeExplicitAutograd in dispatch
|
|
|
)
|
|
|
has_composite_explicit_autograd_non_functional_kernel = (
|
|
|
DispatchKey.CompositeExplicitAutogradNonFunctional in dispatch
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
backend_metadata = {k: {func.name: v} for k, v in dispatch.items()}
|
|
|
|
|
|
|
|
|
|
|
|
e.pop("__line__", None)
|
|
|
assert not e, f"leftover entries: {e}"
|
|
|
|
|
|
|
|
|
if structured_delegate is not None:
|
|
|
for key in STRUCTURED_DISPATCH_KEYS:
|
|
|
assert key not in dispatch, (
|
|
|
f"if structured_delegate, then must not have {key} in dispatch dictionary "
|
|
|
"(it is delegated!)"
|
|
|
)
|
|
|
|
|
|
return (
|
|
|
NativeFunction(
|
|
|
func=func,
|
|
|
use_const_ref_for_mutable_tensors=use_const_ref_for_mutable_tensors,
|
|
|
variants=variants,
|
|
|
structured=structured,
|
|
|
structured_delegate=structured_delegate,
|
|
|
structured_inherits=structured_inherits,
|
|
|
precomputed=precomputed,
|
|
|
autogen=autogen,
|
|
|
ufunc_inner_loop=ufunc_inner_loop,
|
|
|
manual_kernel_registration=manual_kernel_registration,
|
|
|
manual_cpp_binding=manual_cpp_binding,
|
|
|
python_module=python_module,
|
|
|
category_override=category_override,
|
|
|
device_guard=device_guard,
|
|
|
device_check=device_check,
|
|
|
loc=loc,
|
|
|
cpp_no_default_args=cpp_no_default_args,
|
|
|
is_abstract=is_abstract,
|
|
|
has_composite_implicit_autograd_kernel=has_composite_implicit_autograd_kernel,
|
|
|
has_composite_implicit_autograd_nested_tensor_kernel=has_composite_implicit_autograd_nested_tensor_kernel,
|
|
|
has_composite_explicit_autograd_kernel=has_composite_explicit_autograd_kernel,
|
|
|
has_composite_explicit_autograd_non_functional_kernel=has_composite_explicit_autograd_non_functional_kernel,
|
|
|
tags=tags,
|
|
|
namespace=namespace,
|
|
|
),
|
|
|
backend_metadata,
|
|
|
)
|
|
|
|
|
|
def validate_unstructured(self) -> None:
|
|
|
|
|
|
|
|
|
assert not self.structured, (
|
|
|
"This function is structured, but there was "
|
|
|
"no valid functional variant of it."
|
|
|
)
|
|
|
assert self.structured_delegate, (
|
|
|
"This function delegates to another structured out function, "
|
|
|
"but no valid function was found (the delegate may not exist, or it has the wrong type)"
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def __post_init__(self) -> None:
|
|
|
if self.func.arguments.out:
|
|
|
assert self.variants == {Variant.function}, (
|
|
|
"Native functions with out arguments MUST "
|
|
|
"be declared with only function variant; e.g., variants: function; "
|
|
|
"otherwise you will tickle a Python argument binding bug "
|
|
|
"(which usually manifests itself as the result variable being undefined.)"
|
|
|
)
|
|
|
if self.structured:
|
|
|
assert self.func.kind() == SchemaKind.out, (
|
|
|
"Put structured field on the out= "
|
|
|
"variant of a function; did you mean structured_delegate?"
|
|
|
)
|
|
|
assert self.device_guard, (
|
|
|
"device_guard: False is not respected by structured kernels"
|
|
|
)
|
|
|
if self.structured_delegate:
|
|
|
assert self.func.kind() != SchemaKind.out, (
|
|
|
"structured_delegate field not allowed "
|
|
|
"on out= functions; did you mean structured?"
|
|
|
)
|
|
|
assert self.device_guard, (
|
|
|
"device_guard: False is not respected by structured kernels"
|
|
|
)
|
|
|
|
|
|
|
|
|
assert not (self.structured and self.structured_delegate), (
|
|
|
"Cannot have both structured and structured_delegate on function"
|
|
|
)
|
|
|
defaulted_arguments = {
|
|
|
a.name for a in self.func.schema_order_arguments() if a.default is not None
|
|
|
}
|
|
|
invalid_args = set.difference(self.cpp_no_default_args, defaulted_arguments)
|
|
|
assert len(invalid_args) == 0, f"Invalid cpp_no_default_args: {invalid_args}"
|
|
|
if self.structured_inherits is not None:
|
|
|
assert self.structured, (
|
|
|
"structured_inherits must also imply structured: True"
|
|
|
)
|
|
|
if str(self.func.name).startswith("_foreach"):
|
|
|
assert self.device_check == DeviceCheckType.NoCheck, (
|
|
|
"foreach kernels fall back to slow path when tensor are on different devices, "
|
|
|
"device_check not allowed to be enabled"
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
if (
|
|
|
"rand" in str(self.func.name)
|
|
|
or (
|
|
|
(
|
|
|
"dropout" in str(self.func.name)
|
|
|
or any(
|
|
|
"dropout" in arg.name for arg in self.func.arguments.flat_all
|
|
|
)
|
|
|
)
|
|
|
|
|
|
and "backward" not in str(self.func.name)
|
|
|
and str(self.func.name.name) not in ["_cudnn_init_dropout_state"]
|
|
|
)
|
|
|
or self.func.arguments.has_generator_arg()
|
|
|
):
|
|
|
assert "nondeterministic_seeded" in self.tags, str(self.func.name)
|
|
|
|
|
|
@property
|
|
|
def has_composite_kernel(self) -> bool:
|
|
|
return (
|
|
|
self.has_composite_implicit_autograd_kernel
|
|
|
or self.has_composite_explicit_autograd_kernel
|
|
|
or self.has_composite_explicit_autograd_non_functional_kernel
|
|
|
) or (
|
|
|
self.has_composite_implicit_autograd_kernel
|
|
|
and self.has_composite_implicit_autograd_nested_tensor_kernel
|
|
|
)
|
|
|
|
|
|
@property
|
|
|
def is_view_op(self) -> bool:
|
|
|
rets = self.func.returns
|
|
|
is_non_mutating_view = len(rets) > 0 and any(
|
|
|
r.annotation is not None and not r.annotation.is_write for r in rets
|
|
|
)
|
|
|
|
|
|
is_inplace_view = (
|
|
|
"inplace_view" in self.tags
|
|
|
and str(self.func.name) != "resize_"
|
|
|
and str(self.func.name) != "resize_as_"
|
|
|
)
|
|
|
is_wildcard_view = any(
|
|
|
inp.annotation is not None and "*" in inp.annotation.alias_set_after
|
|
|
for inp in self.func.schema_order_arguments()
|
|
|
)
|
|
|
return is_non_mutating_view or is_inplace_view or is_wildcard_view
|
|
|
|
|
|
@property
|
|
|
def view_schema_kind(self) -> ViewSchemaKind:
|
|
|
if self.is_view_op and self.func.name.name.inplace:
|
|
|
assert "inplace_view" in self.tags
|
|
|
return ViewSchemaKind.aliasing_inplace
|
|
|
if self.is_view_op:
|
|
|
return ViewSchemaKind.aliasing
|
|
|
else:
|
|
|
return ViewSchemaKind.non_aliasing
|
|
|
|
|
|
@property
|
|
|
def root_name(self) -> str:
|
|
|
return self.func.name.name.base
|
|
|
|
|
|
@property
|
|
|
def part_of_structured_group(self) -> bool:
|
|
|
return self.structured or self.structured_delegate is not None
|
|
|
|
|
|
|
|
|
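# The role an operator plays within its variant group, e.g. add
# (functional), add_ (inplace), add.out (out); see FunctionSchema.kind()
# for how a schema is classified.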
class SchemaKind(Enum):
|
|
|
functional = auto()
|
|
|
inplace = auto()
|
|
|
out = auto()
|
|
|
mutable = auto()
|
|
|
scratch = auto()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
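# Bundles the related functional/inplace/mutable/out variants of one
# operator; `functional` and `out` are mandatory, and structured kernels
# are keyed off the out= variant (see the `structured` property).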
@dataclass(frozen=True)
|
|
|
class NativeFunctionsGroup:
|
|
|
functional: NativeFunction
|
|
|
inplace: NativeFunction | None
|
|
|
mutable: NativeFunction | None
|
|
|
out: NativeFunction
|
|
|
|
|
|
@property
|
|
|
def structured(self) -> bool:
|
|
|
|
|
|
return self.out.structured
|
|
|
|
|
|
def __post_init__(self) -> None:
|
|
|
test_sig: FunctionSchema = self.functional.func.signature()
|
|
|
for f in self.functions():
|
|
|
if test_sig != f.func.signature():
|
|
|
raise AssertionError(
|
|
|
"NativeFunctionsGroup constructed from two NativeFunctions "
|
|
|
f"that don't have matching signatures: {test_sig} != {f.func.signature()}"
|
|
|
)
|
|
|
|
|
|
if self.structured != f.part_of_structured_group:
|
|
|
raise AssertionError(
|
|
|
"NativeFunctionsGroup constructed from structured and unstructured "
|
|
|
f"functions: {self.out.func.name} and {f.func.name}"
|
|
|
)
|
|
|
assert self.functional.func.kind() == SchemaKind.functional
|
|
|
assert self.out.func.kind() == SchemaKind.out
|
|
|
assert self.functional.namespace == self.out.namespace
|
|
|
if self.inplace is not None:
|
|
|
assert self.inplace.func.kind() == SchemaKind.inplace
|
|
|
assert self.inplace.namespace == self.functional.namespace
|
|
|
|
|
|
if self.mutable is not None:
|
|
|
assert self.mutable.func.kind() == SchemaKind.mutable
|
|
|
assert self.mutable.namespace == self.functional.namespace
|
|
|
|
|
|
assert self.functional.func.name.name.functional_overload
|
|
|
|
|
|
if self.structured:
|
|
|
|
|
|
|
|
|
assert (
|
|
|
not self.out.has_composite_implicit_autograd_kernel
|
|
|
and not self.out.has_composite_implicit_autograd_nested_tensor_kernel
|
|
|
)
|
|
|
|
|
|
assert self.functional.structured_delegate == self.out.func.name, (
|
|
|
f"{self.functional.func.name} delegates to {self.functional.structured_delegate} "
|
|
|
f"but its actual delegate is {self.out.func.name}"
|
|
|
)
|
|
|
if self.inplace is not None:
|
|
|
assert self.inplace.structured_delegate == self.out.func.name
|
|
|
|
|
|
generated_fns = sorted(
|
|
|
[str(f.func.name) for f in self.functions() if "generated" in f.tags]
|
|
|
)
|
|
|
generated_fns_str = ", ".join(str(x) for x in generated_fns)
|
|
|
expected_generated_fns: set[str] = set()
|
|
|
for f in self.functions():
|
|
|
expected_generated_fns.update(str(op) for op in f.autogen)
|
|
|
expected_generated_fns_str = ", ".join(
|
|
|
str(x) for x in sorted(expected_generated_fns)
|
|
|
)
|
|
|
if len(expected_generated_fns) == 0 and len(generated_fns) > 0:
|
|
|
raise RuntimeError(
|
|
|
f"The codegen expects to be able to generate '{generated_fns_str}'."
|
|
|
" In order to generate them however, we expect them to be called out explicitly in the yaml."
|
|
|
f" Please add an 'autogen: {generated_fns_str}' line to the entry for {str(f.func.name)}"
|
|
|
)
|
|
|
if expected_generated_fns_str != generated_fns_str:
|
|
|
raise RuntimeError(
|
|
|
f"The codegen expects to be able to generate '{generated_fns_str}'."
|
|
|
f" To do so, it expects a line: 'autogen: {generated_fns_str}'."
|
|
|
f" Instead, it found 'autogen: {expected_generated_fns_str}'"
|
|
|
)
|
|
|
|
|
|
def signature(self) -> FunctionSchema:
|
|
|
return self.out.func.signature()
|
|
|
|
|
|
def functions(self) -> Iterator[NativeFunction]:
|
|
|
yield self.functional
|
|
|
yield self.out
|
|
|
if self.inplace is not None:
|
|
|
yield self.inplace
|
|
|
if self.mutable is not None:
|
|
|
yield self.mutable
|
|
|
|
|
|
@property
|
|
|
def root_name(self) -> str:
|
|
|
return self.functional.root_name
|
|
|
|
|
|
@staticmethod
|
|
|
def from_dict(d: dict[SchemaKind, NativeFunction]) -> NativeFunctionsGroup | None:
|
|
|
assert d
|
|
|
if len(d) == 1:
|
|
|
return None
|
|
|
d = dict(d)
|
|
|
functional = d.pop(SchemaKind.functional, None)
|
|
|
inplace = d.pop(SchemaKind.inplace, None)
|
|
|
mutable = d.pop(SchemaKind.mutable, None)
|
|
|
out = d.pop(SchemaKind.out, None)
|
|
|
assert not d
|
|
|
assert functional is not None
|
|
|
|
|
|
|
|
|
if out is None:
|
|
|
return None
|
|
|
|
|
|
return NativeFunctionsGroup(
|
|
|
functional=functional,
|
|
|
inplace=inplace,
|
|
|
mutable=mutable,
|
|
|
out=out,
|
|
|
)
|
|
|
|
|
|
|
|
|
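# Describes where the kernel for one (operator, dispatch key) pair lives:
# the unqualified C++ kernel name, its namespace, and whether it is
# written in the structured-kernel style.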
@dataclass(frozen=True)
|
|
|
class BackendMetadata:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
kernel: str
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
structured: bool
|
|
|
|
|
|
|
|
|
cpp_namespace: str
|
|
|
|
|
|
def supports_symint(self) -> bool:
|
|
|
return "_symint" in self.kernel
|
|
|
|
|
|
|
|
|
@dataclass(frozen=True)
|
|
|
class UfuncInnerLoop:
|
|
|
name: str
|
|
|
supported_dtypes: OrderedSet[ScalarType]
|
|
|
|
|
|
|
|
|
ufunc_key: UfuncKey
|
|
|
|
|
|
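    # Parses entries of the form "<kernel> (<dtype list>)", e.g.
    # (illustrative) UfuncInnerLoop.parse(
    #     "add_stub (AllAndComplex, BFloat16, Half)", UfuncKey.Generic)
    # where each dtype token may be a ScalarType name or a DTYPE_CLASSES
    # group.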
@staticmethod
|
|
|
def parse(value: str, ufunc_key: UfuncKey) -> UfuncInnerLoop:
|
|
|
name, supported_dtypes_str = value.split(" ", 1)
|
|
|
assert supported_dtypes_str[0] == "("
|
|
|
assert supported_dtypes_str[-1] == ")"
|
|
|
supported_dtypes: OrderedSet[ScalarType] = OrderedSet()
|
|
|
for k in supported_dtypes_str[1:-1].split(", "):
|
|
|
supported_dtypes |= ScalarType.parse_set(k)
|
|
|
return UfuncInnerLoop(
|
|
|
name=name, supported_dtypes=supported_dtypes, ufunc_key=ufunc_key
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
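# A per-dispatch-key index over all operators: maps each OperatorName to
# the BackendMetadata registered for this key. get_kernel() resolves a
# NativeFunctionsGroup through primary(), which picks the out= or
# functional variant depending on use_out_as_primary.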
@dataclass(frozen=True)
|
|
|
class BackendIndex:
|
|
|
dispatch_key: DispatchKey
|
|
|
|
|
|
|
|
|
use_out_as_primary: bool
|
|
|
|
|
|
|
|
|
|
|
|
device_guard: bool
|
|
|
|
|
|
external: bool
|
|
|
|
|
|
index: dict[OperatorName, BackendMetadata]
|
|
|
|
|
|
@staticmethod
|
|
|
def grow_index(
|
|
|
parent_index: dict[DispatchKey, dict[OperatorName, BackendMetadata]],
|
|
|
child_index: dict[DispatchKey, dict[OperatorName, BackendMetadata]],
|
|
|
) -> None:
|
|
|
for k, v in child_index.items():
|
|
|
for op_name, metadata in v.items():
|
|
|
assert op_name not in parent_index[k], (
|
|
|
f"duplicate operator {op_name} for dispatch key {k}"
|
|
|
)
|
|
|
parent_index[k][op_name] = metadata
|
|
|
|
|
|
def primary(self, g: NativeFunctionsGroup) -> NativeFunction:
|
|
|
if self.use_out_as_primary:
|
|
|
return g.out
|
|
|
else:
|
|
|
return g.functional
|
|
|
|
|
|
def has_kernel(self, g: NativeFunction | NativeFunctionsGroup) -> bool:
|
|
|
m = self.get_kernel(g)
|
|
|
return m is not None
|
|
|
|
|
|
def get_kernel(
|
|
|
self, g: NativeFunction | NativeFunctionsGroup
|
|
|
) -> BackendMetadata | None:
|
|
|
if isinstance(g, NativeFunction):
|
|
|
f = g
|
|
|
elif isinstance(g, NativeFunctionsGroup):
|
|
|
f = self.primary(g)
|
|
|
else:
|
|
|
assert_never(g)
|
|
|
if f.func.name not in self.index:
|
|
|
return None
|
|
|
return self.index[f.func.name]
|
|
|
|
|
|
def native_function_class_name(self) -> str | None:
|
|
|
if self.external:
|
|
|
return f"{str(self.dispatch_key)}NativeFunctions"
|
|
|
else:
|
|
|
|
|
|
|
|
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
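# The parsed JIT schema of an operator. parse() round-trips through
# str(), e.g. (illustrative)
#   FunctionSchema.parse(
#       "add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor")
# yields name "add.Tensor", three arguments, and a single Tensor return.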
@dataclass(frozen=True)
|
|
|
class FunctionSchema:
|
|
|
|
|
|
name: OperatorName
|
|
|
|
|
|
arguments: Arguments
|
|
|
|
|
|
|
|
|
returns: tuple[Return, ...]
|
|
|
|
|
|
@property
|
|
|
def is_mutable(self) -> bool:
|
|
|
def is_write(arg: Argument) -> bool:
|
|
|
if arg.annotation is None:
|
|
|
return False
|
|
|
return arg.annotation.is_write
|
|
|
|
|
|
|
|
|
|
|
|
return any(is_write(a) for a in self.arguments.flat_all)
|
|
|
|
|
|
def schema_order_arguments(self) -> Iterator[Argument]:
|
|
|
return itertools.chain(
|
|
|
self.arguments.flat_positional,
|
|
|
self.arguments.flat_kwarg_only,
|
|
|
self.arguments.out,
|
|
|
)
|
|
|
|
|
|
decl_re = re.compile(r"(?P<name>[^\(]+)\((?P<args>.*)\) -> (?P<returns>.*)")
|
|
|
|
|
|
@staticmethod
|
|
|
def parse(func: str) -> FunctionSchema:
|
|
|
|
|
|
decls = FunctionSchema.decl_re.findall(func)
|
|
|
assert len(decls) == 1, f"Invalid function schema: {func}"
|
|
|
ops, args, return_decl = decls[0]
|
|
|
name = OperatorName.parse(ops)
|
|
|
arguments = Arguments.parse(args)
|
|
|
returns = parse_returns(return_decl)
|
|
|
r = FunctionSchema(name=name, arguments=arguments, returns=returns)
|
|
|
assert str(r) == func, f"{str(r)} != {func}"
|
|
|
return r
|
|
|
|
|
|
def returns_are_aliased(self) -> bool:
|
|
|
|
|
|
        return any(
            r.annotation is not None and r.annotation.is_write for r in self.returns
        )
|
|
|
|
|
|
def __post_init__(self) -> None:
|
|
|
for arg, ret in zip(self.arguments.out, self.returns):
|
|
|
assert arg.annotation == ret.annotation, (
|
|
|
"Out arguments must have matching return Tensor; furthermore, "
|
|
|
"the ith-argument needs to correspond to the ith return"
|
|
|
)
|
|
|
|
|
|
|
|
|
for a in self.arguments.post_self_positional_mutable:
|
|
|
assert not any(a.annotation == r.annotation for r in self.returns), (
|
|
|
f"If you have a schema with mutable positional args, we expect them to not be returned. schema: {str(self)}"
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
out_and_self = list(self.arguments.out) + [
|
|
|
arg for arg in self.arguments.flat_positional if arg.name == "self"
|
|
|
]
|
|
|
mutable_returns = [
|
|
|
ret
|
|
|
for ret in self.returns
|
|
|
if ret.annotation is not None and ret.annotation.is_write
|
|
|
]
|
|
|
immutable_returns = [
|
|
|
ret
|
|
|
for ret in self.returns
|
|
|
if ret.annotation is None or not ret.annotation.is_write
|
|
|
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
assert len(mutable_returns) == 0 or len(immutable_returns) == 0, (
|
|
|
f"NativeFunctions must have either only mutable returns, or only immutable returns. Found: {str(self)}"
|
|
|
)
|
|
|
for ret in mutable_returns:
|
|
|
assert any(ret.annotation == arg.annotation for arg in out_and_self), (
|
|
|
'All mutable returns must be aliased either to a keyword argument, or to "self". '
|
|
|
"Did you forget to mark an out argument as keyword-only?"
|
|
|
)
|
|
|
if self.arguments.out:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if any(a.type != BaseType(BaseTy.Tensor) for a in self.arguments.out):
|
|
|
                assert len(self.returns) == 0, (
                    "out= ops that accept tensor lists as out arguments "
                    "are expected to have no return type (since you can't do method chaining on them)"
                )
|
|
|
else:
|
|
|
|
|
|
|
|
|
assert len(
|
|
|
[
|
|
|
arg
|
|
|
for arg in self.arguments.out
|
|
|
if not arg.name.startswith("_scratch_")
|
|
|
]
|
|
|
) == len(self.returns), (
|
|
|
"Must return as many arguments as there are out arguments, or no return at all"
|
|
|
)
|
|
|
|
|
|
if self.name.name.inplace:
|
|
|
self_a = self.arguments.self_arg
|
|
|
assert (
|
|
|
self_a
|
|
|
and self_a.argument.annotation
|
|
|
and self_a.argument.annotation.is_write
|
|
|
)
|
|
|
if self_a.argument.type == BaseType(BaseTy.Tensor):
|
|
|
|
|
|
|
|
|
assert (
|
|
|
len(self.returns) == 1
|
|
|
and self.returns[0].annotation == self_a.argument.annotation
|
|
|
)
|
|
|
else:
|
|
|
|
|
|
|
|
|
assert len(self.returns) == 0
|
|
|
|
|
|
if self.arguments.tensor_options is not None:
|
|
|
assert self.kind() == SchemaKind.functional, (
|
|
|
"Found an operator that is not functional or out variant, but has tensor options arguments."
|
|
|
"This is not allowed- tensor options arguments are only allowed for factory functions."
|
|
|
f"schema: {str(self)}"
|
|
|
)
|
|
|
if self.is_functional_fn():
|
|
|
assert self.kind() == SchemaKind.functional, (
|
|
|
"Found an operator that is not functional, but its overload contains the string 'functional'."
|
|
|
"This is a special keyword in the codegen, please use a different overload name."
|
|
|
f"schema: {str(self)}"
|
|
|
)
|
|
|
|
|
|
def is_functional_fn(self) -> bool:
|
|
|
return "functional" in self.name.overload_name
|
|
|
|
|
|
def is_out_fn(self) -> bool:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return bool(self.arguments.out)
|
|
|
|
|
|
def kind(self) -> SchemaKind:
|
|
|
"""
|
|
|
        What kind of schema is this? A functional schema is one
        that returns a newly allocated output; an inplace schema
        modifies the self argument inplace; an out schema writes
        the result into an explicitly provided out argument;
        a mutable schema mutates a post-self positional argument;
        and a scratch schema is an out schema whose out arguments
        are prefixed with _scratch_.
|
|
|
"""
|
|
|
is_out = bool(self.arguments.out)
|
|
|
        is_scratch = any(
            arg.name.startswith("_scratch_") for arg in self.arguments.out
        )
|
|
|
is_inplace = self.name.name.inplace
|
|
|
is_mutable = any(
|
|
|
a.annotation is not None and a.annotation.is_write
|
|
|
for a in self.arguments.post_self_positional
|
|
|
)
|
|
|
assert not (is_out and is_inplace)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if is_inplace:
|
|
|
return SchemaKind.inplace
|
|
|
elif is_scratch:
|
|
|
assert is_out, (
|
|
|
"invariant: all scratch operators are expected to be out= operators too"
|
|
|
)
|
|
|
return SchemaKind.scratch
|
|
|
elif is_out:
|
|
|
assert not is_scratch, (
|
|
|
"We should not categorize a scratch op as an out variant. Check if the order of if statements are expected!"
|
|
|
)
|
|
|
return SchemaKind.out
|
|
|
elif is_mutable:
|
|
|
return SchemaKind.mutable
|
|
|
else:
|
|
|
return SchemaKind.functional
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def aliased_return_names(self) -> list[str | None]:
|
|
|
outs: list[str | None] = []
|
|
|
for r in self.returns:
|
|
|
aliased_args = [
|
|
|
a
|
|
|
for a in self.arguments.flat_all
|
|
|
if a.annotation is not None and a.annotation == r.annotation
|
|
|
]
|
|
|
if len(aliased_args) == 0:
|
|
|
outs.append(None)
|
|
|
elif len(aliased_args) == 1:
|
|
|
outs.append(aliased_args[0].name)
|
|
|
else:
|
|
|
aliased_names = ", ".join(a.name for a in aliased_args)
|
|
|
raise AssertionError(
|
|
|
f"Found a return ({r.name})that aliases multiple inputs ({aliased_names})"
|
|
|
)
|
|
|
return outs
|
|
|
|
|
|
def signature(
|
|
|
self,
|
|
|
*,
|
|
|
strip_default: bool = False,
|
|
|
strip_view_copy_name: bool = False,
|
|
|
keep_return_names: bool = False,
|
|
|
) -> FunctionSchema:
|
|
|
"""
|
|
|
Certain schemas are 'related', in that they are simply
|
|
|
inplace/out/functional versions of the same function. This method
|
|
|
factors these schemas into the "core" functional signature which
|
|
|
is equal across all versions.
|
|
|
|
|
|
Here is what normalization happens to the schema to convert
|
|
|
it to a signature:
|
|
|
- The overload name is stripped (name is retained, since
|
|
|
it expresses semantic content about what the function does)
|
|
|
- Inplace is set False
|
|
|
- Out arguments are stripped
|
|
|
- Mutable post_self_positional args are converted to returns
|
|
|
- Mutability annotations are stripped (this is sound
|
|
|
because you cannot overload on mutability annotation)
|
|
|
- Return names are stripped since they are not overloadable and
|
|
|
some variants have return names but some not
|
|
|
- TensorOptions are dropped
|
|
|
because out= variants of factory functions don't include them
|
|
|
(and we want to be able to pair up factory functions with their out variants)
|
|
|
|
|
|
Finally, we want to be able to pair up related "view" and their
|
|
|
corresponding "view_copy" operators. We do this by optionally
|
|
|
stripping the trailing "_copy" from the base name.
|
|
|
|
|
|
Example of a mutable op before and after:
|
|
|
|
|
|
f.func (Mutable operator):
|
|
|
_fused_moving_avg_obs_fq_helper(Tensor self, Tensor observer_on, Tensor fake_quant_on, Tensor(a!) running_min, Tensor(b!) running_max, Tensor(c!) scale, Tensor(d!) zero_point, float averaging_const, int quant_min, int quant_max, int ch_axis, bool per_row_fake_quant=False, bool symmetric_quant=False) -> (Tensor output, Tensor mask) # noqa: B950
|
|
|
|
|
|
f.func (Corresponding functional operator):
|
|
|
_fused_moving_avg_obs_fq_helper.functional(Tensor self, Tensor observer_on, Tensor fake_quant_on, Tensor running_min, Tensor running_max, Tensor scale, Tensor zero_point, float averaging_const, int quant_min, int quant_max, int ch_axis, bool per_row_fake_quant=False, bool symmetric_quant=False) -> (Tensor output, Tensor mask, Tensor running_min_out, Tensor running_max_out, Tensor scale_out, Tensor zero_point_out) # noqa: B950
|
|
|
|
|
|
f.func.signature() output:
|
|
|
_fused_moving_avg_obs_fq_helper(Tensor self, Tensor observer_on, Tensor fake_quant_on, Tensor running_min, Tensor running_max, Tensor scale, Tensor zero_point, float averaging_const, int quant_min, int quant_max, int ch_axis, bool per_row_fake_quant=False, bool symmetric_quant=False) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor) # noqa: B950
|
|
|
"""
|
|
|
|
|
|
def strip_ret_annotation(r: Return) -> Return:
|
|
|
return Return(
|
|
|
name=r.name if keep_return_names else None,
|
|
|
type=r.type,
|
|
|
annotation=None,
|
|
|
)
|
|
|
|
|
|
base_name = self.name.name.base
|
|
|
if strip_view_copy_name:
|
|
|
if base_name.endswith("_copy"):
|
|
|
base_name = base_name.replace("_copy", "")
|
|
|
elif base_name.endswith("_scatter"):
|
|
|
base_name = base_name.replace("scatter", "inverse")
|
|
|
|
|
|
|
|
|
returns_from_mutable_inputs = tuple(
|
|
|
|
|
|
|
|
|
|
|
|
Return(
|
|
|
name=f"{a.name}_out" if keep_return_names else None,
|
|
|
type=a.type,
|
|
|
annotation=None,
|
|
|
)
|
|
|
for a in itertools.chain(
|
|
|
|
|
|
|
|
|
(
|
|
|
[self.arguments.self_arg.argument]
|
|
|
if self.arguments.self_arg is not None
|
|
|
else []
|
|
|
),
|
|
|
self.arguments.out,
|
|
|
self.arguments.post_self_positional,
|
|
|
)
|
|
|
if a.annotation is not None
|
|
|
and a.annotation.is_write
|
|
|
and not any(a.annotation == r.annotation for r in self.returns)
|
|
|
)
|
|
|
original_returns = tuple(map(strip_ret_annotation, self.returns))
|
|
|
|
|
|
returns = original_returns + returns_from_mutable_inputs
|
|
|
|
|
|
args_sig = self.arguments.signature(strip_default=strip_default)
|
|
|
|
|
|
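        # bernoulli.p's variants disagree on whether p carries a default;
        # patch the default back in so their signatures pair up.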
if str(self.name) == "bernoulli.p":
|
|
|
args_sig = Arguments.parse(str(args_sig).replace("float p", "float p=0.5"))
|
|
|
|
|
|
return FunctionSchema(
|
|
|
name=OperatorName(
|
|
|
name=BaseOperatorName(
|
|
|
base=base_name,
|
|
|
inplace=False,
|
|
|
dunder_method=self.name.name.dunder_method,
|
|
|
),
|
|
|
overload_name="",
|
|
|
),
|
|
|
arguments=args_sig,
|
|
|
returns=returns,
|
|
|
)
|
|
|
|
|
|
def view_signature(self) -> FunctionSchema:
|
|
|
return self.signature(strip_view_copy_name=True)
|
|
|
|
|
|
def with_name(self, name: OperatorName) -> FunctionSchema:
|
|
|
return FunctionSchema(
|
|
|
name=name,
|
|
|
arguments=self.arguments,
|
|
|
returns=self.returns,
|
|
|
)
|
|
|
|
|
|
@property
|
|
|
def modifies_arguments(self) -> bool:
|
|
|
return self.kind() in [SchemaKind.inplace, SchemaKind.out, SchemaKind.mutable]
|
|
|
|
|
|
def has_symint(self) -> bool:
|
|
|
return self.arguments.has_symint_arg()
|
|
|
|
|
|
def __str__(self) -> str:
|
|
|
all_arguments_str = str(self.arguments)
|
|
|
if len(self.returns) == 1:
|
|
|
returns = str(self.returns[0])
|
|
|
else:
|
|
|
returns = "(" + ", ".join(map(str, self.returns)) + ")"
|
|
|
return f"{self.name}({all_arguments_str}) -> {returns}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
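# An alias annotation as written in schema types, e.g. the "a" in
# Tensor(a), "a!" for mutation, or "a -> *" for wildcard aliasing after
# the op runs; parse() round-trips through str().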
@dataclass(frozen=True)
|
|
|
class Annotation:
|
|
|
|
|
|
|
|
|
alias_set: tuple[str, ...]
|
|
|
is_write: bool
|
|
|
alias_set_after: tuple[str, ...]
|
|
|
|
|
|
@staticmethod
|
|
|
def parse(ann: str) -> Annotation:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
m = re.match(r"^([a-z])(\|[a-z])*(!?)( -> (\*|[a-z](\|[a-z])*))?$", ann)
|
|
|
|
|
|
assert m is not None, f"unrecognized alias annotation {ann}"
|
|
|
before_alias = m.group(1) + (m.group(2) if m.group(2) else "")
|
|
|
alias_set = tuple(before_alias.split("|"))
|
|
|
is_write = m.group(3) == "!"
|
|
|
assert not (is_write and len(alias_set) > 1), (
|
|
|
f"alias set larger than 1 is not mutable, got {ann} instead."
|
|
|
)
|
|
|
after_set = tuple(m.group(5).split("|")) if m.group(5) else ()
|
|
|
assert not (len(before_alias) > 1 and len(after_set) > 1), (
|
|
|
f"before alias set and after alias set cannot be larger than 1 at the same time, got {ann} instead."
|
|
|
)
|
|
|
r = Annotation(
|
|
|
alias_set=alias_set, is_write=is_write, alias_set_after=after_set
|
|
|
)
|
|
|
assert str(r) == ann, f"{r} != {ann}"
|
|
|
return r
|
|
|
|
|
|
def __str__(self) -> str:
|
|
|
alias_set = "|".join(self.alias_set)
|
|
|
if self.is_write:
|
|
|
alias_set = f"{alias_set}!"
|
|
|
alias_set_after = "|".join(self.alias_set_after)
|
|
|
if alias_set_after:
|
|
|
alias_set = f"{alias_set} -> {alias_set_after}"
|
|
|
return alias_set
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
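# Root of the small schema type language: BaseType wraps a BaseTy,
# OptionalType is "T?", ListType is "T[]" or "T[3]", and CustomClassType
# is "__torch__.torch.classes.<name>". Type.parse round-trips through
# str().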
@dataclass(frozen=True)
|
|
|
class Type:
|
|
|
@staticmethod
|
|
|
def parse(t: str) -> Type:
|
|
|
r = Type._parse(t)
|
|
|
assert str(r) == t, f"{r} != {t}"
|
|
|
return r
|
|
|
|
|
|
@staticmethod
|
|
|
def _parse(t: str) -> Type:
|
|
|
m = re.match(r"^(.+)\?$", t)
|
|
|
if m is not None:
|
|
|
return OptionalType(Type.parse(m.group(1)))
|
|
|
m = re.match(r"^(.+)\[([0-9]+)?\]$", t)
|
|
|
if m is not None:
|
|
|
size = int(m.group(2)) if m.group(2) is not None else None
|
|
|
return ListType(elem=Type.parse(m.group(1)), size=size)
|
|
|
|
|
|
|
|
|
m = re.match(r"^__torch__\.torch\.classes\.([a-zA-Z0-9_.]+)$", t)
|
|
|
if m is not None:
|
|
|
return CustomClassType(m.group(1))
|
|
|
try:
|
|
|
return BaseType(BaseTy[t])
|
|
|
except KeyError as e:
|
|
|
raise RuntimeError(f"unrecognized type {t}") from e
|
|
|
|
|
|
def __str__(self) -> str:
|
|
|
raise NotImplementedError
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def is_base_ty_like(self, base_ty: BaseTy) -> bool:
|
|
|
raise NotImplementedError
|
|
|
|
|
|
def is_tensor_like(self) -> bool:
|
|
|
return self.is_base_ty_like(BaseTy.Tensor)
|
|
|
|
|
|
def is_generator_like(self) -> bool:
|
|
|
return self.is_base_ty_like(BaseTy.Generator)
|
|
|
|
|
|
def is_symint_like(self) -> bool:
|
|
|
return self.is_base_ty_like(BaseTy.SymInt)
|
|
|
|
|
|
def is_nullable(self) -> bool:
|
|
|
raise NotImplementedError
|
|
|
|
|
|
def is_list_like(self) -> ListType | None:
|
|
|
raise NotImplementedError
|
|
|
|
|
|
|
|
|
|
|
|
class BaseTy(Enum):
|
|
|
Generator = auto()
|
|
|
ScalarType = auto()
|
|
|
Tensor = auto()
|
|
|
int = auto()
|
|
|
Dimname = auto()
|
|
|
DimVector = auto()
|
|
|
float = auto()
|
|
|
str = auto()
|
|
|
bool = auto()
|
|
|
Layout = auto()
|
|
|
Device = auto()
|
|
|
DeviceIndex = auto()
|
|
|
Scalar = auto()
|
|
|
MemoryFormat = auto()
|
|
|
QScheme = auto()
|
|
|
Storage = auto()
|
|
|
Stream = auto()
|
|
|
SymInt = auto()
|
|
|
SymBool = auto()
|
|
|
GraphModule = auto()
|
|
|
|
|
|
|
|
|
@dataclass(frozen=True)
|
|
|
class BaseType(Type):
|
|
|
name: BaseTy
|
|
|
|
|
|
def __str__(self) -> str:
|
|
|
return f"{self.name.name}"
|
|
|
|
|
|
def is_base_ty_like(self, base_ty: BaseTy) -> bool:
|
|
|
return self.name == base_ty
|
|
|
|
|
|
def is_nullable(self) -> bool:
|
|
|
return False
|
|
|
|
|
|
def is_list_like(self) -> ListType | None:
|
|
|
return None
|
|
|
|
|
|
def is_symint_like(self) -> bool:
|
|
|
return self.name == BaseTy.SymInt
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass(frozen=True)
|
|
|
class OptionalType(Type):
|
|
|
elem: Type
|
|
|
|
|
|
def __str__(self) -> str:
|
|
|
return f"{self.elem}?"
|
|
|
|
|
|
def is_base_ty_like(self, base_ty: BaseTy) -> bool:
|
|
|
return self.elem.is_base_ty_like(base_ty)
|
|
|
|
|
|
def is_symint_like(self) -> bool:
|
|
|
return self.elem.is_symint_like()
|
|
|
|
|
|
def is_nullable(self) -> bool:
|
|
|
return True
|
|
|
|
|
|
def is_list_like(self) -> ListType | None:
|
|
|
return self.elem.is_list_like()
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass(frozen=True)
|
|
|
class CustomClassType(Type):
|
|
|
class_name: str
|
|
|
|
|
|
def __str__(self) -> str:
|
|
|
"""
|
|
|
        Return the class name, prefixed with "__torch__.torch.classes.".
|
|
|
"""
|
|
|
return f"__torch__.torch.classes.{self.class_name}"
|
|
|
|
|
|
def is_base_ty_like(self, base_ty: BaseTy) -> bool:
|
|
|
return False
|
|
|
|
|
|
def is_symint_like(self) -> bool:
|
|
|
return False
|
|
|
|
|
|
def is_nullable(self) -> bool:
|
|
|
"""
|
|
|
Assume a custom class is not nullable.
|
|
|
"""
|
|
|
return False
|
|
|
|
|
|
def is_list_like(self) -> ListType | None:
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass(frozen=True)
|
|
|
class ListType(Type):
|
|
|
elem: Type
|
|
|
size: int | None
|
|
|
|
|
|
def __str__(self) -> str:
|
|
|
size = f"{self.size}" if self.size else ""
|
|
|
return f"{self.elem}[{size}]"
|
|
|
|
|
|
def is_base_ty_like(self, base_ty: BaseTy) -> bool:
|
|
|
return self.elem.is_base_ty_like(base_ty)
|
|
|
|
|
|
def is_symint_like(self) -> bool:
|
|
|
return self.elem.is_symint_like()
|
|
|
|
|
|
def is_nullable(self) -> bool:
|
|
|
return self.elem.is_nullable()
|
|
|
|
|
|
def is_list_like(self) -> ListType | None:
|
|
|
return self
|
|
|
|
|
|
|
|
|
@dataclass(frozen=True)
|
|
|
class Argument:
|
|
|
|
|
|
|
|
|
|
|
|
name: str
|
|
|
type: Type
|
|
|
default: str | None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
annotation: Annotation | None
|
|
|
|
|
|
@property
|
|
|
def alias_info(self) -> Annotation | None:
|
|
|
return self.annotation
|
|
|
|
|
|
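    # e.g. (illustrative) Argument.parse("Tensor(a!) self") yields name
    # "self", a Tensor type, and a mutable annotation, while
    # Argument.parse("int dim=-1") yields default "-1".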
@staticmethod
|
|
|
def parse(arg: str) -> Argument:
|
|
|
name: str
|
|
|
default: str | None
|
|
|
assert " " in arg, f"illegal argument '{arg}'"
|
|
|
if "=" in arg:
|
|
|
assert arg.count("=") == 1, f"illegal argument with default value: '{arg}'"
|
|
|
type_and_annot_and_name, default = arg.split("=")
|
|
|
type_and_annot, name = type_and_annot_and_name.rsplit(" ", 1)
|
|
|
name_and_default = f"{name}={default}"
|
|
|
else:
|
|
|
type_and_annot, name_and_default = arg.rsplit(" ", 1)
|
|
|
name = name_and_default
|
|
|
default = None
|
|
|
|
|
|
match = re.match(r"Tensor\((.+)\)(.*)", type_and_annot)
|
|
|
annotation: Annotation | None
|
|
|
if match:
|
|
|
|
|
|
assert match.group(2) in [
|
|
|
"",
|
|
|
"?",
|
|
|
"[]",
|
|
|
], "unrecognized alias analysis form with Tensor"
|
|
|
type_s = "Tensor" + match.group(2)
|
|
|
annotation = Annotation.parse(match.group(1))
|
|
|
else:
|
|
|
type_s = type_and_annot
|
|
|
annotation = None
|
|
|
type = Type.parse(type_s)
|
|
|
r = Argument(
|
|
|
name=name,
|
|
|
type=type,
|
|
|
default=default,
|
|
|
annotation=annotation,
|
|
|
)
|
|
|
assert str(r) == arg, f"{str(r)} != {arg}"
|
|
|
return r
|
|
|
|
|
|
@property
|
|
|
def is_write(self) -> bool:
|
|
|
return self.annotation is not None and self.annotation.is_write
|
|
|
|
|
|
def __str__(self) -> str:
|
|
|
type = f"{self.type}"
|
|
|
if self.annotation:
|
|
|
assert type in ["Tensor", "Tensor?", "Tensor[]"]
|
|
|
type = type.replace("Tensor", f"Tensor({self.annotation})")
|
|
|
if self.name is None:
|
|
|
return type
|
|
|
else:
|
|
|
mb_default = ""
|
|
|
if self.default:
|
|
|
mb_default = f"={self.default}"
|
|
|
return f"{type} {self.name}{mb_default}"
|
|
|
|
|
|
|
|
|
@dataclass(frozen=True)
|
|
|
class Return:
|
|
|
name: str | None
|
|
|
type: Type
|
|
|
annotation: Annotation | None
|
|
|
|
|
|
@property
|
|
|
def alias_info(self) -> Annotation | None:
|
|
|
return self.annotation
|
|
|
|
|
|
@staticmethod
|
|
|
def parse(arg: str) -> Return:
|
|
|
name: str | None
|
|
|
if " " in arg:
|
|
|
type_and_annot, name = arg.rsplit(" ", 1)
|
|
|
else:
|
|
|
type_and_annot = arg
|
|
|
name = None
|
|
|
match = re.match(r"Tensor\((.+)\)(.*)", type_and_annot)
|
|
|
annotation: Annotation | None
|
|
|
if match:
|
|
|
|
|
|
assert match.group(2) in [
|
|
|
"",
|
|
|
"?",
|
|
|
"[]",
|
|
|
], "unrecognized alias analysis form with Tensor"
|
|
|
type_s = "Tensor" + match.group(2)
|
|
|
annotation = Annotation.parse(match.group(1))
|
|
|
else:
|
|
|
type_s = type_and_annot
|
|
|
annotation = None
|
|
|
type = Type.parse(type_s)
|
|
|
r = Return(
|
|
|
name=name,
|
|
|
type=type,
|
|
|
annotation=annotation,
|
|
|
)
|
|
|
assert str(r) == arg, f"{str(r)} != {arg}"
|
|
|
return r
|
|
|
|
|
|
@property
|
|
|
def is_write(self) -> bool:
|
|
|
return self.annotation is not None and self.annotation.is_write
|
|
|
|
|
|
def __str__(self) -> str:
|
|
|
type = f"{self.type}"
|
|
|
if self.annotation:
|
|
|
assert type in ["Tensor", "Tensor?", "Tensor[]"]
|
|
|
type = type.replace("Tensor", f"Tensor({self.annotation})")
|
|
|
if self.name is None:
|
|
|
return type
|
|
|
else:
|
|
|
return f"{type} {self.name}"
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass(frozen=True)
|
|
|
class SelfArgument:
|
|
|
argument: Argument
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass(frozen=True)
|
|
|
class TensorOptionsArguments:
|
|
|
dtype: Argument
|
|
|
layout: Argument
|
|
|
device: Argument
|
|
|
pin_memory: Argument
|
|
|
|
|
|
def all(self) -> Sequence[Argument]:
|
|
|
return [self.dtype, self.layout, self.device, self.pin_memory]
|
|
|
|
|
|
|
|
|
@dataclass(frozen=True)
|
|
|
class Arguments:
|
|
|
|
|
|
|
|
|
|
|
|
    # Positional arguments that appear before the `self` argument, if any.
    pre_self_positional: tuple[Argument, ...]
|
|
|
    # The `self` argument, present when the operator can be a Tensor method.
    self_arg: SelfArgument | None
|
|
|
post_self_positional: tuple[Argument, ...]
|
|
|
|
|
|
    # Keyword-only arguments, split around the TensorOptions block
    # (dtype, layout, device, pin_memory) when one is present.
    pre_tensor_options_kwarg_only: tuple[Argument, ...]
|
|
|
tensor_options: TensorOptionsArguments | None
|
|
|
|
|
|
|
|
|
|
|
|
post_tensor_options_kwarg_only: tuple[Argument, ...]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
    # Output arguments: mutable, keyword-only tensors that come last in
    # the schema (see _preparse below).
    out: tuple[Argument, ...]
|
|
|
|
|
|
@property
|
|
|
def flat_non_out(self) -> Sequence[Argument]:
|
|
|
ret: list[Argument] = []
|
|
|
ret.extend(self.flat_positional)
|
|
|
ret.extend(self.flat_kwarg_only)
|
|
|
return ret
|
|
|
|
|
|
@property
|
|
|
def flat_positional(self) -> Sequence[Argument]:
|
|
|
ret: list[Argument] = []
|
|
|
ret.extend(self.pre_self_positional)
|
|
|
if self.self_arg is not None:
|
|
|
ret.append(self.self_arg.argument)
|
|
|
ret.extend(self.post_self_positional)
|
|
|
return ret
|
|
|
|
|
|
@property
|
|
|
def post_self_positional_mutable(self) -> Sequence[Argument]:
|
|
|
return [a for a in self.post_self_positional if a.is_write]
|
|
|
|
|
|
|
|
|
@property
|
|
|
def flat_kwarg_only(self) -> Sequence[Argument]:
|
|
|
ret: list[Argument] = []
|
|
|
ret.extend(self.pre_tensor_options_kwarg_only)
|
|
|
if self.tensor_options is not None:
|
|
|
ret.extend(self.tensor_options.all())
|
|
|
ret.extend(self.post_tensor_options_kwarg_only)
|
|
|
return ret
|
|
|
|
|
|
@property
|
|
|
def flat_all(self) -> Sequence[Argument]:
|
|
|
ret: list[Argument] = []
|
|
|
ret.extend(self.flat_positional)
|
|
|
ret.extend(self.flat_kwarg_only)
|
|
|
ret.extend(self.out)
|
|
|
return ret
|
|
|
|
|
|
@property
|
|
|
def non_out(
|
|
|
self,
|
|
|
) -> Sequence[Argument | SelfArgument | TensorOptionsArguments]:
|
|
|
ret: list[Argument | SelfArgument | TensorOptionsArguments] = []
|
|
|
ret.extend(self.positional)
|
|
|
ret.extend(self.kwarg_only)
|
|
|
return ret
|
|
|
|
|
|
@property
|
|
|
def positional(self) -> Sequence[Argument | SelfArgument]:
|
|
|
ret: list[Argument | SelfArgument] = []
|
|
|
ret.extend(self.pre_self_positional)
|
|
|
if self.self_arg is not None:
|
|
|
ret.append(self.self_arg)
|
|
|
ret.extend(self.post_self_positional)
|
|
|
return ret
|
|
|
|
|
|
@property
|
|
|
def kwarg_only(self) -> Sequence[Argument | TensorOptionsArguments]:
|
|
|
ret: list[Argument | TensorOptionsArguments] = []
|
|
|
ret.extend(self.pre_tensor_options_kwarg_only)
|
|
|
if self.tensor_options is not None:
|
|
|
ret.append(self.tensor_options)
|
|
|
ret.extend(self.post_tensor_options_kwarg_only)
|
|
|
return ret
|
|
|
|
|
|
@property
|
|
|
def all(self) -> Sequence[Argument | SelfArgument | TensorOptionsArguments]:
|
|
|
ret: list[Argument | SelfArgument | TensorOptionsArguments] = []
|
|
|
ret.extend(self.positional)
|
|
|
ret.extend(self.kwarg_only)
|
|
|
ret.extend(self.out)
|
|
|
return ret
|
|
|
|
|
|
def mutable_arg_names(self) -> list[str]:
|
|
|
return [
|
|
|
a.name
|
|
|
for a in self.flat_all
|
|
|
if a.annotation is not None and a.annotation.is_write
|
|
|
]
|
|
|
|
|
|
def has_tensor_arg(self) -> bool:
|
|
|
return any(a.type.is_tensor_like() for a in self.flat_non_out)
|
|
|
|
|
|
def has_symint_arg(self) -> bool:
|
|
|
return any(a.type.is_symint_like() for a in self.flat_non_out)
|
|
|
|
|
|
def has_generator_arg(self) -> bool:
|
|
|
return any(a.type.is_generator_like() for a in self.flat_non_out)
|
|
|
|
|
|
def signature(self, *, strip_default: bool = False) -> Arguments:
|
|
|
|
|
|
|
|
|
        # The signature erases everything irrelevant to overload
        # resolution: alias annotations, (optionally) defaults, the
        # TensorOptions grouping, and out arguments.
        def strip_arg_annotation(a: Argument) -> Argument:
|
|
|
return Argument(
|
|
|
name=a.name,
|
|
|
type=a.type,
|
|
|
default=a.default if not strip_default else None,
|
|
|
annotation=None,
|
|
|
)
|
|
|
|
|
|
return Arguments(
|
|
|
pre_self_positional=tuple(
|
|
|
map(strip_arg_annotation, self.pre_self_positional)
|
|
|
),
|
|
|
self_arg=(
|
|
|
SelfArgument(strip_arg_annotation(self.self_arg.argument))
|
|
|
if self.self_arg is not None
|
|
|
else None
|
|
|
),
|
|
|
post_self_positional=tuple(
|
|
|
map(strip_arg_annotation, self.post_self_positional)
|
|
|
),
|
|
|
|
|
|
|
|
|
pre_tensor_options_kwarg_only=tuple(
|
|
|
map(strip_arg_annotation, self.pre_tensor_options_kwarg_only)
|
|
|
)
|
|
|
+ tuple(map(strip_arg_annotation, self.post_tensor_options_kwarg_only)),
|
|
|
|
|
|
|
|
|
tensor_options=None,
|
|
|
post_tensor_options_kwarg_only=(),
|
|
|
|
|
|
out=(),
|
|
|
)
|
|
|
|
|
|
def remove_self_annotation(self) -> Arguments:
|
|
|
assert self.self_arg is not None
|
|
|
return dataclasses.replace(
|
|
|
self,
|
|
|
self_arg=SelfArgument(
|
|
|
dataclasses.replace(self.self_arg.argument, annotation=None)
|
|
|
),
|
|
|
)
|
|
|
|
|
|
def with_out_args(self, outs: list[Argument]) -> Arguments:
|
|
|
assert len(self.out) == 0
|
|
|
return dataclasses.replace(
|
|
|
self,
|
|
|
out=tuple(outs),
|
|
|
)
|
|
|
|
|
|
@staticmethod
|
|
|
def _preparse(args: str) -> tuple[list[Argument], list[Argument], list[Argument]]:
|
|
|
positional: list[Argument] = []
|
|
|
kwarg_only: list[Argument] = []
|
|
|
out: list[Argument] = []
|
|
|
arguments_acc = positional
|
|
|
|
|
|
|
|
|
|
|
|
for arg in args.split(", "):
|
|
|
if not arg:
|
|
|
continue
|
|
|
if arg == "*":
|
|
|
assert arguments_acc is positional, (
|
|
|
"invalid syntax: kwarg-only specifier * can only occur once"
|
|
|
)
|
|
|
arguments_acc = kwarg_only
|
|
|
continue
|
|
|
parg = Argument.parse(arg)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
            # A mutable argument after the `*` marks the start of the out
            # arguments; non-mutable arguments may not follow them.
            if parg.annotation is not None and parg.annotation.is_write:
|
|
|
if arguments_acc is positional:
|
|
|
pass
|
|
|
elif arguments_acc is kwarg_only:
|
|
|
arguments_acc = out
|
|
|
else:
|
|
|
assert arguments_acc is not out
|
|
|
arguments_acc.append(parg)
|
|
|
|
|
|
return positional, kwarg_only, out
|
|
|
|
|
|
@staticmethod
|
|
|
def parse(args: str) -> Arguments:
|
|
|
"""
|
|
|
Input: 'int x, int y, int z'
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
positional, kwarg_only, out = Arguments._preparse(args)
|
|
|
|
|
|
|
|
|
self_ix = None
|
|
|
for i, a in enumerate(positional):
|
|
|
if a.name == "self":
|
|
|
self_ix = i
|
|
|
break
|
|
|
pre_self_positional: list[Argument]
|
|
|
self_arg: SelfArgument | None
|
|
|
post_self_positional: list[Argument]
|
|
|
if self_ix is not None:
|
|
|
pre_self_positional = positional[:self_ix]
|
|
|
self_arg = SelfArgument(positional[self_ix])
|
|
|
post_self_positional = positional[self_ix + 1 :]
|
|
|
else:
|
|
|
pre_self_positional = []
|
|
|
self_arg = None
|
|
|
post_self_positional = positional
|
|
|
|
|
|
|
|
|
pre_tensor_options_kwarg_only: list[Argument] = []
|
|
|
tensor_options: TensorOptionsArguments | None = None
|
|
|
post_tensor_options_kwarg_only: list[Argument] = []
|
|
|
kwarg_only_acc = pre_tensor_options_kwarg_only
|
|
|
|
|
|
def pred(name: str, ty: Type) -> Callable[[Argument], bool]:
|
|
|
return lambda a: a.name == name and a.type in [ty, OptionalType(ty)]
|
|
|
|
|
|
predicates = [
|
|
|
pred("dtype", Type.parse("ScalarType")),
|
|
|
pred("layout", Type.parse("Layout")),
|
|
|
pred("device", Type.parse("Device")),
|
|
|
pred("pin_memory", Type.parse("bool")),
|
|
|
]
|
|
|
|
|
|
        # Scan the keyword-only arguments, grouping a consecutive
        # dtype/layout/device/pin_memory run into TensorOptions.
        i = 0
        while i < len(kwarg_only):
            # Only attempt a match if enough arguments remain.
            if i <= len(kwarg_only) - len(predicates):
|
|
|
|
|
|
if all(
|
|
|
p(a)
|
|
|
for p, a in zip(predicates, kwarg_only[i : i + len(predicates)])
|
|
|
):
|
|
|
assert kwarg_only_acc is pre_tensor_options_kwarg_only
|
|
|
|
|
|
tensor_options = TensorOptionsArguments(
|
|
|
dtype=kwarg_only[i],
|
|
|
layout=kwarg_only[i + 1],
|
|
|
device=kwarg_only[i + 2],
|
|
|
pin_memory=kwarg_only[i + 3],
|
|
|
)
|
|
|
i += len(predicates)
|
|
|
kwarg_only_acc = post_tensor_options_kwarg_only
|
|
|
continue
|
|
|
kwarg_only_acc.append(kwarg_only[i])
|
|
|
i += 1
|
|
|
|
|
|
return Arguments(
|
|
|
pre_self_positional=tuple(pre_self_positional),
|
|
|
self_arg=self_arg,
|
|
|
post_self_positional=tuple(post_self_positional),
|
|
|
pre_tensor_options_kwarg_only=tuple(pre_tensor_options_kwarg_only),
|
|
|
tensor_options=tensor_options,
|
|
|
post_tensor_options_kwarg_only=tuple(post_tensor_options_kwarg_only),
|
|
|
out=tuple(out),
|
|
|
)
|
|
|
|
|
|
def __str__(self) -> str:
|
|
|
all_arguments: list[str] = []
|
|
|
all_arguments.extend(map(str, self.flat_positional))
|
|
|
if self.flat_kwarg_only or self.out:
|
|
|
all_arguments.append("*")
|
|
|
all_arguments.extend(map(str, self.flat_kwarg_only))
|
|
|
all_arguments.extend(map(str, self.out))
|
|
|
return ", ".join(all_arguments)
|
|
|
|
|
|
def __post_init__(self) -> None:
|
|
|
|
|
|
|
|
|
        # Structural invariants: pre-self positionals require a self
        # argument, and post-TensorOptions kwargs require TensorOptions.
        if self.self_arg is None:
|
|
|
assert not self.pre_self_positional
|
|
|
if self.tensor_options is None:
|
|
|
assert not self.post_tensor_options_kwarg_only
|
|
|
|
|
|
|
|
|
|
|
|
mutable_pre_self_positionals = [
|
|
|
a
|
|
|
for a in self.pre_self_positional
|
|
|
if a.annotation is not None and a.annotation.is_write
|
|
|
]
|
|
|
assert len(mutable_pre_self_positionals) == 0, (
|
|
|
"mutable pre_self_positional arguments are not currently supported in the schema"
|
|
|
)
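
# Illustrative sketch (hypothetical helper, documentation only): parsing
# round-trips through __str__, mutable kwarg-only arguments land in
# `out`, and a dtype/layout/device/pin_memory run is grouped into
# TensorOptions.
def _example_arguments_parsing() -> None:
    s = "Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out"
    args = Arguments.parse(s)
    assert str(args) == s
    assert [a.name for a in args.flat_positional] == ["self", "other"]
    assert [a.name for a in args.out] == ["out"]
    assert args.mutable_arg_names() == ["out"]

    factory = Arguments.parse(
        "int[] size, *, ScalarType? dtype=None, Layout? layout=None,"
        " Device? device=None, bool? pin_memory=None"
    )
    assert factory.tensor_options is not None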
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
AUGMENTED_ASSIGNMENT_NAMES = [
|
|
|
"add",
|
|
|
"sub",
|
|
|
"mul",
|
|
|
"div",
|
|
|
"mod",
|
|
|
"pow",
|
|
|
"lshift",
|
|
|
"rshift",
|
|
|
"and",
|
|
|
"xor",
|
|
|
"or",
|
|
|
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass(frozen=True)
|
|
|
class BaseOperatorName:
|
|
|
    # The operator name stripped of inplace/dunder/functional decorations.
    base: str
|
|
|
inplace: bool
|
|
|
dunder_method: bool
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
    # Whether this is the "functional" overload of an operator that also
    # has a mutable variant; spelled with a trailing "_functional".
    functional_overload: bool = False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
    # Optional namespace qualifier (the "ns" in "ns::op"); None when the
    # operator name is unqualified.
    namespace: Optional[str] = None
|
|
|
|
|
|
@staticmethod
|
|
|
def parse(op: str) -> BaseOperatorName:
|
|
|
assert op != ""
|
|
|
assert not op.endswith("_out"), (
|
|
|
"_out suffix is reserved and not permitted for operator names; "
|
|
|
"did you mean to specify an out overload name instead?"
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
match = re.match(r"^(?:(.*)::)?(.*)$", op)
|
|
|
namespace = match.group(1) if match else ""
|
|
|
op_without_ns = match.group(2) if match else op
|
|
|
m = re.match(r"^__([^_]+)__$", op_without_ns)
|
|
|
if m is not None:
|
|
|
dunder_method = True
|
|
|
base = m.group(1)
|
|
|
if any(base == f"i{n}" for n in AUGMENTED_ASSIGNMENT_NAMES):
|
|
|
inplace = True
|
|
|
base = base[1:]
|
|
|
else:
|
|
|
inplace = False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
assert base[0] != "i"
|
|
|
else:
|
|
|
dunder_method = False
|
|
|
base = op_without_ns
|
|
|
if base[-1] == "_":
|
|
|
inplace = True
|
|
|
base = base[:-1]
|
|
|
else:
|
|
|
inplace = False
|
|
|
|
|
|
|
|
|
functional_suffix = "_functional"
|
|
|
if base.endswith(functional_suffix):
|
|
|
functional_overload = True
|
|
|
base = base[: -len(functional_suffix)]
|
|
|
|
|
|
|
|
|
            # Functional overloads of dunder or in-place operators aren't
            # supported by this naming scheme.
            assert not dunder_method and not inplace
|
|
|
else:
|
|
|
functional_overload = False
|
|
|
|
|
|
r = BaseOperatorName(
|
|
|
base=base,
|
|
|
inplace=inplace,
|
|
|
dunder_method=dunder_method,
|
|
|
functional_overload=functional_overload,
|
|
|
namespace=namespace,
|
|
|
)
|
|
|
assert str(r) == op, f"{str(r)} != {op}"
|
|
|
return r
|
|
|
|
|
|
def __str__(self) -> str:
|
|
|
namespace_prefix = f"{self.namespace}::" if self.namespace else ""
|
|
|
if self.dunder_method:
|
|
|
i = "i" if self.inplace else ""
|
|
|
return f"{namespace_prefix}__{i}{self.base}__"
|
|
|
else:
|
|
|
i = (
|
|
|
"_"
|
|
|
if self.inplace
|
|
|
else "_functional"
|
|
|
if self.functional_overload
|
|
|
else ""
|
|
|
)
|
|
|
return f"{namespace_prefix}{self.base}{i}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass(frozen=True)
|
|
|
class OperatorName:
|
|
|
name: BaseOperatorName
|
|
|
overload_name: str
|
|
|
|
|
|
@staticmethod
|
|
|
def parse(op_name: str) -> OperatorName:
|
|
|
if "." in op_name:
|
|
|
name, overload_name = op_name.split(".", 1)
|
|
|
else:
|
|
|
name = op_name
|
|
|
overload_name = ""
|
|
|
r = OperatorName(name=BaseOperatorName.parse(name), overload_name=overload_name)
|
|
|
assert str(r) == op_name, f"{str(r)} != {op_name}"
|
|
|
return r
|
|
|
|
|
|
def __str__(self) -> str:
|
|
|
if self.overload_name:
|
|
|
return f"{self.name}.{self.overload_name}"
|
|
|
else:
|
|
|
return f"{self.name}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def unambiguous_name(self) -> str:
|
|
|
if self.overload_name:
|
|
|
return f"{self.name}_{self.overload_name}"
|
|
|
else:
|
|
|
return f"{self.name}"
|
|
|
|
|
|
def remove_inplace(self) -> OperatorName:
|
|
|
return OperatorName(
|
|
|
name=BaseOperatorName(
|
|
|
base=self.name.base,
|
|
|
inplace=False,
|
|
|
dunder_method=self.name.dunder_method,
|
|
|
),
|
|
|
overload_name=self.overload_name,
|
|
|
)
|
|
|
|
|
|
def with_overload(self, overload: str) -> OperatorName:
|
|
|
return OperatorName(
|
|
|
name=BaseOperatorName(
|
|
|
base=self.name.base,
|
|
|
inplace=False,
|
|
|
dunder_method=self.name.dunder_method,
|
|
|
),
|
|
|
overload_name=overload,
|
|
|
)
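
# Illustrative sketch (hypothetical helper, documentation only): the
# overload name is separated by "." when printing and by "_" in
# unambiguous_name(), and remove_inplace() drops the trailing "_".
def _example_operator_name() -> None:
    op = OperatorName.parse("add.Tensor")
    assert (str(op), op.unambiguous_name()) == ("add.Tensor", "add_Tensor")
    assert str(OperatorName.parse("add_.Tensor").remove_inplace()) == "add.Tensor"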
|
|
|
|
|
|
|
|
|
# A backend that only implements the functional variant of a structured
# operator group gets generated out/inplace wrappers for the variants it
# does not implement directly.
def gets_generated_out_inplace_wrapper(
    f: NativeFunction, g: NativeFunctionsGroup, b: BackendIndex
) -> bool:
|
|
|
return (
|
|
|
f.func.kind() is not SchemaKind.functional
|
|
|
and not b.has_kernel(f)
|
|
|
and b.has_kernel(g.functional)
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass(frozen=True)
|
|
|
class NativeFunctionsViewGroup:
|
|
|
view: NativeFunction
|
|
|
|
|
|
|
|
|
|
|
|
    # The out-of-place copy variant of the view op (e.g. view_copy for
    # view); None when no such variant exists or is generated (see
    # __post_init__).
    view_copy: NativeFunction | None
|
|
|
|
|
|
    # The in-place variant of the view op, if one exists.
    view_inplace: NativeFunction | None
|
|
|
|
|
|
def __post_init__(self) -> None:
|
|
|
assert self.view.is_view_op
|
|
|
if self.view_copy is None:
|
|
|
assert not gets_generated_view_copy(self.view), (
|
|
|
f"{str(self.view.func.name)} appears to be a new operator that aliases its inputs."
|
|
|
" The codegen expects you to add a corresponding operator to native_functions.yaml:"
|
|
|
f" {get_view_copy_name(self.view)!s}."
|
|
|
" See Note [view_copy NativeFunctions] for details."
|
|
|
)
|
|
|
else:
|
|
|
assert self.view_copy.func.name.name.base.endswith(("_copy", "_scatter"))
|
|
|
assert self.view.func.signature() == self.view_copy.func.signature(
|
|
|
strip_view_copy_name=True,
|
|
|
)
|
|
|
assert "view_copy" in self.view_copy.tags, (
|
|
|
f"{str(self.view_copy.func.name), str(self.view.tags)} appears to be a view_copy operator. The codegen expects"
|
|
|
" view_copy operators to be annotated with the 'view_copy' tag in native_functions.yaml."
|
|
|
" See Note [view_copy NativeFunction] for details."
|
|
|
)
|
|
|
if self.view_inplace is not None:
|
|
|
assert self.view.func.signature() == self.view_inplace.func.signature()
|
|
|
|
|
|
if self.view.has_composite_implicit_autograd_kernel:
|
|
|
if self.view_inplace is not None:
|
|
|
assert self.view_inplace.has_composite_implicit_autograd_kernel, (
|
|
|
f"{str(self.view.func.name)} and {str(self.view_inplace.func.name)} must either"
|
|
|
" both have CompositeImplicitAutograd kernels, or both not have composite kernels."
|
|
|
)
|
|
|
if self.view.has_composite_implicit_autograd_nested_tensor_kernel:
|
|
|
if self.view_inplace is not None:
|
|
|
assert self.view_inplace.has_composite_implicit_autograd_nested_tensor_kernel, (
|
|
|
f"{str(self.view.func.name)} and {str(self.view_inplace.func.name)} must either"
|
|
|
" both have CompositeImplicitAutogradNestedTensor kernels, or both not have composite kernels."
|
|
|
)
|
|
|
|
|
|
def functions(self, *, include_copy: bool = True) -> Iterator[NativeFunction]:
|
|
|
yield self.view
|
|
|
if self.view_inplace is not None:
|
|
|
yield self.view_inplace
|
|
|
if self.view_copy is not None and include_copy:
|
|
|
yield self.view_copy
|
|
|
|
|
|
@property
|
|
|
def root_name(self) -> str:
|
|
|
return self.view.root_name
|
|
|
|
|
|
@property
|
|
|
def composite(self) -> bool:
|
|
|
|
|
|
|
|
|
        # A view group is composite when its base view op decomposes via
        # CompositeImplicitAutograd.
        return self.view.has_composite_implicit_autograd_kernel
|
|
|
|
|
|
|
|
|
def gets_generated_view_copy(f: NativeFunction) -> bool:
|
|
|
|
|
|
    # Only view operators get a generated view_copy variant.
    if not f.is_view_op:
|
|
|
return False
|
|
|
|
|
|
|
|
|
    # CompositeImplicitAutograd ops decompose into other operators, so
    # they don't need a view_copy variant of their own.
    if f.has_composite_implicit_autograd_kernel:
|
|
|
return False
|
|
|
|
|
|
if "inplace_view" in f.tags:
|
|
|
return False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if f.func.name.name.base.endswith("_inverse"):
|
|
|
return False
|
|
|
return True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_view_copy_name(f: NativeFunction) -> OperatorName:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
list_of_ops_with_explicit_view_copy_operators = ["narrow"]
|
|
|
if str(f.func.name) not in list_of_ops_with_explicit_view_copy_operators:
|
|
|
assert gets_generated_view_copy(f)
|
|
|
|
|
|
base_name = f"{f.func.name.name.base}_copy"
|
|
|
view_copy_name = OperatorName(
|
|
|
name=BaseOperatorName(
|
|
|
base=base_name, inplace=False, dunder_method=f.func.name.name.dunder_method
|
|
|
),
|
|
|
overload_name=f.func.name.overload_name,
|
|
|
)
|
|
|
return view_copy_name
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def parse_returns(return_decl: str) -> tuple[Return, ...]:
|
|
|
"""
|
|
|
Input: '()'
|
|
|
Output: []
|
|
|
"""
|
|
|
if return_decl == "()":
|
|
|
return ()
|
|
|
if return_decl[0] == "(" and return_decl[-1] == ")":
|
|
|
return_decl = return_decl[1:-1]
|
|
|
return tuple(Return.parse(arg) for arg in return_decl.split(", "))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass(frozen=True)
|
|
|
class Precompute:
|
|
|
|
|
|
|
|
|
    # Maps a kernel parameter name to the precomputed arguments that
    # replace it in the kernel signature.
    replace: dict[str, list[Argument]]
|
|
|
|
|
|
    # Additional precomputed arguments appended to the kernel signature
    # without replacing an existing parameter.
    add: list[Argument]
|
|
|
|
|
|
@staticmethod
|
|
|
def parse(src: object) -> Precompute:
|
|
|
assert isinstance(src, list)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
        # A final line without " -> " lists parameters that are added
        # without replacing an existing kernel parameter.
        add_args = []
        if " -> " not in src[-1]:
|
|
|
add_list = src[-1].split(",")
|
|
|
add_args = [Argument.parse(name.strip()) for name in add_list]
|
|
|
src = src[:-1]
|
|
|
|
|
|
replace = {}
|
|
|
for raw_replace_item in src:
|
|
|
assert isinstance(raw_replace_item, str)
|
|
|
assert " -> " in raw_replace_item, (
|
|
|
"precomputed parameters without replacement"
|
|
|
" are allowed only in the last line"
|
|
|
)
|
|
|
|
|
|
arg, with_list_raw = raw_replace_item.split(" -> ")
|
|
|
assert " " not in arg, (
|
|
|
f"illegal kernel param name '{arg}' in precomputed parameters'"
|
|
|
)
|
|
|
with_list = with_list_raw.split(",")
|
|
|
with_list_args = [Argument.parse(name.strip()) for name in with_list]
|
|
|
replace[arg] = with_list_args
|
|
|
|
|
|
r = Precompute(replace=replace, add=add_args)
|
|
|
assert r.to_list() == src, "r.to_list() != src"
|
|
|
return r
|
|
|
|
|
|
def __post_init__(self) -> None:
|
|
|
|
|
|
|
|
|
        # Fully upper-case names are reserved for the codegen's template
        # placeholders, so reject them here to avoid ambiguity.
        for a in self.add:
|
|
|
assert a.name.upper() != a.name
|
|
|
for args in self.replace.values():
|
|
|
for a in args:
|
|
|
assert a.name.upper() != a.name
|
|
|
|
|
|
def to_list(self) -> list[str]:
|
|
|
replace_list = []
|
|
|
for kernel_param, replacement_params in self.replace.items():
|
|
|
replacements = ", ".join(str(param) for param in replacement_params)
|
|
|
replace_list.append(f"{kernel_param} -> {replacements}")
|
|
|
|
|
|
return replace_list
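
# Illustrative sketch (hypothetical helper, documentation only): each
# replace entry round-trips through to_list(), and a trailing line
# without " -> " contributes to `add`.
def _example_precompute() -> None:
    p = Precompute.parse(["kernel_size -> int kH, int kW"])
    assert p.to_list() == ["kernel_size -> int kH, int kW"]
    p2 = Precompute.parse(["kernel_size -> int kH, int kW", "bool ceil_mode"])
    assert [str(a) for a in p2.add] == ["bool ceil_mode"]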
|
|
|
|