|
|
|
|
|
|
|
|
|
|
|
from __future__ import annotations |
|
|
|
|
|
from typing import Any, Optional |
|
|
|
|
|
import numpy as np |
|
|
|
|
|
from onnxscript import ir, onnx_opset |
|
|
from onnxscript._internal import autocast |
|
|
|
|
|
|
|
|
class Tensor:
    """An eager ONNX tensor implemented as a thin wrapper around a numpy array.

    Python operators (``+``, ``*``, ``%``, ``[]``, comparisons, ...) are
    overloaded to dispatch to the corresponding ONNX operator of the tensor's
    opset. Eagerly evaluated ONNXScript code therefore follows ONNX semantics
    rather than numpy semantics.

    Note: ``__eq__`` is overloaded to return the element-wise result of
    ``Equal``. Python therefore sets ``__hash__`` to ``None``, and instances are
    unhashable.
    """

    def __init__(self, nparray: Optional[np.ndarray], opset=None):
        """Wrap ``nparray``.

        Args:
            nparray: Value of the tensor, or ``None`` for a tensor without a
                value (accessing :attr:`value` then raises ``ValueError``).
            opset: Opset used to implement the overloaded operators. When it
                is falsy (normally ``None``), ``onnx_opset.opset18`` is used.

        Raises:
            TypeError: If ``nparray`` is neither ``None`` nor a numpy array.
        """
        if nparray is not None and not isinstance(nparray, np.ndarray):
            raise TypeError(
                f"Unexpected type {type(nparray)}. It must be a numpy array or None."
            )
        self._nparray = nparray
        self._opset: Any = opset or onnx_opset.opset18

    @property
    def value(self) -> np.ndarray:
        """The wrapped numpy array.

        Raises:
            ValueError: If the tensor was created without a value.
        """
        if self._nparray is None:
            raise ValueError("Tensor does not have a value.")
        return self._nparray

    @property
    def rank(self) -> int:
        """Number of dimensions of the wrapped array."""
        return len(self.value.shape)

    @property
    def is_scalar(self) -> bool:
        """True when the tensor has rank 0."""
        return self.rank == 0

    @property
    def shape(self) -> tuple[int, ...]:
        """Shape of the wrapped array."""
        return self.value.shape

    @property
    def dtype(self) -> np.dtype:
        """Numpy dtype of the wrapped array."""
        return self.value.dtype

    @property
    def onnx_dtype(self) -> int:
        """ONNX data type corresponding to :attr:`dtype`."""
        return ir.DataType.from_numpy(self.dtype)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self.value!r})"

    # Conversions delegate to numpy. For example, ``bool`` of an array with
    # more than one element raises, exactly as it does for numpy arrays.
    def __bool__(self) -> bool:
        return bool(self.value)

    def __int__(self) -> int:
        return int(self.value)

    def __float__(self) -> float:
        return float(self.value)

    def __len__(self) -> int:
        # Size of the first dimension; a rank-0 tensor raises IndexError here.
        return self.shape[0]

    def __index__(self) -> int:
        return self.value.__index__()

    def __getitem__(self, index):
        """Index the tensor with numpy-like syntax, lowered to ONNX operators.

        Each index component applies to one leading axis:

        * a ``slice`` is lowered to ``Slice``; a full ``:`` is a no-op;
        * a scalar integer (or scalar Tensor) selects one element and removes
          the axis;
        * a non-scalar integer Tensor is lowered to ``Gather`` along that
          axis. Several non-scalar components are applied one after another
          (outer indexing), which differs from numpy's broadcasting rules for
          advanced indexing.

        Raises:
            RuntimeError: If the opset version is below 13.
            ValueError: If there are more index components than dimensions.
            TypeError: If an index component has an unsupported type.
        """
        op = self._opset
        if op.version < 13:
            raise RuntimeError("Indexing requires opset 13 or later.")
        if not isinstance(index, tuple):
            # A single component, e.g. ``x[0]`` or ``x[1:2]``.
            index = (index,)
        if len(index) > self.rank:
            raise ValueError(
                f"Number of indices {len(index)} is greater than rank {self.rank}"
            )

        # Convert Python values (e.g. ints) to Tensors. Slices are expected to
        # come back unchanged (handled by the ``slice`` branch below) — TODO
        # confirm against autocast.cast_pyvalue_to_os_tensor.
        index = [autocast.cast_pyvalue_to_os_tensor(x) for x in index]

        shape = self.shape
        sliced_indices = []  # [start, end, axis, step] rows from slices
        scalar_indices = []  # [start, end, axis, step] rows from scalar indices
        to_squeeze = []  # input axes removed by scalar indexing
        non_scalar_indices = []  # (input axis, index Tensor) pairs for Gather
        for axis_, s in enumerate(index):
            if isinstance(s, slice):
                if s.start is None and s.stop is None and s.step is None:
                    continue
                if s.step is None or s.step > 0:
                    sliced_indices.append(
                        [
                            s.start or 0,
                            s.stop if s.stop is not None else shape[axis_],
                            axis_,
                            s.step or 1,
                        ]
                    )
                else:
                    # Negative step: the default start is the last element.
                    # The default end -(dim + 1) is clamped by ONNX Slice to
                    # "before element 0", so element 0 is included.
                    sliced_indices.append(
                        [
                            s.start if s.start is not None else (shape[axis_] - 1),
                            s.stop if s.stop is not None else -(shape[axis_] + 1),
                            axis_,
                            s.step,
                        ]
                    )
            elif isinstance(s, Tensor):
                if s.is_scalar:
                    start = int(s.value)
                    # [i, i + 1) selects element i, except when i == -1: then
                    # i + 1 == 0 would give an empty slice, so slice to the end.
                    stop = start + 1 if start != -1 else shape[axis_]
                    scalar_indices.append([start, stop, axis_, 1])
                    to_squeeze.append(axis_)
                else:
                    non_scalar_indices.append((axis_, s))
            else:
                raise TypeError(f"Unexpected type {type(s)}: slice or int expected.")

        if not (sliced_indices or scalar_indices or non_scalar_indices):
            # Every component was a full slice ``:``.
            return op.Identity(self)
        if not sliced_indices and len(scalar_indices) == 1:
            # A lone scalar index maps directly to Gather, which drops the axis.
            axis = to_squeeze[0]
            result = op.Gather(self, index[axis], axis=axis)
        elif sliced_indices or scalar_indices:
            rows = np.array(sliced_indices + scalar_indices, dtype=np.int64).T
            starts, ends, axes, steps = (Tensor(row) for row in rows)
            result = op.Slice(self, starts, ends, axes, steps)
            if to_squeeze:
                # Scalar-indexed axes now have size 1; remove them.
                result = Tensor(np.squeeze(result.value, axis=tuple(to_squeeze)))
        else:
            result = self

        # Axes in non_scalar_indices are numbered in the *input* tensor. Shift
        # each one to account for axes removed by scalar indexing and for the
        # rank change (index rank - 1) introduced by each earlier Gather.
        rank_delta = 0
        for axis, value in non_scalar_indices:
            removed_before = sum(1 for squeezed in to_squeeze if squeezed < axis)
            result = op.Gather(result, value, axis=axis - removed_before + rank_delta)
            rank_delta += value.rank - 1

        return result

    def __mod__(self, other):
        """Python ``%``: ONNX ``Mod``, with ``fmod=1`` for floating-point types.

        ONNX ``Mod`` requires ``fmod=1`` for floating-point inputs.
        """
        if self.onnx_dtype in {
            ir.DataType.FLOAT,
            ir.DataType.DOUBLE,
            ir.DataType.FLOAT16,
            ir.DataType.BFLOAT16,
        }:
            return self._opset.Mod(self, other, fmod=1)
        return self._opset.Mod(self, other)

    def __ne__(self, other):
        """Element-wise ``!=``, computed as ``Not(Equal(self, other))``."""
        temp = self._opset.Equal(self, other)
        return self._opset.Not(temp)

    def __neg__(self):
        return self._opset.Neg(self)

    def __add__(self, other):
        return self._opset.Add(self, other)

    def __radd__(self, other):
        return self._opset.Add(other, self)

    def __and__(self, other):
        return self._opset.And(self, other)

    def __rand__(self, other):
        return self._opset.And(other, self)

    def __mul__(self, other):
        return self._opset.Mul(self, other)

    def __rmul__(self, other):
        return self._opset.Mul(other, self)

    def __matmul__(self, other):
        return self._opset.MatMul(self, other)

    def __or__(self, other):
        return self._opset.Or(self, other)

    def __ror__(self, other):
        return self._opset.Or(other, self)

    def __pow__(self, other):
        return self._opset.Pow(self, other)

    def __rpow__(self, other):
        return self._opset.Pow(other, self)

    def __sub__(self, other):
        return self._opset.Sub(self, other)

    def __rsub__(self, other):
        return self._opset.Sub(other, self)

    def __truediv__(self, other):
        return self._opset.Div(self, other)

    def __rtruediv__(self, other):
        return self._opset.Div(other, self)

    # Comparisons return element-wise ONNX results, not Python booleans.
    def __lt__(self, other):
        return self._opset.Less(self, other)

    def __le__(self, other):
        return self._opset.LessOrEqual(self, other)

    def __eq__(self, other):
        return self._opset.Equal(self, other)

    def __ge__(self, other):
        return self._opset.GreaterOrEqual(self, other)

    def __gt__(self, other):
        return self._opset.Greater(self, other)
|
|
|