|
|
|
|
|
|
|
|
from functools import partial

import torch

from .binary import (
    _apply_native_binary,
    NATIVE_BINARY_FNS,
    NATIVE_INPLACE_BINARY_FNS,
)
from .core import is_masked_tensor, MaskedTensor, _get_data, _masks_match, _maybe_get_mask
from .passthrough import (
    _apply_pass_through_fn,
    PASSTHROUGH_FNS,
)
from .reductions import (
    _apply_reduction,
    NATIVE_REDUCE_FNS,
    TORCH_REDUCE_FNS,
    TENSOR_REDUCE_FNS,
)
from .unary import (
    _apply_native_unary,
    NATIVE_UNARY_FNS,
    NATIVE_INPLACE_UNARY_FNS,
)


__all__ = []
|
|
def _check_args_kwargs_length(args, kwargs, error_prefix, len_args=None, len_kwargs=None):
    if len_args is not None and len_args != len(args):
        raise ValueError(f"{error_prefix}: len(args) must be {len_args} but got {len(args)}")
    if len_kwargs is not None and len_kwargs != len(kwargs):
        raise ValueError(f"{error_prefix}: len(kwargs) must be {len_kwargs} but got {len(kwargs)}")
|
|
|
|
|
|
|
|
class _MaskedContiguous(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        if not is_masked_tensor(input):
            raise ValueError("MaskedContiguous forward: input must be a MaskedTensor.")

        if input.is_contiguous():
            return input

        data = input.get_data()
        mask = input.get_mask()

        return MaskedTensor(data.contiguous(), mask.contiguous())

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output


class _MaskedToDense(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        if not is_masked_tensor(input):
            raise ValueError("MaskedToDense forward: input must be a MaskedTensor.")

        if input.layout == torch.strided:
            return input

        # Remember the original layout so backward can convert the gradient back.
        ctx.layout = input.layout
        data = input.get_data()
        mask = input.get_mask()

        return MaskedTensor(data.to_dense(), mask.to_dense())

    @staticmethod
    def backward(ctx, grad_output):
        layout = ctx.layout

        if layout == torch.sparse_coo:
            return grad_output.to_sparse_coo()
        elif layout == torch.sparse_csr:
            return grad_output.to_sparse_csr()
        elif layout == torch.strided:
            return grad_output.to_dense()
        raise ValueError(f"to_dense: Unsupported input layout: {layout}")


class _MaskedToSparse(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        if not is_masked_tensor(input):
            raise ValueError("MaskedToSparse forward: input must be a MaskedTensor.")

        if input.layout == torch.sparse_coo:
            return input

        data = input.get_data()
        mask = input.get_mask()
        sparse_mask = mask.to_sparse_coo().coalesce()
        sparse_data = data.sparse_mask(sparse_mask)

        return MaskedTensor(sparse_data, sparse_mask)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output.to_dense()


class _MaskedToSparseCsr(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        if not is_masked_tensor(input):
            raise ValueError("MaskedToSparseCsr forward: input must be a MaskedTensor.")

        if input._masked_data.ndim != 2:
            raise ValueError(f"Only 2D tensors can be converted to the SparseCsr layout but got shape: {input._masked_data.size()}")

        if input.layout == torch.sparse_csr:
            return input

        data = input.get_data()
        mask = input.get_mask()
        sparse_mask = mask.to_sparse_csr()
        sparse_data = data.sparse_mask(sparse_mask)

        return MaskedTensor(sparse_data, sparse_mask)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output.to_dense()


class _MaskedWhere(torch.autograd.Function):
    @staticmethod
    def forward(ctx, cond, self, other):
        ctx.mark_non_differentiable(cond)
        ctx.save_for_backward(cond)
        return torch.ops.aten.where(cond, self, other)

    @staticmethod
    def backward(ctx, grad_output):
        (cond,) = ctx.saved_tensors

        # Gradients flowing to the unselected branch are fully masked out.
        def masked_out_like(mt):
            return MaskedTensor(mt.get_data(), torch.zeros_like(mt.get_mask()).bool())

        return (
            None,
            torch.ops.aten.where(cond, grad_output, masked_out_like(grad_output)),
            torch.ops.aten.where(cond, masked_out_like(grad_output), grad_output),
        )
|
|
_MASKEDTENSOR_FUNCTION_TABLE = {}

_function_fn_apply_map = {
    (tuple(NATIVE_REDUCE_FNS), tuple(TORCH_REDUCE_FNS), tuple(TENSOR_REDUCE_FNS)): _apply_reduction,
}

for fn_map_list, apply_fn in _function_fn_apply_map.items():
    for fn_map in fn_map_list:
        for fn in fn_map:
            _MASKEDTENSOR_FUNCTION_TABLE[fn] = partial(apply_fn, fn)
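# Illustrative sketch (an assumption about how the table is consumed; the actual
# __torch_function__ hook lives on MaskedTensor in .core): handlers registered in
# _MASKEDTENSOR_FUNCTION_TABLE are looked up by the overridden function object and
# invoked with the original args/kwargs, roughly like
#
#     def _torch_function_sketch(func, types, args=(), kwargs=None):
#         kwargs = kwargs or {}
#         if func in _MASKEDTENSOR_FUNCTION_TABLE:
#             return _MASKEDTENSOR_FUNCTION_TABLE[func](*args, **kwargs)
#         return NotImplemented  # fall back to the default handling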
|
|
|
|
|
|
|
|
def register_function_func(ops):
    """
    Used for registering a new __torch_function__ function to MaskedTensor
    Called via _MASKEDTENSOR_FUNCTION_TABLE[func](*args, **kwargs)

    The code to register a new function looks like:

    @register_function_func(list_of_ops)
    def foo(func, *args, **kwargs):
        <implementation>
    """

    def wrapper(func):
        for op in ops:
            _MASKEDTENSOR_FUNCTION_TABLE[op] = partial(func, op)

    return wrapper
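# Illustrative usage (hypothetical handler; torch.Tensor.unique is only an example
# op and is not actually registered by this module): a new __torch_function__
# handler is installed by decorating a function whose first argument is the
# overridden op.
#
#     @register_function_func([torch.Tensor.unique])
#     def _function_unique(func, *args, **kwargs):
#         data = _get_data(args[0])
#         return func(data, *args[1:], **kwargs)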
|
|
|
|
|
|
|
|
@register_function_func(NATIVE_REDUCE_FNS + TORCH_REDUCE_FNS + TENSOR_REDUCE_FNS)
def _general_function_reductions(func, *args, **kwargs):
    return _apply_reduction(func, *args, **kwargs)


@register_function_func([torch.Tensor.where, torch.where])
def _function_where(func, *args, **kwargs):
    _check_args_kwargs_length(args, kwargs, "__torch_function__, torch.where", len_args=3, len_kwargs=0)
    return _MaskedWhere.apply(*args)


@register_function_func([torch.Tensor.contiguous])
def _function_contiguous(func, *args, **kwargs):
    return _MaskedContiguous.apply(args[0])


@register_function_func([torch.Tensor.to_dense])
def _function_to_dense(func, *args, **kwargs):
    return _MaskedToDense.apply(args[0])


@register_function_func([torch.Tensor.to_sparse])
def _function_to_sparse(func, *args, **kwargs):
    return _MaskedToSparse.apply(args[0])


@register_function_func([torch.Tensor.to_sparse_csr])
def _function_to_sparse_csr(func, *args, **kwargs):
    return _MaskedToSparseCsr.apply(args[0])
|
|
_MASKEDTENSOR_DISPATCH_TABLE = {}


def register_dispatch_func(aten_ops):
    """
    Used for registering a new __torch_dispatch__ function to MaskedTensor
    Called via _MASKEDTENSOR_DISPATCH_TABLE[func](*args, **kwargs)

    The code to register a new function looks like:

    @register_dispatch_func(list_of_ops)
    def foo(func, *args, **kwargs):
        <implementation>
    """

    def wrapper(func):
        for aten_op in aten_ops:
            _MASKEDTENSOR_DISPATCH_TABLE[aten_op] = partial(func, aten_op)

    return wrapper
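# Illustrative usage (hypothetical handler name; torch.ops.aten.flip is only an
# example op and is not registered by this module): a __torch_dispatch__ handler
# follows the same decorator pattern, receiving the aten op as its first argument.
#
#     @register_dispatch_func([torch.ops.aten.flip])
#     def _flip(func, *args, **kwargs):
#         return MaskedTensor(
#             func(_get_data(args[0]), *args[1:], **kwargs),
#             func(_maybe_get_mask(args[0]), *args[1:], **kwargs),
#         )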
|
|
|
|
|
|
|
|
@register_dispatch_func(NATIVE_REDUCE_FNS + TORCH_REDUCE_FNS + TENSOR_REDUCE_FNS)
def _general_reduction(func, *args, **kwargs):
    return _apply_reduction(func, *args, **kwargs)


@register_dispatch_func(PASSTHROUGH_FNS)
def _general_passthrough(func, *args, **kwargs):
    return _apply_pass_through_fn(func, *args, **kwargs)


@register_dispatch_func(NATIVE_UNARY_FNS + NATIVE_INPLACE_UNARY_FNS)
def _general_unary(func, *args, **kwargs):
    return _apply_native_unary(func, *args, **kwargs)


@register_dispatch_func(NATIVE_BINARY_FNS + NATIVE_INPLACE_BINARY_FNS)
def _general_binary(func, *args, **kwargs):
    return _apply_native_binary(func, *args, **kwargs)


@register_dispatch_func([torch.ops.aten.stride])
def stride(func, *args, **kwargs):
    return None


@register_dispatch_func([torch.ops.aten.sym_stride])
def sym_stride(func, *args, **kwargs):
    return None


@register_dispatch_func([torch.ops.prim.layout])
def layout(func, *args, **kwargs):
    return _get_data(args[0]).layout


@register_dispatch_func([torch.ops.aten.is_contiguous])
def is_contiguous(func, *args, **kwargs):
    data = _get_data(args[0])
    if data.is_sparse:
        raise ValueError("MaskedTensors with sparse data do not have is_contiguous")
    return func(data, *args[1:], **kwargs)


@register_dispatch_func([torch.ops.aten.is_strides_like_format])
def is_strides_like_format(func, *args, **kwargs):
    data = _get_data(args[0])
    if data.is_sparse:
        raise ValueError("MaskedTensors with sparse data do not have is_strides_like_format")
    return func(data, *args[1:], **kwargs)


@register_dispatch_func([torch.ops.aten.is_non_overlapping_and_dense])
def is_non_overlapping_and_dense(func, *args, **kwargs):
    data = _get_data(args[0])
    if data.is_sparse:
        raise ValueError("MaskedTensors with sparse data do not have is_non_overlapping_and_dense")
    return func(data, *args[1:], **kwargs)


@register_dispatch_func([torch.ops.aten.contiguous])
def contiguous(func, *args, **kwargs):
    if _get_data(args[0]).is_sparse:
        raise ValueError("MaskedTensors with sparse data do not have contiguous")
    return _MaskedContiguous.apply(args[0])


@register_dispatch_func([torch.ops.aten.new_empty_strided])
def new_empty_strided(func, *args, **kwargs):
    _check_args_kwargs_length(args, kwargs, f"__torch_dispatch__, {func}", len_args=3)
    data = _get_data(args[0])
    mask = _maybe_get_mask(args[0])
    if tuple(args[1]) != tuple(data.size()):
        raise ValueError(f"__torch_dispatch__, {func}: args[1] expected to be the same as data.size()")
    if tuple(args[2]) != tuple(data.stride()):
        raise ValueError(f"__torch_dispatch__, {func}: args[2] expected to be the same as data.stride()")
    return MaskedTensor(func(data, args[1], args[2], **kwargs), mask)


@register_dispatch_func([torch.ops.aten._local_scalar_dense])
def _local_scalar_dense(func, *args, **kwargs):
    if not _maybe_get_mask(args[0]):
        raise ValueError(f"__torch_dispatch__, {func}: expected a mask tensor")
    return torch.ops.aten._local_scalar_dense(_get_data(args[0]))


@register_dispatch_func([torch.ops.aten.detach, torch.ops.aten.clone])
def _apply_fn_on_data(func, *args, **kwargs):
    return MaskedTensor(func(_get_data(args[0])), _maybe_get_mask(args[0]))


@register_dispatch_func([torch.ops.aten._to_copy])
def _to_copy(func, *args, **kwargs):
    new_data = func(_get_data(args[0]), *args[1:], **kwargs)
    return MaskedTensor(new_data, _maybe_get_mask(args[0]))
|
|
@register_dispatch_func([torch.ops.aten._softmax])
def _softmax(func, *args, **kwargs):
    _check_args_kwargs_length(args, kwargs, f"__torch_dispatch__, {func}", len_args=3, len_kwargs=0)
    data = _get_data(args[0])
    mask = _maybe_get_mask(args[0])
    result_data = torch.ops.aten._masked_softmax(data, ~mask, args[1], 2)
    return MaskedTensor(result_data, mask)
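# Illustrative sketch (assumes the public torch.masked.masked_tensor constructor;
# not executed as part of this module): because the inverted mask is handed to
# aten._masked_softmax, masked-out entries are excluded from the normalization.
#
#     import torch
#     from torch.masked import masked_tensor
#
#     data = torch.tensor([[1.0, 2.0, 3.0]])
#     mask = torch.tensor([[True, True, False]])
#     mt = masked_tensor(data, mask)
#     out = torch.softmax(mt, dim=-1)  # normalizes over the two unmasked entries only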
|
|
|
|
|
|
|
|
@register_dispatch_func([torch.ops.aten.ones_like])
def ones_like(func, *args, **kwargs):
    _check_args_kwargs_length(args, kwargs, f"__torch_dispatch__, {func}", len_args=1)
    result_data = func(_get_data(args[0]), **kwargs)
    return MaskedTensor(result_data, _maybe_get_mask(args[0]))


@register_dispatch_func([torch.ops.aten._softmax_backward_data])
def _softmax_backward_data(func, *args, **kwargs):
    _check_args_kwargs_length(args, kwargs, f"__torch_dispatch__, {func}", len_args=4)
    grad, output, dim, input_dtype = args
    if is_masked_tensor(grad) and is_masked_tensor(output):
        if not _masks_match(grad, output):
            raise ValueError(f"__torch_dispatch__, {func}: expected the masks of grad and output to match")
        grad_data = _get_data(grad)
        new_grad_data = torch.ops.aten._masked_softmax_backward(
            grad_data,
            _get_data(output),
            ~_maybe_get_mask(grad),
            dim % grad_data.ndim,
        )
        res = MaskedTensor(new_grad_data, _maybe_get_mask(grad))
        return res
    else:
        raise ValueError(f"__torch_dispatch__, {func}: grad and output must both be MaskedTensors")


@register_dispatch_func([torch.ops.aten.copy_])
def copy_(func, *args, **kwargs):
    _check_args_kwargs_length(args, kwargs, f"__torch_dispatch__, {func}", len_args=2)
    if not _masks_match(_maybe_get_mask(args[0]), _maybe_get_mask(args[1])):
        raise ValueError("args[0] mask and args[1] mask must match but do not")
    func(_get_data(args[0]), _get_data(args[1]))
    return args[0]
|
|
@register_dispatch_func([torch.ops.aten.where])
def where(func, *args, **kwargs):
    _check_args_kwargs_length(args, kwargs, f"__torch_dispatch__, {func}", len_args=3, len_kwargs=0)
    if not torch.is_tensor(args[0]):
        raise ValueError(f"__torch_dispatch__, {func}: expected args[0] to be a tensor")
    mx = args[1]
    my = args[2]
    if not is_masked_tensor(mx):
        mx = MaskedTensor(mx, torch.ones_like(mx, dtype=torch.bool))
    if not is_masked_tensor(my):
        my = MaskedTensor(my, torch.ones_like(my, dtype=torch.bool))
    new_data = func(args[0], mx.get_data(), my.get_data())
    new_mask = func(args[0], mx.get_mask(), my.get_mask())
    return MaskedTensor(new_data, new_mask)
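# Illustrative sketch (assumes the public torch.masked.masked_tensor constructor):
# a plain tensor branch is promoted to a MaskedTensor with an all-True mask, so the
# result mask is the condition-wise selection of the two input masks.
#
#     import torch
#     from torch.masked import masked_tensor
#
#     cond = torch.tensor([True, False])
#     mt = masked_tensor(torch.tensor([1.0, 2.0]), torch.tensor([True, False]))
#     out = torch.where(cond, mt, torch.zeros(2))  # element 0 from mt, element 1 an unmasked zero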
|
|
|
|
|
|
|
|
@register_dispatch_func([torch.ops.aten.to_sparse])
def to_sparse(func, *args, **kwargs):
    _check_args_kwargs_length(args, kwargs, f"__torch_dispatch__, {func}", len_args=1, len_kwargs=0)
    if not torch.is_tensor(args[0]):
        raise TypeError(f"__torch_dispatch__, {func}: expected args[0] to be a tensor")
    mt = args[0]
    if not is_masked_tensor(mt):
        mt = MaskedTensor(mt, torch.ones_like(mt, dtype=torch.bool))
    if mt.is_sparse_coo():
        return mt
    new_mask = func(_maybe_get_mask(args[0])).coalesce()
    new_data = _get_data(args[0]).sparse_mask(new_mask)
    return MaskedTensor(new_data, new_mask)


@register_dispatch_func([torch.ops.aten.to_sparse_csr])
def to_sparse_csr(func, *args, **kwargs):
    _check_args_kwargs_length(args, kwargs, f"__torch_dispatch__, {func}", len_args=1, len_kwargs=0)
    if not torch.is_tensor(args[0]):
        raise ValueError(f"__torch_dispatch__, {func}: expected args[0] to be a tensor")
    mt = args[0]
    if not is_masked_tensor(mt):
        mt = MaskedTensor(mt, torch.ones_like(mt).bool())
    if mt.is_sparse_csr():
        return mt
    new_mask = func(_maybe_get_mask(args[0]))
    new_data = _get_data(args[0]).sparse_mask(new_mask)
    return MaskedTensor(new_data, new_mask)


@register_dispatch_func([torch.ops.aten._to_dense])
def _to_dense(func, *args, **kwargs):
    _check_args_kwargs_length(args, kwargs, f"__torch_dispatch__, {func}", len_args=1, len_kwargs=0)
    if not torch.is_tensor(args[0]):
        raise ValueError(f"__torch_dispatch__, {func}: expected args[0] to be a tensor")
    mt = args[0]
    if not is_masked_tensor(mt):
        mt = MaskedTensor(mt, torch.ones_like(mt).bool())
    new_data = func(_get_data(args[0]))
    new_mask = func(_maybe_get_mask(args[0]))
    return MaskedTensor(new_data, new_mask)


@register_dispatch_func([torch.ops.aten._indices])
def _indices(func, *args, **kwargs):
    # Assumes the underlying data is sparse
    _check_args_kwargs_length(args, kwargs, f"__torch_dispatch__, {func}", len_args=1, len_kwargs=0)
    data = _get_data(args[0]).indices()
    return MaskedTensor(data, torch.ones_like(data).bool())


@register_dispatch_func([torch.ops.aten._values])
def _values(func, *args, **kwargs):
    _check_args_kwargs_length(args, kwargs, f"__torch_dispatch__, {func}", len_args=1, len_kwargs=0)
    data = _get_data(args[0]).values()
    return MaskedTensor(data, torch.ones_like(data).bool())


@register_dispatch_func([torch.ops.aten._sparse_coo_tensor_with_dims_and_tensors])
def _sparse_coo_tensor_with_dims_and_tensors(func, *args, **kwargs):
    new_args = list(args)
    if is_masked_tensor(args[-1]):
        new_args[-1] = args[-1].get_data()
    if is_masked_tensor(args[-2]):
        new_args[-2] = args[-2].get_data()

    new_data = func(*new_args, **kwargs)
    new_args[-1] = torch.ones_like(new_args[-1])
    new_mask = func(*new_args, **kwargs).bool()

    return MaskedTensor(new_data, new_mask)


@register_dispatch_func([torch.ops.aten.is_same_size])
def is_same_size(func, *args, **kwargs):
    _check_args_kwargs_length(args, kwargs, f"__torch_dispatch__, {func}", len_args=2)
    return _get_data(args[0]).is_same_size(_get_data(args[1]))
|
|
|