| |
| import torch |
| import torch._prims_common as utils |
|
|
| |
| from torch._decomp import register_decomposition |
| from torch._prims_common import TensorLikeType |
| from torch._prims_common.wrappers import out_wrapper |
| from torch._refs import _broadcast_shapes |
|
|
|
|
| |
| |
| |
| |
| |
| |
|
|
| __all__ = [ |
| |
| "bfloat16", |
| "bool", |
| "byte", |
| "cdouble", |
| "cfloat", |
| "chalf", |
| "char", |
| "double", |
| "float", |
| "half", |
| "int", |
| "long", |
| "short", |
| |
| "complex", |
| "polar", |
| ] |
|
|
|
|
| def _make_conversion_method(name: str, dtype: torch.dtype): |
| def fn( |
| self: TensorLikeType, memory_format: torch.memory_format = torch.preserve_format |
| ) -> TensorLikeType: |
| return self.to(dtype, memory_format=memory_format) |
|
|
| fn.__name__ = name |
| return fn |
|
|
|
|
| bfloat16 = _make_conversion_method("bfloat16", torch.bfloat16) |
|
|
| bool = _make_conversion_method("bool", torch.bool) |
|
|
| byte = _make_conversion_method("byte", torch.uint8) |
|
|
| cdouble = _make_conversion_method("cdouble", torch.cdouble) |
|
|
| cfloat = _make_conversion_method("cfloat", torch.cfloat) |
|
|
| chalf = _make_conversion_method("chalf", torch.complex32) |
|
|
| char = _make_conversion_method("char", torch.int8) |
|
|
| double = _make_conversion_method("double", torch.double) |
|
|
| float = _make_conversion_method("float", torch.float) |
|
|
| half = _make_conversion_method("half", torch.half) |
|
|
| int = _make_conversion_method("int", torch.int) |
|
|
| long = _make_conversion_method("long", torch.long) |
|
|
| short = _make_conversion_method("short", torch.short) |
|
|
|
|
| @register_decomposition(torch._ops.ops.aten.complex) |
| |
| |
| @out_wrapper(exact_dtype=True) |
| def complex(real: TensorLikeType, imag: TensorLikeType) -> TensorLikeType: |
| allowed_dtypes = (torch.float32, torch.float64, torch.float16) |
| torch._check( |
| real.dtype in allowed_dtypes and imag.dtype in allowed_dtypes, |
| lambda: ( |
| f"Expected both inputs to be Half, Float or Double tensors but got " |
| f"{real.dtype} and {imag.dtype}" |
| ), |
| ) |
| torch._check( |
| real.dtype == imag.dtype, |
| lambda: ( |
| f"Expected object of scalar type {real.dtype} but got " |
| f"scalar type {imag.dtype} for second argument" |
| ), |
| ) |
| result_dtype = utils.corresponding_complex_dtype(real.dtype) |
| common_shape = _broadcast_shapes(real.shape, imag.shape) |
| result = real.new_empty( |
| common_shape, |
| dtype=result_dtype, |
| layout=real.layout, |
| device=real.device, |
| |
| ) |
| result.real = real |
| result.imag = imag |
| return result |
|
|
|
|
| @register_decomposition(torch._ops.ops.aten.polar) |
| |
| |
| @out_wrapper(exact_dtype=True) |
| def polar(abs: TensorLikeType, angle: TensorLikeType) -> TensorLikeType: |
| result = torch.complex(abs, angle) |
| result.real = abs * torch.cos(angle) |
| result.imag = abs * torch.sin(angle) |
| return result |
|
|