Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .venv/lib/python3.11/site-packages/torchgen/dest/__init__.py +19 -0
- .venv/lib/python3.11/site-packages/torchgen/dest/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchgen/dest/__pycache__/lazy_ir.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchgen/dest/__pycache__/lazy_ts_lowering.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchgen/dest/__pycache__/native_functions.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchgen/dest/__pycache__/register_dispatch_key.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchgen/dest/__pycache__/ufunc.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchgen/dest/lazy_ir.py +707 -0
- .venv/lib/python3.11/site-packages/torchgen/dest/native_functions.py +63 -0
- .venv/lib/python3.11/site-packages/torchgen/dest/register_dispatch_key.py +1005 -0
- .venv/lib/python3.11/site-packages/torchgen/dest/ufunc.py +551 -0
- .venv/lib/python3.11/site-packages/torchgen/executorch/__init__.py +0 -0
- .venv/lib/python3.11/site-packages/torchgen/executorch/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchgen/executorch/__pycache__/model.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchgen/executorch/__pycache__/parse.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchgen/executorch/api/__init__.py +0 -0
- .venv/lib/python3.11/site-packages/torchgen/executorch/api/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchgen/executorch/api/__pycache__/custom_ops.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchgen/executorch/api/__pycache__/et_cpp.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchgen/executorch/api/__pycache__/unboxing.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchgen/executorch/api/custom_ops.py +149 -0
- .venv/lib/python3.11/site-packages/torchgen/executorch/api/et_cpp.py +370 -0
- .venv/lib/python3.11/site-packages/torchgen/executorch/api/types/__init__.py +4 -0
- .venv/lib/python3.11/site-packages/torchgen/executorch/api/types/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchgen/executorch/api/types/__pycache__/signatures.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchgen/executorch/api/types/__pycache__/types.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchgen/executorch/api/types/signatures.py +76 -0
- .venv/lib/python3.11/site-packages/torchgen/executorch/api/types/types.py +83 -0
- .venv/lib/python3.11/site-packages/torchgen/executorch/api/unboxing.py +230 -0
- .venv/lib/python3.11/site-packages/torchgen/executorch/model.py +220 -0
- .venv/lib/python3.11/site-packages/torchgen/executorch/parse.py +153 -0
- .venv/lib/python3.11/site-packages/torchgen/packaged/ATen/native/native_functions.yaml +0 -0
- .venv/lib/python3.11/site-packages/torchgen/packaged/ATen/native/tags.yaml +74 -0
- .venv/lib/python3.11/site-packages/torchgen/packaged/ATen/templates/ATenOpList.cpp +36 -0
- .venv/lib/python3.11/site-packages/torchgen/packaged/ATen/templates/CompositeViewCopyKernels.cpp +73 -0
- .venv/lib/python3.11/site-packages/torchgen/packaged/ATen/templates/DispatchKeyFunction.h +23 -0
- .venv/lib/python3.11/site-packages/torchgen/packaged/ATen/templates/DispatchKeyFunctions.h +29 -0
- .venv/lib/python3.11/site-packages/torchgen/packaged/ATen/templates/DispatchKeyFunctions_inl.h +22 -0
- .venv/lib/python3.11/site-packages/torchgen/packaged/ATen/templates/DispatchKeyNativeFunctions.cpp +13 -0
- .venv/lib/python3.11/site-packages/torchgen/packaged/ATen/templates/DispatchKeyNativeFunctions.h +19 -0
- .venv/lib/python3.11/site-packages/torchgen/packaged/ATen/templates/Function.h +26 -0
- .venv/lib/python3.11/site-packages/torchgen/packaged/ATen/templates/FunctionalInverses.h +33 -0
- .venv/lib/python3.11/site-packages/torchgen/packaged/ATen/templates/Functions.cpp +103 -0
- .venv/lib/python3.11/site-packages/torchgen/packaged/ATen/templates/Functions.h +143 -0
- .venv/lib/python3.11/site-packages/torchgen/packaged/ATen/templates/LazyIr.h +19 -0
- .venv/lib/python3.11/site-packages/torchgen/packaged/ATen/templates/LazyNonNativeIr.h +11 -0
- .venv/lib/python3.11/site-packages/torchgen/packaged/ATen/templates/MethodOperators.h +24 -0
- .venv/lib/python3.11/site-packages/torchgen/packaged/ATen/templates/NativeFunction.h +17 -0
- .venv/lib/python3.11/site-packages/torchgen/packaged/ATen/templates/NativeFunctions.h +33 -0
- .venv/lib/python3.11/site-packages/torchgen/packaged/ATen/templates/NativeMetaFunction.h +23 -0
.venv/lib/python3.11/site-packages/torchgen/dest/__init__.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from torchgen.dest.lazy_ir import (
|
| 2 |
+
generate_non_native_lazy_ir_nodes as generate_non_native_lazy_ir_nodes,
|
| 3 |
+
GenLazyIR as GenLazyIR,
|
| 4 |
+
GenLazyNativeFuncDefinition as GenLazyNativeFuncDefinition,
|
| 5 |
+
GenLazyShapeInferenceDefinition as GenLazyShapeInferenceDefinition,
|
| 6 |
+
)
|
| 7 |
+
from torchgen.dest.native_functions import (
|
| 8 |
+
compute_native_function_declaration as compute_native_function_declaration,
|
| 9 |
+
)
|
| 10 |
+
from torchgen.dest.register_dispatch_key import (
|
| 11 |
+
gen_registration_headers as gen_registration_headers,
|
| 12 |
+
gen_registration_helpers as gen_registration_helpers,
|
| 13 |
+
RegisterDispatchKey as RegisterDispatchKey,
|
| 14 |
+
)
|
| 15 |
+
from torchgen.dest.ufunc import (
|
| 16 |
+
compute_ufunc_cpu as compute_ufunc_cpu,
|
| 17 |
+
compute_ufunc_cpu_kernel as compute_ufunc_cpu_kernel,
|
| 18 |
+
compute_ufunc_cuda as compute_ufunc_cuda,
|
| 19 |
+
)
|
.venv/lib/python3.11/site-packages/torchgen/dest/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (913 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/torchgen/dest/__pycache__/lazy_ir.cpython-311.pyc
ADDED
|
Binary file (40.8 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torchgen/dest/__pycache__/lazy_ts_lowering.cpython-311.pyc
ADDED
|
Binary file (3.52 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torchgen/dest/__pycache__/native_functions.cpython-311.pyc
ADDED
|
Binary file (3.58 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torchgen/dest/__pycache__/register_dispatch_key.cpython-311.pyc
ADDED
|
Binary file (44.9 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torchgen/dest/__pycache__/ufunc.cpython-311.pyc
ADDED
|
Binary file (27.6 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torchgen/dest/lazy_ir.py
ADDED
|
@@ -0,0 +1,707 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import itertools
|
| 4 |
+
from abc import ABC
|
| 5 |
+
from dataclasses import dataclass
|
| 6 |
+
from typing import Any
|
| 7 |
+
|
| 8 |
+
import torchgen.api.dispatcher as dispatcher
|
| 9 |
+
from torchgen.api.lazy import (
|
| 10 |
+
getValueT,
|
| 11 |
+
isValueType,
|
| 12 |
+
LazyArgument,
|
| 13 |
+
LazyIrProperties,
|
| 14 |
+
LazyIrSchema,
|
| 15 |
+
tensorListValueT,
|
| 16 |
+
)
|
| 17 |
+
from torchgen.api.translate import translate
|
| 18 |
+
from torchgen.api.types import (
|
| 19 |
+
BaseCType,
|
| 20 |
+
Binding,
|
| 21 |
+
deviceT,
|
| 22 |
+
DispatcherSignature,
|
| 23 |
+
kernel_signature,
|
| 24 |
+
NativeSignature,
|
| 25 |
+
OptionalCType,
|
| 26 |
+
VectorCType,
|
| 27 |
+
)
|
| 28 |
+
from torchgen.context import method_with_native_function
|
| 29 |
+
from torchgen.dest.lazy_ts_lowering import ts_lowering_body
|
| 30 |
+
from torchgen.model import (
|
| 31 |
+
Argument,
|
| 32 |
+
BackendIndex,
|
| 33 |
+
BackendMetadata,
|
| 34 |
+
BaseTy,
|
| 35 |
+
BaseType,
|
| 36 |
+
FunctionSchema,
|
| 37 |
+
ListType,
|
| 38 |
+
NativeFunction,
|
| 39 |
+
NativeFunctionsGroup,
|
| 40 |
+
)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def node_ctor_arg_rvalue_string(arg: LazyArgument) -> str:
|
| 44 |
+
"""
|
| 45 |
+
Given a LazyArgument,
|
| 46 |
+
generate a c++ string for materializing an rvalue of that arg for passing into
|
| 47 |
+
a lazy Node constructor.
|
| 48 |
+
"""
|
| 49 |
+
|
| 50 |
+
# TODO: Matching on CType seems wrong; should be matching on Type
|
| 51 |
+
if isValueType(arg.lazy_type):
|
| 52 |
+
if isinstance(arg.lazy_type, BaseCType):
|
| 53 |
+
if arg.is_wrapped_scalar:
|
| 54 |
+
return f"node_{arg.name}"
|
| 55 |
+
elif arg.lazy_type.type is tensorListValueT:
|
| 56 |
+
return f"lazy_{arg.name}_tensorlist"
|
| 57 |
+
elif arg.is_symint_or_list:
|
| 58 |
+
return f"GetSymIntValue({arg.name})"
|
| 59 |
+
return f"lazy_{arg.name}->GetIrValue()"
|
| 60 |
+
elif isinstance(arg.lazy_type, OptionalCType):
|
| 61 |
+
if arg.is_symint_or_list:
|
| 62 |
+
# TODO: I don't understand when you should put lazy_ in the name
|
| 63 |
+
# or not
|
| 64 |
+
return f"{arg.name} ? std::make_optional(GetSymIntValue(*{arg.name})) : ::std::nullopt"
|
| 65 |
+
elif arg.is_wrapped_scalar:
|
| 66 |
+
return f"node_{arg.name}"
|
| 67 |
+
return (
|
| 68 |
+
f"lazy_{arg.name} ? "
|
| 69 |
+
f"std::make_optional(lazy_{arg.name}->GetIrValue()) : "
|
| 70 |
+
"::std::nullopt"
|
| 71 |
+
)
|
| 72 |
+
else:
|
| 73 |
+
raise AssertionError(
|
| 74 |
+
f"TODO not sure if there are other valid types to handle here ({arg.lazy_type})"
|
| 75 |
+
)
|
| 76 |
+
else:
|
| 77 |
+
# NB: this is here because right now we aren't treating SymInt[] as a
|
| 78 |
+
# value type; when we do this needs to move above
|
| 79 |
+
# NB: we cannot test arg.lazy_type as we've already specified it is an
|
| 80 |
+
# int64_t and so we cannot distinguish between SymInt and int64_t
|
| 81 |
+
if isinstance(arg.orig_type, ListType) and arg.orig_type.elem == BaseType(
|
| 82 |
+
BaseTy.SymInt
|
| 83 |
+
):
|
| 84 |
+
if arg.symint:
|
| 85 |
+
return f"GetSymIntArrayRefValue({arg.name})"
|
| 86 |
+
else:
|
| 87 |
+
return f"std::vector<int64_t>({arg.name}.begin(), {arg.name}.end())"
|
| 88 |
+
elif isinstance(arg.lazy_type, VectorCType) and isinstance(
|
| 89 |
+
arg.lazy_type.elem, BaseCType
|
| 90 |
+
):
|
| 91 |
+
return f"std::vector<{arg.lazy_type.elem.type}>({arg.name}.begin(), {arg.name}.end())"
|
| 92 |
+
elif (
|
| 93 |
+
isinstance(arg.lazy_type, OptionalCType)
|
| 94 |
+
and isinstance(arg.lazy_type.elem, VectorCType)
|
| 95 |
+
and isinstance(arg.lazy_type.elem.elem, BaseCType)
|
| 96 |
+
):
|
| 97 |
+
return f"torch::lazy::ToOptionalVector<{arg.lazy_type.elem.elem.type}>({arg.name})"
|
| 98 |
+
else:
|
| 99 |
+
return f"{arg.name}"
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def node_ctor_inputs(schema: LazyIrSchema) -> str:
|
| 103 |
+
"""
|
| 104 |
+
Produce a formatted string with the arguments as passed into the constructor of a node class.
|
| 105 |
+
"""
|
| 106 |
+
node_ctor_values = [
|
| 107 |
+
node_ctor_arg_rvalue_string(arg) for arg in schema.filtered_args()
|
| 108 |
+
]
|
| 109 |
+
return ", ".join(node_ctor_values)
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def gen_fallback_code(
|
| 113 |
+
schema: LazyIrSchema,
|
| 114 |
+
sig: DispatcherSignature | NativeSignature,
|
| 115 |
+
overload_name: str,
|
| 116 |
+
) -> str:
|
| 117 |
+
"""
|
| 118 |
+
Generate code that falls back to eager conditioned on a predicate
|
| 119 |
+
"""
|
| 120 |
+
dispatcher_sig = DispatcherSignature.from_schema(schema.func)
|
| 121 |
+
exprs = translate(sig.arguments(), dispatcher_sig.arguments())
|
| 122 |
+
fallback_args = ",\n ".join([a.expr for a in exprs])
|
| 123 |
+
if len(overload_name):
|
| 124 |
+
aten_op_str = f"ATEN_OP2({schema.aten_name}, {overload_name})"
|
| 125 |
+
else:
|
| 126 |
+
aten_op_str = f"ATEN_OP({schema.aten_name})"
|
| 127 |
+
return f"""
|
| 128 |
+
if (force_eager_fallback({aten_symbol(schema)})) {{
|
| 129 |
+
return at::native::call_fallback_fn_symint<<c_eager_fallback, {aten_op_str}>::call(
|
| 130 |
+
{fallback_args}
|
| 131 |
+
);
|
| 132 |
+
}}
|
| 133 |
+
"""
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def aten_symbol(schema: LazyIrSchema) -> str:
|
| 137 |
+
missing_interned_strings = {
|
| 138 |
+
"sigmoid_backward",
|
| 139 |
+
}
|
| 140 |
+
if schema.aten_name in missing_interned_strings:
|
| 141 |
+
return f'c10::Symbol::fromQualString("aten::{schema.aten_name}")'
|
| 142 |
+
|
| 143 |
+
if not schema.aten_name.startswith("at::"):
|
| 144 |
+
return f"at::aten::{schema.aten_name}"
|
| 145 |
+
else:
|
| 146 |
+
return schema.aten_name
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
# converts all tensor-like arguments to meta tensors. Returns:
|
| 150 |
+
# (1) a string containing all of the logic that does the conversions.
|
| 151 |
+
# (2) a context, to be used by translate(), with all of the relevant bindings.
|
| 152 |
+
def convert_to_meta_tensors(sig: DispatcherSignature) -> tuple[str, list[Binding]]:
|
| 153 |
+
context: list[Binding] = []
|
| 154 |
+
unwrapped_tensor_args: list[str] = []
|
| 155 |
+
for arg in sig.arguments():
|
| 156 |
+
if isinstance(arg.argument, Argument) and arg.argument.type.is_tensor_like():
|
| 157 |
+
unwrapped_name = f"{arg.name}_meta"
|
| 158 |
+
unwrapped_tensor_args.append(
|
| 159 |
+
f"auto {unwrapped_name} = to_meta({arg.name});"
|
| 160 |
+
)
|
| 161 |
+
context.append(arg.with_name(unwrapped_name))
|
| 162 |
+
else:
|
| 163 |
+
context.append(arg)
|
| 164 |
+
unwrap_tensor_args_str = "\n ".join(unwrapped_tensor_args)
|
| 165 |
+
return unwrap_tensor_args_str, context
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
@dataclass(frozen=True)
|
| 169 |
+
class GenLazyIR(ABC):
|
| 170 |
+
backend_index: BackendIndex
|
| 171 |
+
backend_name: str
|
| 172 |
+
node_base: str
|
| 173 |
+
use_lazy_shape: bool
|
| 174 |
+
|
| 175 |
+
@method_with_native_function
|
| 176 |
+
def __call__(self, f: NativeFunctionsGroup | NativeFunction) -> list[str]:
|
| 177 |
+
func = f.functional.func if isinstance(f, NativeFunctionsGroup) else f.func
|
| 178 |
+
metadata = self.backend_index.get_kernel(
|
| 179 |
+
f.functional if isinstance(f, NativeFunctionsGroup) else f
|
| 180 |
+
)
|
| 181 |
+
schema = LazyIrSchema(
|
| 182 |
+
func, symint=metadata is not None and metadata.supports_symint()
|
| 183 |
+
)
|
| 184 |
+
return self.gen(schema)
|
| 185 |
+
|
| 186 |
+
# there is no lowering functionality generated unless this IR base class is subclassed and
|
| 187 |
+
# implemented as a backend-specific node
|
| 188 |
+
def lowering_function(self, schema: LazyIrSchema) -> str:
|
| 189 |
+
return ""
|
| 190 |
+
|
| 191 |
+
def create_function(self, schema: LazyIrSchema, node_ctor_args: str) -> str:
|
| 192 |
+
return ""
|
| 193 |
+
|
| 194 |
+
def can_be_reused_function(self, schema: LazyIrSchema, node_ctor_args: str) -> str:
|
| 195 |
+
return f"""bool CanBeReused({node_ctor_args}) const {{
|
| 196 |
+
return false;
|
| 197 |
+
}}"""
|
| 198 |
+
|
| 199 |
+
def node_base_ctor_call(self, schema: LazyIrSchema) -> str:
|
| 200 |
+
value_args = schema.filtered_args(values=True, scalars=False)
|
| 201 |
+
# backends can customize the way the node base class constructor is called,
|
| 202 |
+
# as long as all of its arguments can be generated from information available from the schema
|
| 203 |
+
base_ctor_value_args_list = []
|
| 204 |
+
for arg in value_args:
|
| 205 |
+
if isinstance(arg.lazy_type, (BaseCType, VectorCType)):
|
| 206 |
+
base_ctor_value_args_list.append(f"{arg.name}")
|
| 207 |
+
elif isinstance(arg.lazy_type, OptionalCType):
|
| 208 |
+
base_ctor_value_args_list.append(f"{arg.name}.value_or(kNullValue)")
|
| 209 |
+
else:
|
| 210 |
+
raise AssertionError(
|
| 211 |
+
f"Unsupported type ({arg.lazy_type}) - add support if necessary"
|
| 212 |
+
)
|
| 213 |
+
base_ctor_value_args = ", ".join(base_ctor_value_args_list)
|
| 214 |
+
|
| 215 |
+
scalar_args = schema.filtered_args(values=False, scalars=True)
|
| 216 |
+
|
| 217 |
+
# Shape construction.
|
| 218 |
+
# Conditionally build shape depending on specified shape property
|
| 219 |
+
if schema.properties.ShapePrecompute:
|
| 220 |
+
shape_ctor_arg = "std::move(shapes),"
|
| 221 |
+
elif schema.properties.ShapeCompute:
|
| 222 |
+
shape_args = [a.name for a in value_args]
|
| 223 |
+
shape_args.extend(a.name for a in scalar_args)
|
| 224 |
+
shape_ctor_arg = f"compute_shape_{schema.name}({', '.join(shape_args)}),"
|
| 225 |
+
elif schema.properties.ShapeCache:
|
| 226 |
+
shape_args = [f"operand({i})" for i in range(len(value_args))]
|
| 227 |
+
shape_args.extend(a.name for a in scalar_args)
|
| 228 |
+
shape_ctor_arg = f"[&](){{ return compute_shape_{schema.name}({', '.join(shape_args)})[0]; }},"
|
| 229 |
+
else:
|
| 230 |
+
shape_ctor_arg = ""
|
| 231 |
+
|
| 232 |
+
scalar_hashes = ", ".join(f"{a.name}" for a in scalar_args)
|
| 233 |
+
|
| 234 |
+
return f"""{self.node_base}(
|
| 235 |
+
{schema.node_name}::ClassOpKind(),
|
| 236 |
+
OpList{{{base_ctor_value_args}}},
|
| 237 |
+
{shape_ctor_arg}
|
| 238 |
+
/* num_outputs */ {len(schema.returns)},
|
| 239 |
+
torch::lazy::MHash({scalar_hashes}))"""
|
| 240 |
+
|
| 241 |
+
def gen(self, schema: LazyIrSchema) -> list[str]:
|
| 242 |
+
opkind = schema.opkind or aten_symbol(schema)
|
| 243 |
+
|
| 244 |
+
# for now, we just want one IR class decl and soon after also the method defs
|
| 245 |
+
# and we use the functional version not out/inplace.
|
| 246 |
+
all_args = schema.filtered_args()
|
| 247 |
+
scalar_args = schema.filtered_args(values=False, scalars=True)
|
| 248 |
+
|
| 249 |
+
ctor_args = [f"const {i.lazy_type.cpp_type()}& {i.name}" for i in all_args]
|
| 250 |
+
reuse_ctor_args = ", ".join(ctor_args)
|
| 251 |
+
if self.use_lazy_shape and schema.properties.ShapePrecompute:
|
| 252 |
+
ctor_args.append("std::vector<torch::lazy::Shape>&& shapes")
|
| 253 |
+
node_ctor_args = ", ".join(ctor_args)
|
| 254 |
+
|
| 255 |
+
scalar_initializers = ",\n ".join(
|
| 256 |
+
[
|
| 257 |
+
# This code is just special casing the mapping from string_view -> strings
|
| 258 |
+
f"{a.name}({a.name}.has_value() ? ::std::make_optional(std::string(*{a.name})) : ::std::nullopt)"
|
| 259 |
+
if a.lazy_type.cpp_type() == "::std::optional<c10::string_view>"
|
| 260 |
+
else f"{a.name}({a.name})"
|
| 261 |
+
for a in scalar_args
|
| 262 |
+
]
|
| 263 |
+
)
|
| 264 |
+
if len(scalar_initializers):
|
| 265 |
+
scalar_initializers = f",\n {scalar_initializers}"
|
| 266 |
+
scalar_decls = "\n ".join(
|
| 267 |
+
[
|
| 268 |
+
f"std::string {a.name};"
|
| 269 |
+
if a.lazy_type.cpp_type() == "c10::string_view"
|
| 270 |
+
else f"::std::optional<std::string> {a.name};"
|
| 271 |
+
if a.lazy_type.cpp_type() == "::std::optional<c10::string_view>"
|
| 272 |
+
else f"{a.lazy_type.cpp_type()} {a.name};"
|
| 273 |
+
for a in scalar_args
|
| 274 |
+
]
|
| 275 |
+
)
|
| 276 |
+
optional_values = [
|
| 277 |
+
arg.name
|
| 278 |
+
for arg in schema.filtered_args(values=True, scalars=False)
|
| 279 |
+
if isinstance(arg.lazy_type, OptionalCType)
|
| 280 |
+
]
|
| 281 |
+
has_optional_decls = "\n ".join(
|
| 282 |
+
[f"bool has_{value}: 1;" for value in optional_values]
|
| 283 |
+
)
|
| 284 |
+
has_optional_defs = "\n ".join(
|
| 285 |
+
[f"has_{value} = !!{value};" for value in optional_values]
|
| 286 |
+
)
|
| 287 |
+
members_to_string = []
|
| 288 |
+
for arg in scalar_args:
|
| 289 |
+
if isinstance(arg.lazy_type, OptionalCType):
|
| 290 |
+
value = f"{arg.name}.value()"
|
| 291 |
+
if arg.is_generator:
|
| 292 |
+
value = '"torch.Generator()"'
|
| 293 |
+
members_to_string.append(
|
| 294 |
+
f"""if ({arg.name}.has_value()) {{
|
| 295 |
+
ss << ", {arg.name}=" << {value};
|
| 296 |
+
}} else {{
|
| 297 |
+
ss << ", {arg.name}=null";
|
| 298 |
+
}}"""
|
| 299 |
+
)
|
| 300 |
+
else:
|
| 301 |
+
members_to_string.append(f'ss << ", {arg.name}=" << {arg.name};')
|
| 302 |
+
members_to_string_str = "\n ".join(members_to_string)
|
| 303 |
+
|
| 304 |
+
return [
|
| 305 |
+
f"""\
|
| 306 |
+
class {schema.node_name} : public {self.node_base} {{
|
| 307 |
+
public:
|
| 308 |
+
static torch::lazy::OpKind ClassOpKind() {{
|
| 309 |
+
return torch::lazy::OpKind({opkind});
|
| 310 |
+
}}
|
| 311 |
+
|
| 312 |
+
{schema.node_name}({node_ctor_args})
|
| 313 |
+
: {self.node_base_ctor_call(schema)}{scalar_initializers}
|
| 314 |
+
{{
|
| 315 |
+
{has_optional_defs}
|
| 316 |
+
}}
|
| 317 |
+
|
| 318 |
+
std::string ToString() const override {{
|
| 319 |
+
std::stringstream ss;
|
| 320 |
+
ss << {self.node_base}::ToString();
|
| 321 |
+
{members_to_string_str}
|
| 322 |
+
return ss.str();
|
| 323 |
+
}}
|
| 324 |
+
|
| 325 |
+
{self.create_function(schema, reuse_ctor_args)}
|
| 326 |
+
|
| 327 |
+
{self.can_be_reused_function(schema, reuse_ctor_args)}
|
| 328 |
+
|
| 329 |
+
{self.lowering_function(schema)}
|
| 330 |
+
|
| 331 |
+
{scalar_decls}
|
| 332 |
+
{has_optional_decls}
|
| 333 |
+
|
| 334 |
+
}};
|
| 335 |
+
|
| 336 |
+
""",
|
| 337 |
+
]
|
| 338 |
+
|
| 339 |
+
|
| 340 |
+
@dataclass(frozen=True)
|
| 341 |
+
class GenTSLazyIR(GenLazyIR):
|
| 342 |
+
def lowering_function(self, schema: LazyIrSchema) -> str:
|
| 343 |
+
signature = """
|
| 344 |
+
torch::lazy::TSOpVector Lower(
|
| 345 |
+
std::shared_ptr<torch::jit::GraphFunction> function,
|
| 346 |
+
torch::lazy::TSLoweringContext* loctx) const override"""
|
| 347 |
+
|
| 348 |
+
if schema.properties.LowerDeclOnly:
|
| 349 |
+
return f"{signature};"
|
| 350 |
+
elif schema.properties.Lower:
|
| 351 |
+
return f"""{signature} {{
|
| 352 |
+
{ts_lowering_body(schema)}
|
| 353 |
+
}}
|
| 354 |
+
"""
|
| 355 |
+
else:
|
| 356 |
+
return ""
|
| 357 |
+
|
| 358 |
+
def create_function(self, schema: LazyIrSchema, node_ctor_args: str) -> str:
|
| 359 |
+
signature = f"static NodePtr Create({node_ctor_args})"
|
| 360 |
+
if schema.properties.CreateFnDeclOnly:
|
| 361 |
+
return f"{signature};"
|
| 362 |
+
elif not schema.properties.CreateFn:
|
| 363 |
+
return ""
|
| 364 |
+
return f"""{signature} {{
|
| 365 |
+
return ReuseOrMakeNode<{schema.node_name}>(data);
|
| 366 |
+
}}"""
|
| 367 |
+
|
| 368 |
+
def can_be_reused_function(self, schema: LazyIrSchema, node_ctor_args: str) -> str:
|
| 369 |
+
signature = f"bool CanBeReused({node_ctor_args}) const"
|
| 370 |
+
if schema.properties.CanBeReusedDeclOnly:
|
| 371 |
+
return f"{signature};"
|
| 372 |
+
elif not schema.properties.CanBeReused:
|
| 373 |
+
return ""
|
| 374 |
+
value_comparison = []
|
| 375 |
+
for arg in itertools.chain(schema.positional_values, schema.keyword_values):
|
| 376 |
+
if isinstance(arg.lazy_type, OptionalCType):
|
| 377 |
+
value_comparison.append(
|
| 378 |
+
f"nullable_operand(i++) == {arg.name}.value_or(kNullValue)"
|
| 379 |
+
)
|
| 380 |
+
else:
|
| 381 |
+
value_comparison.append(f"operand(i++) == {arg.name}")
|
| 382 |
+
for arg in itertools.chain(schema.positional_scalars, schema.keyword_scalars):
|
| 383 |
+
if isinstance(arg.lazy_type, OptionalCType):
|
| 384 |
+
value_comparison.append(
|
| 385 |
+
f"((!this->{arg.name}&&!{arg.name}) || (this->{arg.name}&&{arg.name} && *(this->{arg.name}) == *{arg.name}))"
|
| 386 |
+
)
|
| 387 |
+
else:
|
| 388 |
+
value_comparison.append(f"this->{arg.name} == {arg.name}")
|
| 389 |
+
value_comparison_str = " &&\n ".join(value_comparison)
|
| 390 |
+
|
| 391 |
+
return f"""{signature} {{
|
| 392 |
+
size_t i = 0;
|
| 393 |
+
return ({value_comparison_str});
|
| 394 |
+
}}"""
|
| 395 |
+
|
| 396 |
+
|
| 397 |
+
@dataclass(frozen=True)
|
| 398 |
+
class GenLazyNativeFuncDefinition:
|
| 399 |
+
class_method_name: str
|
| 400 |
+
backend_index: BackendIndex
|
| 401 |
+
tensor_class: str
|
| 402 |
+
gen_forced_fallback_code: bool
|
| 403 |
+
backend_namespace: str
|
| 404 |
+
get_tensorlist: str
|
| 405 |
+
get_tensor_or_wrap_number: str
|
| 406 |
+
try_get_tensor: str
|
| 407 |
+
metrics_counter: str
|
| 408 |
+
create_tensor: str
|
| 409 |
+
create_from_first_tensor: bool
|
| 410 |
+
create_aten_from_ltc_tensor: str
|
| 411 |
+
tuple_aten_from_ltc_tensors: str
|
| 412 |
+
lazy_tensor_ptr: str
|
| 413 |
+
get_device_fn: str
|
| 414 |
+
|
| 415 |
+
def lazy_tensor_decls(self, func: NativeFunction, schema: LazyIrSchema) -> str:
|
| 416 |
+
value_args = schema.filtered_args(values=True, scalars=False)
|
| 417 |
+
# Generates lazy_{name} variables for LazyTensors wrapping input tensors
|
| 418 |
+
lazy_tensor_decls: list[str] = []
|
| 419 |
+
for arg in value_args:
|
| 420 |
+
if arg.is_wrapped_scalar:
|
| 421 |
+
if isinstance(arg.lazy_type, OptionalCType):
|
| 422 |
+
lazy_tensor_decls.append(
|
| 423 |
+
f"""auto node_{arg.name} = {arg.name} ?
|
| 424 |
+
std::make_optional(torch::lazy::LazyGraphExecutor::Get()->
|
| 425 |
+
GetIrValueForScalarFromCodegen(*{arg.name}, *common_device)):
|
| 426 |
+
::std::nullopt;"""
|
| 427 |
+
)
|
| 428 |
+
else:
|
| 429 |
+
lazy_tensor_decls.append(
|
| 430 |
+
f"""auto node_{arg.name} = torch::lazy::LazyGraphExecutor::Get()->
|
| 431 |
+
GetIrValueForScalarFromCodegen({arg.name}, *common_device);"""
|
| 432 |
+
)
|
| 433 |
+
elif arg.is_symint_or_list:
|
| 434 |
+
continue # values are extracted in isValueType
|
| 435 |
+
elif isinstance(arg.lazy_type, BaseCType):
|
| 436 |
+
if arg.lazy_type.type is tensorListValueT:
|
| 437 |
+
lazy_tensor_decls.append(
|
| 438 |
+
f"auto lazy_{arg.name}_tensorlist = "
|
| 439 |
+
f"{self.backend_namespace}::{self.get_tensorlist}({arg.name});"
|
| 440 |
+
)
|
| 441 |
+
else:
|
| 442 |
+
lazy_tensor_decls.append(
|
| 443 |
+
f"{self.lazy_tensor_ptr} lazy_{arg.name} = "
|
| 444 |
+
f"{self.backend_namespace}::{self.get_tensor_or_wrap_number}({arg.name}, *common_device);"
|
| 445 |
+
)
|
| 446 |
+
elif isinstance(arg.lazy_type, OptionalCType):
|
| 447 |
+
assert arg.lazy_type.elem == BaseCType(getValueT()), arg.lazy_type.elem
|
| 448 |
+
# TODO(alanwaketan): Maybe we want to apply GetLtcTensorOrCreateForWrappedNumber here, but hold it
|
| 449 |
+
# until we encounter a real world example.
|
| 450 |
+
lazy_tensor_decls.append(
|
| 451 |
+
f"{self.lazy_tensor_ptr} lazy_{arg.name} = "
|
| 452 |
+
f"{self.backend_namespace}::{self.try_get_tensor}({arg.name}.value_or(at::Tensor()));"
|
| 453 |
+
)
|
| 454 |
+
else:
|
| 455 |
+
raise AssertionError(
|
| 456 |
+
f"TODO not sure if there are other valid types to handle here ({arg.lazy_type})"
|
| 457 |
+
)
|
| 458 |
+
return ("\n ").join(lazy_tensor_decls)
|
| 459 |
+
|
| 460 |
+
def force_eager_fallback(
|
| 461 |
+
self,
|
| 462 |
+
func: NativeFunction,
|
| 463 |
+
schema: LazyIrSchema,
|
| 464 |
+
metadata: BackendMetadata,
|
| 465 |
+
sig: DispatcherSignature | NativeSignature,
|
| 466 |
+
) -> str:
|
| 467 |
+
if self.gen_forced_fallback_code:
|
| 468 |
+
return gen_fallback_code(
|
| 469 |
+
schema, sig, overload_name=func.func.name.overload_name
|
| 470 |
+
)
|
| 471 |
+
return ""
|
| 472 |
+
|
| 473 |
+
def metrics(self, func: NativeFunction, schema: LazyIrSchema) -> str:
|
| 474 |
+
return f"{self.metrics_counter};"
|
| 475 |
+
|
| 476 |
+
def get_device(self, func: NativeFunction, schema: LazyIrSchema) -> str:
|
| 477 |
+
value_args = schema.filtered_args(values=True, scalars=False)
|
| 478 |
+
scalar_args = schema.filtered_args(values=False, scalars=True)
|
| 479 |
+
value_types_names = [f"{a.name}" for a in value_args if not a.is_wrapped_scalar]
|
| 480 |
+
optional_device = OptionalCType(BaseCType(deviceT))
|
| 481 |
+
optional_devices = [
|
| 482 |
+
a.name for a in scalar_args if a.lazy_type == optional_device
|
| 483 |
+
]
|
| 484 |
+
assert (
|
| 485 |
+
len(value_types_names) > 0 or len(optional_devices) > 0
|
| 486 |
+
), "Expected at least one Value or Device type"
|
| 487 |
+
get_device_str = (
|
| 488 |
+
f"{self.get_device_fn}({', '.join(value_types_names + optional_devices)})"
|
| 489 |
+
)
|
| 490 |
+
return f"""auto common_device = {get_device_str};
|
| 491 |
+
TORCH_INTERNAL_ASSERT(common_device);
|
| 492 |
+
"""
|
| 493 |
+
|
| 494 |
+
def shape_inference(self, func: NativeFunction, schema: LazyIrSchema) -> str:
|
| 495 |
+
metadata = self.backend_index.get_kernel(func)
|
| 496 |
+
assert metadata is not None
|
| 497 |
+
all_args = schema.filtered_args()
|
| 498 |
+
returns_length = len(schema.returns)
|
| 499 |
+
# call the meta kernel if it exists, to compute output shape/dtype for our IR
|
| 500 |
+
# Note [Generated LTC Shape Functions]
|
| 501 |
+
# LTC uses meta tensors from core to do shape inference when possible, and otherwise
|
| 502 |
+
# we generate a shape function declaration that needs to be manually implemented.
|
| 503 |
+
# How do we detect which ops are eligible to use meta tensors?
|
| 504 |
+
# In general we should be able to use meta tensors not just on structured operators,
|
| 505 |
+
# but also on composite operators that are implemented in terms of structured kernels.
|
| 506 |
+
# We don't currently have a way of knowing at codegen time which ops are implemented that way.
|
| 507 |
+
# This is the case for all view and view_copy operators however, so we're going to
|
| 508 |
+
# use them specifically for all of the view_copy ops (instead of manually writing shape rules for all of them).
|
| 509 |
+
is_view_copy_op = "view_copy" in func.tags
|
| 510 |
+
is_structured = func.structured or func.structured_delegate is not None
|
| 511 |
+
if is_structured or is_view_copy_op:
|
| 512 |
+
meta_out = """
|
| 513 |
+
std::vector<torch::lazy::Shape> shapes{torch::lazy::Shape(out_meta.scalar_type(), out_meta.sizes().vec())};"""
|
| 514 |
+
if returns_length > 1:
|
| 515 |
+
|
| 516 |
+
def this_shape(i: int) -> str:
|
| 517 |
+
return f"torch::lazy::Shape(std::get<{i}>(out_meta).scalar_type(), std::get<{i}>(out_meta).sizes().vec())"
|
| 518 |
+
|
| 519 |
+
shapes_str = ",".join([this_shape(i) for i in range(returns_length)])
|
| 520 |
+
meta_out = "std::vector<torch::lazy::Shape> shapes{" + shapes_str + "};"
|
| 521 |
+
|
| 522 |
+
# Convert tensor args to the meta device and call it.
|
| 523 |
+
# (We can't pass in the input tensors directly, because they are "functional wrappers".
|
| 524 |
+
# If any of the meta kernels call a tensor op and redispatch, we don't want to hit the functionalize kernels.)
|
| 525 |
+
# Even at::meta:: functions might redispatch, e.g. if they call into view ops.
|
| 526 |
+
dispatcher_sig = DispatcherSignature.from_schema(func.func)
|
| 527 |
+
meta_conversion_str, meta_call_ctx = convert_to_meta_tensors(dispatcher_sig)
|
| 528 |
+
meta_call_args = [
|
| 529 |
+
e.expr
|
| 530 |
+
for e in translate(
|
| 531 |
+
meta_call_ctx, dispatcher_sig.arguments(), method=False
|
| 532 |
+
)
|
| 533 |
+
]
|
| 534 |
+
if is_view_copy_op:
|
| 535 |
+
# view_copy ops always have a CompositeExplicitAutogradNonFunctional kernel
|
| 536 |
+
assert func.has_composite_explicit_autograd_non_functional_kernel
|
| 537 |
+
dispatch_ns = "compositeexplicitautogradnonfunctional"
|
| 538 |
+
else:
|
| 539 |
+
dispatch_ns = "meta"
|
| 540 |
+
aten_name = schema.aten_name
|
| 541 |
+
# TODO: this is trolling
|
| 542 |
+
if func.func.has_symint() and metadata.supports_symint():
|
| 543 |
+
aten_name += "_symint"
|
| 544 |
+
shape_str = f"""\
|
| 545 |
+
{meta_conversion_str}
|
| 546 |
+
auto out_meta = at::{dispatch_ns}::{aten_name}({', '.join(meta_call_args)});
|
| 547 |
+
{meta_out}"""
|
| 548 |
+
else:
|
| 549 |
+
shape_sig = ComputeShapeSignature(
|
| 550 |
+
metadata.kernel, func, symint=metadata.supports_symint()
|
| 551 |
+
)
|
| 552 |
+
shape_str = f"""
|
| 553 |
+
auto shapes = {shape_sig.shape_call};"""
|
| 554 |
+
|
| 555 |
+
shape_str += f"""
|
| 556 |
+
TORCH_INTERNAL_ASSERT(shapes.size() == {returns_length});"""
|
| 557 |
+
|
| 558 |
+
# Calculating which dimensions are symbolic
|
| 559 |
+
func_schema_str = "aten::" + str(func.func)
|
| 560 |
+
shape_str += f"""
|
| 561 |
+
if(torch::lazy::symbolicShapeEnabled()){{
|
| 562 |
+
std::vector<torch::jit::IValue> inputs = {{ {', '.join(str(a.name) for a in all_args)} }};
|
| 563 |
+
const char* schema_str = "{func_schema_str}";
|
| 564 |
+
applySymbolicShapesOnLT(schema_str, inputs, shapes);
|
| 565 |
+
}}
|
| 566 |
+
"""
|
| 567 |
+
return shape_str
|
| 568 |
+
|
| 569 |
+
def build_ir_node(self, func: NativeFunction, schema: LazyIrSchema) -> str:
|
| 570 |
+
node_ctor_input_str = node_ctor_inputs(schema)
|
| 571 |
+
return f"""torch::lazy::NodePtr node = torch::lazy::ReuseNode<{schema.node_name}>({node_ctor_input_str});
|
| 572 |
+
if (!node) {{
|
| 573 |
+
{self.shape_inference(func, schema)}
|
| 574 |
+
node = torch::lazy::MakeNode<{schema.node_name}>({node_ctor_input_str}, std::move(shapes));
|
| 575 |
+
CacheNode(node);
|
| 576 |
+
}}
|
| 577 |
+
"""
|
| 578 |
+
|
| 579 |
+
def create_lazy_tensor(self, first_tensor_name: str | None = None) -> str:
|
| 580 |
+
# xla uses an instance method for tensor creation, for the time being
|
| 581 |
+
if self.create_from_first_tensor:
|
| 582 |
+
# TODO(whc) remove this if XLA switches to using static method for creation
|
| 583 |
+
assert (
|
| 584 |
+
first_tensor_name is not None
|
| 585 |
+
), "Requires first tensor to create lazy tensor"
|
| 586 |
+
return f"{first_tensor_name}.{self.create_tensor}"
|
| 587 |
+
return f"{self.backend_namespace}::{self.create_tensor}"
|
| 588 |
+
|
| 589 |
+
def return_aten_tensor(self, func: NativeFunction, schema: LazyIrSchema) -> str:
|
| 590 |
+
returns_length = len(schema.returns)
|
| 591 |
+
value_args = schema.filtered_args(values=True, scalars=False)
|
| 592 |
+
value_types_names = [f"{a.name}" for a in value_args if not a.is_wrapped_scalar]
|
| 593 |
+
first_tensor_name = value_types_names[0] if len(value_types_names) > 0 else None
|
| 594 |
+
bridge_str = f"""auto result = {self.create_aten_from_ltc_tensor}(
|
| 595 |
+
{self.create_lazy_tensor(first_tensor_name)}(std::move(node), *common_device));"""
|
| 596 |
+
|
| 597 |
+
if returns_length > 1:
|
| 598 |
+
assert (
|
| 599 |
+
len(value_types_names) > 0
|
| 600 |
+
), "Code below assumes there is at least one tensor arg"
|
| 601 |
+
bridge_str = f"""std::vector<{self.lazy_tensor_ptr}> lazy_tensors;
|
| 602 |
+
for (int i = 0; i < {returns_length}; i++) {{
|
| 603 |
+
lazy_tensors.push_back({self.create_lazy_tensor(first_tensor_name)}({getValueT()}(node, i), *common_device));
|
| 604 |
+
}}
|
| 605 |
+
auto result = {self.tuple_aten_from_ltc_tensors}<{returns_length}>(lazy_tensors);"""
|
| 606 |
+
|
| 607 |
+
if schema.name.name.inplace or func.func.is_out_fn():
|
| 608 |
+
assert returns_length == 1, (
|
| 609 |
+
"We assumed there was no such case where an op is an in-place variant "
|
| 610 |
+
f"and has tuple outputs, but got tuple of len {returns_length}."
|
| 611 |
+
)
|
| 612 |
+
bridge_str = f"""lazy_{first_tensor_name}->SetInPlaceIrValue(node);
|
| 613 |
+
auto& result = {first_tensor_name};"""
|
| 614 |
+
|
| 615 |
+
bridge_str += """
|
| 616 |
+
return result;"""
|
| 617 |
+
return bridge_str
|
| 618 |
+
|
| 619 |
+
@method_with_native_function
|
| 620 |
+
def __call__(self, func: NativeFunction) -> list[str]:
|
| 621 |
+
sig = kernel_signature(func, self.backend_index)
|
| 622 |
+
metadata = self.backend_index.get_kernel(func)
|
| 623 |
+
assert metadata is not None
|
| 624 |
+
schema = LazyIrSchema(func.func, symint=metadata.supports_symint())
|
| 625 |
+
return [
|
| 626 |
+
f"""\
|
| 627 |
+
{sig.decl(name=f"{self.class_method_name}::{metadata.kernel}")} {{
|
| 628 |
+
{self.force_eager_fallback(func, schema, metadata, sig)}
|
| 629 |
+
{self.metrics(func, schema)}
|
| 630 |
+
{self.get_device(func, schema)}
|
| 631 |
+
{self.lazy_tensor_decls(func, schema)}
|
| 632 |
+
{self.build_ir_node(func, schema)}
|
| 633 |
+
{self.return_aten_tensor(func, schema)}
|
| 634 |
+
}}\n
|
| 635 |
+
"""
|
| 636 |
+
]
|
| 637 |
+
|
| 638 |
+
|
| 639 |
+
class ComputeShapeSignature:
|
| 640 |
+
"""
|
| 641 |
+
Here we use the base name as the suffix of the signature to avoid generating for in-place variants.
|
| 642 |
+
"""
|
| 643 |
+
|
| 644 |
+
def __init__(self, kernel_name: str, f: NativeFunction, *, symint: bool) -> None:
|
| 645 |
+
self.__schema = LazyIrSchema(f.func, symint=symint)
|
| 646 |
+
self.__dispatch_args = ", ".join(
|
| 647 |
+
[a.decl() for a in dispatcher.arguments(f.func, symint=symint)]
|
| 648 |
+
)
|
| 649 |
+
self.__call_args = ", ".join(
|
| 650 |
+
[f"{arg.name}" for arg in self.__schema.filtered_args(generator=True)]
|
| 651 |
+
)
|
| 652 |
+
self.__kernel_name = kernel_name
|
| 653 |
+
|
| 654 |
+
def __decl_suffix(self) -> str:
|
| 655 |
+
return f"{self.__kernel_name}({self.__dispatch_args})"
|
| 656 |
+
|
| 657 |
+
def __call_suffix(self) -> str:
|
| 658 |
+
return f"{self.__kernel_name}({self.__call_args})"
|
| 659 |
+
|
| 660 |
+
@property
|
| 661 |
+
def shape_decl(self) -> str:
|
| 662 |
+
return f"TORCH_API std::vector<torch::lazy::Shape> compute_shape_{self.__decl_suffix()}"
|
| 663 |
+
|
| 664 |
+
@property
|
| 665 |
+
def shape_call(self) -> str:
|
| 666 |
+
return f"torch::lazy::compute_shape_{self.__call_suffix()}"
|
| 667 |
+
|
| 668 |
+
|
| 669 |
+
@dataclass(frozen=True)
|
| 670 |
+
class GenLazyShapeInferenceDefinition:
|
| 671 |
+
backend_index: BackendIndex
|
| 672 |
+
tensor_class: str
|
| 673 |
+
|
| 674 |
+
@method_with_native_function
|
| 675 |
+
def __call__(self, f: NativeFunction) -> list[str]:
|
| 676 |
+
metadata = self.backend_index.get_kernel(f)
|
| 677 |
+
assert metadata is not None
|
| 678 |
+
|
| 679 |
+
# See Note [Generated LTC Shape Functions]
|
| 680 |
+
is_view_copy_op = "view_copy" in f.tags
|
| 681 |
+
is_structured = f.structured or f.structured_delegate is not None
|
| 682 |
+
if is_structured or is_view_copy_op:
|
| 683 |
+
return []
|
| 684 |
+
else:
|
| 685 |
+
shape_sig = ComputeShapeSignature(
|
| 686 |
+
metadata.kernel, f, symint=metadata.supports_symint()
|
| 687 |
+
)
|
| 688 |
+
return ["\n".join([f"{shape_sig.shape_decl};"])]
|
| 689 |
+
|
| 690 |
+
|
| 691 |
+
def generate_non_native_lazy_ir_nodes(
|
| 692 |
+
non_native: list[dict[str, Any]], gen_lazy_ir: GenLazyIR
|
| 693 |
+
) -> list[str]:
|
| 694 |
+
"""Generate the non-native lazy IR node classes"""
|
| 695 |
+
nodes = []
|
| 696 |
+
for op in non_native:
|
| 697 |
+
# Set default properties for Non-Native IRs
|
| 698 |
+
properties = LazyIrProperties("ShapeCache", "CanBeReused", "LowerDeclOnly")
|
| 699 |
+
for p in op.get("properties", []):
|
| 700 |
+
setattr(properties, p, True)
|
| 701 |
+
|
| 702 |
+
# non-native is assumed to want symint bindings if you wrote symint
|
| 703 |
+
schema = LazyIrSchema(FunctionSchema.parse(op["func"]), properties, symint=True)
|
| 704 |
+
schema.opkind = op.get("opkind")
|
| 705 |
+
nodes.append(gen_lazy_ir.gen(schema)[0])
|
| 706 |
+
|
| 707 |
+
return nodes
|
.venv/lib/python3.11/site-packages/torchgen/dest/native_functions.py
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import torchgen.api.meta as meta
|
| 4 |
+
import torchgen.api.structured as structured
|
| 5 |
+
from torchgen.api.types import kernel_signature
|
| 6 |
+
from torchgen.context import with_native_function_and_index
|
| 7 |
+
from torchgen.model import BackendIndex, NativeFunction, NativeFunctionsGroup
|
| 8 |
+
from torchgen.utils import mapMaybe
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
@with_native_function_and_index
|
| 12 |
+
def gen_unstructured(f: NativeFunction, backend_index: BackendIndex) -> str | None:
|
| 13 |
+
sig = kernel_signature(f, backend_index)
|
| 14 |
+
metadata = backend_index.get_kernel(f)
|
| 15 |
+
if metadata is None:
|
| 16 |
+
return None
|
| 17 |
+
if "legacy::" in metadata.kernel:
|
| 18 |
+
return None
|
| 19 |
+
else:
|
| 20 |
+
prefix = "static" if backend_index.external else "TORCH_API"
|
| 21 |
+
return f"{prefix} {sig.decl(name=metadata.kernel)};"
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
@with_native_function_and_index
|
| 25 |
+
def gen_structured(g: NativeFunctionsGroup, backend_index: BackendIndex) -> list[str]:
|
| 26 |
+
meta_name = meta.name(g)
|
| 27 |
+
out_args = structured.impl_arguments(g)
|
| 28 |
+
metadata = backend_index.get_kernel(g)
|
| 29 |
+
if metadata is None:
|
| 30 |
+
return []
|
| 31 |
+
prefix = "" if backend_index.external else "TORCH_API "
|
| 32 |
+
return [
|
| 33 |
+
f"""\
|
| 34 |
+
struct {prefix}structured_{metadata.kernel} : public at::meta::structured_{meta_name} {{
|
| 35 |
+
void impl({', '.join(a.decl() for a in out_args)});
|
| 36 |
+
}};
|
| 37 |
+
"""
|
| 38 |
+
]
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
# Generates NativeFunctions.h, a list of forward declarations of all
|
| 42 |
+
# actual kernel definitions we keep in aten/src/ATen/native/
|
| 43 |
+
@with_native_function_and_index
|
| 44 |
+
def compute_native_function_declaration(
|
| 45 |
+
g: NativeFunctionsGroup | NativeFunction, backend_index: BackendIndex
|
| 46 |
+
) -> list[str]:
|
| 47 |
+
metadata = backend_index.get_kernel(g)
|
| 48 |
+
if isinstance(g, NativeFunctionsGroup):
|
| 49 |
+
if metadata is not None and metadata.structured:
|
| 50 |
+
if backend_index.external:
|
| 51 |
+
# Structured hasn't been tested with external backends yet.
|
| 52 |
+
raise AssertionError(
|
| 53 |
+
"Structured external backend functions are not implemented yet."
|
| 54 |
+
)
|
| 55 |
+
else:
|
| 56 |
+
return gen_structured(g, backend_index)
|
| 57 |
+
else:
|
| 58 |
+
return list(
|
| 59 |
+
mapMaybe(lambda f: gen_unstructured(f, backend_index), g.functions())
|
| 60 |
+
)
|
| 61 |
+
else:
|
| 62 |
+
x = gen_unstructured(g, backend_index)
|
| 63 |
+
return [] if x is None else [x]
|
.venv/lib/python3.11/site-packages/torchgen/dest/register_dispatch_key.py
ADDED
|
@@ -0,0 +1,1005 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import itertools
|
| 4 |
+
import textwrap
|
| 5 |
+
from dataclasses import dataclass
|
| 6 |
+
from typing import Literal, TYPE_CHECKING
|
| 7 |
+
|
| 8 |
+
import torchgen.api.cpp as cpp
|
| 9 |
+
import torchgen.api.meta as meta
|
| 10 |
+
import torchgen.api.structured as structured
|
| 11 |
+
from torchgen.api.translate import translate
|
| 12 |
+
from torchgen.api.types import (
|
| 13 |
+
BaseCType,
|
| 14 |
+
Binding,
|
| 15 |
+
ConstRefCType,
|
| 16 |
+
CppSignature,
|
| 17 |
+
CppSignatureGroup,
|
| 18 |
+
DispatcherSignature,
|
| 19 |
+
Expr,
|
| 20 |
+
kernel_signature,
|
| 21 |
+
MutRefCType,
|
| 22 |
+
NamedCType,
|
| 23 |
+
NativeSignature,
|
| 24 |
+
tensorT,
|
| 25 |
+
)
|
| 26 |
+
from torchgen.context import method_with_native_function, native_function_manager
|
| 27 |
+
from torchgen.model import (
|
| 28 |
+
Argument,
|
| 29 |
+
BackendIndex,
|
| 30 |
+
DeviceCheckType,
|
| 31 |
+
DispatchKey,
|
| 32 |
+
gets_generated_out_inplace_wrapper,
|
| 33 |
+
is_cuda_dispatch_key,
|
| 34 |
+
NativeFunction,
|
| 35 |
+
NativeFunctionsGroup,
|
| 36 |
+
SchemaKind,
|
| 37 |
+
TensorOptionsArguments,
|
| 38 |
+
)
|
| 39 |
+
from torchgen.utils import assert_never, mapMaybe, Target
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
if TYPE_CHECKING:
|
| 43 |
+
from torchgen.selective_build.selector import SelectiveBuilder
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def gen_registration_headers(
|
| 47 |
+
backend_index: BackendIndex,
|
| 48 |
+
per_operator_headers: bool,
|
| 49 |
+
rocm: bool,
|
| 50 |
+
) -> list[str]:
|
| 51 |
+
if per_operator_headers:
|
| 52 |
+
headers = ["#include <ATen/ops/as_strided_native.h>"]
|
| 53 |
+
else:
|
| 54 |
+
headers = ["#include <ATen/NativeFunctions.h>"]
|
| 55 |
+
|
| 56 |
+
if backend_index.dispatch_key in (DispatchKey.CPU, DispatchKey.Meta):
|
| 57 |
+
headers.append("#include <ATen/EmptyTensor.h>")
|
| 58 |
+
elif backend_index.dispatch_key == DispatchKey.CUDA:
|
| 59 |
+
if rocm:
|
| 60 |
+
headers.append("#include <ATen/hip/EmptyTensor.h>")
|
| 61 |
+
else:
|
| 62 |
+
headers.append("#include <ATen/cuda/EmptyTensor.h>")
|
| 63 |
+
elif backend_index.dispatch_key == DispatchKey.MPS:
|
| 64 |
+
headers.append("#include <ATen/mps/EmptyTensor.h>")
|
| 65 |
+
elif backend_index.dispatch_key == DispatchKey.XPU:
|
| 66 |
+
# XPU specific, this header resides in third_party/torch-xpu-ops
|
| 67 |
+
headers.append("#include <ATen/xpu/EmptyTensor.h>")
|
| 68 |
+
elif per_operator_headers:
|
| 69 |
+
headers += [
|
| 70 |
+
"#include <ATen/ops/empty.h>",
|
| 71 |
+
"#include <ATen/ops/empty_strided.h>",
|
| 72 |
+
"#include <ATen/ops/_copy_from_and_resize.h>",
|
| 73 |
+
"#include <ATen/ops/_copy_from.h>",
|
| 74 |
+
]
|
| 75 |
+
else:
|
| 76 |
+
headers.append("#include <ATen/Functions.h>")
|
| 77 |
+
|
| 78 |
+
headers.append("#include <c10/macros/Macros.h>")
|
| 79 |
+
return headers
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def gen_empty_impl_names(
|
| 83 |
+
backend_index: BackendIndex,
|
| 84 |
+
) -> tuple[str | None, str | None]:
|
| 85 |
+
empty_impl = None
|
| 86 |
+
empty_strided_impl = None
|
| 87 |
+
|
| 88 |
+
if backend_index.dispatch_key in (
|
| 89 |
+
DispatchKey.Meta,
|
| 90 |
+
DispatchKey.CPU,
|
| 91 |
+
DispatchKey.CUDA,
|
| 92 |
+
DispatchKey.MPS,
|
| 93 |
+
DispatchKey.XPU,
|
| 94 |
+
):
|
| 95 |
+
dispatch = str(backend_index.dispatch_key).lower()
|
| 96 |
+
empty_impl = f"at::detail::empty_{dispatch}"
|
| 97 |
+
empty_strided_impl = f"at::detail::empty_strided_{dispatch}"
|
| 98 |
+
elif backend_index.dispatch_key in (
|
| 99 |
+
DispatchKey.CompositeExplicitAutogradNonFunctional,
|
| 100 |
+
DispatchKey.QuantizedCPU,
|
| 101 |
+
DispatchKey.QuantizedCUDA,
|
| 102 |
+
DispatchKey.XPU,
|
| 103 |
+
):
|
| 104 |
+
empty_impl = "at::empty"
|
| 105 |
+
empty_strided_impl = "at::empty_strided"
|
| 106 |
+
|
| 107 |
+
return empty_impl, empty_strided_impl
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def gen_create_out_helper(backend_index: BackendIndex) -> list[str]:
|
| 111 |
+
if backend_index.dispatch_key == DispatchKey.Meta:
|
| 112 |
+
empty_options = "options.device(at::kMeta)"
|
| 113 |
+
else:
|
| 114 |
+
empty_options = "options"
|
| 115 |
+
|
| 116 |
+
empty_impl, empty_strided_impl = gen_empty_impl_names(backend_index)
|
| 117 |
+
if empty_impl is None:
|
| 118 |
+
return []
|
| 119 |
+
|
| 120 |
+
return [
|
| 121 |
+
f"""
|
| 122 |
+
Tensor create_out(IntArrayRef sizes, IntArrayRef strides, const TensorOptions &options) {{
|
| 123 |
+
if (strides.empty()) {{
|
| 124 |
+
return {empty_impl}(sizes, {empty_options});
|
| 125 |
+
}} else {{
|
| 126 |
+
return {empty_strided_impl}(sizes, strides, {empty_options});
|
| 127 |
+
}}
|
| 128 |
+
}}
|
| 129 |
+
"""
|
| 130 |
+
]
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
def gen_maybe_create_proxy_helper(backend_index: BackendIndex) -> list[str]:
|
| 134 |
+
_, empty_strided_impl = gen_empty_impl_names(backend_index)
|
| 135 |
+
return (
|
| 136 |
+
[]
|
| 137 |
+
if empty_strided_impl is None
|
| 138 |
+
else [
|
| 139 |
+
f"""
|
| 140 |
+
std::optional<Tensor> maybe_create_proxy(const Tensor &out, IntArrayRef sizes, IntArrayRef strides, const TensorOptions &options) {{
|
| 141 |
+
if (out.strides() != strides) {{
|
| 142 |
+
return {empty_strided_impl}(sizes, strides, options);
|
| 143 |
+
}}
|
| 144 |
+
return std::nullopt;
|
| 145 |
+
}}
|
| 146 |
+
"""
|
| 147 |
+
]
|
| 148 |
+
)
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
def gen_resize_out_helper(backend_index: BackendIndex) -> list[str]:
|
| 152 |
+
if backend_index.dispatch_key == DispatchKey.CompositeExplicitAutogradNonFunctional:
|
| 153 |
+
# The function isn't used by this key (since only functional ops have a kernel for this key),
|
| 154 |
+
# so we need to not include it to avoid a defined-but-not-used error.
|
| 155 |
+
return []
|
| 156 |
+
return [
|
| 157 |
+
"""
|
| 158 |
+
void resize_out(const Tensor &out, IntArrayRef sizes, IntArrayRef strides, const TensorOptions &options) {
|
| 159 |
+
TORCH_CHECK(options.dtype() == out.dtype(),
|
| 160 |
+
"Expected out tensor to have dtype ", options.dtype(), ", but got ", out.dtype(), " instead");
|
| 161 |
+
TORCH_CHECK(options.device() == out.device(),
|
| 162 |
+
"Expected out tensor to have device ", options.device(), ", but got ", out.device(), " instead");
|
| 163 |
+
const bool resized = at::native::resize_output(out, sizes);
|
| 164 |
+
// Only restride if a resize occurred; otherwise we ignore the (advisory)
|
| 165 |
+
// strides from the meta function and directly use the output tensor's
|
| 166 |
+
// preexisting strides
|
| 167 |
+
if (resized) {
|
| 168 |
+
if (!strides.empty()) {
|
| 169 |
+
TORCH_INTERNAL_ASSERT(!options.memory_format_opt().has_value());
|
| 170 |
+
// TODO: avoid the redispatch here
|
| 171 |
+
out.as_strided_(sizes, strides);
|
| 172 |
+
} else if (options.memory_format_opt().has_value()) {
|
| 173 |
+
out.unsafeGetTensorImpl()->empty_tensor_restride(*options.memory_format_opt());
|
| 174 |
+
}
|
| 175 |
+
}
|
| 176 |
+
}
|
| 177 |
+
"""
|
| 178 |
+
]
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
def gen_check_inplace_helper(backend_index: BackendIndex) -> list[str]:
|
| 182 |
+
return [
|
| 183 |
+
"""
|
| 184 |
+
void check_inplace(const Tensor &self, IntArrayRef sizes, const TensorOptions &options) {
|
| 185 |
+
// These checks are needed on those operators that:
|
| 186 |
+
// 1) don't use 'TensorIterator' (e.g. 'addmm' and 'baddbmm')
|
| 187 |
+
// 2) have particular typing rules (e.g. 'cumsum' and 'cumprod')
|
| 188 |
+
// For other operators (e.g. 'add'), 'TensorIterator' already checks
|
| 189 |
+
// these things separately.
|
| 190 |
+
TORCH_CHECK(options.dtype() == self.dtype(),
|
| 191 |
+
"Bad in-place call: ",
|
| 192 |
+
"input tensor dtype ", self.dtype(), " and output tensor dtype ", options.dtype(), " should match");
|
| 193 |
+
TORCH_CHECK(options.device() == self.device(),
|
| 194 |
+
"Bad in-place call: ",
|
| 195 |
+
"input tensor device ", self.device(), " and output tensor device ", options.device(), " should match");
|
| 196 |
+
TORCH_CHECK(sizes == self.sizes(),
|
| 197 |
+
"Bad in-place call: ",
|
| 198 |
+
"input tensor size ", self.sizes(), " and output tensor size ", sizes, " should match");
|
| 199 |
+
}
|
| 200 |
+
"""
|
| 201 |
+
]
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
def gen_registration_helpers(backend_index: BackendIndex) -> list[str]:
|
| 205 |
+
return [
|
| 206 |
+
'C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-function")',
|
| 207 |
+
*gen_create_out_helper(backend_index),
|
| 208 |
+
*gen_resize_out_helper(backend_index),
|
| 209 |
+
*gen_check_inplace_helper(backend_index),
|
| 210 |
+
*gen_maybe_create_proxy_helper(backend_index),
|
| 211 |
+
"C10_DIAGNOSTIC_POP()",
|
| 212 |
+
]
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
# Generates Register{dispatch}.cpp (e.g., RegisterCPU.cpp).
|
| 216 |
+
#
|
| 217 |
+
# - The primary function of this file is to register all of the
|
| 218 |
+
# implementations for the given dispatch key to the dispatcher,
|
| 219 |
+
# so they are available for use in PyTorch. If dispatch is
|
| 220 |
+
# None, we generate schema (def) registrations and catchall
|
| 221 |
+
# registrations.
|
| 222 |
+
# - The secondary function of this file is to generate a wrapper
|
| 223 |
+
# around functions. In CPUType these wrappers do nothing
|
| 224 |
+
# (and should be removed), but in other cases they handle
|
| 225 |
+
# DeviceGuard. A small extra benefit of wrappers is they
|
| 226 |
+
# are not overloaded, so they can be used in the registration
|
| 227 |
+
# API without having to disambiguate which overload you want
|
| 228 |
+
# (as would be the case if you directly registered native::
|
| 229 |
+
# functions).
|
| 230 |
+
# - The tertiary function of this file is to generate *static*
|
| 231 |
+
# cpp API bindings which can be used to bypass dispatcher
|
| 232 |
+
# directly to kernels, but with user-friendly cpp-style API
|
| 233 |
+
@dataclass(frozen=True)
|
| 234 |
+
class RegisterDispatchKey:
|
| 235 |
+
backend_index: BackendIndex
|
| 236 |
+
|
| 237 |
+
target: Literal[
|
| 238 |
+
Target.ANONYMOUS_DEFINITION,
|
| 239 |
+
Target.NAMESPACED_DEFINITION,
|
| 240 |
+
Target.NAMESPACED_DECLARATION,
|
| 241 |
+
Target.REGISTRATION,
|
| 242 |
+
]
|
| 243 |
+
|
| 244 |
+
# Selector object to determine which operators to generate
|
| 245 |
+
# registration code for.
|
| 246 |
+
selector: SelectiveBuilder
|
| 247 |
+
|
| 248 |
+
# Whether or not we are actually code-genning for ROCm
|
| 249 |
+
rocm: bool
|
| 250 |
+
|
| 251 |
+
# Whether or not to generate symint registrations or not. External users
|
| 252 |
+
# of codegen who don't care about symints can set this to false to get
|
| 253 |
+
# non-SymInt codegen
|
| 254 |
+
symint: bool
|
| 255 |
+
|
| 256 |
+
# The class that all unstructured native functions live under. This is used to improve
|
| 257 |
+
# compiler error messages when a kernel writer adds a native function with the wrong signature.
|
| 258 |
+
# This is only used in unstructured kernels, since structured kernels already live in a class.
|
| 259 |
+
# Finally, this field is currently Optional because it is only used by external backends.
|
| 260 |
+
# It would be nice if we can add the same logic to in-tree kernels too, but that requires updating
|
| 261 |
+
# all of the existing kernel signatures scattered across aten/src/ATen/native.
|
| 262 |
+
class_method_name: str | None
|
| 263 |
+
|
| 264 |
+
# Only set to true in lightweight dispatch. If lightweight dispatch is enabled we are registering
|
| 265 |
+
# operators into JIT op registry, thus we need to avoid generating code to register into the dispatcher.
|
| 266 |
+
skip_dispatcher_op_registration: bool
|
| 267 |
+
|
| 268 |
+
@staticmethod
|
| 269 |
+
def gen_device_check(
|
| 270 |
+
type: DeviceCheckType, args: list[Argument], method_name: str
|
| 271 |
+
) -> str:
|
| 272 |
+
if type == DeviceCheckType.NoCheck:
|
| 273 |
+
return " // No device check\n"
|
| 274 |
+
|
| 275 |
+
device_check = "std::optional<Device> common_device = std::nullopt;\n"
|
| 276 |
+
device_check += "(void)common_device; // Suppress unused variable warning\n"
|
| 277 |
+
for arg in args:
|
| 278 |
+
# Only tensor like arguments are eligible
|
| 279 |
+
if arg.type.is_tensor_like():
|
| 280 |
+
device_check += f"""
|
| 281 |
+
c10::impl::check_and_update_common_device(common_device, {arg.name}, "{method_name}", "{arg.name}");"""
|
| 282 |
+
return device_check
|
| 283 |
+
|
| 284 |
+
@method_with_native_function
|
| 285 |
+
def __call__(self, f: NativeFunctionsGroup | NativeFunction) -> list[str]:
|
| 286 |
+
if isinstance(f, NativeFunctionsGroup):
|
| 287 |
+
g: NativeFunctionsGroup = f
|
| 288 |
+
# Note: We call gen_structured() if the operator is marked structured, regardless of the backend.
|
| 289 |
+
# gen_structured() has special logic to handle auto-generated kernels.
|
| 290 |
+
if g.structured:
|
| 291 |
+
return self.gen_structured(g)
|
| 292 |
+
else:
|
| 293 |
+
return list(
|
| 294 |
+
mapMaybe(lambda f: self.gen_unstructured(f, g), g.functions())
|
| 295 |
+
)
|
| 296 |
+
elif isinstance(f, NativeFunction):
|
| 297 |
+
r = self.gen_unstructured(f)
|
| 298 |
+
return [] if r is None else [r]
|
| 299 |
+
else:
|
| 300 |
+
assert_never(f)
|
| 301 |
+
|
| 302 |
+
def wrapper_kernel_sig(
|
| 303 |
+
self, f: NativeFunction
|
| 304 |
+
) -> NativeSignature | DispatcherSignature:
|
| 305 |
+
# The prefix is just to ensure uniqueness. The Dispatcher API doesn't guarantee unique kernel names.
|
| 306 |
+
return DispatcherSignature.from_schema(
|
| 307 |
+
f.func,
|
| 308 |
+
prefix=f"wrapper_{self.backend_index.dispatch_key}_{f.func.name.overload_name}_",
|
| 309 |
+
symint=self.symint,
|
| 310 |
+
)
|
| 311 |
+
|
| 312 |
+
def gen_out_inplace_wrapper(
|
| 313 |
+
self, f: NativeFunction, g: NativeFunctionsGroup | None
|
| 314 |
+
) -> str | None:
|
| 315 |
+
if g is None:
|
| 316 |
+
return None
|
| 317 |
+
k = f.func.kind()
|
| 318 |
+
if k is SchemaKind.inplace:
|
| 319 |
+
copy_op = "at::_copy_from"
|
| 320 |
+
elif k is SchemaKind.out:
|
| 321 |
+
copy_op = "at::_copy_from_and_resize"
|
| 322 |
+
else:
|
| 323 |
+
raise AssertionError("gen_out_inplace_wrapper called on a functional op")
|
| 324 |
+
|
| 325 |
+
sig = self.wrapper_kernel_sig(f)
|
| 326 |
+
name = sig.name()
|
| 327 |
+
|
| 328 |
+
func_res = f"{name}_tmp"
|
| 329 |
+
return_names = cpp.return_names(f)
|
| 330 |
+
if len(return_names) > 1:
|
| 331 |
+
updates = "\n ".join(
|
| 332 |
+
f"{copy_op}(std::get<{i}>({func_res}), {ret_name});"
|
| 333 |
+
for i, ret_name in enumerate(return_names)
|
| 334 |
+
)
|
| 335 |
+
returns = f'{sig.returns_type().cpp_type()}({", ".join(return_names)})'
|
| 336 |
+
elif len(return_names) == 1:
|
| 337 |
+
ret_name = return_names[0]
|
| 338 |
+
updates = f"{copy_op}({func_res}, {ret_name});"
|
| 339 |
+
returns = ret_name
|
| 340 |
+
else:
|
| 341 |
+
assert len(f.func.arguments.out) == 1
|
| 342 |
+
returns = ""
|
| 343 |
+
out_arg = f.func.arguments.out[0]
|
| 344 |
+
if out_arg.type.is_list_like():
|
| 345 |
+
updates = f"""\
|
| 346 |
+
for (int64_t i = 0; i < {func_res}.size(); ++i) {{
|
| 347 |
+
{copy_op}({func_res}[i], {out_arg.name}[i]);
|
| 348 |
+
}}"""
|
| 349 |
+
else:
|
| 350 |
+
updates = f"{copy_op}({func_res}, {out_arg.name});"
|
| 351 |
+
|
| 352 |
+
functional_sig = self.wrapper_kernel_sig(g.functional)
|
| 353 |
+
wrapper_name = sig.name()
|
| 354 |
+
|
| 355 |
+
return f"""\
|
| 356 |
+
{sig.defn(name=wrapper_name)} {{
|
| 357 |
+
auto {func_res} = {functional_sig.name()}({", ".join(e.expr for e in translate(sig.arguments(), functional_sig.arguments()))});
|
| 358 |
+
{updates}
|
| 359 |
+
return {returns};
|
| 360 |
+
}}
|
| 361 |
+
"""
|
| 362 |
+
|
| 363 |
+
def gen_structured(self, g: NativeFunctionsGroup) -> list[str]:
|
| 364 |
+
metadata = self.backend_index.get_kernel(g)
|
| 365 |
+
if self.backend_index.dispatch_key == DispatchKey.Meta:
|
| 366 |
+
assert not self.backend_index.has_kernel(g.out), (
|
| 367 |
+
"Do not explicitly specify Meta dispatch key on structured "
|
| 368 |
+
"functions, they will be automatically generated for you"
|
| 369 |
+
)
|
| 370 |
+
elif (
|
| 371 |
+
self.backend_index.dispatch_key
|
| 372 |
+
== DispatchKey.CompositeExplicitAutogradNonFunctional
|
| 373 |
+
):
|
| 374 |
+
assert not self.backend_index.has_kernel(g.out), (
|
| 375 |
+
"Do not explicitly specify CompositeExplicitAutograd dispatch key on structured "
|
| 376 |
+
"functions, they will be automatically generated for you"
|
| 377 |
+
)
|
| 378 |
+
elif metadata is None or not metadata.structured:
|
| 379 |
+
return list(mapMaybe(lambda f: self.gen_unstructured(f, g), g.functions()))
|
| 380 |
+
structured_gen = StructuredRegisterDispatchKey(
|
| 381 |
+
self.backend_index,
|
| 382 |
+
self.target,
|
| 383 |
+
self.selector,
|
| 384 |
+
self.rocm,
|
| 385 |
+
self.symint,
|
| 386 |
+
self.class_method_name,
|
| 387 |
+
self.skip_dispatcher_op_registration,
|
| 388 |
+
g,
|
| 389 |
+
)
|
| 390 |
+
return list(mapMaybe(structured_gen.gen_one, g.functions()))
|
| 391 |
+
|
| 392 |
+
def gen_unstructured(
|
| 393 |
+
self, f: NativeFunction, g: NativeFunctionsGroup | None = None
|
| 394 |
+
) -> str | None:
|
| 395 |
+
with native_function_manager(f):
|
| 396 |
+
inplace_meta = False
|
| 397 |
+
gets_out_inplace_wrapper = False
|
| 398 |
+
if not self.backend_index.has_kernel(f):
|
| 399 |
+
if (
|
| 400 |
+
self.backend_index.dispatch_key == DispatchKey.Meta
|
| 401 |
+
and f.func.kind() is SchemaKind.inplace
|
| 402 |
+
and
|
| 403 |
+
# Defer to composites for meta implementation
|
| 404 |
+
not f.has_composite_kernel
|
| 405 |
+
and
|
| 406 |
+
# Inplace list operations are not supported
|
| 407 |
+
len(f.func.returns) == 1
|
| 408 |
+
):
|
| 409 |
+
inplace_meta = True
|
| 410 |
+
elif (
|
| 411 |
+
not self.backend_index.use_out_as_primary
|
| 412 |
+
and g is not None
|
| 413 |
+
and gets_generated_out_inplace_wrapper(f, g, self.backend_index)
|
| 414 |
+
):
|
| 415 |
+
# We want to generate inplace/out wrappers, that don't have a kernel for the backend.
|
| 416 |
+
gets_out_inplace_wrapper = True
|
| 417 |
+
else:
|
| 418 |
+
return None
|
| 419 |
+
if f.manual_kernel_registration:
|
| 420 |
+
return None
|
| 421 |
+
|
| 422 |
+
if (
|
| 423 |
+
self.target is Target.REGISTRATION
|
| 424 |
+
and not self.selector.is_native_function_selected(f)
|
| 425 |
+
):
|
| 426 |
+
return None
|
| 427 |
+
|
| 428 |
+
sig = self.wrapper_kernel_sig(f)
|
| 429 |
+
|
| 430 |
+
name = sig.name()
|
| 431 |
+
returns_type = sig.returns_type().cpp_type()
|
| 432 |
+
args = sig.arguments()
|
| 433 |
+
args_str = ", ".join(a.defn() for a in args)
|
| 434 |
+
|
| 435 |
+
# See Note [Direct dispatch bindings]
|
| 436 |
+
cpp_sig_group = CppSignatureGroup.from_native_function(
|
| 437 |
+
f, method=False, fallback_binding=False
|
| 438 |
+
)
|
| 439 |
+
|
| 440 |
+
# TODO: dedupe this with the structured codegen
|
| 441 |
+
if self.target is Target.NAMESPACED_DECLARATION:
|
| 442 |
+
result = ""
|
| 443 |
+
for cpp_sig in cpp_sig_group.signatures(symint=self.symint):
|
| 444 |
+
result += f"TORCH_API {cpp_sig.decl()};\n"
|
| 445 |
+
return result
|
| 446 |
+
elif self.target is Target.NAMESPACED_DEFINITION:
|
| 447 |
+
|
| 448 |
+
def generate_defn(cpp_sig: CppSignature) -> str:
|
| 449 |
+
return f"""
|
| 450 |
+
{cpp_sig.defn()} {{
|
| 451 |
+
return {sig.name()}({', '.join(e.expr for e in translate(cpp_sig.arguments(), sig.arguments()))});
|
| 452 |
+
}}
|
| 453 |
+
"""
|
| 454 |
+
|
| 455 |
+
result = ""
|
| 456 |
+
for cpp_sig in cpp_sig_group.signatures(symint=self.symint):
|
| 457 |
+
result += generate_defn(cpp_sig)
|
| 458 |
+
return result
|
| 459 |
+
|
| 460 |
+
elif self.target is Target.ANONYMOUS_DEFINITION:
|
| 461 |
+
# short circuit for inplace_meta
|
| 462 |
+
if inplace_meta:
|
| 463 |
+
assert f.func.arguments.self_arg is not None
|
| 464 |
+
self_arg_name = f.func.arguments.self_arg.argument.name
|
| 465 |
+
# TODO: handle in place on tensor list
|
| 466 |
+
return f"""
|
| 467 |
+
{returns_type} {name}({args_str}) {{
|
| 468 |
+
TORCH_CHECK_NOT_IMPLEMENTED({self_arg_name}.is_meta(),
|
| 469 |
+
"Cannot inplace into non-meta tensor with meta tensor argument");
|
| 470 |
+
return {self_arg_name};
|
| 471 |
+
}}
|
| 472 |
+
"""
|
| 473 |
+
|
| 474 |
+
# short circuit for generated inplace/out wrappers
|
| 475 |
+
if gets_out_inplace_wrapper:
|
| 476 |
+
return self.gen_out_inplace_wrapper(f, g)
|
| 477 |
+
|
| 478 |
+
metadata = self.backend_index.get_kernel(f)
|
| 479 |
+
if metadata is None:
|
| 480 |
+
return None
|
| 481 |
+
if self.class_method_name is None:
|
| 482 |
+
impl_name = f"{metadata.cpp_namespace}::{metadata.kernel}"
|
| 483 |
+
else:
|
| 484 |
+
impl_name = f"{metadata.cpp_namespace}::{self.class_method_name}::{metadata.kernel}"
|
| 485 |
+
|
| 486 |
+
kernel_sig = kernel_signature(f, self.backend_index)
|
| 487 |
+
|
| 488 |
+
args_exprs_str = ", ".join(
|
| 489 |
+
e.expr
|
| 490 |
+
for e in translate(
|
| 491 |
+
sig.arguments(), kernel_sig.arguments(), method=False
|
| 492 |
+
)
|
| 493 |
+
)
|
| 494 |
+
|
| 495 |
+
device_check = " // No device check\n"
|
| 496 |
+
# Backends that require device guards presumably also require device checks.
|
| 497 |
+
if self.backend_index.device_guard:
|
| 498 |
+
device_check_args = itertools.chain(
|
| 499 |
+
f.func.arguments.out, f.func.arguments.flat_positional
|
| 500 |
+
)
|
| 501 |
+
device_check = RegisterDispatchKey.gen_device_check(
|
| 502 |
+
f.device_check, list(device_check_args), name
|
| 503 |
+
)
|
| 504 |
+
|
| 505 |
+
device_guard = "// DeviceGuard omitted" # default
|
| 506 |
+
if f.device_guard and self.backend_index.device_guard:
|
| 507 |
+
has_tensor_options = any(
|
| 508 |
+
isinstance(a, TensorOptionsArguments)
|
| 509 |
+
for a in f.func.arguments.non_out
|
| 510 |
+
)
|
| 511 |
+
if has_tensor_options:
|
| 512 |
+
# kernel is creating a tensor
|
| 513 |
+
device_guard = """
|
| 514 |
+
const DeviceGuard device_guard(device_or_default(device));"""
|
| 515 |
+
|
| 516 |
+
# CUDA requires special handling
|
| 517 |
+
if is_cuda_dispatch_key(self.backend_index.dispatch_key):
|
| 518 |
+
device_guard = (
|
| 519 |
+
f"globalContext().lazyInitCUDA();\n{device_guard}"
|
| 520 |
+
)
|
| 521 |
+
else:
|
| 522 |
+
# kernel is operating on existing tensors
|
| 523 |
+
|
| 524 |
+
# There is precedence for which argument we use to do
|
| 525 |
+
# device guard. This describes the precedence order.
|
| 526 |
+
self_arg = (
|
| 527 |
+
[f.func.arguments.self_arg.argument]
|
| 528 |
+
if f.func.arguments.self_arg is not None
|
| 529 |
+
else []
|
| 530 |
+
)
|
| 531 |
+
candidate_args = itertools.chain(
|
| 532 |
+
self_arg,
|
| 533 |
+
f.func.arguments.out,
|
| 534 |
+
f.func.arguments.flat_positional,
|
| 535 |
+
)
|
| 536 |
+
|
| 537 |
+
# Only tensor like arguments are eligible
|
| 538 |
+
device_of = next(
|
| 539 |
+
(
|
| 540 |
+
f"{a.name}"
|
| 541 |
+
for a in candidate_args
|
| 542 |
+
if a.type.is_tensor_like()
|
| 543 |
+
),
|
| 544 |
+
None,
|
| 545 |
+
)
|
| 546 |
+
if device_of is not None:
|
| 547 |
+
device_guard = f"const OptionalDeviceGuard device_guard(device_of({device_of}));"
|
| 548 |
+
|
| 549 |
+
return f"""\
|
| 550 |
+
namespace {{
|
| 551 |
+
|
| 552 |
+
{returns_type} {name}({args_str}) {{
|
| 553 |
+
{device_check}
|
| 554 |
+
|
| 555 |
+
{device_guard}
|
| 556 |
+
return {impl_name}({args_exprs_str});
|
| 557 |
+
}}
|
| 558 |
+
|
| 559 |
+
}} // anonymous namespace
|
| 560 |
+
"""
|
| 561 |
+
|
| 562 |
+
elif self.target is Target.REGISTRATION:
|
| 563 |
+
if f.manual_kernel_registration or self.skip_dispatcher_op_registration:
|
| 564 |
+
return None
|
| 565 |
+
else:
|
| 566 |
+
payload = f"TORCH_FN({name})"
|
| 567 |
+
return f'm.impl("{f.func.name}",\n{payload});\n'
|
| 568 |
+
else:
|
| 569 |
+
assert_never(self.target)
|
| 570 |
+
|
| 571 |
+
|
| 572 |
+
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
|
| 573 |
+
#
|
| 574 |
+
# STRUCTURED
|
| 575 |
+
#
|
| 576 |
+
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
|
| 577 |
+
|
| 578 |
+
|
| 579 |
+
@dataclass(frozen=True)
|
| 580 |
+
class StructuredRegisterDispatchKey(RegisterDispatchKey):
|
| 581 |
+
g: NativeFunctionsGroup
|
| 582 |
+
|
| 583 |
+
def gen_class_set_output_functions(
|
| 584 |
+
self, k: SchemaKind, parent_class: str, generate_super: bool
|
| 585 |
+
) -> str:
|
| 586 |
+
if generate_super:
|
| 587 |
+
set_output_super = f"{parent_class}::set_output_raw_strided(output_idx, sizes, strides, options, names);"
|
| 588 |
+
else:
|
| 589 |
+
set_output_super = ""
|
| 590 |
+
|
| 591 |
+
def gen_set_output_function(name: str, maybe_create_proxy: bool) -> str:
|
| 592 |
+
return f"""
|
| 593 |
+
void set_output_{name}(
|
| 594 |
+
int64_t output_idx, IntArrayRef sizes, IntArrayRef strides,
|
| 595 |
+
TensorOptions options, DimnameList names
|
| 596 |
+
) override {{
|
| 597 |
+
{textwrap.indent(self.gen_class_set_output_body(k, maybe_create_proxy), " ")}
|
| 598 |
+
if (!names.empty()) {{
|
| 599 |
+
namedinference::propagate_names(outputs_[output_idx], names);
|
| 600 |
+
}}
|
| 601 |
+
// super must happen after, so that downstream can use maybe_get_output
|
| 602 |
+
// to retrieve the output
|
| 603 |
+
{textwrap.indent(set_output_super, " ")}
|
| 604 |
+
}}
|
| 605 |
+
"""
|
| 606 |
+
|
| 607 |
+
return f"""
|
| 608 |
+
{gen_set_output_function("strided", maybe_create_proxy=True)}
|
| 609 |
+
{gen_set_output_function("raw_strided", maybe_create_proxy=False)}
|
| 610 |
+
"""
|
| 611 |
+
|
| 612 |
+
def gen_class_set_output_body(self, k: SchemaKind, maybe_create_proxy: bool) -> str:
|
| 613 |
+
if self.backend_index.dispatch_key in [
|
| 614 |
+
DispatchKey.CUDA,
|
| 615 |
+
DispatchKey.MPS,
|
| 616 |
+
DispatchKey.CompositeExplicitAutogradNonFunctional,
|
| 617 |
+
]:
|
| 618 |
+
maybe_set_guard = """
|
| 619 |
+
auto current_device = guard_.current_device();
|
| 620 |
+
if (C10_UNLIKELY(current_device.has_value())) {
|
| 621 |
+
TORCH_INTERNAL_ASSERT(*current_device == options.device(),
|
| 622 |
+
"structured kernels don't support multi-device outputs");
|
| 623 |
+
} else {
|
| 624 |
+
guard_.reset_device(options.device());
|
| 625 |
+
}
|
| 626 |
+
"""
|
| 627 |
+
maybe_set_guard_line = maybe_set_guard + "\n"
|
| 628 |
+
else:
|
| 629 |
+
maybe_set_guard_line = maybe_set_guard = ""
|
| 630 |
+
|
| 631 |
+
if maybe_create_proxy:
|
| 632 |
+
create_proxy = """
|
| 633 |
+
auto maybe_proxy = maybe_create_proxy(out, sizes, strides, options);
|
| 634 |
+
if (C10_UNLIKELY(maybe_proxy.has_value())) {
|
| 635 |
+
proxy_outputs_[output_idx] = std::move(maybe_proxy).value();
|
| 636 |
+
}
|
| 637 |
+
"""
|
| 638 |
+
else:
|
| 639 |
+
create_proxy = ""
|
| 640 |
+
|
| 641 |
+
if k is SchemaKind.functional:
|
| 642 |
+
assert self.backend_index.dispatch_key in (
|
| 643 |
+
DispatchKey.Meta,
|
| 644 |
+
DispatchKey.CPU,
|
| 645 |
+
DispatchKey.CUDA,
|
| 646 |
+
DispatchKey.MPS,
|
| 647 |
+
DispatchKey.XPU,
|
| 648 |
+
DispatchKey.CompositeExplicitAutogradNonFunctional,
|
| 649 |
+
)
|
| 650 |
+
return f"""{maybe_set_guard_line}
|
| 651 |
+
outputs_[output_idx] = create_out(sizes, strides, options);"""
|
| 652 |
+
elif k is SchemaKind.inplace:
|
| 653 |
+
return f"""{maybe_set_guard_line}
|
| 654 |
+
const auto& out = outputs_[output_idx].get();
|
| 655 |
+
check_inplace(out, sizes, options);
|
| 656 |
+
{create_proxy}"""
|
| 657 |
+
elif k is SchemaKind.out:
|
| 658 |
+
return f"""{maybe_set_guard_line}
|
| 659 |
+
const auto& out = outputs_[output_idx].get();
|
| 660 |
+
resize_out(out, sizes, strides, options);
|
| 661 |
+
{create_proxy}"""
|
| 662 |
+
elif k is SchemaKind.mutable or k is SchemaKind.scratch:
|
| 663 |
+
raise AssertionError(
|
| 664 |
+
f"{k} structured operators are currently not supported"
|
| 665 |
+
)
|
| 666 |
+
else:
|
| 667 |
+
assert_never(k)
|
| 668 |
+
|
| 669 |
+
# returns the definition of a ctor, as well as how to construct
|
| 670 |
+
# this class to a variable named op
|
| 671 |
+
def gen_class_ctor(self, k: SchemaKind, class_name: str, returns: int) -> str:
|
| 672 |
+
if k is SchemaKind.functional:
|
| 673 |
+
return ""
|
| 674 |
+
elif k is SchemaKind.inplace:
|
| 675 |
+
# TODO: Make sure out argument is guaranteed to be self
|
| 676 |
+
return f"{class_name}(Tensor& self) : outputs_{{std::ref(self)}} {{}}"
|
| 677 |
+
elif k is SchemaKind.out:
|
| 678 |
+
out_args = ", ".join(f"Tensor& out{i}" for i in range(returns))
|
| 679 |
+
out_refs = ", ".join(f"std::ref(out{i})" for i in range(returns))
|
| 680 |
+
return f"{class_name}({out_args}) : outputs_{{ {out_refs} }} {{}}"
|
| 681 |
+
elif k is SchemaKind.mutable or k is SchemaKind.scratch:
|
| 682 |
+
raise AssertionError(
|
| 683 |
+
f"{k} structured operators are currently not supported"
|
| 684 |
+
)
|
| 685 |
+
else:
|
| 686 |
+
assert_never(k)
|
| 687 |
+
|
| 688 |
+
def gen_class(
|
| 689 |
+
self,
|
| 690 |
+
f: NativeFunction,
|
| 691 |
+
k: SchemaKind,
|
| 692 |
+
*,
|
| 693 |
+
class_name: str,
|
| 694 |
+
parent_class: str,
|
| 695 |
+
generate_super: bool,
|
| 696 |
+
) -> str:
|
| 697 |
+
if k is SchemaKind.functional:
|
| 698 |
+
output_type = "Tensor"
|
| 699 |
+
output_value = "outputs_[output_idx]"
|
| 700 |
+
proxy_field = ""
|
| 701 |
+
elif k is SchemaKind.inplace:
|
| 702 |
+
output_type = "std::reference_wrapper<Tensor>"
|
| 703 |
+
output_value = "proxy_outputs_[output_idx].has_value() ? *proxy_outputs_[output_idx] : outputs_[output_idx].get()"
|
| 704 |
+
proxy_field = f"std::array<::std::optional<Tensor>, {len(f.func.returns)}> proxy_outputs_;"
|
| 705 |
+
elif k is SchemaKind.out:
|
| 706 |
+
output_type = "std::reference_wrapper<Tensor>"
|
| 707 |
+
output_value = "proxy_outputs_[output_idx].has_value() ? *proxy_outputs_[output_idx] : outputs_[output_idx].get()"
|
| 708 |
+
proxy_field = f"std::array<::std::optional<Tensor>, {len(f.func.returns)}> proxy_outputs_;"
|
| 709 |
+
else:
|
| 710 |
+
raise RuntimeError(f"Unsupported SchemaKind {k}")
|
| 711 |
+
|
| 712 |
+
if self.backend_index.dispatch_key == DispatchKey.CUDA:
|
| 713 |
+
if self.rocm:
|
| 714 |
+
guard_field = "c10::hip::OptionalHIPGuardMasqueradingAsCUDA guard_;"
|
| 715 |
+
else:
|
| 716 |
+
guard_field = "c10::cuda::OptionalCUDAGuard guard_;"
|
| 717 |
+
elif (
|
| 718 |
+
self.backend_index.dispatch_key
|
| 719 |
+
== DispatchKey.CompositeExplicitAutogradNonFunctional
|
| 720 |
+
):
|
| 721 |
+
guard_field = "c10::OptionalDeviceGuard guard_;"
|
| 722 |
+
elif self.backend_index.dispatch_key == DispatchKey.MPS:
|
| 723 |
+
# TODO: Move to OptionalMPSGuard.
|
| 724 |
+
guard_field = "c10::OptionalDeviceGuard guard_;"
|
| 725 |
+
else:
|
| 726 |
+
guard_field = ""
|
| 727 |
+
|
| 728 |
+
indent = " " * 4
|
| 729 |
+
class_ctor_str = self.gen_class_ctor(k, class_name, len(f.func.returns))
|
| 730 |
+
lines = (
|
| 731 |
+
f"struct {class_name} final : public {parent_class} {{",
|
| 732 |
+
f"{textwrap.indent(class_ctor_str, indent)}",
|
| 733 |
+
f"{textwrap.indent(self.gen_class_set_output_functions(k, parent_class, generate_super), indent)}",
|
| 734 |
+
" const Tensor& maybe_get_output(int64_t output_idx) override {",
|
| 735 |
+
f" return {output_value};\n", # type: ignore[possibly-undefined] # TODO: audit
|
| 736 |
+
" }",
|
| 737 |
+
# type: ignore[possibly-undefined] # TODO: audit
|
| 738 |
+
f" std::array<{output_type}, {len(f.func.returns)}> outputs_;",
|
| 739 |
+
f"{textwrap.indent(proxy_field, indent)}", # type: ignore[possibly-undefined] # TODO: audit
|
| 740 |
+
f"{textwrap.indent(guard_field, indent)}",
|
| 741 |
+
"};",
|
| 742 |
+
)
|
| 743 |
+
return "\n".join(line for line in lines if line)
|
| 744 |
+
|
| 745 |
+
@method_with_native_function
|
| 746 |
+
def gen_one(self, f: NativeFunction) -> str | None:
|
| 747 |
+
assert not f.manual_kernel_registration
|
| 748 |
+
|
| 749 |
+
if (
|
| 750 |
+
self.target is Target.REGISTRATION
|
| 751 |
+
and not self.selector.is_native_function_selected(f)
|
| 752 |
+
):
|
| 753 |
+
return None
|
| 754 |
+
|
| 755 |
+
# TODO: Now, there is something interesting going on here. In the code below,
|
| 756 |
+
# we generate CompositeExplicitAutogradNonFunctional implementations of functional and inplace
|
| 757 |
+
# based on the out implementation. But in fact, out is definable by
|
| 758 |
+
# functional too (just not very efficiently), and this is honestly the
|
| 759 |
+
# MORE likely situation for a backend implementor. How do we pick?
|
| 760 |
+
# Well, taking a page from Haskell type classes and default methods,
|
| 761 |
+
# we could conceivably register a circular definition (out in terms
|
| 762 |
+
# of functional, and functional in terms of out) and just require
|
| 763 |
+
# someone to implement one or the other. We'd have to do a little bit
|
| 764 |
+
# of work to not register one of these "weak" definitions unless there
|
| 765 |
+
# is a strong definition somewhere in the DAG! So it's not implemented yet.
|
| 766 |
+
if (
|
| 767 |
+
self.backend_index.dispatch_key
|
| 768 |
+
== DispatchKey.CompositeExplicitAutogradNonFunctional
|
| 769 |
+
and f.func.kind() is SchemaKind.out
|
| 770 |
+
):
|
| 771 |
+
# Never generate a default implementation for out, that's what you
|
| 772 |
+
# have to define as a backend implementor
|
| 773 |
+
return None
|
| 774 |
+
|
| 775 |
+
# Note [Direct dispatch bindings]
|
| 776 |
+
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 777 |
+
# Signature of the non-dispatched function we'll expose in a header
|
| 778 |
+
# (e.g., at::cpu::add). We don't generate methods (TODO: do this
|
| 779 |
+
# when CPUTensor class is a thing); nor do we generate fallback
|
| 780 |
+
# bindings for manual_cpp_binding functions.
|
| 781 |
+
cpp_sig_group = CppSignatureGroup.from_native_function(
|
| 782 |
+
f, method=False, fallback_binding=False
|
| 783 |
+
)
|
| 784 |
+
|
| 785 |
+
# Signature of the wrapper function we'll register to the dispatcher
|
| 786 |
+
kern = self.backend_index.get_kernel(f)
|
| 787 |
+
sig = NativeSignature(
|
| 788 |
+
f.func,
|
| 789 |
+
prefix=f"wrapper_{self.backend_index.dispatch_key}_",
|
| 790 |
+
symint=kern is not None and kern.supports_symint(),
|
| 791 |
+
)
|
| 792 |
+
|
| 793 |
+
if self.target is Target.NAMESPACED_DECLARATION:
|
| 794 |
+
result = ""
|
| 795 |
+
for cpp_sig in cpp_sig_group.signatures(symint=self.symint):
|
| 796 |
+
result += f"TORCH_API {cpp_sig.decl()};\n"
|
| 797 |
+
return result
|
| 798 |
+
|
| 799 |
+
elif self.target is Target.NAMESPACED_DEFINITION:
|
| 800 |
+
|
| 801 |
+
def generate_defn(cpp_sig: CppSignature) -> str:
|
| 802 |
+
return f"""
|
| 803 |
+
{cpp_sig.defn()} {{
|
| 804 |
+
return {sig.name()}({', '.join(e.expr for e in translate(cpp_sig.arguments(), sig.arguments()))});
|
| 805 |
+
}}
|
| 806 |
+
"""
|
| 807 |
+
|
| 808 |
+
result = ""
|
| 809 |
+
for cpp_sig in cpp_sig_group.signatures(symint=self.symint):
|
| 810 |
+
result += generate_defn(cpp_sig)
|
| 811 |
+
return result
|
| 812 |
+
|
| 813 |
+
elif self.target is Target.ANONYMOUS_DEFINITION:
|
| 814 |
+
k = f.func.kind()
|
| 815 |
+
|
| 816 |
+
# Construct the body of the wrapper function with signature sig
|
| 817 |
+
sig_body = []
|
| 818 |
+
# We'll use context to keep track of any variables we've brought
|
| 819 |
+
# into scope while generating code
|
| 820 |
+
context: list[Binding | Expr] = list(sig.arguments())
|
| 821 |
+
|
| 822 |
+
# Initialize the class corresponding to this structured
|
| 823 |
+
# operator; feeding it the output argument(s) if it is known
|
| 824 |
+
if self.backend_index.dispatch_key is DispatchKey.Meta:
|
| 825 |
+
class_name = f"structured_{meta.name(self.g)}_meta_{k.name}"
|
| 826 |
+
parent_class = f"at::meta::structured_{meta.name(self.g)}"
|
| 827 |
+
elif (
|
| 828 |
+
self.backend_index.dispatch_key
|
| 829 |
+
is DispatchKey.CompositeExplicitAutogradNonFunctional
|
| 830 |
+
):
|
| 831 |
+
# TODO: dedup this branch
|
| 832 |
+
class_name = f"structured_{meta.name(self.g)}_default_backend_{k.name}"
|
| 833 |
+
parent_class = f"at::meta::structured_{meta.name(self.g)}"
|
| 834 |
+
else:
|
| 835 |
+
metadata = self.backend_index.get_kernel(self.g)
|
| 836 |
+
assert metadata is not None
|
| 837 |
+
class_name = f"structured_{metadata.kernel}_{k.name}"
|
| 838 |
+
parent_class = f"{metadata.cpp_namespace}::structured_{metadata.kernel}"
|
| 839 |
+
|
| 840 |
+
if self.backend_index.device_guard:
|
| 841 |
+
device_check_args = itertools.chain(
|
| 842 |
+
f.func.arguments.out, f.func.arguments.flat_positional
|
| 843 |
+
)
|
| 844 |
+
sig_body.append(
|
| 845 |
+
RegisterDispatchKey.gen_device_check(
|
| 846 |
+
f.device_check, list(device_check_args), sig.name()
|
| 847 |
+
)
|
| 848 |
+
)
|
| 849 |
+
|
| 850 |
+
if k is SchemaKind.functional:
|
| 851 |
+
sig_body.append(f"{class_name} op;")
|
| 852 |
+
elif k is SchemaKind.inplace:
|
| 853 |
+
sig_body.append(f"{class_name} op(self);")
|
| 854 |
+
elif k is SchemaKind.out:
|
| 855 |
+
out_args_str = ", ".join(a.name for a in f.func.arguments.out)
|
| 856 |
+
sig_body.append(f"{class_name} op({out_args_str});")
|
| 857 |
+
|
| 858 |
+
# Translate the input native arguments into structured
|
| 859 |
+
# arguments for the meta call
|
| 860 |
+
meta_exprs = ", ".join(
|
| 861 |
+
e.expr
|
| 862 |
+
for e in translate(
|
| 863 |
+
context, structured.meta_arguments(self.g), method=False
|
| 864 |
+
)
|
| 865 |
+
)
|
| 866 |
+
|
| 867 |
+
if self.g.out.precomputed:
|
| 868 |
+
# If this function group has precomputed elements, the meta function
|
| 869 |
+
# returns a struct containing them which must be saved so that it
|
| 870 |
+
# can be unpacked when generating code to call the impl.
|
| 871 |
+
sig_body.append(f"auto precompute = op.meta({meta_exprs});")
|
| 872 |
+
|
| 873 |
+
# Put all of the contents of the precompute struct into the context
|
| 874 |
+
# so that translate will be able to return the correct args for the
|
| 875 |
+
# call to the impl.
|
| 876 |
+
precomputed_values = [
|
| 877 |
+
*self.g.out.precomputed.replace.values(),
|
| 878 |
+
self.g.out.precomputed.add,
|
| 879 |
+
]
|
| 880 |
+
for precomputed_elems in precomputed_values:
|
| 881 |
+
for arg in precomputed_elems:
|
| 882 |
+
context.append(
|
| 883 |
+
Expr(
|
| 884 |
+
expr=f"precompute.{arg.name}",
|
| 885 |
+
type=structured.argument_type(arg, binds=arg.name),
|
| 886 |
+
)
|
| 887 |
+
)
|
| 888 |
+
|
| 889 |
+
# Add a use of the precompute struct so FB internal compilers don't
|
| 890 |
+
# complain that there is an unused variable.
|
| 891 |
+
sig_body.append("(void)precompute;")
|
| 892 |
+
else:
|
| 893 |
+
sig_body.append(f"op.meta({meta_exprs});")
|
| 894 |
+
|
| 895 |
+
# After running meta, op.outputs_ is guaranteed to be valid;
|
| 896 |
+
# add it to the context
|
| 897 |
+
out_args = structured.out_arguments(self.g)
|
| 898 |
+
for i, out_arg in enumerate(out_args):
|
| 899 |
+
assert ConstRefCType(BaseCType(tensorT)) == out_arg.nctype.type
|
| 900 |
+
|
| 901 |
+
if k is SchemaKind.out:
|
| 902 |
+
expr = f"op.maybe_get_output({i})"
|
| 903 |
+
else:
|
| 904 |
+
expr = f"op.outputs_[{i}]"
|
| 905 |
+
|
| 906 |
+
context.append(
|
| 907 |
+
Expr(
|
| 908 |
+
expr=expr,
|
| 909 |
+
# TODO: Stop hardcoding that the output type is a Tensor. Note
|
| 910 |
+
# that for the codegen here this is fine because outputs_ is
|
| 911 |
+
# hardcoded to be tensor already
|
| 912 |
+
type=NamedCType(
|
| 913 |
+
out_arg.nctype.name, MutRefCType(BaseCType(tensorT))
|
| 914 |
+
),
|
| 915 |
+
)
|
| 916 |
+
)
|
| 917 |
+
|
| 918 |
+
# With the expanded context, do the impl call (if not a meta
|
| 919 |
+
# function)
|
| 920 |
+
if (
|
| 921 |
+
self.backend_index.dispatch_key
|
| 922 |
+
== DispatchKey.CompositeExplicitAutogradNonFunctional
|
| 923 |
+
):
|
| 924 |
+
# TODO: https://github.com/pytorch/pytorch/issues/53023
|
| 925 |
+
out_sig_group = CppSignatureGroup.from_native_function(
|
| 926 |
+
self.g.out, method=False, fallback_binding=f.manual_cpp_binding
|
| 927 |
+
)
|
| 928 |
+
out_sig = out_sig_group.most_faithful_signature()
|
| 929 |
+
api_name = out_sig.name()
|
| 930 |
+
out_exprs = ", ".join(
|
| 931 |
+
e.expr
|
| 932 |
+
for e in translate(context, out_sig.arguments(), method=False)
|
| 933 |
+
)
|
| 934 |
+
# TODO: I think this means structured won't work with method
|
| 935 |
+
# only functions (but maybe you're saved by faithful? iunno.)
|
| 936 |
+
# NB: Originally I wrote this as an at::redispatch call, but
|
| 937 |
+
# I got in trouble because that meant I needed a DispatchKeySet
|
| 938 |
+
# in the wrapper function, which meant I needed a DispatchKeySet
|
| 939 |
+
# in the DispatchKeyFunctions declarations, but the defined API
|
| 940 |
+
# there does NOT permit a dispatch key set. I think you can
|
| 941 |
+
# probably unwind this by calling some function to do the TLS
|
| 942 |
+
# fetch and get the DispatchKeySet when you don't have it, but
|
| 943 |
+
# I didn't do it for this version
|
| 944 |
+
sig_body.append(f"at::{api_name}({out_exprs});")
|
| 945 |
+
elif self.backend_index.dispatch_key != DispatchKey.Meta:
|
| 946 |
+
impl_exprs = ", ".join(
|
| 947 |
+
e.expr
|
| 948 |
+
for e in translate(
|
| 949 |
+
context, structured.impl_arguments(self.g), method=False
|
| 950 |
+
)
|
| 951 |
+
)
|
| 952 |
+
sig_body.append(f"op.impl({impl_exprs});")
|
| 953 |
+
|
| 954 |
+
# Go over each output, and check if there is a proxy created for it.
|
| 955 |
+
# If so, copy it over to the original output.
|
| 956 |
+
if k is SchemaKind.out or k is SchemaKind.inplace:
|
| 957 |
+
for i in range(len(f.func.returns)):
|
| 958 |
+
sig_body.append(
|
| 959 |
+
f"if (op.proxy_outputs_[{i}].has_value()) op.outputs_[{i}].get().copy_(*op.proxy_outputs_[{i}]);"
|
| 960 |
+
)
|
| 961 |
+
|
| 962 |
+
# Destructively return the final tensors
|
| 963 |
+
# TODO: Do this in translate instead
|
| 964 |
+
if k is SchemaKind.functional:
|
| 965 |
+
if len(f.func.returns) == 1:
|
| 966 |
+
ret_expr = "std::move(op.outputs_[0])" # small optimization
|
| 967 |
+
else:
|
| 968 |
+
moved = ", ".join(
|
| 969 |
+
f"std::move(op.outputs_[{i}])"
|
| 970 |
+
for i in range(len(f.func.returns))
|
| 971 |
+
)
|
| 972 |
+
ret_expr = f"std::make_tuple({moved})"
|
| 973 |
+
elif k is SchemaKind.inplace:
|
| 974 |
+
ret_expr = "self"
|
| 975 |
+
elif k is SchemaKind.out:
|
| 976 |
+
if len(f.func.returns) == 1:
|
| 977 |
+
ret_expr = f.func.arguments.out[0].name
|
| 978 |
+
else:
|
| 979 |
+
refs = ", ".join(a.name for a in f.func.arguments.out)
|
| 980 |
+
ret_expr = f"std::forward_as_tuple({refs})"
|
| 981 |
+
sig_body.append(f"return {ret_expr};") # type: ignore[possibly-undefined] # TODO: audit
|
| 982 |
+
|
| 983 |
+
sig_body_str = "\n".join(sig_body)
|
| 984 |
+
|
| 985 |
+
# For an overview of what this template code looks like, see
|
| 986 |
+
# https://github.com/pytorch/rfcs/pull/9
|
| 987 |
+
return f"""\
|
| 988 |
+
{self.gen_class(
|
| 989 |
+
f, k,
|
| 990 |
+
class_name=class_name,
|
| 991 |
+
parent_class=parent_class,
|
| 992 |
+
generate_super=self.g.out.structured_inherits is not None
|
| 993 |
+
)}
|
| 994 |
+
|
| 995 |
+
{sig.defn()} {{
|
| 996 |
+
{sig_body_str}
|
| 997 |
+
}}
|
| 998 |
+
"""
|
| 999 |
+
|
| 1000 |
+
elif self.target is Target.REGISTRATION:
|
| 1001 |
+
return f'm.impl("{f.func.name}", TORCH_FN({sig.name()}));'
|
| 1002 |
+
else:
|
| 1003 |
+
assert_never(self.target)
|
| 1004 |
+
# Silence mypy's "Missing return statement" error
|
| 1005 |
+
return None
|
.venv/lib/python3.11/site-packages/torchgen/dest/ufunc.py
ADDED
|
@@ -0,0 +1,551 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
from typing import Sequence, TYPE_CHECKING
|
| 5 |
+
|
| 6 |
+
import torchgen.api.ufunc as ufunc
|
| 7 |
+
from torchgen.api.translate import translate
|
| 8 |
+
from torchgen.api.types import (
|
| 9 |
+
BaseCType,
|
| 10 |
+
Binding,
|
| 11 |
+
CType,
|
| 12 |
+
Expr,
|
| 13 |
+
NamedCType,
|
| 14 |
+
opmath_t,
|
| 15 |
+
scalar_t,
|
| 16 |
+
StructuredImplSignature,
|
| 17 |
+
VectorizedCType,
|
| 18 |
+
)
|
| 19 |
+
from torchgen.context import with_native_function
|
| 20 |
+
from torchgen.model import (
|
| 21 |
+
Argument,
|
| 22 |
+
BaseTy,
|
| 23 |
+
BaseType,
|
| 24 |
+
DispatchKey,
|
| 25 |
+
NativeFunctionsGroup,
|
| 26 |
+
ScalarType,
|
| 27 |
+
UfuncKey,
|
| 28 |
+
)
|
| 29 |
+
from torchgen.utils import OrderedSet
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
if TYPE_CHECKING:
|
| 33 |
+
from torchgen.api.ufunc import UfunctorBindings
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
|
| 37 |
+
#
|
| 38 |
+
# CUDA STUFF
|
| 39 |
+
#
|
| 40 |
+
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
|
| 41 |
+
|
| 42 |
+
# NB: not bothering to generate dispatch stub forward declaration in header,
|
| 43 |
+
# we can just paste it whereever necessary
|
| 44 |
+
|
| 45 |
+
# TODO: use BackendIndex
|
| 46 |
+
# dispatch_key: DispatchKey # only CPU/CUDA right now
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
# Represents functors for implementing CUDA ufuncs.
|
| 50 |
+
# Functors are templated by scalar_t because when USERS instantiate functors
|
| 51 |
+
# they are templated. A functor looks something like this:
|
| 52 |
+
#
|
| 53 |
+
# template <typename scalar_t>
|
| 54 |
+
# struct CUDAFunctorOnSelf_add {
|
| 55 |
+
# using opmath_t = at::opmath_type<scalar_t>;
|
| 56 |
+
# opmath_t other_;
|
| 57 |
+
# opmath_t alpha_;
|
| 58 |
+
# CUDAFunctorOnSelf_add(opmath_t other, opmath_t alpha)
|
| 59 |
+
# : other_(other), alpha_(alpha) {}
|
| 60 |
+
# __device__ scalar_t operator()(scalar_t self) {
|
| 61 |
+
# return ufunc::add(static_cast<opmath_t>(self), other_, alpha_);
|
| 62 |
+
# }
|
| 63 |
+
# };
|
| 64 |
+
#
|
| 65 |
+
@dataclass(frozen=True)
|
| 66 |
+
class UfunctorSignature:
|
| 67 |
+
g: NativeFunctionsGroup
|
| 68 |
+
scalar_tensor_idx: int | None
|
| 69 |
+
name: str
|
| 70 |
+
|
| 71 |
+
def arguments(self) -> UfunctorBindings:
|
| 72 |
+
return ufunc.ufunctor_arguments(
|
| 73 |
+
self.g, scalar_tensor_idx=self.scalar_tensor_idx, scalar_t=scalar_t
|
| 74 |
+
)
|
| 75 |
+
|
| 76 |
+
def fields(self) -> list[Binding]:
|
| 77 |
+
# fields are renamed to have a trailing underscore, as is conventional
|
| 78 |
+
return [b.rename(f"{b.name}_") for b in self.arguments().ctor]
|
| 79 |
+
|
| 80 |
+
def returns_type(self) -> CType:
|
| 81 |
+
# TODO: don't hardcode; return type will be inferred based on tags on
|
| 82 |
+
# the native function
|
| 83 |
+
return BaseCType(scalar_t)
|
| 84 |
+
|
| 85 |
+
def decl_fields(self) -> str:
|
| 86 |
+
return "\n".join(f"{f.type} {f.name};" for f in self.fields())
|
| 87 |
+
|
| 88 |
+
def inline_defn_ctor(self) -> str:
|
| 89 |
+
args_str = ", ".join(a.decl() for a in self.arguments().ctor)
|
| 90 |
+
# NB: hypothetically could do this with translate but the
|
| 91 |
+
# transition here is very regular
|
| 92 |
+
init_str = ", ".join(f"{a.name}_({a.name})" for a in self.arguments().ctor)
|
| 93 |
+
return f"{self.name}({args_str}) : {init_str} {{}}"
|
| 94 |
+
|
| 95 |
+
def decl_apply(self) -> str:
|
| 96 |
+
args_str = ", ".join(a.decl() for a in self.arguments().apply)
|
| 97 |
+
return f"{self.returns_type().cpp_type()} operator()({args_str}) const"
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
@dataclass(frozen=True)
|
| 101 |
+
class UfuncSignature:
|
| 102 |
+
g: NativeFunctionsGroup
|
| 103 |
+
name: str
|
| 104 |
+
compute_t: CType
|
| 105 |
+
|
| 106 |
+
def arguments(self) -> list[Binding]:
|
| 107 |
+
return ufunc.ufunc_arguments(self.g, compute_t=self.compute_t)
|
| 108 |
+
|
| 109 |
+
def call(self, ctx: Sequence[Binding | Expr]) -> str:
|
| 110 |
+
return f"{self.name}({', '.join(a.expr for a in translate(ctx, self.arguments()))})"
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
# steps:
|
| 114 |
+
# 1. take the functional signature
|
| 115 |
+
# 2. use api.ufunc to convert it to template signature. this establishes
|
| 116 |
+
# the type of the template function
|
| 117 |
+
# 3. use api.ufunc (II) to generate a split struct / operator() signature.
|
| 118 |
+
# this establish context in which we call the template signature
|
| 119 |
+
#
|
| 120 |
+
# StructuredImplSignature context
|
| 121 |
+
# ~> functor constructor sig
|
| 122 |
+
#
|
| 123 |
+
# Functor constructor context
|
| 124 |
+
# ~> functor fields sig
|
| 125 |
+
#
|
| 126 |
+
# Functor apply context (functor fields + functor apply sig)
|
| 127 |
+
# ~> template sig
|
| 128 |
+
#
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
def eligible_for_binary_scalar_specialization(g: NativeFunctionsGroup) -> bool:
|
| 132 |
+
num_tensors = sum(
|
| 133 |
+
1 for a in g.functional.func.arguments.flat_non_out if a.type.is_tensor_like()
|
| 134 |
+
)
|
| 135 |
+
return num_tensors == 2
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def compute_ufunc_cuda_functors(
|
| 139 |
+
g: NativeFunctionsGroup,
|
| 140 |
+
) -> tuple[dict[ScalarType, dict[UfuncKey, UfunctorSignature]], str]:
|
| 141 |
+
# First, build the functors.
|
| 142 |
+
ufunctor_sigs: dict[ScalarType, dict[UfuncKey, UfunctorSignature]] = {}
|
| 143 |
+
ufunctors: list[str] = []
|
| 144 |
+
loops = g.out.ufunc_inner_loop
|
| 145 |
+
scalar_tensor_idx_lookup = {
|
| 146 |
+
UfuncKey.CUDAFunctorOnSelf: 1,
|
| 147 |
+
UfuncKey.CUDAFunctorOnOther: 0,
|
| 148 |
+
UfuncKey.CUDAFunctor: None,
|
| 149 |
+
}
|
| 150 |
+
if eligible_for_binary_scalar_specialization(g):
|
| 151 |
+
keys = [
|
| 152 |
+
UfuncKey.CUDAFunctorOnSelf,
|
| 153 |
+
UfuncKey.CUDAFunctorOnOther,
|
| 154 |
+
UfuncKey.CUDAFunctor,
|
| 155 |
+
]
|
| 156 |
+
else:
|
| 157 |
+
keys = [UfuncKey.CUDAFunctor]
|
| 158 |
+
for k in [UfuncKey.CUDAFunctorOnSelf, UfuncKey.CUDAFunctorOnOther]:
|
| 159 |
+
assert k not in loops, f"cannot use {k} on non-binary function"
|
| 160 |
+
for k in keys:
|
| 161 |
+
# If the key was directly defined, skip functor codegen; we assume the
|
| 162 |
+
# user already done it for us
|
| 163 |
+
if k in loops:
|
| 164 |
+
ufunctor_sig = UfunctorSignature(
|
| 165 |
+
g, scalar_tensor_idx=scalar_tensor_idx_lookup[k], name=loops[k].name
|
| 166 |
+
)
|
| 167 |
+
for dtype in loops[k].supported_dtypes:
|
| 168 |
+
ufunctor_sigs.setdefault(dtype, {})[k] = ufunctor_sig
|
| 169 |
+
continue
|
| 170 |
+
|
| 171 |
+
# Note [ScalarOnly and Generic must match names for CUDA]
|
| 172 |
+
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 173 |
+
# Otherwise, look in ANY of the generic entries. For simplicity of
|
| 174 |
+
# codegen, both ScalarOnly and Generic are defined, the ufunc name
|
| 175 |
+
# must match (if they didn't match, we'd have to generate distinct
|
| 176 |
+
# functors per dtype, which is awful, so we're not going to do it unless
|
| 177 |
+
# someone really forces us to)
|
| 178 |
+
ufunc_name = None
|
| 179 |
+
supported_dtypes: OrderedSet[ScalarType] = OrderedSet()
|
| 180 |
+
for lk in [UfuncKey.ScalarOnly, UfuncKey.Generic]:
|
| 181 |
+
if lk not in loops:
|
| 182 |
+
continue
|
| 183 |
+
if ufunc_name is None:
|
| 184 |
+
ufunc_name = loops[lk].name
|
| 185 |
+
else:
|
| 186 |
+
# See Note [ScalarOnly and Generic must match names for CUDA]
|
| 187 |
+
assert (
|
| 188 |
+
ufunc_name == loops[lk].name
|
| 189 |
+
), "ScalarOnly and Generic must have same ufunc name"
|
| 190 |
+
supported_dtypes |= loops[lk].supported_dtypes
|
| 191 |
+
assert ufunc_name is not None
|
| 192 |
+
|
| 193 |
+
name = f"{k}_{ufunc_name}"
|
| 194 |
+
ufunctor_sig = UfunctorSignature(
|
| 195 |
+
g, scalar_tensor_idx=scalar_tensor_idx_lookup[k], name=name
|
| 196 |
+
)
|
| 197 |
+
for dtype in supported_dtypes:
|
| 198 |
+
ufunctor_sigs.setdefault(dtype, {})[k] = ufunctor_sig
|
| 199 |
+
|
| 200 |
+
ufunc_sig = UfuncSignature(
|
| 201 |
+
g, name=f"ufunc::{ufunc_name}", compute_t=BaseCType(opmath_t)
|
| 202 |
+
)
|
| 203 |
+
apply_ctx = ufunctor_sig.fields() + ufunctor_sig.arguments().apply
|
| 204 |
+
ufunctors.append(
|
| 205 |
+
f"""
|
| 206 |
+
template <typename scalar_t>
|
| 207 |
+
struct {ufunctor_sig.name} {{
|
| 208 |
+
using opmath_t = at::opmath_type<scalar_t>;
|
| 209 |
+
{ufunctor_sig.decl_fields()}
|
| 210 |
+
{ufunctor_sig.inline_defn_ctor()}
|
| 211 |
+
__device__ {ufunctor_sig.decl_apply()} {{
|
| 212 |
+
return {ufunc_sig.call(apply_ctx)};
|
| 213 |
+
}}
|
| 214 |
+
}};
|
| 215 |
+
"""
|
| 216 |
+
)
|
| 217 |
+
|
| 218 |
+
return ufunctor_sigs, "\n".join(ufunctors)
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
@dataclass(frozen=True)
|
| 222 |
+
class BinaryScalarSpecializationConfig:
|
| 223 |
+
scalar_idx: int
|
| 224 |
+
ctor_tensor: str
|
| 225 |
+
ufunc_key: UfuncKey
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
BinaryScalarSpecializationConfigs = [
|
| 229 |
+
BinaryScalarSpecializationConfig(
|
| 230 |
+
scalar_idx=0,
|
| 231 |
+
ctor_tensor="self",
|
| 232 |
+
ufunc_key=UfuncKey.CUDAFunctorOnOther,
|
| 233 |
+
),
|
| 234 |
+
BinaryScalarSpecializationConfig(
|
| 235 |
+
scalar_idx=1,
|
| 236 |
+
ctor_tensor="other",
|
| 237 |
+
ufunc_key=UfuncKey.CUDAFunctorOnSelf,
|
| 238 |
+
),
|
| 239 |
+
]
|
| 240 |
+
|
| 241 |
+
|
| 242 |
+
def compute_ufunc_cuda_dtype_body(
|
| 243 |
+
g: NativeFunctionsGroup,
|
| 244 |
+
dtype: ScalarType,
|
| 245 |
+
inner_loops: dict[UfuncKey, UfunctorSignature],
|
| 246 |
+
parent_ctx: Sequence[Binding],
|
| 247 |
+
) -> str:
|
| 248 |
+
body = "using opmath_t = at::opmath_type<scalar_t>;"
|
| 249 |
+
body += "if (false) {}\n" # for ease of codegen
|
| 250 |
+
for config in BinaryScalarSpecializationConfigs:
|
| 251 |
+
if config.ufunc_key not in inner_loops:
|
| 252 |
+
continue
|
| 253 |
+
ufunctor_sig = inner_loops[config.ufunc_key]
|
| 254 |
+
scalar_idx = config.scalar_idx + 1
|
| 255 |
+
# Make a copy and at the same time widen the type (not permissible
|
| 256 |
+
# without copy; we don't want to mutate the input argument anyway)
|
| 257 |
+
ctx: list[Expr | Binding] = list(parent_ctx)
|
| 258 |
+
ctx.append(
|
| 259 |
+
Expr(
|
| 260 |
+
expr=f"iter.scalar_value<opmath_t>({scalar_idx})",
|
| 261 |
+
type=NamedCType(config.ctor_tensor, BaseCType(opmath_t)),
|
| 262 |
+
)
|
| 263 |
+
)
|
| 264 |
+
ufunctor_ctor_exprs_str = ", ".join(
|
| 265 |
+
a.expr for a in translate(ctx, ufunctor_sig.arguments().ctor)
|
| 266 |
+
)
|
| 267 |
+
|
| 268 |
+
# NB: ufunctor must be allocated before iter.remove_operand is called,
|
| 269 |
+
# as it relies on iter
|
| 270 |
+
body += f"""\
|
| 271 |
+
else if (iter.is_cpu_scalar({scalar_idx})) {{
|
| 272 |
+
{ufunctor_sig.name}<scalar_t> ufunctor({ufunctor_ctor_exprs_str});
|
| 273 |
+
iter.remove_operand({scalar_idx});
|
| 274 |
+
gpu_kernel(iter, ufunctor);
|
| 275 |
+
}}"""
|
| 276 |
+
|
| 277 |
+
ufunctor_sig = inner_loops[UfuncKey.CUDAFunctor]
|
| 278 |
+
ufunctor_ctor_exprs_str = ", ".join(
|
| 279 |
+
a.expr for a in translate(parent_ctx, ufunctor_sig.arguments().ctor)
|
| 280 |
+
)
|
| 281 |
+
body += f"""
|
| 282 |
+
else {{
|
| 283 |
+
gpu_kernel(iter, {ufunctor_sig.name}<scalar_t>({ufunctor_ctor_exprs_str}));
|
| 284 |
+
}}
|
| 285 |
+
"""
|
| 286 |
+
return body
|
| 287 |
+
|
| 288 |
+
|
| 289 |
+
@with_native_function
|
| 290 |
+
def compute_ufunc_cuda(g: NativeFunctionsGroup) -> str:
|
| 291 |
+
# First, build the functors, indexing them by dtype
|
| 292 |
+
ufunctor_sigs, ufunctors = compute_ufunc_cuda_functors(g)
|
| 293 |
+
|
| 294 |
+
# Next, build the conditionals
|
| 295 |
+
sig = StructuredImplSignature(g, ufunc.kernel_name(g, DispatchKey.CUDA))
|
| 296 |
+
dtype_cases = []
|
| 297 |
+
for dtype, inner_ufunc_sigs in ufunctor_sigs.items():
|
| 298 |
+
dtype_cases.append(
|
| 299 |
+
f"""
|
| 300 |
+
AT_DISPATCH_CASE(at::ScalarType::{dtype},
|
| 301 |
+
[&]() {{
|
| 302 |
+
{compute_ufunc_cuda_dtype_body(g, dtype, inner_ufunc_sigs, sig.arguments())}
|
| 303 |
+
}}
|
| 304 |
+
)
|
| 305 |
+
"""
|
| 306 |
+
)
|
| 307 |
+
|
| 308 |
+
dtype_cases_str = "\n".join(dtype_cases)
|
| 309 |
+
|
| 310 |
+
stub_sig = StubSignature(g)
|
| 311 |
+
|
| 312 |
+
return f"""
|
| 313 |
+
{ufunctors}
|
| 314 |
+
|
| 315 |
+
{stub_sig.type_defn()};
|
| 316 |
+
{stub_sig.dispatch_decl()};
|
| 317 |
+
|
| 318 |
+
{stub_sig.kernel_defn()} {{
|
| 319 |
+
AT_DISPATCH_SWITCH(iter.common_dtype(), "{sig.name}",
|
| 320 |
+
{dtype_cases_str}
|
| 321 |
+
);
|
| 322 |
+
}}
|
| 323 |
+
REGISTER_DISPATCH({stub_sig.name}, &{stub_sig.kernel_name});
|
| 324 |
+
|
| 325 |
+
{sig.defn()} {{
|
| 326 |
+
{stub_sig.direct_call(sig.arguments())};
|
| 327 |
+
}}
|
| 328 |
+
"""
|
| 329 |
+
|
| 330 |
+
|
| 331 |
+
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
|
| 332 |
+
#
|
| 333 |
+
# CPU STUFF
|
| 334 |
+
#
|
| 335 |
+
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
|
| 336 |
+
|
| 337 |
+
|
| 338 |
+
@dataclass(frozen=True)
|
| 339 |
+
class StubSignature:
|
| 340 |
+
g: NativeFunctionsGroup
|
| 341 |
+
|
| 342 |
+
@property
|
| 343 |
+
def name(self) -> str:
|
| 344 |
+
return f"{str(self.g.functional.func.name.name)}_stub"
|
| 345 |
+
|
| 346 |
+
@property
|
| 347 |
+
def kernel_name(self) -> str:
|
| 348 |
+
return f"{str(self.g.functional.func.name.name)}_kernel"
|
| 349 |
+
|
| 350 |
+
@property
|
| 351 |
+
def type_name(self) -> str:
|
| 352 |
+
return f"{str(self.g.functional.func.name.name)}_fn"
|
| 353 |
+
|
| 354 |
+
def arguments(self) -> list[Binding]:
|
| 355 |
+
return ufunc.stub_arguments(self.g)
|
| 356 |
+
|
| 357 |
+
def type(self) -> str:
|
| 358 |
+
cpp_args = self.arguments()
|
| 359 |
+
return f"void(*)(TensorIteratorBase&, {', '.join(a.type for a in cpp_args)})"
|
| 360 |
+
|
| 361 |
+
def dispatch_decl(self) -> str:
|
| 362 |
+
return f"DECLARE_DISPATCH({self.type_name}, {self.name})"
|
| 363 |
+
|
| 364 |
+
def dispatch_defn(self) -> str:
|
| 365 |
+
return f"DEFINE_DISPATCH({self.name})"
|
| 366 |
+
|
| 367 |
+
def kernel_defn(self) -> str:
|
| 368 |
+
return f"void {self.kernel_name}(TensorIteratorBase& iter, {', '.join(a.defn() for a in self.arguments())})"
|
| 369 |
+
|
| 370 |
+
def type_defn(self) -> str:
|
| 371 |
+
return f"using {self.type_name} = {self.type()}"
|
| 372 |
+
|
| 373 |
+
# must be called from context where this is TensorIteratorBase*
|
| 374 |
+
def call(self, ctx: Sequence[Binding]) -> str:
|
| 375 |
+
return f"{self.name}(device_type(), *this, {', '.join(a.expr for a in translate(ctx, self.arguments()))})"
|
| 376 |
+
|
| 377 |
+
# used in CUDA to skip the unnecessary dynamic dispatch
|
| 378 |
+
def direct_call(self, ctx: Sequence[Binding]) -> str:
|
| 379 |
+
return f"{self.kernel_name}(*this, {', '.join(a.expr for a in translate(ctx, self.arguments()))})"
|
| 380 |
+
|
| 381 |
+
|
| 382 |
+
@with_native_function
|
| 383 |
+
def compute_ufunc_cpu(g: NativeFunctionsGroup) -> str:
|
| 384 |
+
stub_sig = StubSignature(g)
|
| 385 |
+
sig = StructuredImplSignature(g, ufunc.kernel_name(g, DispatchKey.CPU))
|
| 386 |
+
|
| 387 |
+
return f"""
|
| 388 |
+
{stub_sig.type_defn()};
|
| 389 |
+
{stub_sig.dispatch_decl()};
|
| 390 |
+
{stub_sig.dispatch_defn()};
|
| 391 |
+
|
| 392 |
+
{sig.defn()} {{
|
| 393 |
+
{stub_sig.call(sig.arguments())};
|
| 394 |
+
}}
|
| 395 |
+
"""
|
| 396 |
+
|
| 397 |
+
|
| 398 |
+
def compute_ufunc_cpu_dtype_body(
|
| 399 |
+
g: NativeFunctionsGroup,
|
| 400 |
+
dtype: ScalarType,
|
| 401 |
+
inner_loops: dict[UfuncKey, UfuncSignature],
|
| 402 |
+
parent_ctx: Sequence[Binding],
|
| 403 |
+
) -> str:
|
| 404 |
+
assert UfuncKey.CPUScalar in inner_loops, f"{dtype}, {inner_loops.keys()}"
|
| 405 |
+
assert inner_loops.keys() <= {UfuncKey.CPUScalar, UfuncKey.CPUVector}
|
| 406 |
+
scalar_loop = inner_loops[UfuncKey.CPUScalar]
|
| 407 |
+
vec_loop = None
|
| 408 |
+
if UfuncKey.CPUVector in inner_loops:
|
| 409 |
+
vec_loop = inner_loops[UfuncKey.CPUVector]
|
| 410 |
+
|
| 411 |
+
# NB: We DON'T use translate here, because translate is
|
| 412 |
+
# incapable of CSE'ing the scalar accesses in case it is also
|
| 413 |
+
# used by Vectorized; also, the unpacking here is very simple
|
| 414 |
+
# and only affects Scalar; everything else is implicitly captured
|
| 415 |
+
# by the lambda
|
| 416 |
+
|
| 417 |
+
# Setup scalar in scope
|
| 418 |
+
body = []
|
| 419 |
+
ctx = []
|
| 420 |
+
for b in parent_ctx:
|
| 421 |
+
if isinstance(b.argument, Argument) and b.argument.type != BaseType(
|
| 422 |
+
BaseTy.Scalar
|
| 423 |
+
):
|
| 424 |
+
continue
|
| 425 |
+
body.append(f"auto _s_{b.name} = {b.name}.to<scalar_t>();")
|
| 426 |
+
ctx.append(Expr(f"_s_{b.name}", NamedCType(b.nctype.name, BaseCType(scalar_t))))
|
| 427 |
+
if vec_loop is not None:
|
| 428 |
+
for b in parent_ctx:
|
| 429 |
+
if isinstance(b.argument, Argument) and b.argument.type != BaseType(
|
| 430 |
+
BaseTy.Scalar
|
| 431 |
+
):
|
| 432 |
+
continue
|
| 433 |
+
body.append(
|
| 434 |
+
f"auto _v_{b.name} = at::vec::Vectorized<scalar_t>(_s_{b.name});"
|
| 435 |
+
)
|
| 436 |
+
ctx.append(
|
| 437 |
+
Expr(
|
| 438 |
+
f"_v_{b.name}",
|
| 439 |
+
NamedCType(b.nctype.name, VectorizedCType(BaseCType(scalar_t))),
|
| 440 |
+
)
|
| 441 |
+
)
|
| 442 |
+
|
| 443 |
+
# Setup lambda signature
|
| 444 |
+
# NB: simplified version of ufunctor_arguments
|
| 445 |
+
scalar_bindings = []
|
| 446 |
+
vec_bindings = []
|
| 447 |
+
for a in g.functional.func.arguments.flat_non_out:
|
| 448 |
+
if not a.type.is_tensor_like():
|
| 449 |
+
continue
|
| 450 |
+
assert a.type == BaseType(BaseTy.Tensor)
|
| 451 |
+
scalar_bindings.append(
|
| 452 |
+
Binding(
|
| 453 |
+
name=a.name,
|
| 454 |
+
nctype=NamedCType(a.name, BaseCType(scalar_t)),
|
| 455 |
+
argument=a,
|
| 456 |
+
)
|
| 457 |
+
)
|
| 458 |
+
if vec_loop is not None:
|
| 459 |
+
vec_bindings.append(
|
| 460 |
+
Binding(
|
| 461 |
+
name=a.name,
|
| 462 |
+
nctype=NamedCType(a.name, VectorizedCType(BaseCType(scalar_t))),
|
| 463 |
+
argument=a,
|
| 464 |
+
)
|
| 465 |
+
)
|
| 466 |
+
|
| 467 |
+
def with_ctx(b: Sequence[Binding]) -> list[Expr | Binding]:
|
| 468 |
+
r: list[Expr | Binding] = []
|
| 469 |
+
r.extend(ctx)
|
| 470 |
+
r.extend(b)
|
| 471 |
+
return r
|
| 472 |
+
|
| 473 |
+
body_str = "\n".join(body)
|
| 474 |
+
if vec_loop is not None:
|
| 475 |
+
return f"""
|
| 476 |
+
{body_str}
|
| 477 |
+
cpu_kernel_vec(iter,
|
| 478 |
+
[=]({', '.join(b.decl() for b in scalar_bindings)}) {{ return {scalar_loop.call(with_ctx(scalar_bindings))}; }},
|
| 479 |
+
[=]({', '.join(b.decl() for b in vec_bindings)}) {{ return {vec_loop.call(with_ctx(vec_bindings))}; }}
|
| 480 |
+
);
|
| 481 |
+
"""
|
| 482 |
+
else:
|
| 483 |
+
return f"""
|
| 484 |
+
{body_str}
|
| 485 |
+
cpu_kernel(iter,
|
| 486 |
+
[=]({', '.join(b.decl() for b in scalar_bindings)}) {{ return {scalar_loop.call(with_ctx(scalar_bindings))}; }}
|
| 487 |
+
);
|
| 488 |
+
"""
|
| 489 |
+
|
| 490 |
+
|
| 491 |
+
@with_native_function
|
| 492 |
+
def compute_ufunc_cpu_kernel(g: NativeFunctionsGroup) -> str:
|
| 493 |
+
stub_sig = StubSignature(g)
|
| 494 |
+
|
| 495 |
+
# Reindex the ufunc by dtypes; processing generic/scalaronly as well
|
| 496 |
+
loops = g.out.ufunc_inner_loop
|
| 497 |
+
ufunc_sigs: dict[ScalarType, dict[UfuncKey, UfuncSignature]] = {}
|
| 498 |
+
for k in [UfuncKey.CPUScalar, UfuncKey.CPUVector]:
|
| 499 |
+
lks = []
|
| 500 |
+
# ORDER MATTERS: this specifies overriding precedence
|
| 501 |
+
if k in loops: # should happen rarely
|
| 502 |
+
lks.append(k)
|
| 503 |
+
if UfuncKey.ScalarOnly in loops and k is UfuncKey.CPUScalar:
|
| 504 |
+
lks.append(UfuncKey.ScalarOnly)
|
| 505 |
+
if UfuncKey.Generic in loops:
|
| 506 |
+
lks.append(UfuncKey.Generic)
|
| 507 |
+
# TODO: don't hardcode ufunc:: namespace here, should be centralized smh
|
| 508 |
+
for lk in lks:
|
| 509 |
+
for dtype in loops[lk].supported_dtypes:
|
| 510 |
+
compute_t: CType
|
| 511 |
+
if k is UfuncKey.CPUScalar:
|
| 512 |
+
compute_t = BaseCType(scalar_t)
|
| 513 |
+
elif k is UfuncKey.CPUVector:
|
| 514 |
+
compute_t = VectorizedCType(BaseCType(scalar_t))
|
| 515 |
+
else:
|
| 516 |
+
raise AssertionError
|
| 517 |
+
inner_ufunc_sigs = ufunc_sigs.setdefault(dtype, {})
|
| 518 |
+
if k not in inner_ufunc_sigs:
|
| 519 |
+
inner_ufunc_sigs[k] = UfuncSignature(
|
| 520 |
+
g, name=f"ufunc::{loops[lk].name}", compute_t=compute_t
|
| 521 |
+
)
|
| 522 |
+
|
| 523 |
+
# Build the conditionals
|
| 524 |
+
dtype_cases = []
|
| 525 |
+
for dtype, inner_ufunc_sigs in ufunc_sigs.items():
|
| 526 |
+
dtype_cases.append(
|
| 527 |
+
f"""
|
| 528 |
+
AT_DISPATCH_CASE(at::ScalarType::{dtype},
|
| 529 |
+
[&]() {{
|
| 530 |
+
{compute_ufunc_cpu_dtype_body(g, dtype, inner_ufunc_sigs, stub_sig.arguments())}
|
| 531 |
+
}}
|
| 532 |
+
)
|
| 533 |
+
"""
|
| 534 |
+
)
|
| 535 |
+
|
| 536 |
+
dtype_cases_str = "\n".join(dtype_cases)
|
| 537 |
+
return f"""
|
| 538 |
+
namespace {{
|
| 539 |
+
|
| 540 |
+
{stub_sig.kernel_defn()} {{
|
| 541 |
+
AT_DISPATCH_SWITCH(iter.common_dtype(), "{stub_sig.name}",
|
| 542 |
+
{dtype_cases_str}
|
| 543 |
+
);
|
| 544 |
+
}}
|
| 545 |
+
|
| 546 |
+
}} // anonymous namespace
|
| 547 |
+
|
| 548 |
+
{stub_sig.type_defn()};
|
| 549 |
+
{stub_sig.dispatch_decl()};
|
| 550 |
+
REGISTER_DISPATCH({stub_sig.name}, &{stub_sig.kernel_name});
|
| 551 |
+
"""
|
.venv/lib/python3.11/site-packages/torchgen/executorch/__init__.py
ADDED
|
File without changes
|
.venv/lib/python3.11/site-packages/torchgen/executorch/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (192 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/torchgen/executorch/__pycache__/model.cpython-311.pyc
ADDED
|
Binary file (11.6 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torchgen/executorch/__pycache__/parse.cpython-311.pyc
ADDED
|
Binary file (6.98 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torchgen/executorch/api/__init__.py
ADDED
|
File without changes
|
.venv/lib/python3.11/site-packages/torchgen/executorch/api/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (196 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/torchgen/executorch/api/__pycache__/custom_ops.cpython-311.pyc
ADDED
|
Binary file (7.53 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torchgen/executorch/api/__pycache__/et_cpp.cpython-311.pyc
ADDED
|
Binary file (15.3 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torchgen/executorch/api/__pycache__/unboxing.cpython-311.pyc
ADDED
|
Binary file (10.6 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torchgen/executorch/api/custom_ops.py
ADDED
|
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from collections import defaultdict
|
| 4 |
+
from dataclasses import dataclass
|
| 5 |
+
from typing import Sequence, TYPE_CHECKING
|
| 6 |
+
|
| 7 |
+
from torchgen import dest
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
# disable import sorting to avoid circular dependency.
|
| 11 |
+
from torchgen.api.types import DispatcherSignature # usort: skip
|
| 12 |
+
from torchgen.context import method_with_native_function
|
| 13 |
+
from torchgen.model import BaseTy, BaseType, DispatchKey, NativeFunction, Variant
|
| 14 |
+
from torchgen.utils import concatMap, Target
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
if TYPE_CHECKING:
|
| 18 |
+
from torchgen.executorch.model import ETKernelIndex
|
| 19 |
+
from torchgen.selective_build.selector import SelectiveBuilder
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
# Generates RegisterKernelStub.cpp, which provides placeholder kernels for custom operators. This will be used at
|
| 23 |
+
# model authoring side.
|
| 24 |
+
@dataclass(frozen=True)
|
| 25 |
+
class ComputeNativeFunctionStub:
|
| 26 |
+
@method_with_native_function
|
| 27 |
+
def __call__(self, f: NativeFunction) -> str | None:
|
| 28 |
+
if Variant.function not in f.variants:
|
| 29 |
+
return None
|
| 30 |
+
|
| 31 |
+
sig = DispatcherSignature.from_schema(
|
| 32 |
+
f.func, prefix=f"wrapper_CPU_{f.func.name.overload_name}_", symint=False
|
| 33 |
+
)
|
| 34 |
+
assert sig is not None
|
| 35 |
+
if len(f.func.returns) == 0:
|
| 36 |
+
ret_name = ""
|
| 37 |
+
elif len(f.func.returns) == 1:
|
| 38 |
+
if f.func.arguments.out:
|
| 39 |
+
ret_name = f.func.arguments.out[0].name
|
| 40 |
+
else:
|
| 41 |
+
ret_name = next(
|
| 42 |
+
(
|
| 43 |
+
a.name
|
| 44 |
+
for a in f.func.arguments.flat_non_out
|
| 45 |
+
if a.type == f.func.returns[0].type
|
| 46 |
+
),
|
| 47 |
+
"",
|
| 48 |
+
)
|
| 49 |
+
if not ret_name:
|
| 50 |
+
# if return type is tensor
|
| 51 |
+
if f.func.returns[0].type == BaseType(BaseTy.Tensor):
|
| 52 |
+
# Returns an empty tensor
|
| 53 |
+
ret_name = "at::Tensor()"
|
| 54 |
+
else:
|
| 55 |
+
raise Exception( # noqa: TRY002
|
| 56 |
+
f"Can't handle this return type {f.func}"
|
| 57 |
+
) # noqa: TRY002
|
| 58 |
+
elif len(f.func.arguments.out) == len(f.func.returns):
|
| 59 |
+
# Returns a tuple of out arguments
|
| 60 |
+
tensor_type = "at::Tensor &"
|
| 61 |
+
comma = ", "
|
| 62 |
+
ret_name = f"""::std::tuple<{comma.join([tensor_type] * len(f.func.returns))}>(
|
| 63 |
+
{comma.join([r.name for r in f.func.arguments.out])}
|
| 64 |
+
)"""
|
| 65 |
+
else:
|
| 66 |
+
assert all(
|
| 67 |
+
a.type == BaseType(BaseTy.Tensor) for a in f.func.returns
|
| 68 |
+
), f"Only support tensor returns but got {f.func.returns}"
|
| 69 |
+
# Returns a tuple of empty tensors
|
| 70 |
+
tensor_type = "at::Tensor"
|
| 71 |
+
comma = ", "
|
| 72 |
+
ret_name = f"""::std::tuple<{comma.join([tensor_type] * len(f.func.returns))}>(
|
| 73 |
+
{comma.join(["at::Tensor()" for _ in f.func.returns])}
|
| 74 |
+
)"""
|
| 75 |
+
ret_str = f"return {ret_name};" if len(f.func.returns) > 0 else ""
|
| 76 |
+
return f"""
|
| 77 |
+
{sig.defn()} {{
|
| 78 |
+
{ret_str}
|
| 79 |
+
}}
|
| 80 |
+
"""
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def gen_custom_ops_registration(
|
| 84 |
+
*,
|
| 85 |
+
native_functions: Sequence[NativeFunction],
|
| 86 |
+
selector: SelectiveBuilder,
|
| 87 |
+
kernel_index: ETKernelIndex,
|
| 88 |
+
rocm: bool,
|
| 89 |
+
) -> tuple[str, str]:
|
| 90 |
+
"""
|
| 91 |
+
Generate custom ops registration code for dest.RegisterDispatchKey.
|
| 92 |
+
|
| 93 |
+
:param native_functions: a sequence of `NativeFunction`
|
| 94 |
+
:param selector: for selective build.
|
| 95 |
+
:param kernel_index: kernels for all the ops.
|
| 96 |
+
:param rocm: bool for dest.RegisterDispatchKey.
|
| 97 |
+
:return: generated C++ code to register custom operators into PyTorch
|
| 98 |
+
"""
|
| 99 |
+
|
| 100 |
+
# convert kernel index to BackendIndex. This is because we can't handle ETKernelIndex yet.
|
| 101 |
+
# TODO larryliu: evaluate if this code is still needed. If yes let it handle ETKernelIndex.
|
| 102 |
+
|
| 103 |
+
dispatch_key = DispatchKey.CPU
|
| 104 |
+
backend_index = kernel_index._to_backend_index()
|
| 105 |
+
static_init_dispatch_registrations = ""
|
| 106 |
+
ns_grouped_native_functions: dict[str, list[NativeFunction]] = defaultdict(list)
|
| 107 |
+
for native_function in native_functions:
|
| 108 |
+
ns_grouped_native_functions[native_function.namespace].append(native_function)
|
| 109 |
+
|
| 110 |
+
for namespace, functions in ns_grouped_native_functions.items():
|
| 111 |
+
if len(functions) == 0:
|
| 112 |
+
continue
|
| 113 |
+
dispatch_registrations_body = "\n".join(
|
| 114 |
+
list(
|
| 115 |
+
concatMap(
|
| 116 |
+
dest.RegisterDispatchKey(
|
| 117 |
+
backend_index,
|
| 118 |
+
Target.REGISTRATION,
|
| 119 |
+
selector,
|
| 120 |
+
rocm=rocm,
|
| 121 |
+
symint=False,
|
| 122 |
+
class_method_name=None,
|
| 123 |
+
skip_dispatcher_op_registration=False,
|
| 124 |
+
),
|
| 125 |
+
functions,
|
| 126 |
+
)
|
| 127 |
+
)
|
| 128 |
+
)
|
| 129 |
+
static_init_dispatch_registrations += f"""
|
| 130 |
+
TORCH_LIBRARY_IMPL({namespace}, {dispatch_key}, m) {{
|
| 131 |
+
{dispatch_registrations_body}
|
| 132 |
+
}};"""
|
| 133 |
+
anonymous_definition = "\n".join(
|
| 134 |
+
list(
|
| 135 |
+
concatMap(
|
| 136 |
+
dest.RegisterDispatchKey(
|
| 137 |
+
backend_index,
|
| 138 |
+
Target.ANONYMOUS_DEFINITION,
|
| 139 |
+
selector,
|
| 140 |
+
rocm=rocm,
|
| 141 |
+
symint=False,
|
| 142 |
+
class_method_name=None,
|
| 143 |
+
skip_dispatcher_op_registration=False,
|
| 144 |
+
),
|
| 145 |
+
native_functions,
|
| 146 |
+
)
|
| 147 |
+
)
|
| 148 |
+
)
|
| 149 |
+
return anonymous_definition, static_init_dispatch_registrations
|
.venv/lib/python3.11/site-packages/torchgen/executorch/api/et_cpp.py
ADDED
|
@@ -0,0 +1,370 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from typing import Sequence
|
| 4 |
+
|
| 5 |
+
from torchgen import local
|
| 6 |
+
from torchgen.api.types import (
|
| 7 |
+
ArgName,
|
| 8 |
+
BaseCType,
|
| 9 |
+
Binding,
|
| 10 |
+
ConstRefCType,
|
| 11 |
+
CType,
|
| 12 |
+
MutRefCType,
|
| 13 |
+
NamedCType,
|
| 14 |
+
SpecialArgName,
|
| 15 |
+
TupleCType,
|
| 16 |
+
VectorCType,
|
| 17 |
+
voidT,
|
| 18 |
+
)
|
| 19 |
+
from torchgen.executorch.api.types import (
|
| 20 |
+
ArrayRefCType,
|
| 21 |
+
BaseTypeToCppMapping,
|
| 22 |
+
OptionalCType,
|
| 23 |
+
scalarT,
|
| 24 |
+
tensorListT,
|
| 25 |
+
tensorT,
|
| 26 |
+
)
|
| 27 |
+
from torchgen.model import (
|
| 28 |
+
Argument,
|
| 29 |
+
Arguments,
|
| 30 |
+
BaseTy,
|
| 31 |
+
BaseType,
|
| 32 |
+
ListType,
|
| 33 |
+
NativeFunction,
|
| 34 |
+
OptionalType,
|
| 35 |
+
Return,
|
| 36 |
+
SelfArgument,
|
| 37 |
+
TensorOptionsArguments,
|
| 38 |
+
Type,
|
| 39 |
+
)
|
| 40 |
+
from torchgen.utils import assert_never
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
"""
|
| 44 |
+
This file describes the translation of JIT schema to the public C++ API, which is what people use when they call
|
| 45 |
+
functions like at::add. It also serves as a native function API, which is the signature of kernels,
|
| 46 |
+
since in Executorch CppSignature is the same as NativeSignature.
|
| 47 |
+
|
| 48 |
+
Difference between this file and torchgen.api.cpp.py:
|
| 49 |
+
|
| 50 |
+
- Executorch doesn't support TensorOptions, however in this file we still keep the logic here to be compatible with
|
| 51 |
+
torchgen.api.cpp, so that we can do stuff like ATen mode (running ATen kernels in Executorch).
|
| 52 |
+
|
| 53 |
+
- Executorch doesn't support Dimname.
|
| 54 |
+
|
| 55 |
+
- Executorch runtime doesn't support SymInt, will treat it as int.
|
| 56 |
+
"""
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
# Translation of "value types" in JIT schema to C++ API type. Value
|
| 60 |
+
# types look the same no matter if they are argument types or return
|
| 61 |
+
# types. Returns None if the type in question is not a value type.
|
| 62 |
+
def valuetype_type(
    t: Type,
    *,
    binds: ArgName,
    remove_non_owning_ref_types: bool = False,
) -> NamedCType | None:
    """Translate a JIT "value type" to its Executorch C++ type.

    Returns None when ``t`` is not a value type (Tensor/Scalar, or any list
    other than a sized ``bool[N]``), signalling the caller to apply
    argument- or return-specific translation instead.

    :param t: the JIT schema type to translate
    :param binds: the name the resulting NamedCType is bound to
    :param remove_non_owning_ref_types: if set, guarantee the output is not a
        non-owning reference type; only ``str`` would be affected here and
        that conversion is unimplemented (raises)
    """
    if isinstance(t, BaseType):
        # Tensor and Scalar are not value types; the caller handles them.
        if t.name == BaseTy.Tensor or t.name == BaseTy.Scalar:
            return None
        # For SymInt we simply treat it as int.
        elif str(t) == "SymInt":
            return NamedCType(binds, BaseCType(BaseTypeToCppMapping[BaseTy.int]))
        if remove_non_owning_ref_types:
            if t.name == BaseTy.str:
                raise AssertionError(
                    "string ref->value conversion: not implemented yet"
                )
        # All other BaseType currently map directly to BaseCppTypes.
        return NamedCType(binds, BaseCType(BaseTypeToCppMapping[t.name]))
    elif isinstance(t, OptionalType):
        # Optional of a value type is a value type; otherwise defer to caller.
        elem = valuetype_type(t.elem, binds=binds)
        if elem is None:
            return None
        return NamedCType(binds, OptionalCType(elem.type))
    elif isinstance(t, ListType):
        if str(t.elem) == "bool":
            # Only fixed-size bool lists (e.g. bool[4]) are treated as value
            # types here; the size must be present in the schema.
            assert t.size is not None
            return NamedCType(
                binds, ArrayRefCType(BaseCType(BaseTypeToCppMapping[BaseTy.bool]))
            )
        else:
            return None
    else:
        raise AssertionError(f"unrecognized type {repr(t)}")
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
# Translation of types occurring in JIT arguments to a C++ argument type.
|
| 99 |
+
# If remove_non_owning_ref_types is set, we'll guarantee that the outputed CType is not a non-owning reference type.
|
| 100 |
+
# For example, we'll return std::vector<int> instead of IntArrayRef.
|
| 101 |
+
# See Note [translation from C++ reference to value types]
|
| 102 |
+
def argumenttype_type(
    t: Type,
    *,
    mutable: bool,
    binds: ArgName,
    remove_non_owning_ref_types: bool = False,
) -> NamedCType:
    """Translate a type occurring in a JIT argument into a C++ argument type.

    Value types are delegated to :func:`valuetype_type`; Tensor/Scalar (and
    optionals/lists of them) get reference-typed argument translations.

    :param t: the JIT schema type of the argument
    :param mutable: whether the argument is mutable (``Tensor(a!)`` style)
    :param binds: the name the resulting NamedCType is bound to
    :param remove_non_owning_ref_types: forwarded to valuetype_type; see
        Note [translation from C++ reference to value types]
    """
    # If it's a value type, do the value type translation
    r = valuetype_type(
        t,
        binds=binds,
        remove_non_owning_ref_types=remove_non_owning_ref_types,
    )
    if r is not None:
        return r
    if isinstance(t, BaseType):
        if t.name == BaseTy.Tensor:
            # Mutable tensors are passed by mutable ref unless the local
            # config asks for const refs even for mutable tensors.
            if mutable and not local.use_const_ref_for_mutable_tensors():
                return NamedCType(binds, MutRefCType(BaseCType(tensorT)))
            else:
                return NamedCType(binds, ConstRefCType(BaseCType(tensorT)))
        elif t.name == BaseTy.Scalar:
            return NamedCType(binds, ConstRefCType(BaseCType(scalarT)))
        else:
            # valuetype_type above should have covered every other BaseType.
            raise AssertionError(f"base type should have been value type {t}")
    elif isinstance(t, OptionalType):
        if str(t.elem) == "Tensor":
            if mutable and not local.use_const_ref_for_mutable_tensors():
                # Note the optionality is dropped for mutable optional tensors.
                return NamedCType(
                    binds, MutRefCType(BaseCType(tensorT))
                )  # TODO: fix this discrepancy
            else:
                return NamedCType(
                    binds, ConstRefCType(OptionalCType(BaseCType(tensorT)))
                )
        elif str(t.elem) == "Scalar":
            return NamedCType(binds, ConstRefCType(OptionalCType(BaseCType(scalarT))))
        # Generic optional: translate the element and wrap it.
        elem = argumenttype_type(t.elem, mutable=mutable, binds=binds)
        return NamedCType(binds, OptionalCType(elem.type))
    elif isinstance(t, ListType):
        # TODO: keeping these special cases for Tensor[] and Tensor?[] so that we can hookup with ATen kernels.
        if str(t.elem) == "Tensor":
            return NamedCType(binds, BaseCType(tensorListT))
        elif str(t.elem) == "Dimname":
            raise NotImplementedError("Executorch doesn't support Dimname")
        elif str(t.elem) == "Tensor?":
            return NamedCType(binds, ArrayRefCType(OptionalCType(BaseCType(tensorT))))
        # Generic list: translate the element and wrap it in an ArrayRef.
        elem = argumenttype_type(t.elem, mutable=mutable, binds=binds)
        return NamedCType(binds, ArrayRefCType(elem.type))
    else:
        raise AssertionError(f"unrecognized type {repr(t)}")
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
# Translate a JIT argument into its C++ type
|
| 156 |
+
def argument_type(a: Argument, *, binds: ArgName) -> NamedCType:
    """Translate a single JIT argument into its C++ NamedCType."""
    jit_type = a.type
    is_mutable = a.is_write
    return argumenttype_type(jit_type, mutable=is_mutable, binds=binds)
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
# Translation of a (non-multi) return type from JIT to C++
|
| 161 |
+
# N.B: returntype_type returns a CType, not a NamedCType.
|
| 162 |
+
# This is mostly because of the mismatch between return types and return names.
|
| 163 |
+
# e.g. a function with a return type of 'void' has 0 return names,
|
| 164 |
+
# and a function with a return type of 'std::tuple' has >1 return name.
|
| 165 |
+
def returntype_type(t: Type, *, mutable: bool) -> CType:
    """Translate a (non-multi) return type from JIT to C++.

    N.B.: returns a CType, not a NamedCType, because of the mismatch between
    return types and return names (a 'void' function has 0 return names, a
    'std::tuple' function has >1).

    :param t: the JIT schema return type
    :param mutable: whether the return aliases a mutable argument
    """
    # placeholder is ignored
    r = valuetype_type(t, binds="__placeholder__")
    if r is not None:
        return r.type

    if isinstance(t, BaseType):
        if t.name == BaseTy.Tensor:
            if mutable:
                if local.use_const_ref_for_mutable_tensors():
                    return ConstRefCType(BaseCType(tensorT))
                else:
                    return MutRefCType(BaseCType(tensorT))
            else:
                # Note [Tensor Copy Returns]
                # Currently, we use "Argument.is_write" to determine
                # whether or not Tensor return types should be copies or references.
                # If that ever changes, take a look at other locations of this note!
                return BaseCType(tensorT)
        elif t.name == BaseTy.Scalar:
            return BaseCType(scalarT)
    elif isinstance(t, ListType):
        assert (
            not mutable
        ), "Native functions should never return a mutable tensor list. They should return void."
        elem = returntype_type(t.elem, mutable=False)
        assert t.size is None, f"fixed size list returns not supported: {t}"
        return VectorCType(elem)

    # Any BaseType not handled above (and any other Type subclass) is
    # unsupported as a return type and falls through to here.
    raise AssertionError(f"unrecognized return type {t}")
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
# Translation of a single return to its C++ type
|
| 198 |
+
def return_type(r: Return) -> CType:
    """Translate a single JIT return into its C++ CType."""
    return returntype_type(r.type, mutable=r.is_write)
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
# Translation of a full (possibly multi) return from JIT to its C++ type
|
| 203 |
+
def returns_type(rs: Sequence[Return]) -> CType:
    """Translate a full (possibly multi) JIT return list into one C++ type.

    Zero returns map to void, one return maps to its own type, and multiple
    returns are packed into a tuple type.
    """
    if not rs:
        return BaseCType(voidT)
    if len(rs) == 1:
        return return_type(rs[0])
    return TupleCType([return_type(ret) for ret in rs])
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
def return_names(f: NativeFunction, *, fallback_name: str = "result") -> Sequence[str]:
    """Compute the C++ names of all returns of a NativeFunction.

    Inplace functions return "self"; out functions reuse the out argument
    names; explicitly named returns keep their names (suffixed with
    "_return" on a collision with an argument name); anything else gets the
    fallback name, indexed when there are multiple returns.
    """
    names: list[str] = []
    multi = len(f.func.returns) != 1
    for idx, ret in enumerate(f.func.returns):
        # An inplace function implicitly returns self.
        # TODO: Consider incorporating this into the data model
        if f.func.name.name.inplace:
            assert idx == 0, "illegal inplace function with multiple returns"
            chosen = "self"
        # Out functions take their names from the corresponding out argument
        # (ret.name gets recorded in field_name later.)
        elif f.func.is_out_fn():
            chosen = f.func.arguments.out[idx].name
        # Explicitly named return: keep it, disambiguating on collision.
        elif ret.name:
            collides = any(
                ret.name == arg.name for arg in f.func.schema_order_arguments()
            )
            chosen = f"{ret.name}_return" if collides and not f.func.is_out_fn() else ret.name
        # Otherwise fall back to "result" (or "result0", "result1", ... for
        # multi-returns, zero-indexed).
        else:
            chosen = f"{fallback_name}{idx}" if multi else fallback_name
        names.append(chosen)
    return names
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
# Mapping from JIT schema default-value spellings to their C++ spellings.
# NB: the Executorch runtime namespace is `torch::executor` — the same
# namespace used by every BaseCppType in this package and by default_expr's
# nullopt below; the previous `torch::executorch::...` spellings named a
# namespace that does not appear anywhere else in this codegen.
JIT_TO_CPP_DEFAULT = {
    "False": "false",
    "True": "true",
    "None": "torch::executor::nullopt",  # UGH this one is type directed
    "[]": "{}",
    "contiguous_format": "torch::executor::MemoryFormat::Contiguous",
    "long": "torch::executor::kLong",
}
|
| 252 |
+
|
| 253 |
+
|
| 254 |
+
# Convert a JIT default into C++ expression representing the default
|
| 255 |
+
def default_expr(d: str, t: Type) -> str:
    """Convert a JIT default value into a C++ expression representing it.

    :param d: the default value as spelled in the JIT schema
    :param t: the JIT type of the defaulted argument (drives translation for
        None, strings, optionals, and lists)
    """
    if d == "None" and str(t) == "Tensor?":
        return "{}"
    if isinstance(t, BaseType) and t.name is BaseTy.str:
        # Schema allows single quotes but C++ needs double
        if len(d) >= 2 and d[0] == "'" and d[-1] == "'":
            s = ""
            i = 1
            # Walk the characters between the quotes; `i + 1 < len(d)` stops
            # before the closing quote.
            while i + 1 < len(d):
                if d[i] != "\\":
                    # Unescaped char; double quotes must be escaped for C++.
                    if d[i] == '"':
                        s += '\\"'
                    else:
                        s += d[i]
                    i += 1
                else:
                    # Escape sequence: \' becomes a bare ', anything else is
                    # copied through verbatim as a two-char sequence.
                    if d[i + 1] == "'":
                        s += "'"
                    else:
                        s += d[i : i + 2]
                    i += 2

            return f'"{s}"'

    if isinstance(t, OptionalType):
        if d == "None":
            return "torch::executor::nullopt"

        # Non-None optional default: translate for the element type.
        return default_expr(d, t.elem)

    if isinstance(t, ListType):
        if d.startswith("[") and d.endswith("]"):
            return "{" + d[1:-1] + "}"
        elif t.size is None:
            # NOTE: Sized lists can have scalar defaults
            raise ValueError(f"Expected a list default '[...]' but found: '{d}'")

    # Everything else: look up a direct JIT->C++ spelling, or pass through.
    return JIT_TO_CPP_DEFAULT.get(d, d)
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
# Convert an argument into its C++ API form
|
| 296 |
+
|
| 297 |
+
|
| 298 |
+
def argument(
    a: Argument | TensorOptionsArguments | SelfArgument,
    *,
    cpp_no_default_args: set[str],
    method: bool,
    faithful: bool,
    has_tensor_options: bool,
) -> list[Binding]:
    """Convert one schema argument into its C++ API Binding(s).

    :param a: the argument (plain, tensor-options bundle, or self)
    :param cpp_no_default_args: names whose C++ defaults must be suppressed
    :param method: whether we are generating a method (self becomes implicit)
    :param faithful: forwarded when recursing into sub-arguments
    :param has_tensor_options: whether the enclosing signature carries
        TensorOptions (affects how `memory_format` binds)
    """

    def sub_argument(
        a: Argument | TensorOptionsArguments | SelfArgument,
    ) -> list[Binding]:
        # Recurse with the same configuration.
        return argument(
            a,
            cpp_no_default_args=cpp_no_default_args,
            method=method,
            faithful=faithful,
            has_tensor_options=has_tensor_options,
        )

    if isinstance(a, Argument):
        binds: ArgName
        # memory_format alongside TensorOptions is possibly redundant and
        # gets a special bind name.
        if a.name == "memory_format" and has_tensor_options:
            binds = SpecialArgName.possibly_redundant_memory_format
        else:
            binds = a.name
        default: str | None = None
        if a.name not in cpp_no_default_args and a.default is not None:
            default = default_expr(a.default, a.type)
        return [
            Binding(
                nctype=argument_type(a, binds=binds),
                name=a.name,
                default=default,
                argument=a,
            )
        ]
    elif isinstance(a, TensorOptionsArguments):
        # Executorch doesn't support TensorOptions (see module docstring).
        raise NotImplementedError("Need to implement type resolution for TensorOptions")
    elif isinstance(a, SelfArgument):
        if method:
            # Caller is responsible for installing implicit this in context!
            return []
        else:
            return sub_argument(a.argument)
    else:
        assert_never(a)
|
| 344 |
+
|
| 345 |
+
|
| 346 |
+
def arguments(
    arguments: Arguments,
    *,
    faithful: bool,
    method: bool,
    cpp_no_default_args: set[str],
) -> list[Binding]:
    """Convert every argument of a schema into C++ Bindings.

    Faithful order places out arguments last and strips defaults from the
    resulting bindings; non-faithful order places out arguments first and
    keeps defaults.
    """
    if faithful:
        ordered: list[Argument | TensorOptionsArguments | SelfArgument] = [
            *arguments.non_out,
            *arguments.out,
        ]
    else:
        ordered = [*arguments.out, *arguments.non_out]

    has_tensor_options = arguments.tensor_options is not None
    bindings: list[Binding] = []
    for a in ordered:
        for b in argument(
            a,
            faithful=faithful,
            method=method,
            has_tensor_options=has_tensor_options,
            cpp_no_default_args=cpp_no_default_args,
        ):
            bindings.append(b.no_default() if faithful else b)
    return bindings
|
.venv/lib/python3.11/site-packages/torchgen/executorch/api/types/__init__.py
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from torchgen.executorch.api.types.types import *
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
from torchgen.executorch.api.types.signatures import * # usort: skip
|
.venv/lib/python3.11/site-packages/torchgen/executorch/api/types/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (326 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/torchgen/executorch/api/types/__pycache__/signatures.cpython-311.pyc
ADDED
|
Binary file (4.76 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torchgen/executorch/api/types/__pycache__/types.cpython-311.pyc
ADDED
|
Binary file (4.29 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torchgen/executorch/api/types/signatures.py
ADDED
|
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
from typing import TYPE_CHECKING
|
| 5 |
+
|
| 6 |
+
import torchgen.api.cpp as aten_cpp
|
| 7 |
+
from torchgen.executorch.api.types.types import contextArg
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
if TYPE_CHECKING:
|
| 11 |
+
from torchgen.api.types import Binding, CType
|
| 12 |
+
from torchgen.model import FunctionSchema, NativeFunction
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
@dataclass(frozen=True)
class ExecutorchCppSignature:
    """
    This signature is merely a CppSignature with Executorch types (optionally
    contains KernelRuntimeContext as well). The inline definition of
    CppSignature is generated in Functions.h and it's used by unboxing
    functions.
    """

    # The schema this signature is derived from
    func: FunctionSchema

    # The set of C++ arguments which should not have defaults applied to them
    cpp_no_default_args: set[str]

    # Allows you to prepend an arbitrary prefix to the signature name.
    # This is useful for parts of the codegen that generate wrappers around kernels,
    # and need to avoid naming collisions.
    prefix: str = ""

    def arguments(self, *, include_context: bool = True) -> list[Binding]:
        """Return the C++ bindings, optionally led by the runtime context arg.

        NOTE(review): `et_cpp` is imported at the bottom of this module,
        presumably to avoid a circular import — it is only resolved at call
        time here.
        """
        return ([contextArg] if include_context else []) + et_cpp.arguments(
            self.func.arguments,
            faithful=True,  # always faithful, out argument at the end
            method=False,  # method not supported
            cpp_no_default_args=self.cpp_no_default_args,
        )

    def name(self) -> str:
        """Return the C++ function name, with this signature's prefix prepended."""
        return self.prefix + aten_cpp.name(
            self.func,
            faithful_name_for_out_overloads=True,
        )

    def decl(self, name: str | None = None, *, include_context: bool = True) -> str:
        """Render this signature as a C++ declaration string."""
        args_str = ", ".join(
            a.decl() for a in self.arguments(include_context=include_context)
        )
        if name is None:
            name = self.name()
        return f"{self.returns_type().cpp_type()} {name}({args_str})"

    def defn(self, name: str | None = None) -> str:
        """Render this signature as a C++ definition signature string."""
        args = [a.defn() for a in self.arguments()]
        args_str = ", ".join(args)
        if name is None:
            name = self.name()
        return f"{self.returns_type().cpp_type()} {name}({args_str})"

    def returns_type(self) -> CType:
        """Return the C++ return type for this signature."""
        return et_cpp.returns_type(self.func.returns)

    @staticmethod
    def from_native_function(
        f: NativeFunction, *, prefix: str = ""
    ) -> ExecutorchCppSignature:
        """Alternate constructor: derive a signature from a NativeFunction."""
        return ExecutorchCppSignature(
            func=f.func, prefix=prefix, cpp_no_default_args=f.cpp_no_default_args
        )
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
from torchgen.executorch.api import et_cpp
|
.venv/lib/python3.11/site-packages/torchgen/executorch/api/types/types.py
ADDED
|
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
|
| 5 |
+
from torchgen.api.types import (
|
| 6 |
+
BaseCppType,
|
| 7 |
+
BaseCType,
|
| 8 |
+
Binding,
|
| 9 |
+
boolT,
|
| 10 |
+
CType,
|
| 11 |
+
doubleT,
|
| 12 |
+
Expr,
|
| 13 |
+
longT,
|
| 14 |
+
MutRefCType,
|
| 15 |
+
NamedCType,
|
| 16 |
+
)
|
| 17 |
+
from torchgen.model import BaseTy
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
# Executorch-native C++ types, all living in the `torch::executor` namespace.
halfT = BaseCppType("torch::executor", "Half")
bfloat16T = BaseCppType("torch::executor", "BFloat16")
stringT = BaseCppType("torch::executor", "string_view")
scalarTypeT = BaseCppType("torch::executor", "ScalarType")
tensorT = BaseCppType("torch::executor", "Tensor")
tensorListT = BaseCppType("torch::executor", "TensorList")
scalarT = BaseCppType("torch::executor", "Scalar")
memoryFormatT = BaseCppType("torch::executor", "MemoryFormat")
intArrayRefT = BaseCppType("torch::executor", "IntArrayRef")
optionalT = BaseCppType("torch::executor", "optional")
contextT = BaseCppType("torch::executor", "KernelRuntimeContext")

# C++ expression for the kernel runtime context: a mutable reference named
# `context`.
contextExpr = Expr(
    expr="context",
    type=NamedCType(name="context", type=MutRefCType(BaseCType(contextT))),
)

# Binding for the runtime-context parameter that kernel signatures can take;
# it corresponds to no JIT schema argument, hence argument=None.
contextArg = Binding(
    name="context",
    nctype=contextExpr.type,
    argument=None,  # type: ignore[arg-type]
    default=None,
)

# Mapping from JIT schema base types to their Executorch C++ counterparts.
# NB: SymInt is intentionally absent — the Executorch codegen treats it as
# int (see et_cpp.valuetype_type).
BaseTypeToCppMapping: dict[BaseTy, BaseCppType] = {
    BaseTy.int: longT,
    BaseTy.float: doubleT,
    BaseTy.bool: boolT,
    BaseTy.str: stringT,
    BaseTy.ScalarType: scalarTypeT,
    BaseTy.Tensor: tensorT,
    BaseTy.Scalar: scalarT,
    BaseTy.MemoryFormat: memoryFormatT,
}
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
@dataclass(frozen=True)
class OptionalCType(CType):
    """C++ type wrapping an element type in ``torch::executor::optional``."""

    elem: CType

    def cpp_type(self, *, strip_ref: bool = False) -> str:
        # `strip_ref` is deliberately not forwarded to the element type.
        inner = self.elem.cpp_type()
        return "torch::executor::optional<" + inner + ">"

    def cpp_type_registration_declarations(self) -> str:
        inner = self.elem.cpp_type_registration_declarations()
        return "torch::executor::optional<" + inner + ">"

    def remove_const_ref(self) -> CType:
        # Strip const-ref from the element; the optional wrapper is kept.
        return OptionalCType(self.elem.remove_const_ref())
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
@dataclass(frozen=True)
class ArrayRefCType(CType):
    """C++ type wrapping an element type in ``torch::executor::ArrayRef``."""

    elem: CType

    def cpp_type(self, *, strip_ref: bool = False) -> str:
        # `strip_ref` is deliberately not forwarded to the element type.
        inner = self.elem.cpp_type()
        return "torch::executor::ArrayRef<" + inner + ">"

    def cpp_type_registration_declarations(self) -> str:
        inner = self.elem.cpp_type_registration_declarations()
        return "torch::executor::ArrayRef<" + inner + ">"

    def remove_const_ref(self) -> CType:
        # Strip const-ref from the element; the ArrayRef wrapper is kept.
        return ArrayRefCType(self.elem.remove_const_ref())
|
.venv/lib/python3.11/site-packages/torchgen/executorch/api/unboxing.py
ADDED
|
@@ -0,0 +1,230 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
from typing import Callable, Sequence, TYPE_CHECKING
|
| 5 |
+
|
| 6 |
+
from torchgen.model import (
|
| 7 |
+
Argument,
|
| 8 |
+
BaseTy,
|
| 9 |
+
BaseType,
|
| 10 |
+
ListType,
|
| 11 |
+
NativeFunction,
|
| 12 |
+
OptionalType,
|
| 13 |
+
Type,
|
| 14 |
+
)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
if TYPE_CHECKING:
|
| 18 |
+
from torchgen.api.types import Binding, CType, NamedCType
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
# Separator used when joining generated C++ statements into one body.
connector = "\n\t"
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
# Return unboxing function name for a NativeFunction
|
| 25 |
+
def name(f: NativeFunction) -> str:
    """Return the unboxing function name for a NativeFunction."""
    op_name = f.func.name
    return op_name.unambiguous_name()
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
@dataclass(frozen=True)
|
| 30 |
+
class Unboxing:
|
| 31 |
+
"""
|
| 32 |
+
Takes a sequence of Bindings and unbox EValues to these Bindings. Return generated code that performs correct unboxing.
|
| 33 |
+
A sample generated code:
|
| 34 |
+
// aten::mul.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
|
| 35 |
+
void mul_out(EValue** stack) {
|
| 36 |
+
EValue& self = *stack[0];
|
| 37 |
+
EValue& other = *stack[1];
|
| 38 |
+
EValue& out = *stack[2];
|
| 39 |
+
const torch::executor::Tensor & self_base = self.to<torch::executor::Tensor>();
|
| 40 |
+
const torch::executor::Tensor & other_base = other.to<torch::executor::Tensor>();
|
| 41 |
+
torch::executor::Tensor & out_base = out.to<torch::executor::Tensor>();
|
| 42 |
+
|
| 43 |
+
EXECUTORCH_SCOPE_PROF("native_call_mul.out");
|
| 44 |
+
torch::executor::mul_outf(self_base, other_base, out_base);
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
}
|
| 48 |
+
"""
|
| 49 |
+
|
| 50 |
+
# this is a callable that converts a JIT argument, into its C++ type.
|
| 51 |
+
# Translates (type, mutability, binds) to NamedCType. E.g., torchgen.api.cpp.argumenttype_type.
|
| 52 |
+
argument_type_gen: Callable[
|
| 53 |
+
...,
|
| 54 |
+
NamedCType,
|
| 55 |
+
]
|
| 56 |
+
|
| 57 |
+
# Convert all the arguments in a NativeFunction to C++ code
|
| 58 |
+
def convert_arguments(
|
| 59 |
+
self, args: Sequence[Binding]
|
| 60 |
+
) -> tuple[list[Binding], list[str]]:
|
| 61 |
+
code_list = [f"EValue& {args[i].name} = *stack[{i}];" for i in range(len(args))]
|
| 62 |
+
binding_list = []
|
| 63 |
+
for arg in args:
|
| 64 |
+
# expecting only Argument
|
| 65 |
+
if not isinstance(arg.argument, Argument):
|
| 66 |
+
raise Exception( # noqa: TRY002
|
| 67 |
+
f"Unexpected argument type, expecting `Argument` but got {arg}"
|
| 68 |
+
)
|
| 69 |
+
argument: Argument = arg.argument
|
| 70 |
+
unboxed_name, _, code, decl = self.argumenttype_evalue_convert(
|
| 71 |
+
argument.type, argument.name, mutable=argument.is_write
|
| 72 |
+
)
|
| 73 |
+
code_list.extend(decl)
|
| 74 |
+
code_list.extend(code)
|
| 75 |
+
binding_list.append(arg.with_name(unboxed_name))
|
| 76 |
+
return binding_list, code_list
|
| 77 |
+
|
| 78 |
+
def argumenttype_evalue_convert(
|
| 79 |
+
self, t: Type, arg_name: str, *, mutable: bool = False
|
| 80 |
+
) -> tuple[str, CType, list[str], list[str]]:
|
| 81 |
+
"""
|
| 82 |
+
Takes in the type, name and mutability corresponding to an argument, and generates a tuple of:
|
| 83 |
+
(1) the C++ code necessary to unbox the argument
|
| 84 |
+
(2) A Binding corresponding to the newly created unboxed variable, including variable name and its CType
|
| 85 |
+
:param t: a `Type` of an argument
|
| 86 |
+
:param arg_name: argument name
|
| 87 |
+
:param mutable: boolean for whether this argument type is mutable
|
| 88 |
+
:return: unboxed result
|
| 89 |
+
"""
|
| 90 |
+
ctype = self.argument_type_gen(t, mutable=mutable, binds=arg_name).type
|
| 91 |
+
|
| 92 |
+
if isinstance(t, BaseType):
|
| 93 |
+
out_name = f"{arg_name}_base"
|
| 94 |
+
code, decl = self._gen_code_base_type(
|
| 95 |
+
arg_name=arg_name, out_name=out_name, ctype=ctype
|
| 96 |
+
)
|
| 97 |
+
elif isinstance(t, OptionalType):
|
| 98 |
+
out_name = f"{arg_name}_opt_out"
|
| 99 |
+
code, decl = self._gen_code_optional_type(
|
| 100 |
+
arg_name=arg_name, out_name=out_name, t=t, ctype=ctype
|
| 101 |
+
)
|
| 102 |
+
elif isinstance(t, ListType):
|
| 103 |
+
out_name = f"{arg_name}_list_out"
|
| 104 |
+
code, decl = self._gen_code_list_type(
|
| 105 |
+
arg_name=arg_name, out_name=out_name, t=t, ctype=ctype
|
| 106 |
+
)
|
| 107 |
+
else:
|
| 108 |
+
raise Exception( # noqa: TRY002
|
| 109 |
+
f"Cannot handle type {t}. arg_name: {arg_name}"
|
| 110 |
+
) # noqa: TRY002
|
| 111 |
+
return out_name, ctype, code, decl
|
| 112 |
+
|
| 113 |
+
def _gen_code_base_type(
|
| 114 |
+
self, arg_name: str, out_name: str, ctype: CType
|
| 115 |
+
) -> tuple[list[str], list[str]]:
|
| 116 |
+
return [
|
| 117 |
+
f"{ctype.cpp_type()} {out_name} = {arg_name}.to<{ctype.cpp_type(strip_ref=True)}>();"
|
| 118 |
+
], []
|
| 119 |
+
|
| 120 |
+
def _gen_code_optional_type(
|
| 121 |
+
self, arg_name: str, out_name: str, t: OptionalType, ctype: CType
|
| 122 |
+
) -> tuple[list[str], list[str]]:
|
| 123 |
+
in_name = f"{arg_name}_opt_in"
|
| 124 |
+
res_name, base_type, res_code, decl = self.argumenttype_evalue_convert(
|
| 125 |
+
t.elem, in_name
|
| 126 |
+
)
|
| 127 |
+
return (
|
| 128 |
+
f"""
|
| 129 |
+
auto {out_name} = {arg_name}.toOptional<{base_type.cpp_type(strip_ref=True)}>();
|
| 130 |
+
""".split(
|
| 131 |
+
"\n"
|
| 132 |
+
),
|
| 133 |
+
decl,
|
| 134 |
+
)
|
| 135 |
+
|
| 136 |
+
def _gen_code_list_type(
|
| 137 |
+
self, arg_name: str, out_name: str, t: ListType, ctype: CType
|
| 138 |
+
) -> tuple[list[str], list[str]]:
|
| 139 |
+
in_name = f"{arg_name}_list_in"
|
| 140 |
+
elem_name = f"{arg_name}_elem"
|
| 141 |
+
code = []
|
| 142 |
+
res_name, res_ctype, res_code, decl = self.argumenttype_evalue_convert(
|
| 143 |
+
t.elem, elem_name
|
| 144 |
+
)
|
| 145 |
+
|
| 146 |
+
if isinstance(t.elem, BaseType) and t.elem.name == BaseTy.Tensor:
|
| 147 |
+
code.extend(
|
| 148 |
+
f"""
|
| 149 |
+
auto {out_name} = {arg_name}.toTensorList();
|
| 150 |
+
""".split(
|
| 151 |
+
"\n"
|
| 152 |
+
)
|
| 153 |
+
)
|
| 154 |
+
elif isinstance(t.elem, BaseType) and (
|
| 155 |
+
t.elem.name == BaseTy.int or t.elem.name == BaseTy.SymInt
|
| 156 |
+
):
|
| 157 |
+
code.extend(
|
| 158 |
+
f"""
|
| 159 |
+
auto {out_name} = {arg_name}.toIntList();
|
| 160 |
+
""".split(
|
| 161 |
+
"\n"
|
| 162 |
+
)
|
| 163 |
+
)
|
| 164 |
+
elif isinstance(t.elem, BaseType) and t.elem.name == BaseTy.float:
|
| 165 |
+
code.extend(
|
| 166 |
+
f"""
|
| 167 |
+
auto {out_name} = {arg_name}.toDoubleList();
|
| 168 |
+
""".split(
|
| 169 |
+
"\n"
|
| 170 |
+
)
|
| 171 |
+
)
|
| 172 |
+
elif isinstance(t.elem, BaseType) and t.elem.name == BaseTy.bool:
|
| 173 |
+
# handle list type with size, e.g., bool[4]
|
| 174 |
+
code.extend(
|
| 175 |
+
f"""
|
| 176 |
+
#ifdef USE_ATEN_LIB
|
| 177 |
+
std::array<bool, {t.size}> {out_name};
|
| 178 |
+
auto {in_name} = {arg_name}.toBoolList();
|
| 179 |
+
size_t _i = 0;
|
| 180 |
+
for (auto {elem_name}: {in_name}) {{
|
| 181 |
+
{out_name}[_i++] = {elem_name};
|
| 182 |
+
}}
|
| 183 |
+
#else
|
| 184 |
+
auto {out_name} = {arg_name}.toBoolList();
|
| 185 |
+
#endif
|
| 186 |
+
""".split(
|
| 187 |
+
"\n"
|
| 188 |
+
)
|
| 189 |
+
)
|
| 190 |
+
# pytorch codegen:
|
| 191 |
+
# we have to use c10::List for optional element. e.g., Tensor?[] -> c10::List<::std::optional<at::Tensor>>
|
| 192 |
+
elif (
|
| 193 |
+
isinstance(t.elem, OptionalType)
|
| 194 |
+
and isinstance(t.elem.elem, BaseType)
|
| 195 |
+
and t.elem.elem.name == BaseTy.Tensor
|
| 196 |
+
):
|
| 197 |
+
code.extend(
|
| 198 |
+
f"""
|
| 199 |
+
#ifdef USE_ATEN_LIB
|
| 200 |
+
auto {in_name} = {arg_name}.toListOptionalTensor();
|
| 201 |
+
c10::List<::std::optional<at::Tensor>> {out_name};
|
| 202 |
+
for (auto {elem_name}: {in_name}) {{
|
| 203 |
+
{out_name}.push_back({elem_name});
|
| 204 |
+
}}
|
| 205 |
+
#else
|
| 206 |
+
auto {out_name} = {arg_name}.toListOptionalTensor();
|
| 207 |
+
#endif
|
| 208 |
+
""".split(
|
| 209 |
+
"\n"
|
| 210 |
+
)
|
| 211 |
+
)
|
| 212 |
+
else:
|
| 213 |
+
# use ArrayRef as default.
|
| 214 |
+
vec_name = arg_name + "_vec"
|
| 215 |
+
# need to bring vector instantiation out of scope so that ArrayRef has valid data
|
| 216 |
+
decl.append(
|
| 217 |
+
f"std::vector<{res_ctype.cpp_type(strip_ref=True)}> {vec_name};"
|
| 218 |
+
)
|
| 219 |
+
code.extend(
|
| 220 |
+
f"""
|
| 221 |
+
for (EValue {elem_name}: {in_name}) {{
|
| 222 |
+
{connector.join(res_code)}
|
| 223 |
+
{vec_name}.push_back({res_name});
|
| 224 |
+
}}
|
| 225 |
+
{ctype.cpp_type(strip_ref=True)} {out_name}({vec_name});
|
| 226 |
+
""".split(
|
| 227 |
+
"\n"
|
| 228 |
+
)
|
| 229 |
+
)
|
| 230 |
+
return code, decl
|
.venv/lib/python3.11/site-packages/torchgen/executorch/model.py
ADDED
|
@@ -0,0 +1,220 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Represents all kernels used by an Executorch model.
|
| 2 |
+
# It maintains a Dict[OperatorName, Dict[ETKernelKey, BackendMetadata]] structure.
|
| 3 |
+
|
| 4 |
+
from __future__ import annotations
|
| 5 |
+
|
| 6 |
+
import itertools
|
| 7 |
+
from collections import defaultdict, namedtuple
|
| 8 |
+
from dataclasses import dataclass
|
| 9 |
+
from enum import IntEnum
|
| 10 |
+
|
| 11 |
+
from torchgen.model import (
|
| 12 |
+
BackendIndex,
|
| 13 |
+
BackendMetadata,
|
| 14 |
+
DispatchKey,
|
| 15 |
+
NativeFunction,
|
| 16 |
+
NativeFunctionsGroup,
|
| 17 |
+
OperatorName,
|
| 18 |
+
)
|
| 19 |
+
from torchgen.utils import assert_never
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
KERNEL_KEY_VERSION = 1
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
# TODO: Duplicated Subset from codegen.tool.gen_oplist, remove declaration in codegen
|
| 26 |
+
class ScalarType(IntEnum):
|
| 27 |
+
Byte = 0
|
| 28 |
+
Char = 1
|
| 29 |
+
Short = 2
|
| 30 |
+
Int = 3
|
| 31 |
+
Long = 4
|
| 32 |
+
Float = 6
|
| 33 |
+
Double = 7
|
| 34 |
+
Bool = 11
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
ETParsedYaml = namedtuple("ETParsedYaml", ["native_functions", "kernel_index"])
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
@dataclass(frozen=True)
|
| 41 |
+
class ETKernelKeyOpArgMeta:
|
| 42 |
+
arg_name: str
|
| 43 |
+
dtype: str
|
| 44 |
+
# The order of the dimensions if entry is a Tensor
|
| 45 |
+
dim_order: tuple[int, ...]
|
| 46 |
+
|
| 47 |
+
def to_native_string(self) -> str:
|
| 48 |
+
dtype_str = ScalarType[self.dtype].value
|
| 49 |
+
dim_str = str(self.dim_order)[1:-1].replace(" ", "")
|
| 50 |
+
return f"{dtype_str};{dim_str}"
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
@dataclass(frozen=True)
|
| 54 |
+
class ETKernelKey:
|
| 55 |
+
# Field undefined is default = True
|
| 56 |
+
arg_meta: tuple[ETKernelKeyOpArgMeta, ...] = ()
|
| 57 |
+
|
| 58 |
+
# Indicator for this kernel being used as a catch all
|
| 59 |
+
default: bool = False
|
| 60 |
+
|
| 61 |
+
version: int = KERNEL_KEY_VERSION
|
| 62 |
+
|
| 63 |
+
@staticmethod
|
| 64 |
+
def gen_from_yaml(
|
| 65 |
+
args: dict[str, tuple[str, str]],
|
| 66 |
+
type_alias_map: dict[str, list[str]], # TODO: Support unwrapped str val
|
| 67 |
+
dim_order_alias_map: dict[str, list[int]],
|
| 68 |
+
) -> list[ETKernelKey]:
|
| 69 |
+
"""Generate ETKernelKeys from arg kernel specs
|
| 70 |
+
Multiple ETKernelKeys are returned due to dtype permutations from utilizing
|
| 71 |
+
type_alias_map (actualizing each potential type permutation as a KernelKey)
|
| 72 |
+
|
| 73 |
+
Args:
|
| 74 |
+
args: Mapping from argument name to kernel specs
|
| 75 |
+
Kernel specs are a tuple of (dtype, dim_order).
|
| 76 |
+
Currently tuple entries must be aliased via the alias map arguments
|
| 77 |
+
type_alias_map: Mapping from type alias to potential type enums
|
| 78 |
+
i.e { T0 : [Double, Int] } means T0 can be either Double or Int
|
| 79 |
+
Used for lookup by args
|
| 80 |
+
dim_order_alias_map: Mapping from alias to a list of dimension orders
|
| 81 |
+
Used for lookup by args
|
| 82 |
+
"""
|
| 83 |
+
# Cast to dim order to int
|
| 84 |
+
dim_order_alias_map = {
|
| 85 |
+
k: [int(alias) for alias in v] for k, v in dim_order_alias_map.items()
|
| 86 |
+
}
|
| 87 |
+
kernel_keys = []
|
| 88 |
+
|
| 89 |
+
# Get all used Dtype Alias
|
| 90 |
+
dtype_alias_used = set()
|
| 91 |
+
for type_alias, dim_order in args.values():
|
| 92 |
+
# Enforce usage of alias initially
|
| 93 |
+
# TODO: Support inlined arguments
|
| 94 |
+
assert type_alias in type_alias_map, "Undefined type alias: " + str(
|
| 95 |
+
type_alias
|
| 96 |
+
)
|
| 97 |
+
assert (
|
| 98 |
+
dim_order in dim_order_alias_map
|
| 99 |
+
), "Undefined dim_order alias: " + str(dim_order)
|
| 100 |
+
dtype_alias_used.add(type_alias)
|
| 101 |
+
|
| 102 |
+
# Generate all permutations of dtype alias values
|
| 103 |
+
alias_dtypes = [
|
| 104 |
+
[(alias, dtype) for dtype in type_alias_map[alias]]
|
| 105 |
+
for alias in dtype_alias_used
|
| 106 |
+
]
|
| 107 |
+
alias_permutations = [
|
| 108 |
+
dict(permutation) for permutation in list(itertools.product(*alias_dtypes))
|
| 109 |
+
]
|
| 110 |
+
|
| 111 |
+
# Using each alias value permutation, generate kernel keys
|
| 112 |
+
op_arg_cache = {}
|
| 113 |
+
for permutation in alias_permutations:
|
| 114 |
+
arg_list = []
|
| 115 |
+
for arg_name, arg_spec in args.items():
|
| 116 |
+
dtype = permutation[arg_spec[0]]
|
| 117 |
+
dim_order = dim_order_alias_map[arg_spec[1]] # type: ignore[assignment]
|
| 118 |
+
if (
|
| 119 |
+
cache_key := (arg_name, dtype, tuple(dim_order))
|
| 120 |
+
) not in op_arg_cache:
|
| 121 |
+
op_arg_cache[cache_key] = ETKernelKeyOpArgMeta(*cache_key) # type: ignore[arg-type]
|
| 122 |
+
|
| 123 |
+
arg_list.append(op_arg_cache[cache_key])
|
| 124 |
+
kernel_keys.append(ETKernelKey(tuple(arg_list)))
|
| 125 |
+
|
| 126 |
+
return kernel_keys
|
| 127 |
+
|
| 128 |
+
def to_native_string(self) -> str:
|
| 129 |
+
if self.default:
|
| 130 |
+
return "default"
|
| 131 |
+
return (
|
| 132 |
+
"v"
|
| 133 |
+
+ str(KERNEL_KEY_VERSION)
|
| 134 |
+
+ "/"
|
| 135 |
+
+ "|".join([arg.to_native_string() for arg in self.arg_meta])
|
| 136 |
+
)
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
@dataclass(frozen=True)
|
| 140 |
+
class ETKernelIndex:
|
| 141 |
+
index: dict[OperatorName, dict[ETKernelKey, BackendMetadata]]
|
| 142 |
+
|
| 143 |
+
def has_kernels(self, g: NativeFunction | NativeFunctionsGroup) -> bool:
|
| 144 |
+
m = self.get_kernels(g)
|
| 145 |
+
return m is not None
|
| 146 |
+
|
| 147 |
+
def get_kernels(
|
| 148 |
+
self, g: NativeFunction | NativeFunctionsGroup
|
| 149 |
+
) -> dict[ETKernelKey, BackendMetadata]:
|
| 150 |
+
if isinstance(g, NativeFunction):
|
| 151 |
+
f = g
|
| 152 |
+
elif isinstance(g, NativeFunctionsGroup):
|
| 153 |
+
f = g.functional
|
| 154 |
+
else:
|
| 155 |
+
assert_never(g)
|
| 156 |
+
if f.func.name not in self.index:
|
| 157 |
+
return {}
|
| 158 |
+
return self.index[f.func.name]
|
| 159 |
+
|
| 160 |
+
@staticmethod
|
| 161 |
+
def grow_from_backend_indices(
|
| 162 |
+
kernel_index: dict[OperatorName, dict[ETKernelKey, BackendMetadata]],
|
| 163 |
+
backend_indices: dict[DispatchKey, dict[OperatorName, BackendMetadata]],
|
| 164 |
+
) -> None:
|
| 165 |
+
for dk in backend_indices:
|
| 166 |
+
index = backend_indices[dk]
|
| 167 |
+
for op, backend_metadata in index.items():
|
| 168 |
+
if op in kernel_index:
|
| 169 |
+
kernel_index[op][ETKernelKey(default=True)] = backend_metadata
|
| 170 |
+
else:
|
| 171 |
+
kernel_index[op] = {ETKernelKey(default=True): backend_metadata}
|
| 172 |
+
|
| 173 |
+
@staticmethod
|
| 174 |
+
def from_backend_indices(
|
| 175 |
+
backend_indices: dict[DispatchKey, dict[OperatorName, BackendMetadata]]
|
| 176 |
+
) -> ETKernelIndex:
|
| 177 |
+
kernel_index: dict[
|
| 178 |
+
OperatorName, dict[ETKernelKey, BackendMetadata]
|
| 179 |
+
] = defaultdict(dict)
|
| 180 |
+
ETKernelIndex.grow_from_backend_indices(kernel_index, backend_indices)
|
| 181 |
+
return ETKernelIndex(kernel_index)
|
| 182 |
+
|
| 183 |
+
def grow(
|
| 184 |
+
self, backend_indices: dict[DispatchKey, dict[OperatorName, BackendMetadata]]
|
| 185 |
+
) -> ETKernelIndex:
|
| 186 |
+
ETKernelIndex.grow_from_backend_indices(self.index, backend_indices)
|
| 187 |
+
return self
|
| 188 |
+
|
| 189 |
+
def _to_backend_index(self) -> BackendIndex:
|
| 190 |
+
"""
|
| 191 |
+
WARNING: this will be deprecated once all the codegen places know how to handle ETKernelIndex.
|
| 192 |
+
"""
|
| 193 |
+
index: dict[OperatorName, BackendMetadata] = {}
|
| 194 |
+
for op in self.index:
|
| 195 |
+
kernel_dict = self.index[op]
|
| 196 |
+
assert (
|
| 197 |
+
len(kernel_dict.values()) == 1
|
| 198 |
+
), f"Can't convert ETKernelIndex to BackendIndex because {op} has more than one kernels. Got {kernel_dict}"
|
| 199 |
+
index[op] = kernel_dict.get(
|
| 200 |
+
ETKernelKey(default=True),
|
| 201 |
+
BackendMetadata(kernel="", structured=False, cpp_namespace=""),
|
| 202 |
+
)
|
| 203 |
+
return BackendIndex(
|
| 204 |
+
dispatch_key=DispatchKey.CPU,
|
| 205 |
+
use_out_as_primary=False,
|
| 206 |
+
device_guard=False,
|
| 207 |
+
external=False,
|
| 208 |
+
index=index,
|
| 209 |
+
)
|
| 210 |
+
|
| 211 |
+
# Note duplicate ETKernelKey from index_b will clobber the metadata from index_a
|
| 212 |
+
@staticmethod
|
| 213 |
+
def merge_indices(index_a: ETKernelIndex, index_b: ETKernelIndex) -> ETKernelIndex:
|
| 214 |
+
combined = defaultdict(dict, index_a.index.copy())
|
| 215 |
+
|
| 216 |
+
for op, entry in index_b.index.items():
|
| 217 |
+
for key, metadata in entry.items():
|
| 218 |
+
combined[op][key] = metadata
|
| 219 |
+
|
| 220 |
+
return ETKernelIndex(combined)
|
.venv/lib/python3.11/site-packages/torchgen/executorch/parse.py
ADDED
|
@@ -0,0 +1,153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from collections import defaultdict, namedtuple
|
| 4 |
+
from typing import Any
|
| 5 |
+
|
| 6 |
+
import yaml
|
| 7 |
+
|
| 8 |
+
from torchgen.executorch.model import ETKernelIndex, ETKernelKey
|
| 9 |
+
from torchgen.gen import LineLoader, parse_native_yaml
|
| 10 |
+
from torchgen.model import (
|
| 11 |
+
BackendMetadata,
|
| 12 |
+
DispatchKey,
|
| 13 |
+
FunctionSchema,
|
| 14 |
+
NativeFunction,
|
| 15 |
+
OperatorName,
|
| 16 |
+
)
|
| 17 |
+
from torchgen.utils import NamespaceHelper
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
# Parse native_functions.yaml into a sequence of NativeFunctions and ET Backend Indices.
|
| 21 |
+
ETParsedYaml = namedtuple("ETParsedYaml", ["native_functions", "et_kernel_indices"])
|
| 22 |
+
|
| 23 |
+
# Fields in native_functions.yaml used to determine which kernels should be used
|
| 24 |
+
ET_FIELDS = ["kernels", "type_alias", "dim_order_alias"]
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def parse_from_yaml(ei: dict[str, object]) -> dict[ETKernelKey, BackendMetadata]:
|
| 28 |
+
"""Given a loaded yaml representing kernel assignment information, extract the
|
| 29 |
+
mapping from `kernel keys` to `BackendMetadata` (the latter representing the kernel instance)
|
| 30 |
+
|
| 31 |
+
Args:
|
| 32 |
+
ei: Dict keys {kernels, type_alias, dim_order_alias}
|
| 33 |
+
See ETKernelKey for description of arguments
|
| 34 |
+
"""
|
| 35 |
+
e = ei.copy()
|
| 36 |
+
if (kernels := e.pop("kernels", None)) is None:
|
| 37 |
+
return {}
|
| 38 |
+
|
| 39 |
+
type_alias: dict[str, list[str]] = e.pop("type_alias", {}) # type: ignore[assignment]
|
| 40 |
+
dim_order_alias: dict[str, list[str]] = e.pop("dim_order_alias", {}) # type: ignore[assignment]
|
| 41 |
+
dim_order_alias.pop("__line__", None)
|
| 42 |
+
|
| 43 |
+
kernel_mapping: dict[ETKernelKey, BackendMetadata] = {}
|
| 44 |
+
|
| 45 |
+
for entry in kernels: # type: ignore[attr-defined]
|
| 46 |
+
arg_meta = entry.get("arg_meta")
|
| 47 |
+
if arg_meta is not None:
|
| 48 |
+
arg_meta.pop("__line__")
|
| 49 |
+
|
| 50 |
+
kernel_name = entry.get("kernel_name")
|
| 51 |
+
namespace_helper = NamespaceHelper.from_namespaced_entity(
|
| 52 |
+
kernel_name, max_level=3
|
| 53 |
+
)
|
| 54 |
+
kernel_namespace = namespace_helper.get_cpp_namespace(default="at")
|
| 55 |
+
backend_metadata = BackendMetadata(
|
| 56 |
+
kernel=namespace_helper.entity_name,
|
| 57 |
+
structured=False,
|
| 58 |
+
cpp_namespace=(kernel_namespace + "::native"),
|
| 59 |
+
)
|
| 60 |
+
|
| 61 |
+
kernel_keys = (
|
| 62 |
+
[ETKernelKey((), default=True)]
|
| 63 |
+
if arg_meta is None
|
| 64 |
+
else ETKernelKey.gen_from_yaml(arg_meta, type_alias, dim_order_alias) # type: ignore[arg-type]
|
| 65 |
+
)
|
| 66 |
+
|
| 67 |
+
for kernel_key in kernel_keys:
|
| 68 |
+
assert kernel_key not in kernel_mapping, (
|
| 69 |
+
"Duplicate kernel key: " + str(kernel_key) + " " + str(e)
|
| 70 |
+
)
|
| 71 |
+
kernel_mapping[kernel_key] = backend_metadata
|
| 72 |
+
|
| 73 |
+
return kernel_mapping
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def parse_et_yaml_struct(es: object) -> ETKernelIndex:
|
| 77 |
+
"""Given a loaded yaml representing a list of operators, for each op extract the mapping
|
| 78 |
+
of `kernel keys` to `BackendMetadata` (the latter representing the kernel instance
|
| 79 |
+
that should be used by the kernel key).
|
| 80 |
+
"""
|
| 81 |
+
indices: dict[OperatorName, dict[ETKernelKey, BackendMetadata]] = {}
|
| 82 |
+
for ei in es: # type: ignore[attr-defined]
|
| 83 |
+
e = ei.copy()
|
| 84 |
+
|
| 85 |
+
funcs = e.pop("func")
|
| 86 |
+
assert isinstance(funcs, str), f"not a str: {funcs}"
|
| 87 |
+
namespace_helper = NamespaceHelper.from_namespaced_entity(
|
| 88 |
+
namespaced_entity=funcs, max_level=1
|
| 89 |
+
)
|
| 90 |
+
opname = FunctionSchema.parse(namespace_helper.entity_name).name
|
| 91 |
+
|
| 92 |
+
assert opname not in indices, f"Duplicate func found in yaml: {opname} already"
|
| 93 |
+
|
| 94 |
+
if len(index := parse_from_yaml(e)) != 0:
|
| 95 |
+
indices[opname] = index
|
| 96 |
+
|
| 97 |
+
return ETKernelIndex(indices)
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def extract_kernel_fields(es: object) -> dict[OperatorName, dict[str, Any]]:
|
| 101 |
+
"""Given a loaded yaml representing a list of operators, extract the
|
| 102 |
+
kernel key related fields indexed by the operator name.
|
| 103 |
+
"""
|
| 104 |
+
fields: dict[OperatorName, dict[str, Any]] = defaultdict(dict)
|
| 105 |
+
for ei in es: # type: ignore[attr-defined]
|
| 106 |
+
funcs = ei.get("func")
|
| 107 |
+
assert isinstance(funcs, str), f"not a str: {funcs}"
|
| 108 |
+
namespace_helper = NamespaceHelper.from_namespaced_entity(
|
| 109 |
+
namespaced_entity=funcs, max_level=1
|
| 110 |
+
)
|
| 111 |
+
opname = FunctionSchema.parse(namespace_helper.entity_name).name
|
| 112 |
+
|
| 113 |
+
for field in ET_FIELDS:
|
| 114 |
+
if (value := ei.get(field)) is not None:
|
| 115 |
+
fields[opname][field] = value
|
| 116 |
+
|
| 117 |
+
return fields
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def parse_et_yaml(
|
| 121 |
+
path: str,
|
| 122 |
+
tags_yaml_path: str,
|
| 123 |
+
ignore_keys: set[DispatchKey] | None = None,
|
| 124 |
+
skip_native_fns_gen: bool = False,
|
| 125 |
+
) -> tuple[list[NativeFunction], dict[OperatorName, dict[str, Any]]]:
|
| 126 |
+
"""Parse native_functions.yaml into NativeFunctions and an Operator Indexed Dict
|
| 127 |
+
of fields to persist from native_functions.yaml to functions.yaml
|
| 128 |
+
"""
|
| 129 |
+
with open(path) as f:
|
| 130 |
+
es = yaml.load(f, Loader=LineLoader)
|
| 131 |
+
|
| 132 |
+
et_kernel = extract_kernel_fields(es)
|
| 133 |
+
|
| 134 |
+
# Remove ET specific fields from entries for BC compatibility
|
| 135 |
+
strip_et_fields(es)
|
| 136 |
+
|
| 137 |
+
native_yaml = parse_native_yaml(
|
| 138 |
+
path,
|
| 139 |
+
tags_yaml_path,
|
| 140 |
+
ignore_keys,
|
| 141 |
+
skip_native_fns_gen=skip_native_fns_gen,
|
| 142 |
+
loaded_yaml=es,
|
| 143 |
+
)
|
| 144 |
+
return native_yaml.native_functions, et_kernel
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
def strip_et_fields(es: object) -> None:
|
| 148 |
+
"""Given a loaded yaml representing a list of operators,
|
| 149 |
+
remove ET specific fields from every entries for BC compatibility
|
| 150 |
+
"""
|
| 151 |
+
for entry in es: # type: ignore[attr-defined]
|
| 152 |
+
for field in ET_FIELDS:
|
| 153 |
+
entry.pop(field, None)
|
.venv/lib/python3.11/site-packages/torchgen/packaged/ATen/native/native_functions.yaml
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
.venv/lib/python3.11/site-packages/torchgen/packaged/ATen/native/tags.yaml
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# This yaml file contains all the possible tags that can be defined in `tags` in `native_functions.yaml`
|
| 2 |
+
|
| 3 |
+
- tag: inplace_view
|
| 4 |
+
desc: |
|
| 5 |
+
This tag indicates if an operator *only* modifies the tensor metadata
|
| 6 |
+
- tag: pt2_compliant_tag
|
| 7 |
+
desc: |
|
| 8 |
+
This tag indicates if the operator is guaranteed to
|
| 9 |
+
work with the PT2 compilation APIs (torch.compile,
|
| 10 |
+
torch.export, etc). If you add this tag to an
|
| 11 |
+
operator, please use
|
| 12 |
+
`torch.testing._internal.optest.opcheck` to test that
|
| 13 |
+
the operator has been registered correctly and
|
| 14 |
+
works with torch.compile
|
| 15 |
+
- tag: view_copy
|
| 16 |
+
desc: |
|
| 17 |
+
This tag indicates operators that are *_copy* variants
|
| 18 |
+
of view/aliasing operators. If an operator has a view_copy tag,
|
| 19 |
+
then it should have the name {op}_copy, where {op} is a view operator.
|
| 20 |
+
- tag: dynamic_output_shape
|
| 21 |
+
desc: |
|
| 22 |
+
This tag indicates if an operator's output's shape depends on input Tensor
|
| 23 |
+
data.
|
| 24 |
+
- tag: data_dependent_output
|
| 25 |
+
desc: |
|
| 26 |
+
Operator has a non-Tensor output whose value is dependent on the data
|
| 27 |
+
of Tensor inputs. Among other things, this implies that this operator
|
| 28 |
+
cannot be run with meta tensor (since data is not available), nor
|
| 29 |
+
can it be symbolically traced.
|
| 30 |
+
- tag: generated
|
| 31 |
+
desc: |
|
| 32 |
+
This tag indicates that the operator doesn't have an explicit entry in
|
| 33 |
+
native_functions.yaml, and instead was generated automatically by the codegen.
|
| 34 |
+
- tag: nondeterministic_seeded
|
| 35 |
+
desc: |
|
| 36 |
+
This tag indicates if an operator is nondeterministically seeded
|
| 37 |
+
(i.e., is random) such that the operator intentionally produces
|
| 38 |
+
different results when run twice on the same inputs, but this randomness
|
| 39 |
+
is controlled by a Generator which, if reseeded would give you the
|
| 40 |
+
same result.
|
| 41 |
+
- tag: nondeterministic_bitwise
|
| 42 |
+
desc: |
|
| 43 |
+
This tag indicates if an operator doesn't guarantee bitwise equivalence
|
| 44 |
+
across different runs of an operator with identical inputs.
|
| 45 |
+
- tag: needs_fixed_stride_order
|
| 46 |
+
desc: |
|
| 47 |
+
This tag indicates that the operator should be passed Tensors following
|
| 48 |
+
the same stride permutation as observed in eager when compiled in inductor.
|
| 49 |
+
Only one of {needs_fixed_stride_order, flexible_layout} can apply; if
|
| 50 |
+
multiple are assigned then we assume the most restrictive one.
|
| 51 |
+
- tag: flexible_layout
|
| 52 |
+
desc: |
|
| 53 |
+
This tag indicates that the custom operator can accept inputs with varying
|
| 54 |
+
strides/storage_offset and that when compiled, Inductor is allowed to change
|
| 55 |
+
the strides/storage_offset of inputs to the custom operator.
|
| 56 |
+
Only one of {needs_fixed_stride_order, flexible_layout} can apply; if
|
| 57 |
+
multiple are assigned then we assume the most restrictive one.
|
| 58 |
+
|
| 59 |
+
# NOTE [Core ATen Ops]
|
| 60 |
+
- tag: core
|
| 61 |
+
desc: |
|
| 62 |
+
Core aten ops is a subset of aten ops that remains after aten-to-aten decomposition and
|
| 63 |
+
functionalization pass. Core aten ops are fully functional and adhere to single static
|
| 64 |
+
assignment (SSA): this implies there will be no `inplace` or `_out` variants in this opset.
|
| 65 |
+
This opset is designed to serve as the functional IR to interface with compiler backends.
|
| 66 |
+
In contrast to primTorch, core aten opset doesn't decompose ops into explicit
|
| 67 |
+
type promotion and broadcasting ops.
|
| 68 |
+
Core aten ops is also effectively the opset produced by torchdynamo.export(aten_graph=True),
|
| 69 |
+
and thus can be used as an opset for export purpose.
|
| 70 |
+
- tag: pointwise
|
| 71 |
+
desc: |
|
| 72 |
+
Pointwise operators are operators where each element of the output is computed only by accessing
|
| 73 |
+
the corresponding element of all the broadcasted inputs. The output shape will be the broadcasted
|
| 74 |
+
shape of the inputs.
|
.venv/lib/python3.11/site-packages/torchgen/packaged/ATen/templates/ATenOpList.cpp
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <ATen/core/ATenOpList.h>
|
| 2 |
+
|
| 3 |
+
#include <string>
|
| 4 |
+
#include <cstring>
|
| 5 |
+
#include <utility>
|
| 6 |
+
#include <unordered_set>
|
| 7 |
+
#include <ATen/core/operator_name.h>
|
| 8 |
+
|
| 9 |
+
// ${generated_comment}
|
| 10 |
+
|
| 11 |
+
namespace at {
|
| 12 |
+
|
| 13 |
+
namespace {
|
| 14 |
+
struct OpNameEquals final {
|
| 15 |
+
bool operator()(const std::pair<const char*, const char*>& lhs, const std::pair<const char*, const char*>& rhs) const {
|
| 16 |
+
return 0 == strcmp(lhs.first, rhs.first) && 0 == strcmp(lhs.second, rhs.second);
|
| 17 |
+
}
|
| 18 |
+
};
|
| 19 |
+
|
| 20 |
+
struct OpNameHash final {
|
| 21 |
+
size_t operator()(const std::pair<const char*, const char*>& p) const {
|
| 22 |
+
// use std::hash<std::string> because std::hash<const char*> would hash pointers and not pointed-to strings
|
| 23 |
+
return std::hash<std::string>()(p.first) ^ (~ std::hash<std::string>()(p.second));
|
| 24 |
+
}
|
| 25 |
+
};
|
| 26 |
+
}
|
| 27 |
+
|
| 28 |
+
bool is_custom_op(const c10::OperatorName& opName) {
|
| 29 |
+
static std::unordered_set<std::pair<const char*, const char*>, OpNameHash, OpNameEquals> ops {
|
| 30 |
+
${aten_ops}
|
| 31 |
+
{"", ""}
|
| 32 |
+
};
|
| 33 |
+
return ops.count(std::make_pair(
|
| 34 |
+
opName.name.c_str(), opName.overload_name.c_str())) == 0;
|
| 35 |
+
}
|
| 36 |
+
}
|
.venv/lib/python3.11/site-packages/torchgen/packaged/ATen/templates/CompositeViewCopyKernels.cpp
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
|
| 2 |
+
// ${generated_comment}
|
| 3 |
+
|
| 4 |
+
#include <ATen/InferSize.h>
|
| 5 |
+
#include <ATen/Tensor.h>
|
| 6 |
+
#include <ATen/native/Resize.h>
|
| 7 |
+
|
| 8 |
+
#ifndef AT_PER_OPERATOR_HEADERS
|
| 9 |
+
#include <ATen/Operators.h>
|
| 10 |
+
#else
|
| 11 |
+
#include <ATen/ops/clone.h>
|
| 12 |
+
$ops_headers
|
| 13 |
+
#endif
|
| 14 |
+
|
| 15 |
+
namespace at {
|
| 16 |
+
namespace native {
|
| 17 |
+
|
| 18 |
+
// This file contains a number of kernels for aten functions that are fully code-generated.
|
| 19 |
+
// TODO: rename this file to something more generic.
|
| 20 |
+
|
| 21 |
+
namespace {
|
| 22 |
+
at::Tensor clone_arg(const at::Tensor& t) {
|
| 23 |
+
return t.clone();
|
| 24 |
+
}
|
| 25 |
+
|
| 26 |
+
std::vector<at::Tensor> clone_arg(const at::TensorList& t_list) {
|
| 27 |
+
std::vector<at::Tensor> out(t_list.size());
|
| 28 |
+
for (const auto& i : c10::irange(t_list.size())) {
|
| 29 |
+
out[i] = t_list[i].clone();
|
| 30 |
+
}
|
| 31 |
+
return out;
|
| 32 |
+
}
|
| 33 |
+
|
| 34 |
+
// duped with gen_resize_out_helper from structured kernels
|
| 35 |
+
void copy_arg(const at::Tensor& dst, const at::Tensor& src) {
|
| 36 |
+
TORCH_CHECK(src.dtype() == dst.dtype(),
|
| 37 |
+
"Expected out tensor to have dtype ", src.dtype(), ", but got ", dst.dtype(), " instead");
|
| 38 |
+
TORCH_CHECK(src.device() == dst.device(),
|
| 39 |
+
"Expected out tensor to have device ", src.device(), ", but got ", dst.device(), " instead");
|
| 40 |
+
dst.copy_(src);
|
| 41 |
+
}
|
| 42 |
+
|
| 43 |
+
void copy_arg(const at::TensorList& dst, const at::TensorList& src) {
|
| 44 |
+
TORCH_INTERNAL_ASSERT(dst.size() == src.size());
|
| 45 |
+
for (const auto& i : c10::irange(dst.size())) {
|
| 46 |
+
copy_arg(dst[i], src[i]);
|
| 47 |
+
}
|
| 48 |
+
}
|
| 49 |
+
|
| 50 |
+
// TODO: this doesn't handle restriding empty tensors correctly; see
|
| 51 |
+
// gen_resize_out_helper for the correct algorithm
|
| 52 |
+
|
| 53 |
+
void resize_out_helper(const at::Tensor& dst, const at::Tensor& src) {
|
| 54 |
+
at::native::resize_output(dst, src.sizes());
|
| 55 |
+
}
|
| 56 |
+
|
| 57 |
+
void resize_out_helper(const at::TensorList& dst, const at::TensorList& src) {
|
| 58 |
+
TORCH_INTERNAL_ASSERT(dst.size() == src.size());
|
| 59 |
+
for (const auto& i : c10::irange(dst.size())) {
|
| 60 |
+
at::native::resize_output(dst[i], src[i].sizes());
|
| 61 |
+
}
|
| 62 |
+
}
|
| 63 |
+
}
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
${CompositeViewCopyKernel_Definitions}
|
| 67 |
+
|
| 68 |
+
${GeneratedCompositeFunctional_Definitions}
|
| 69 |
+
|
| 70 |
+
${GeneratedCompositeOut_Definitions}
|
| 71 |
+
|
| 72 |
+
} // namespace native
|
| 73 |
+
} // namespace at
|
.venv/lib/python3.11/site-packages/torchgen/packaged/ATen/templates/DispatchKeyFunction.h
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
// ${generated_comment}
|
| 3 |
+
|
| 4 |
+
// NB: The implementing C++ file is RegisterDispatchKey.cpp
|
| 5 |
+
|
| 6 |
+
// The only #includes we need are for custom classes that have defaults in the C++ API
|
| 7 |
+
#include <c10/core/MemoryFormat.h>
|
| 8 |
+
#include <c10/core/Scalar.h>
|
| 9 |
+
#include <ATen/core/Reduction.h>
|
| 10 |
+
|
| 11 |
+
// Forward declarations of any types needed in the operator signatures.
|
| 12 |
+
// We can't directly include these classes because it will cause circular include dependencies.
|
| 13 |
+
// This file is included by TensorBody.h, which defines the Tensor class.
|
| 14 |
+
#include <ATen/core/ATen_fwd.h>
|
| 15 |
+
|
| 16 |
+
namespace at {
|
| 17 |
+
|
| 18 |
+
namespace ${dispatch_namespace} {
|
| 19 |
+
|
| 20 |
+
${dispatch_namespaced_declarations}
|
| 21 |
+
|
| 22 |
+
} // namespace ${dispatch_namespace}
|
| 23 |
+
} // namespace at
|
.venv/lib/python3.11/site-packages/torchgen/packaged/ATen/templates/DispatchKeyFunctions.h
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <ATen/core/TensorBody.h>
|
| 2 |
+
|
| 3 |
+
// TODO Undo all logic introduced for Note [Avoiding Include Cycles In Static Dispatch]
|
| 4 |
+
// Code introduced to avoid cyclic dependency in static dispatch is no longer
|
| 5 |
+
// needed as static dispatch logic is moved from TensorBody.h, which caused cycles in the first place,
|
| 6 |
+
// to Operators.cpp for supporting multiple backends with multiple kernels.
|
| 7 |
+
//
|
| 8 |
+
// Note [Avoiding Include Cycles In Static Dispatch]
|
| 9 |
+
// In order to avoid #include cycles in the static dispatch build, we've carefully split out
|
| 10 |
+
// the static function definition files into {DispatchKey}Functions.h and {DispatchKey}Functions_inl.h.
|
| 11 |
+
//
|
| 12 |
+
// Without this split, the include cycle looks like TensorBody.h -> CPUFunctions.h -> TensorBody.h.
|
| 13 |
+
// - TensorBody.h #includes CPUFunctions.h in the static dispatch build, because the tensor methods
|
| 14 |
+
// all need to call into the fastpath C++ API defined in CPUFunctions.h. The methods are also all
|
| 15 |
+
// directly inlined into TensorBody.h.
|
| 16 |
+
// - CPUFunctions.h #includes TensorBody.h because it contains function declarations for the entire C++ API,
|
| 17 |
+
// which include functions that have defaultable std::optional<Tensor> arguments.
|
| 18 |
+
// That requires knowing the full Tensor class definition.
|
| 19 |
+
//
|
| 20 |
+
// We break the cycle by doing the following:
|
| 21 |
+
// - Split out CPUFunction.h into two files: CPUFunctions.h and CPUFunctions_inl.h
|
| 22 |
+
// - CPUFunction.h is a dummy file that just includes the Tensor class and includes CPUFunctions_inl.,
|
| 23 |
+
// - CPUFunctions_inl.h includes everything else
|
| 24 |
+
// - (only in the static dispatch build) TensorBody.h makes sure to finish defining the Tensor class,
|
| 25 |
+
// and then it includes CPUFunctions_inl.h.
|
| 26 |
+
// - All other files that want the cpu fastpath functions can include CPUFunctions.h directly.
|
| 27 |
+
// - This also means that static dispatch build, CPUFunctions.h only needs to
|
| 28 |
+
// #include TensorBody.h, and it will automatically bring in CPUFunctions_inl.h.
|
| 29 |
+
${inline_headers}
|
.venv/lib/python3.11/site-packages/torchgen/packaged/ATen/templates/DispatchKeyFunctions_inl.h
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
// ${generated_comment}
|
| 3 |
+
|
| 4 |
+
// NB: The implementing C++ file is RegisterDispatchKey.cpp
|
| 5 |
+
|
| 6 |
+
// The only #includes we need are for custom classes that have defaults in the C++ API
|
| 7 |
+
#include <c10/core/MemoryFormat.h>
|
| 8 |
+
#include <c10/core/Scalar.h>
|
| 9 |
+
#include <ATen/core/Reduction.h>
|
| 10 |
+
|
| 11 |
+
#if defined(AT_PER_OPERATOR_HEADERS) && defined(TORCH_ASSERT_ONLY_METHOD_OPERATORS)
|
| 12 |
+
#error This change adds a dependency on all pytorch operators, meaning the \
|
| 13 |
+
file will need to be re-compiled every time an operator is changed or added. \
|
| 14 |
+
Consider including a specific operator from \
|
| 15 |
+
<ATen/ops/{my_operator}_${dispatch_namespace}_dispatch.h>. \
|
| 16 |
+
See NOTE [TORCH_ASSERT_ONLY_METHOD_OPERATORS].
|
| 17 |
+
#endif
|
| 18 |
+
|
| 19 |
+
${DispatchKeyFunctions_inl_includes}
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
${dispatch_namespaced_declarations}
|
.venv/lib/python3.11/site-packages/torchgen/packaged/ATen/templates/DispatchKeyNativeFunctions.cpp
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// ${generated_comment}
|
| 2 |
+
${includes}
|
| 3 |
+
${native_functions_include}
|
| 4 |
+
|
| 5 |
+
namespace {
|
| 6 |
+
${helper_fns}
|
| 7 |
+
} // namespace
|
| 8 |
+
|
| 9 |
+
${namespace_prologue}
|
| 10 |
+
|
| 11 |
+
${native_function_definitions}
|
| 12 |
+
|
| 13 |
+
${namespace_epilogue}
|
.venv/lib/python3.11/site-packages/torchgen/packaged/ATen/templates/DispatchKeyNativeFunctions.h
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
// an external backend might generate file within its code tree
|
| 4 |
+
// and check all the source files within the tree with clang-format.
|
| 5 |
+
// so, disable it since the backend might have a different config.
|
| 6 |
+
// clang-format off
|
| 7 |
+
|
| 8 |
+
// ${generated_comment}
|
| 9 |
+
|
| 10 |
+
#include <ATen/Tensor.h>
|
| 11 |
+
|
| 12 |
+
${namespace_prologue}
|
| 13 |
+
|
| 14 |
+
struct ${class_name} {
|
| 15 |
+
|
| 16 |
+
${dispatch_declarations}
|
| 17 |
+
|
| 18 |
+
};
|
| 19 |
+
${namespace_epilogue}
|
.venv/lib/python3.11/site-packages/torchgen/packaged/ATen/templates/Function.h
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
// ${generated_comment}
|
| 4 |
+
|
| 5 |
+
#include <ATen/Context.h>
|
| 6 |
+
#include <ATen/DeviceGuard.h>
|
| 7 |
+
#include <ATen/TensorUtils.h>
|
| 8 |
+
#include <ATen/TracerMode.h>
|
| 9 |
+
#include <ATen/core/Generator.h>
|
| 10 |
+
#include <ATen/core/Reduction.h>
|
| 11 |
+
#include <ATen/core/Tensor.h>
|
| 12 |
+
#include <c10/core/Scalar.h>
|
| 13 |
+
#include <c10/core/Storage.h>
|
| 14 |
+
#include <c10/core/TensorOptions.h>
|
| 15 |
+
#include <c10/util/Deprecated.h>
|
| 16 |
+
#include <optional>
|
| 17 |
+
|
| 18 |
+
${static_dispatch_ops_headers}
|
| 19 |
+
|
| 20 |
+
${operator_includes}
|
| 21 |
+
|
| 22 |
+
namespace at {
|
| 23 |
+
|
| 24 |
+
${function_definitions}
|
| 25 |
+
|
| 26 |
+
}
|
.venv/lib/python3.11/site-packages/torchgen/packaged/ATen/templates/FunctionalInverses.h
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
// ${generated_comment}
|
| 4 |
+
|
| 5 |
+
#include <ATen/Tensor.h>
|
| 6 |
+
|
| 7 |
+
namespace at {
|
| 8 |
+
namespace functionalization {
|
| 9 |
+
|
| 10 |
+
enum class InverseReturnMode {
|
| 11 |
+
/// Specifies that functional inverses should always return a view.
|
| 12 |
+
AlwaysView,
|
| 13 |
+
/// Specifies that functional inverses should always return a non-view / copy.
|
| 14 |
+
NeverView,
|
| 15 |
+
/// Specifies that functional inverses should return a view unless a (copying) scatter
|
| 16 |
+
/// inverse exists, in which case that will be used instead.
|
| 17 |
+
/// This avoids as_strided() calls that can be difficult for subclasses to handle.
|
| 18 |
+
ViewOrScatterInverse,
|
| 19 |
+
};
|
| 20 |
+
|
| 21 |
+
struct FunctionalInverses {
|
| 22 |
+
|
| 23 |
+
${view_inverse_declarations}
|
| 24 |
+
|
| 25 |
+
// NB: These are not generated! They're manually implemented in the template.
|
| 26 |
+
// TODO: Change codegen to generate these. See the following link:
|
| 27 |
+
// https://github.com/pytorch/pytorch/blob/main/torchgen/model.py#L2583-L2585
|
| 28 |
+
static at::Tensor chunk_inverse(const at::Tensor & base, const at::Tensor & mutated_view, InverseReturnMode inverse_return_mode, int64_t mutated_view_idx, int chunks, int dim);
|
| 29 |
+
static at::Tensor narrow_inverse(const at::Tensor & base, const at::Tensor & mutated_view, InverseReturnMode inverse_return_mode, int dim, c10::SymInt start, c10::SymInt length);
|
| 30 |
+
|
| 31 |
+
};
|
| 32 |
+
}
|
| 33 |
+
}
|
.venv/lib/python3.11/site-packages/torchgen/packaged/ATen/templates/Functions.cpp
ADDED
|
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <array>
|
| 2 |
+
|
| 3 |
+
#include <ATen/Functions.h>
|
| 4 |
+
#include <ATen/Utils.h>
|
| 5 |
+
#include <c10/core/Allocator.h>
|
| 6 |
+
|
| 7 |
+
namespace at {
|
| 8 |
+
|
| 9 |
+
Tensor TensorMaker::make_tensor() {
|
| 10 |
+
AutoDispatchBelowADInplaceOrView guard{}; // TODO: Remove.
|
| 11 |
+
tracer::impl::NoTracerDispatchMode tracer_guard{};
|
| 12 |
+
|
| 13 |
+
check_size_nonnegative(sizes_);
|
| 14 |
+
|
| 15 |
+
TORCH_CHECK_VALUE(
|
| 16 |
+
!deleter_ || !ctx_,
|
| 17 |
+
"The deleter and context arguments are mutually exclusive.");
|
| 18 |
+
|
| 19 |
+
if (device_ == std::nullopt) {
|
| 20 |
+
device_ = globalContext().getDeviceFromPtr(data_, opts_.device().type());
|
| 21 |
+
}
|
| 22 |
+
|
| 23 |
+
if (opts_.device().has_index()) {
|
| 24 |
+
// clang-format off
|
| 25 |
+
TORCH_CHECK_VALUE(
|
| 26 |
+
opts_.device() == *device_,
|
| 27 |
+
"Specified device ", opts_.device(), " does not match device of data ", *device_);
|
| 28 |
+
// clang-format on
|
| 29 |
+
}
|
| 30 |
+
|
| 31 |
+
std::size_t size_bytes = computeStorageSize();
|
| 32 |
+
|
| 33 |
+
DataPtr data_ptr{};
|
| 34 |
+
if (deleter_) {
|
| 35 |
+
data_ptr = makeDataPtrFromDeleter();
|
| 36 |
+
} else {
|
| 37 |
+
data_ptr = makeDataPtrFromContext();
|
| 38 |
+
}
|
| 39 |
+
|
| 40 |
+
TORCH_CHECK(!resizeable_ || allocator_ != nullptr, "Must specify an allocator with allocator() if you want to use resizeable_storage()");
|
| 41 |
+
Storage storage{Storage::use_byte_size_t{}, size_bytes, std::move(data_ptr), /*allocator=*/allocator_, /*resizable=*/resizeable_};
|
| 42 |
+
|
| 43 |
+
Tensor tensor = detail::make_tensor<TensorImpl>(
|
| 44 |
+
std::move(storage), opts_.computeDispatchKey(), opts_.dtype());
|
| 45 |
+
|
| 46 |
+
TensorImpl* tensor_impl = tensor.unsafeGetTensorImpl();
|
| 47 |
+
if (strides_) {
|
| 48 |
+
tensor_impl->set_sizes_and_strides(sizes_, *strides_);
|
| 49 |
+
} else {
|
| 50 |
+
tensor_impl->set_sizes_contiguous(sizes_);
|
| 51 |
+
}
|
| 52 |
+
if (storage_offset_) {
|
| 53 |
+
tensor_impl->set_storage_offset(*storage_offset_);
|
| 54 |
+
}
|
| 55 |
+
|
| 56 |
+
return tensor;
|
| 57 |
+
}
|
| 58 |
+
|
| 59 |
+
std::size_t TensorMaker::computeStorageSize() const noexcept {
|
| 60 |
+
std::size_t itemsize = opts_.dtype().itemsize();
|
| 61 |
+
|
| 62 |
+
if (strides_) {
|
| 63 |
+
auto storage_size = detail::computeStorageNbytes(sizes_, *strides_, itemsize);
|
| 64 |
+
if (storage_offset_) {
|
| 65 |
+
storage_size += storage_offset_.value();
|
| 66 |
+
}
|
| 67 |
+
return storage_size;
|
| 68 |
+
}
|
| 69 |
+
|
| 70 |
+
std::size_t size = 1;
|
| 71 |
+
for (std::int64_t s : sizes_) {
|
| 72 |
+
size *= static_cast<std::size_t>(s);
|
| 73 |
+
}
|
| 74 |
+
auto storage_size = size * itemsize;
|
| 75 |
+
if (storage_offset_) {
|
| 76 |
+
storage_size += storage_offset_.value();
|
| 77 |
+
}
|
| 78 |
+
return storage_size;
|
| 79 |
+
}
|
| 80 |
+
|
| 81 |
+
inline DataPtr TensorMaker::makeDataPtrFromDeleter() noexcept {
|
| 82 |
+
return InefficientStdFunctionContext::makeDataPtr(data_, std::move(deleter_), *device_);
|
| 83 |
+
}
|
| 84 |
+
|
| 85 |
+
inline DataPtr TensorMaker::makeDataPtrFromContext() noexcept {
|
| 86 |
+
return DataPtr{data_, ctx_.release(), ctx_.get_deleter(), *device_};
|
| 87 |
+
}
|
| 88 |
+
|
| 89 |
+
IntArrayRef TensorMaker::makeTempSizes() const noexcept {
|
| 90 |
+
static std::int64_t zeros[5] = {0, 0, 0, 0, 0};
|
| 91 |
+
if (opts_.has_memory_format()) {
|
| 92 |
+
MemoryFormat format = *opts_.memory_format_opt();
|
| 93 |
+
if (format == MemoryFormat::ChannelsLast) {
|
| 94 |
+
return IntArrayRef(zeros, 4);
|
| 95 |
+
}
|
| 96 |
+
if (format == MemoryFormat::ChannelsLast3d) {
|
| 97 |
+
return IntArrayRef(zeros, 5);
|
| 98 |
+
}
|
| 99 |
+
}
|
| 100 |
+
return IntArrayRef(zeros, 1);
|
| 101 |
+
}
|
| 102 |
+
|
| 103 |
+
} // namespace at
|
.venv/lib/python3.11/site-packages/torchgen/packaged/ATen/templates/Functions.h
ADDED
|
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
// ${generated_comment}
|
| 4 |
+
|
| 5 |
+
#ifdef TORCH_ASSERT_NO_OPERATORS
|
| 6 |
+
#error This change adds a dependency on native_functions.yaml, \
|
| 7 |
+
meaning the file will need to be re-compiled every time an operator \
|
| 8 |
+
is changed or added. Consider if your change would be better placed in \
|
| 9 |
+
another file, or if a more specific header might achieve the same goal. \
|
| 10 |
+
See NOTE: [Tensor vs. TensorBase]
|
| 11 |
+
#endif
|
| 12 |
+
|
| 13 |
+
#if defined(AT_PER_OPERATOR_HEADERS) && defined(TORCH_ASSERT_ONLY_METHOD_OPERATORS)
|
| 14 |
+
#error This change adds a dependency on all pytorch operators, meaning the \
|
| 15 |
+
file will need to be re-compiled every time an operator is changed or added. \
|
| 16 |
+
Consider including a specific operator from <ATen/ops/{my_operator}.h> and \
|
| 17 |
+
see NOTE [TORCH_ASSERT_ONLY_METHOD_OPERATORS].
|
| 18 |
+
#endif
|
| 19 |
+
|
| 20 |
+
// NOTE: [TORCH_ASSERT_ONLY_METHOD_OPERATORS]
|
| 21 |
+
//
|
| 22 |
+
// In ATen, certain generated headers files include the definitions of
|
| 23 |
+
// every single operator in PyTorch. Unfortunately this means every
|
| 24 |
+
// time an operator signature is updated or changed in
|
| 25 |
+
// native_functions.yaml, you (and every other PyTorch developer) need
|
| 26 |
+
// to recompile every source file that includes any of these headers.
|
| 27 |
+
//
|
| 28 |
+
// To break up these header dependencies, and improve incremental
|
| 29 |
+
// build times for all PyTorch developers. These headers are split
|
| 30 |
+
// into per-operator headers in the `ATen/ops` folder. This limits
|
| 31 |
+
// incremental builds to only changes to methods of `Tensor`, or files
|
| 32 |
+
// that use the specific operator being changed. With `at::sum` as an
|
| 33 |
+
// example, you should include
|
| 34 |
+
//
|
| 35 |
+
// <ATen/ops/sum.h> // instead of ATen/Functions.h
|
| 36 |
+
// <ATen/ops/sum_native.h> // instead of ATen/NativeFunctions.h
|
| 37 |
+
// <ATen/ops/sum_ops.h> // instead of ATen/Operators.h
|
| 38 |
+
// <ATen/ops/sum_cpu_dispatch.h> // instead of ATen/CPUFunctions.h
|
| 39 |
+
//
|
| 40 |
+
// However, even if you're careful to use this in your own code.
|
| 41 |
+
// `Functions.h` might be included indirectly through another header
|
| 42 |
+
// without you realising. To avoid this, you can add
|
| 43 |
+
//
|
| 44 |
+
// #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
|
| 45 |
+
//
|
| 46 |
+
// to the top of your source file. This way any time the non-specific
|
| 47 |
+
// headers are included, the compiler will error out.
|
| 48 |
+
//
|
| 49 |
+
// Also, be aware that `ops` are not available in all build
|
| 50 |
+
// configurations (namely fb-internal) so you must guard these
|
| 51 |
+
// includes with `#ifdef AT_PER_OPERATOR_HEADERS`. e.g.
|
| 52 |
+
//
|
| 53 |
+
// #ifndef AT_PER_OPERATOR_HEADERS
|
| 54 |
+
// #include <ATen/Functions.h>
|
| 55 |
+
// #else
|
| 56 |
+
// #include <ATen/ops/sum.h>
|
| 57 |
+
// #endif
|
| 58 |
+
|
| 59 |
+
#include <ATen/Context.h>
|
| 60 |
+
#include <ATen/DeviceGuard.h>
|
| 61 |
+
#include <ATen/TensorUtils.h>
|
| 62 |
+
#include <ATen/TracerMode.h>
|
| 63 |
+
#include <ATen/core/Generator.h>
|
| 64 |
+
#include <ATen/core/Reduction.h>
|
| 65 |
+
#include <c10/core/SymInt.h>
|
| 66 |
+
#include <ATen/core/Tensor.h>
|
| 67 |
+
#include <c10/core/Scalar.h>
|
| 68 |
+
#include <c10/core/Storage.h>
|
| 69 |
+
#include <c10/core/TensorOptions.h>
|
| 70 |
+
#include <c10/util/Deprecated.h>
|
| 71 |
+
#include <optional>
|
| 72 |
+
#include <c10/util/OptionalArrayRef.h>
|
| 73 |
+
|
| 74 |
+
#include <ATen/ops/from_blob.h>
|
| 75 |
+
#include <ATen/ops/tensor.h>
|
| 76 |
+
|
| 77 |
+
${Functions_includes}
|
| 78 |
+
|
| 79 |
+
namespace at {
|
| 80 |
+
|
| 81 |
+
${Functions_declarations}
|
| 82 |
+
|
| 83 |
+
// Special C++ only overloads for std()-like functions (See gh-40287)
|
| 84 |
+
// These are needed because int -> bool conversion takes precedence over int -> IntArrayRef
|
| 85 |
+
// So, for example std(0) would select the std(unbiased=False) overload
|
| 86 |
+
TORCH_API inline Tensor var(const Tensor& self, int dim) {
|
| 87 |
+
return at::var(self, IntArrayRef{dim});
|
| 88 |
+
}
|
| 89 |
+
TORCH_API inline std::tuple<Tensor, Tensor> var_mean(const Tensor& self, int dim) {
|
| 90 |
+
return at::var_mean(self, IntArrayRef{dim});
|
| 91 |
+
}
|
| 92 |
+
TORCH_API inline Tensor std(const Tensor& self, int dim) {
|
| 93 |
+
return at::std(self, IntArrayRef{dim});
|
| 94 |
+
}
|
| 95 |
+
TORCH_API inline std::tuple<Tensor, Tensor> std_mean(const Tensor& self, int dim) {
|
| 96 |
+
return at::std_mean(self, IntArrayRef{dim});
|
| 97 |
+
}
|
| 98 |
+
|
| 99 |
+
inline int64_t numel(const Tensor& tensor) {
|
| 100 |
+
return tensor.numel();
|
| 101 |
+
}
|
| 102 |
+
|
| 103 |
+
inline int64_t size(const Tensor& tensor, int64_t dim) {
|
| 104 |
+
return tensor.size(dim);
|
| 105 |
+
}
|
| 106 |
+
|
| 107 |
+
inline int64_t stride(const Tensor& tensor, int64_t dim) {
|
| 108 |
+
return tensor.stride(dim);
|
| 109 |
+
}
|
| 110 |
+
|
| 111 |
+
inline bool is_complex(const Tensor& tensor) {
|
| 112 |
+
return tensor.is_complex();
|
| 113 |
+
}
|
| 114 |
+
|
| 115 |
+
inline bool is_floating_point(const Tensor& tensor) {
|
| 116 |
+
return tensor.is_floating_point();
|
| 117 |
+
}
|
| 118 |
+
|
| 119 |
+
inline bool is_signed(const Tensor& tensor) {
|
| 120 |
+
return tensor.is_signed();
|
| 121 |
+
}
|
| 122 |
+
|
| 123 |
+
inline bool is_inference(const Tensor& tensor) {
|
| 124 |
+
return tensor.is_inference();
|
| 125 |
+
}
|
| 126 |
+
|
| 127 |
+
inline bool _is_zerotensor(const Tensor& tensor) {
|
| 128 |
+
return tensor._is_zerotensor();
|
| 129 |
+
}
|
| 130 |
+
|
| 131 |
+
inline bool is_conj(const Tensor& tensor) {
|
| 132 |
+
return tensor.is_conj();
|
| 133 |
+
}
|
| 134 |
+
|
| 135 |
+
inline Tensor conj(const Tensor& tensor) {
|
| 136 |
+
return tensor.conj();
|
| 137 |
+
}
|
| 138 |
+
|
| 139 |
+
inline bool is_neg(const Tensor& tensor) {
|
| 140 |
+
return tensor.is_neg();
|
| 141 |
+
}
|
| 142 |
+
|
| 143 |
+
}
|
.venv/lib/python3.11/site-packages/torchgen/packaged/ATen/templates/LazyIr.h
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
// This file contains autogenerated LazyTensor IR nodes
|
| 4 |
+
${lazy_ir_sysinc}
|
| 5 |
+
${lazy_ir_inc}
|
| 6 |
+
|
| 7 |
+
${namespace_prologue}
|
| 8 |
+
using at::operator<<;
|
| 9 |
+
|
| 10 |
+
// kNullValue is used to contribute a static hash value any time
|
| 11 |
+
// a node has an Optional<Value> input that is nullopt. It is important
|
| 12 |
+
// to differentiate between HASH(std::nullopt, something) and HASH(something, std::nullopt),
|
| 13 |
+
// and using kNullValue in the hash function in the order of arguments
|
| 14 |
+
// serves this purpose.
|
| 15 |
+
static const torch::lazy::Value kNullValue = torch::lazy::Value();
|
| 16 |
+
|
| 17 |
+
${ir_declarations}
|
| 18 |
+
|
| 19 |
+
${namespace_epilogue}
|
.venv/lib/python3.11/site-packages/torchgen/packaged/ATen/templates/LazyNonNativeIr.h
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
${lazy_non_native_ir_inc}
|
| 4 |
+
|
| 5 |
+
// This file contains autogenerated LazyTensor Non Native IR nodes
|
| 6 |
+
|
| 7 |
+
${namespace_prologue}
|
| 8 |
+
|
| 9 |
+
${non_native_ir_nodes}
|
| 10 |
+
|
| 11 |
+
${namespace_epilogue}
|
.venv/lib/python3.11/site-packages/torchgen/packaged/ATen/templates/MethodOperators.h
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
// ${generated_comment}
|
| 4 |
+
|
| 5 |
+
#ifdef TORCH_ASSERT_NO_OPERATORS
|
| 6 |
+
#error This change adds a dependency on native_functions.yaml, \
|
| 7 |
+
meaning the file will need to be re-compiled every time an operator \
|
| 8 |
+
is changed or added. Consider if your change would be better placed in \
|
| 9 |
+
another file, or if a more specific header might achieve the same goal. \
|
| 10 |
+
See NOTE: [Tensor vs. TensorBase]
|
| 11 |
+
#endif
|
| 12 |
+
|
| 13 |
+
// Forward declarations of any types needed in the operator signatures.
|
| 14 |
+
// We can't directly include these classes because it will cause circular include dependencies.
|
| 15 |
+
// This file is included by TensorBody.h, which defines the Tensor class.
|
| 16 |
+
#include <ATen/core/ATen_fwd.h>
|
| 17 |
+
|
| 18 |
+
${MethodOperators_includes}
|
| 19 |
+
|
| 20 |
+
namespace at {
|
| 21 |
+
namespace _ops {
|
| 22 |
+
${MethodOperators_declarations}
|
| 23 |
+
} // namespace _ops
|
| 24 |
+
} // namespace at
|
.venv/lib/python3.11/site-packages/torchgen/packaged/ATen/templates/NativeFunction.h
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
// ${generated_comment}
|
| 4 |
+
|
| 5 |
+
#include <c10/core/Scalar.h>
|
| 6 |
+
#include <c10/core/Storage.h>
|
| 7 |
+
#include <c10/core/TensorOptions.h>
|
| 8 |
+
#include <c10/util/Deprecated.h>
|
| 9 |
+
#include <optional>
|
| 10 |
+
#include <c10/core/QScheme.h>
|
| 11 |
+
#include <ATen/core/Reduction.h>
|
| 12 |
+
#include <ATen/core/Tensor.h>
|
| 13 |
+
#include <tuple>
|
| 14 |
+
#include <vector>
|
| 15 |
+
${extra_includes}
|
| 16 |
+
|
| 17 |
+
${native_function_declarations}
|
.venv/lib/python3.11/site-packages/torchgen/packaged/ATen/templates/NativeFunctions.h
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
// ${generated_comment}
|
| 4 |
+
|
| 5 |
+
#ifdef TORCH_ASSERT_NO_OPERATORS
|
| 6 |
+
#error This change adds a dependency on native_functions.yaml, \
|
| 7 |
+
meaning the file will need to be re-compiled every time an operator \
|
| 8 |
+
is changed or added. Consider if your change would be better placed in \
|
| 9 |
+
another file, or if a more specific header might achieve the same goal. \
|
| 10 |
+
See NOTE: [Tensor vs. TensorBase]
|
| 11 |
+
#endif
|
| 12 |
+
|
| 13 |
+
#if defined(AT_PER_OPERATOR_HEADERS) && defined(TORCH_ASSERT_ONLY_METHOD_OPERATORS)
|
| 14 |
+
#error This change adds a dependency on all pytorch operators, meaning the \
|
| 15 |
+
file will need to be re-compiled every time an operator is changed or added. \
|
| 16 |
+
Consider including a specific operator from <ATen/ops/{my_operator}_native.h> \
|
| 17 |
+
and see NOTE [TORCH_ASSERT_ONLY_METHOD_OPERATORS].
|
| 18 |
+
#endif
|
| 19 |
+
|
| 20 |
+
#include <c10/core/Scalar.h>
|
| 21 |
+
#include <c10/core/Storage.h>
|
| 22 |
+
#include <c10/core/TensorOptions.h>
|
| 23 |
+
#include <c10/util/Deprecated.h>
|
| 24 |
+
#include <optional>
|
| 25 |
+
#include <c10/core/QScheme.h>
|
| 26 |
+
#include <ATen/core/Reduction.h>
|
| 27 |
+
#include <ATen/core/Tensor.h>
|
| 28 |
+
#include <tuple>
|
| 29 |
+
#include <vector>
|
| 30 |
+
|
| 31 |
+
${NativeFunctions_includes}
|
| 32 |
+
|
| 33 |
+
${NativeFunctions_declarations}
|
.venv/lib/python3.11/site-packages/torchgen/packaged/ATen/templates/NativeMetaFunction.h
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
// ${generated_comment}
|
| 4 |
+
|
| 5 |
+
#include <c10/core/Scalar.h>
|
| 6 |
+
#include <c10/core/Storage.h>
|
| 7 |
+
#include <c10/core/TensorOptions.h>
|
| 8 |
+
#include <c10/util/Deprecated.h>
|
| 9 |
+
#include <optional>
|
| 10 |
+
#include <c10/core/QScheme.h>
|
| 11 |
+
#include <ATen/core/Reduction.h>
|
| 12 |
+
#include <ATen/TensorIterator.h>
|
| 13 |
+
#include <ATen/TensorMeta.h>
|
| 14 |
+
#include <tuple>
|
| 15 |
+
#include <vector>
|
| 16 |
+
|
| 17 |
+
namespace at {
|
| 18 |
+
namespace meta {
|
| 19 |
+
|
| 20 |
+
${meta_function_declarations}
|
| 21 |
+
|
| 22 |
+
} // namespace native
|
| 23 |
+
} // namespace at
|