# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations
from typing import Any, Optional
import numpy as np
from onnxscript import ir, onnx_opset
from onnxscript._internal import autocast
class Tensor:
    """An implementation of ONNX Tensors, based on a wrapper around numpy arrays.

    Serves to define overloaded ops with an ONNX/ONNXScript semantics.

    Each arithmetic/comparison dunder dispatches to the corresponding ONNX op
    on the instance's opset (e.g. ``__add__`` -> ``opset.Add``), so expressions
    over ``Tensor`` values evaluate with ONNX semantics rather than numpy's.

    NOTE: defining ``__eq__`` (below) without ``__hash__`` makes instances
    unhashable, per the Python data model.
    """

    def __init__(self, nparray: Optional[np.ndarray], opset=None) -> None:
        """Wrap ``nparray`` as an ONNX tensor value.

        Args:
            nparray: Backing numpy array, or None for a tensor without a
                concrete value (accessing ``.value`` then raises ValueError).
            opset: Opset object used to dispatch the overloaded operators.
                Defaults to ``onnx_opset.opset18``.

        Raises:
            TypeError: If ``nparray`` is neither a numpy array nor None.
        """
        if nparray is not None and not isinstance(nparray, np.ndarray):
            raise TypeError(
                f"Unexpected type {type(nparray)}. It must be a numpy array or None."
            )
        self._nparray = nparray
        # FIXME(justinhuby): Create a better way to determine the opset version
        self._opset: Any = opset or onnx_opset.opset18

    @property
    def value(self) -> np.ndarray:
        """The backing numpy array.

        Raises:
            ValueError: If the tensor was constructed with ``nparray=None``.
        """
        if self._nparray is None:
            raise ValueError("Tensor does not have a value.")
        return self._nparray

    @property
    def rank(self) -> int:
        """Number of dimensions of the underlying array."""
        return len(self.value.shape)

    @property
    def is_scalar(self) -> bool:
        """True iff the tensor has rank 0."""
        return self.rank == 0

    @property
    def shape(self) -> tuple[int, ...]:
        """Shape of the underlying numpy array."""
        return self.value.shape

    @property
    def dtype(self) -> np.dtype:
        """Numpy dtype of the underlying array."""
        return self.value.dtype

    @property
    def onnx_dtype(self) -> int:
        """ONNX ``DataType`` enum value corresponding to :attr:`dtype`."""
        return ir.DataType.from_numpy(self.dtype)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self.value!r})"

    # The conversion dunders below delegate to numpy, so they inherit numpy's
    # behavior (e.g. __bool__ on a multi-element array raises ValueError).
    def __bool__(self) -> bool:
        return bool(self.value)

    def __int__(self) -> int:
        return int(self.value)

    def __float__(self) -> float:
        return float(self.value)

    def __len__(self) -> int:
        # Length of the first axis; raises IndexError on rank-0 tensors
        # because self.shape is then an empty tuple.
        return self.shape[0]

    def __index__(self) -> int:
        return self.value.__index__()

    def __getitem__(self, index):
        """Index/slice the tensor using ONNX Slice/Gather ops.

        Supports basic indexing with a mix of Python slices, scalar (rank-0)
        ``Tensor`` indices, and non-scalar ``Tensor`` indices. Python ints are
        promoted to rank-0 tensors via ``autocast`` before classification.
        """
        op = self._opset
        # Slice with explicit steps (used below) requires opset >= 13.
        if op.version < 13:
            raise RuntimeError("Indexing requires opset 13 or later.")
        if not isinstance(index, tuple):
            # Normalize representation to a tuple.
            # A single index-value is equivalent to a tuple with a single element.
            index = (index,)
        if len(index) > self.rank:
            raise ValueError(
                f"Number of indices {len(index)} is greater than rank {self.rank}"
            )
        # Promote integer indices to tensors of rank 0
        index = [autocast.cast_pyvalue_to_os_tensor(x) for x in index]
        # Process all elements in index, classifying each axis's index into:
        #   sliced_indices     -- Python slices, as [start, stop, axis, step]
        #   scalar_indices     -- rank-0 Tensor indices, mapped to width-1 slices
        #   to_squeeze         -- axes indexed by scalars (dropped afterwards)
        #   non_scalar_indices -- Tensor indices of rank >= 1 (handled by Gather)
        shape = self.shape
        sliced_indices = []
        scalar_indices = []
        to_squeeze = []
        non_scalar_indices = []
        for axis_, s in enumerate(index):
            if isinstance(s, slice):
                if s.start is None and s.stop is None and s.step is None:
                    # Full slice `:` -- no-op on this axis.
                    continue
                if s.step is None or s.step > 0:
                    # Forward slice: default start 0, default stop = dim size.
                    sliced_indices.append(
                        [
                            s.start or 0,
                            s.stop if s.stop is not None else shape[axis_],
                            axis_,
                            s.step or 1,
                        ]
                    )
                else:
                    # Backward slice: default start is the last element; the
                    # default stop -(dim+1) is the ONNX Slice sentinel meaning
                    # "one before the first element" for negative steps.
                    sliced_indices.append(
                        [
                            s.start if s.start is not None else (shape[axis_] - 1),
                            s.stop if s.stop is not None else -(shape[axis_] + 1),
                            axis_,
                            s.step,
                        ]
                    )
            elif isinstance(s, Tensor):
                if s.is_scalar:
                    # Represent the scalar index as a width-1 slice [s, s+1)
                    # (note: `s + 1` uses the overloaded __add__, i.e. ONNX Add),
                    # and remember to squeeze this axis out of the result.
                    scalar_indices.append([s, s + 1, axis_, 1])
                    to_squeeze.append(axis_)
                else:
                    non_scalar_indices.append((axis_, s))
            else:
                raise TypeError(f"Unexpected type {type(s)}: slice or int expected.")
        # Non-scalar-indexing requires the use of ONNX Gather operation.
        # Slicing can be implemented efficiently using ONNX's Slice operation.
        # Scalar-indexing can be implemented using either Gather or with the Slice operation.
        # We map scalar-indexing into the Slice operation, except in the special case
        # of a single scalar-index (with no other sliced_index), which we map directly
        # to a Gather.
        if not (sliced_indices or scalar_indices or non_scalar_indices):
            # Edge case: no index specified. Eg. A[:, :]
            return op.Identity(self)
        if not sliced_indices and len(scalar_indices) == 1:
            # Special case of indexing along a single axis: A[i], A[:, i], A[:, :, i] etc.
            # promote integer input to tensor
            axis = to_squeeze[0]
            index_value = index[axis]
            # use Gather to perform indexing
            result = op.Gather(self, index_value, axis=axis)
        elif sliced_indices or scalar_indices:
            # General case: fold all slices and scalar indices into one Slice.
            # The 4-column rows are transposed into the starts/ends/axes/steps
            # inputs that ONNX Slice expects.
            sliced_indices = sliced_indices + scalar_indices
            indices = np.array(sliced_indices, dtype=np.int64).T
            starts = Tensor(indices[0])
            ends = Tensor(indices[1])
            axes = Tensor(indices[2])
            steps = Tensor(indices[3])
            result = op.Slice(self, starts, ends, axes, steps)
            if to_squeeze:
                # Drop the width-1 axes produced by scalar indices.
                result = Tensor(np.squeeze(result.value, axis=tuple(to_squeeze)))
        else:
            result = self
        # Apply advanced (non-scalar) indices one axis at a time via Gather.
        # NOTE(review): these Gathers use the ORIGINAL axis numbers; if axes
        # were squeezed out above, later axis positions have shifted by the
        # number of squeezed axes before them -- verify this combination
        # (scalar + non-scalar indices together) is handled as intended.
        for axis, value in non_scalar_indices:
            result = op.Gather(result, value, axis=axis)
        return result

    def __mod__(self, other):
        # ONNX Mod needs fmod=1 for floating-point operands (C-style fmod);
        # the default integer Mod is used otherwise.
        if self.onnx_dtype in {
            ir.DataType.FLOAT,
            ir.DataType.DOUBLE,
            ir.DataType.FLOAT16,
            ir.DataType.BFLOAT16,
        }:
            return self._opset.Mod(self, other, fmod=1)
        return self._opset.Mod(self, other)

    def __ne__(self, other):
        # ONNX has no NotEqual op; compose it as Not(Equal(...)).
        temp = self._opset.Equal(self, other)
        return self._opset.Not(temp)

    def __neg__(self):
        return self._opset.Neg(self)

    def __add__(self, other):
        return self._opset.Add(self, other)

    def __radd__(self, other):
        return self._opset.Add(other, self)

    def __and__(self, other):
        return self._opset.And(self, other)

    def __rand__(self, other):
        return self._opset.And(other, self)

    def __mul__(self, other):
        return self._opset.Mul(self, other)

    def __rmul__(self, other):
        return self._opset.Mul(other, self)

    def __matmul__(self, other):
        return self._opset.MatMul(self, other)

    def __or__(self, other):
        return self._opset.Or(self, other)

    def __pow__(self, other):
        return self._opset.Pow(self, other)

    def __sub__(self, other):
        return self._opset.Sub(self, other)

    def __rsub__(self, other):
        return self._opset.Sub(other, self)

    def __truediv__(self, other):
        return self._opset.Div(self, other)

    def __lt__(self, other):
        return self._opset.Less(self, other)

    def __le__(self, other):
        return self._opset.LessOrEqual(self, other)

    def __eq__(self, other):
        # Returns an elementwise ONNX Equal result (a Tensor), not a bool.
        return self._opset.Equal(self, other)

    def __ge__(self, other):
        return self._opset.GreaterOrEqual(self, other)

    def __gt__(self, other):
        return self._opset.Greater(self, other)