"""Registration of nvFuser-backed primitives in the torch.ops.nvprims namespace."""

from typing import Any, Dict, Optional

import torch

from torch._prims_common import (
    DimsSequenceType,
    ELEMENTWISE_TYPE_PROMOTION_KIND,
    getnvFuserDtype,
    make_contiguous_strides_for,
    ShapeType,
    TensorLikeType,
)
from torch._prims_common.wrappers import (
    backwards_not_supported,
    elementwise_type_promotion_wrapper,
)

# The "DEF" library declares the nvprims namespace; the "IMPL" libraries
# register implementations for it under specific dispatch keys.
nvprim_namespace = "nvprims"
nvprim = torch.library.Library(nvprim_namespace, "DEF")
nvprim_impl = torch.library.Library(
    nvprim_namespace, "IMPL", "CompositeExplicitAutograd"
)
nvprim_implicit_impl = torch.library.Library(
    nvprim_namespace, "IMPL", "CompositeImplicitAutograd"
)
nvprim_autograd_impl = torch.library.Library(nvprim_namespace, "IMPL", "Autograd")
nvprim_meta_impl = torch.library.Library(nvprim_namespace, "IMPL", "Meta")

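# Each primitive is declared once through the "DEF" library and then given
# implementations through the "IMPL" libraries. Illustrative sketch only (the
# real schemas and implementations are registered in the functions below):
#
#     nvprim.define("abs(Tensor self) -> Tensor")
#     nvprim_impl.impl("abs", torch.abs)
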
nvprim_names = [
    "abs",
    "acos",
    "asin",
    "atan",
    "atanh",
    "cos",
    "cosh",
    "bitwise_not",
    "ceil",
    "erf",
    "erfc",
    "exp",
    "expm1",
    "floor",
    "imag",
    "isfinite",
    "lgamma",
    "log",
    "log1p",
    "log2",
    "log10",
    "real",
    "reciprocal",
    "neg",
    "round",
    "rsqrt",
    "sign",
    "sin",
    "sinh",
    "sqrt",
    "tan",
    "tanh",
    "transpose",
    "trunc",
    "add",
    "atan2",
    "bitwise_and",
    "bitwise_or",
    "bitwise_xor",
    "div",
    "eq",
    "fmod",
    "ge",
    "gt",
    "le",
    "lt",
    "mul",
    "ne",
    "pow",
    "remainder",
    "sub",
    "squeeze",
    "view_of",
    "broadcast_in_dim",
    "where",
    "convert_element_type",
    "sum",
    "var",
    "amax",
    "amin",
]

_nvfuser_impls: Dict[str, Any] = {}

_nvfuser_unary_ops = {
    "abs",
    "acos",
    "asin",
    "atan",
    "atanh",
    "cos",
    "cosh",
    "bitwise_not",
    "ceil",
    "erf",
    "erfc",
    "exp",
    "expm1",
    "floor",
    "imag",
    "isfinite",
    "lgamma",
    "log",
    "log1p",
    "log2",
    "log10",
    "reciprocal",
    "neg",
    "real",
    "round",
    "rsqrt",
    "sign",
    "sin",
    "sinh",
    "sqrt",
    "tan",
    "tanh",
    "trunc",
}


def _assert_nvfuser_op_exists(fname: str):
    try:
        from torch._C._nvfuser import FusionDefinition as fd

        assert getattr(fd.Operators, fname)
    except ImportError:
        # Not all PyTorch builds ship with nvFuser, so the check is best-effort
        pass


for fname in _nvfuser_unary_ops:
    exec(
        f"""
# Ensure that the nvfuser implementation exists
_assert_nvfuser_op_exists("{fname}")

def _{fname}_nvfuser(fd, a):
    return fd.ops.{fname}(a)  # type: ignore[attr-defined]

_nvfuser_impls["{fname}"] = _{fname}_nvfuser
"""
    )

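# For fname == "abs", for example, the exec above generates the equivalent of:
#
#     def _abs_nvfuser(fd, a):
#         return fd.ops.abs(a)
#
#     _nvfuser_impls["abs"] = _abs_nvfuser
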
_nvfuser_binary_ops = {
    "add",
    "atan2",
    "bitwise_and",
    "bitwise_or",
    "bitwise_xor",
    "div",
    "eq",
    "fmod",
    "ge",
    "gt",
    "le",
    "lt",
    "mul",
    "ne",
    "pow",
    "remainder",
    "sub",
}

for fname in _nvfuser_binary_ops:
    exec(
        f"""
# Ensure that the nvfuser implementation exists
_assert_nvfuser_op_exists("{fname}")

def _{fname}_nvfuser(fd, a, b):
    return fd.ops.{fname}(a, b)  # type: ignore[attr-defined]

_nvfuser_impls["{fname}"] = _{fname}_nvfuser
"""
    )

_nvfuser_ternary_ops = {
    "where",
}

for fname in _nvfuser_ternary_ops:
    exec(
        f"""
# Ensure that the nvfuser implementation exists
_assert_nvfuser_op_exists("{fname}")

def _{fname}_nvfuser(fd, a, b, c):
    return fd.ops.{fname}(a, b, c)  # type: ignore[attr-defined]

_nvfuser_impls["{fname}"] = _{fname}_nvfuser
"""
    )


def _native_batch_norm_nvfuser(
    fd, input, weight, bias, running_mean, running_var, training, momentum, eps
):
    # nvFuser's batch_norm expects tensor arguments, so optional inputs are
    # passed as null tensors
    if weight is None:
        weight = fd.define_null_tensor()
    if bias is None:
        bias = fd.define_null_tensor()
    if running_mean is None:
        running_mean = fd.define_null_tensor()
    if running_var is None:
        running_var = fd.define_null_tensor()
    return fd.ops.batch_norm(
        input,
        weight,
        bias,
        running_mean,
        running_var,
        training,
        momentum,
        eps,
    )


def _broadcast_in_dim_nvfuser(
    fd: Any,
    a: TensorLikeType,
    shape: ShapeType,
    broadcast_dimensions: ShapeType,
):
    return fd.ops.broadcast_in_dim(a, shape, broadcast_dimensions)


def _convert_element_type_nvfuser(fd: Any, a: TensorLikeType, dtype: torch.dtype):
    nvfuser_dtype = getnvFuserDtype(dtype)
    return fd.ops.cast(a, nvfuser_dtype)


def _transpose_nvfuser(fd, a, permutation):
    return fd.ops.permute(a, permutation)


def _squeeze_nvfuser(fd, a, a_shape, dimensions):
    # Squeeze one dimension at a time, from the highest index to the lowest,
    # so the remaining indices stay valid as the shape shrinks
    for idx in reversed(sorted(dimensions)):
        a = fd.ops.squeeze(a, a_shape, idx)
        a_shape = a_shape[:idx] + a_shape[idx + 1 :]
    return a

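# For example, squeezing dimensions (0, 2) of a tensor of shape (1, 3, 1)
# first removes dim 2 (shape becomes (1, 3)) and then dim 0 (shape becomes (3,)).
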
def _view_of_nvfuser(fd, a):
    return fd.ops.set(a)


def _sum_nvfuser(
    fd: Any,
    a: TensorLikeType,
    dims: DimsSequenceType,
):
    keep_dims = False
    # A null output dtype tells nvFuser to keep the result in the input's dtype
    output_dtype = torch._C._nvfuser.DataType.Null
    return fd.ops.sum(a, dims, keep_dims, output_dtype)


def _var_nvfuser(
    fd: Any,
    a: TensorLikeType,
    dims: DimsSequenceType,
    *,
    correction: int,
):
    keep_dims = False
    return fd.ops.var(a, dims, correction, keep_dims)


def _var_mean_nvfuser(
    fd: Any,
    a: TensorLikeType,
    dims: DimsSequenceType,
    unbiased: Optional[bool] = None,
    keepdim: bool = False,
    *,
    correction: int,
):
    # The unbiased argument is canonicalized into correction before this
    # function is called, so only correction is used here
    assert unbiased is None
    # keepdim is handled by the reference implementation, which broadcasts the
    # reduced result back to the input rank
    keepdim = False
    return fd.ops.var_mean(a, dims, correction, keepdim)


def _rand_like_nvfuser(fd: Any, a: TensorLikeType):
    return fd.ops.rand_like(a)


def _amax_nvfuser(
    fd: Any,
    a: TensorLikeType,
    dims: DimsSequenceType,
):
    keep_dims = False
    return fd.ops.max(a, dims, keep_dims)


def _amin_nvfuser(
    fd: Any,
    a: TensorLikeType,
    dims: DimsSequenceType,
):
    keep_dims = False
    return fd.ops.min(a, dims, keep_dims)


_nvfuser_impls["native_batch_norm"] = _native_batch_norm_nvfuser
_nvfuser_impls["broadcast_in_dim"] = _broadcast_in_dim_nvfuser
_nvfuser_impls["convert_element_type"] = _convert_element_type_nvfuser
_nvfuser_impls["transpose"] = _transpose_nvfuser
_nvfuser_impls["squeeze"] = _squeeze_nvfuser
_nvfuser_impls["view_of"] = _view_of_nvfuser
_nvfuser_impls["rand_like"] = _rand_like_nvfuser
_nvfuser_impls["sum"] = _sum_nvfuser
_nvfuser_impls["var"] = _var_nvfuser
_nvfuser_impls["var_mean"] = _var_mean_nvfuser
_nvfuser_impls["amax"] = _amax_nvfuser
_nvfuser_impls["amin"] = _amin_nvfuser


def register_native_batch_norm():
    """Registers the native_batch_norm primitive in the torch.ops.nvprims module."""
    name = "native_batch_norm"

    nvprim.define(
        f"{name}(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, "
        + "bool training, float momentum, float eps)"
        + " -> (Tensor, Tensor, Tensor)"
    )

    def _prim_impl(
        input, weight, bias, running_mean, running_var, training, momentum, eps
    ):
        return torch.native_batch_norm(
            input, weight, bias, running_mean, running_var, training, momentum, eps
        )

    nvprim_impl.impl(name, _prim_impl)
    nvprim_autograd_impl.impl(
        name, backwards_not_supported(torch.ops.nvprims.native_batch_norm.default)
    )

    prim_packet = torch.ops.nvprims.native_batch_norm
    prim = prim_packet.default
    for p in (prim_packet, prim):
        p.__doc__ = "Computes batch normalization."
        p.impl_nvfuser = _nvfuser_impls["native_batch_norm"]
        p.return_type = torch._prims_common.RETURN_TYPE.NEW

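# A minimal usage sketch (assumes a CUDA build of PyTorch with nvFuser; the
# output names below are illustrative, following the schema defined above):
#
#     register_native_batch_norm()
#     t = torch.randn(2, 3, 4, device="cuda")
#     out, save_mean, save_invstd = torch.ops.nvprims.native_batch_norm(
#         t, None, None, None, None, True, 0.1, 1e-5
#     )
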
def register_rand_like():
    """Registers the rand_like primitive in the torch.ops.nvprims module."""
    name = "rand_like"

    nvprim.define(
        "rand_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, "
        + "Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor"
    )

    def _meta_rand_like(
        self,
        *,
        dtype=None,
        layout=None,
        device=None,
        pin_memory=None,
        memory_format=None,
    ):
        strides = make_contiguous_strides_for(self.shape)
        return torch._prims.TensorMeta(
            self,
            shape=self.shape,
            strides=strides,
            dtype=dtype,
            device=device,
        )

    def _prim_impl(
        self,
        *,
        dtype=None,
        layout=None,
        device=None,
        pin_memory=None,
        memory_format=None,
    ):
        return torch.rand_like(
            self,
            dtype=dtype,
            layout=layout,
            device=device,
            pin_memory=pin_memory,
            memory_format=memory_format,
        )

    nvprim_impl.impl(name, _prim_impl)
    nvprim_meta_impl.impl(name, _meta_rand_like)

    prim_packet = getattr(torch.ops.nvprims, name)
    prim = prim_packet.default

    nvprim_autograd_impl.impl(name, backwards_not_supported(prim))

    for p in (prim_packet, prim):
        p.__doc__ = "Computes a tensor of uniform random values with the same shape as the input."
        p.impl_nvfuser = _nvfuser_impls["rand_like"]
        p.return_type = torch._prims_common.RETURN_TYPE.NEW

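# Usage sketch (hypothetical; requires prior registration):
#
#     register_rand_like()
#     r = torch.ops.nvprims.rand_like(torch.empty(3, device="cuda"))
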
def register_var_mean():
    """Registers the var_mean primitive in the torch.ops.nvprims module."""
    name = "var_mean.main"

    # A separate overload covers the (Tensor, bool) signature of torch.var_mean
    nvprim.define("var_mean(Tensor inp, bool unbiased) -> (Tensor, Tensor)")

    # The "main" overload folds the remaining torch.var_mean overloads into one signature
    nvprim.define(
        f"{name}(Tensor inp, int[1]? dim=None, bool? unbiased=None, bool keepdim=False, *, int? correction=None)"
        + " -> (Tensor, Tensor)"
    )

    # Computes output metadata for tensors on the "meta" device
    def _meta_var_mean(inp, dim=None, unbiased=None, keepdim=False, *, correction=None):
        if torch._prims_common.is_complex_dtype(inp.dtype):
            output_dtype = torch._prims_common.corresponding_real_dtype(inp.dtype)
        else:
            output_dtype = inp.dtype
        var = torch._prims._reduction_meta(inp, dim, output_dtype=output_dtype)
        mean = torch._prims._reduction_meta(inp, dim, output_dtype=inp.dtype)
        if keepdim:
            output_shape = [
                inp.shape[i] if i not in dim else 1 for i in range(inp.ndim)
            ]
            broadcast_dims = [i for i in range(inp.ndim) if i not in dim]
            var = torch.ops.nvprims.broadcast_in_dim(var, output_shape, broadcast_dims)
            mean = torch.ops.nvprims.broadcast_in_dim(
                mean, output_shape, broadcast_dims
            )
        return (var, mean)

    def _prim_impl(inp, dim=None, unbiased=None, keepdim=False, *, correction=None):
        correction = torch._prims_common.set_correction(unbiased, correction)
        return torch.var_mean(inp, dim, correction=correction, keepdim=keepdim)

    nvprim_impl.impl(name, _prim_impl)
    nvprim_meta_impl.impl(name, _meta_var_mean)

    prim_packet = torch.ops.nvprims.var_mean
    prim = prim_packet.main

    def _unbiased_overload_impl(inp, unbiased):
        return prim(inp, dim=None, unbiased=unbiased)

    nvprim_implicit_impl.impl("var_mean", _unbiased_overload_impl)

    @elementwise_type_promotion_wrapper(
        type_promoting_args=("a",),
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT,
    )
    def _var_mean_ref(a, dim=None, unbiased=None, keepdim=False, *, correction=None):
        correction = torch._prims_common.set_correction(unbiased, correction)
        # Reduces over all dimensions if dim=() or dim=[] is passed
        if dim == () or dim == []:
            dim = None
        dim = torch._prims_common.reduction_dims(a.shape, dim)

        # Complex tensors would require computing the variance from the real
        # and imaginary parts separately, which is not implemented here
        if torch._prims_common.is_complex_dtype(a.dtype):
            raise NotImplementedError("Complex tensors are not supported")

        var_mean = prim(a, dim, correction=correction)

        if keepdim:
            output_shape = [a.shape[i] if i not in dim else 1 for i in range(a.ndim)]
            broadcast_dims = [i for i in range(a.ndim) if i not in dim]
            var, mean = var_mean
            var = torch.ops.nvprims.broadcast_in_dim(var, output_shape, broadcast_dims)
            mean = torch.ops.nvprims.broadcast_in_dim(
                mean, output_shape, broadcast_dims
            )
            var_mean = (var, mean)
        return var_mean

    def _var_mean_autograd(
        a, dim=None, unbiased=None, keepdim=False, *, correction=None
    ):
        # backwards_not_supported calls into prims; NvfuserPrimsMode redirects
        # those calls to the corresponding nvprims
        from torch._prims.context import NvfuserPrimsMode

        with NvfuserPrimsMode():
            return backwards_not_supported(_var_mean_ref)(
                a, dim, unbiased, keepdim, correction=correction
            )

    nvprim_autograd_impl.impl(name, _var_mean_autograd)

    for p in (prim_packet, prim):
        p.__doc__ = "Computes the variance and mean of the input over the dimensions specified in the dim argument."
        p.impl_nvfuser = _nvfuser_impls["var_mean"]
        p.return_type = torch._prims_common.RETURN_TYPE.NEW

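# Usage sketch (hypothetical; requires a prior register_var_mean() call):
#
#     t = torch.randn(3, 4, device="cuda")
#     var, mean = torch.ops.nvprims.var_mean(t, True)  # (Tensor, bool) overload
#     var, mean = torch.ops.nvprims.var_mean.main(t, [1], correction=1)
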
def register_nvprims():
    """Registers all nvFuser primitives in the torch.ops.nvprims module."""
    register_var_mean()
    register_native_batch_norm()
    register_rand_like()

    # The remaining nvprims reuse the schema and the eager/meta implementations
    # of the corresponding torch.ops.prims primitives
    for name in nvprim_names:
        main_prim = getattr(torch.ops.prims, name)

        nvprim.define(main_prim.schema)
        nvprim_impl.impl(name, main_prim.prim_impl)
        nvprim_meta_impl.impl(name, main_prim.prim_meta_impl)

        prim_packet = getattr(torch.ops.nvprims, name)
        prim = prim_packet.default

        nvprim_autograd_impl.impl(name, backwards_not_supported(prim))

        for p in (prim_packet, prim):
            p.__doc__ = main_prim.__doc__
            p.impl_nvfuser = _nvfuser_impls[name]
            p.return_type = main_prim.return_type

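
# A minimal end-to-end sketch (assumes a CUDA build of PyTorch with nvFuser):
#
#     register_nvprims()
#     t = torch.randn(4, device="cuda")
#     out = torch.ops.nvprims.abs(t)  # same schema and impl as torch.ops.prims.abs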