File size: 7,426 Bytes
08157a5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 |
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations
from typing import Any, Optional
import numpy as np
from onnxscript import ir, onnx_opset
from onnxscript._internal import autocast
class Tensor:
    """An implementation of ONNX Tensors, based on a wrapper around numpy arrays.
    Serves to define overloaded ops with an ONNX/ONNXScript semantics.
    """

    def __init__(self, nparray: Optional[np.ndarray], opset=None) -> None:
        """Wrap a numpy array (or ``None`` for a value-less tensor).

        Args:
            nparray: The backing numpy array, or ``None``. Any other type
                raises ``TypeError``.
            opset: Opset object used to dispatch the overloaded operators
                below (e.g. ``self._opset.Add``). Falls back to ``opset18``
                when falsy.
        """
        if nparray is not None and not isinstance(nparray, np.ndarray):
            raise TypeError(
                f"Unexpected type {type(nparray)}. It must be a numpy array or None."
            )
        self._nparray = nparray
        # FIXME(justinhuby): Create a better way to determine the opset version
        self._opset: Any = opset or onnx_opset.opset18

    @property
    def value(self) -> np.ndarray:
        """The backing numpy array; raises ValueError if constructed with None."""
        if self._nparray is None:
            raise ValueError("Tensor does not have a value.")
        return self._nparray

    @property
    def rank(self) -> int:
        """Number of dimensions of the underlying array."""
        return len(self.value.shape)

    @property
    def is_scalar(self) -> bool:
        """True for rank-0 (0-dimensional) tensors."""
        return self.rank == 0

    @property
    def shape(self) -> tuple[int, ...]:
        """Shape of the underlying numpy array."""
        return self.value.shape

    @property
    def dtype(self) -> np.dtype:
        """Numpy dtype of the underlying array."""
        return self.value.dtype

    @property
    def onnx_dtype(self) -> int:
        """ONNX data-type enum value corresponding to :attr:`dtype`."""
        return ir.DataType.from_numpy(self.dtype)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self.value!r})"

    # The conversions below delegate to the numpy array; numpy raises when the
    # array is not convertible (e.g. non-scalar for __int__/__float__/__index__).
    def __bool__(self) -> bool:
        return bool(self.value)

    def __int__(self) -> int:
        return int(self.value)

    def __float__(self) -> float:
        return float(self.value)

    def __len__(self) -> int:
        # Length of the first axis. For a rank-0 tensor, shape is an empty
        # tuple, so this raises IndexError rather than TypeError.
        return self.shape[0]

    def __index__(self) -> int:
        return self.value.__index__()

    def __getitem__(self, index):
        """Index/slice this tensor via the opset's Identity/Slice/Gather ops.

        ``index`` may be an int, a slice, a Tensor, or a tuple of those
        (at most one entry per axis; trailing axes may be omitted).

        Raises:
            RuntimeError: If the opset version is below 13.
            ValueError: If more indices than axes are supplied.
            TypeError: If an index element is neither a slice nor a Tensor
                (after integer promotion via ``cast_pyvalue_to_os_tensor``).
        """
        op = self._opset
        if op.version < 13:
            raise RuntimeError("Indexing requires opset 13 or later.")
        if not isinstance(index, tuple):
            # Normalize representation to a tuple.
            # A single index-value is equivalent to a tuple with a single element.
            index = (index,)
        if len(index) > self.rank:
            raise ValueError(
                f"Number of indices {len(index)} is greater than rank {self.rank}"
            )
        # Promote integer indices to tensors of rank 0
        index = [autocast.cast_pyvalue_to_os_tensor(x) for x in index]
        # Classify each per-axis index into one of four buckets:
        #   sliced_indices:     [start, stop, axis, step] rows for ONNX Slice
        #   scalar_indices:     rank-0 tensor indices, also expressed as Slice rows
        #   to_squeeze:         axes selected by scalar indices (unit dims to drop)
        #   non_scalar_indices: (axis, tensor) pairs handled with ONNX Gather
        shape = self.shape
        sliced_indices = []
        scalar_indices = []
        to_squeeze = []
        non_scalar_indices = []
        for axis_, s in enumerate(index):
            if isinstance(s, slice):
                if s.start is None and s.stop is None and s.step is None:
                    # Full slice `:` — no-op on this axis.
                    continue
                if s.step is None or s.step > 0:
                    sliced_indices.append(
                        [
                            s.start or 0,
                            s.stop if s.stop is not None else shape[axis_],
                            axis_,
                            s.step or 1,
                        ]
                    )
                else:
                    # Negative step: default start is the last element; the
                    # default stop of -(dim+1) means "one past the first
                    # element" in ONNX Slice's negative-index convention
                    # (plain -1 would denote the last element instead).
                    sliced_indices.append(
                        [
                            s.start if s.start is not None else (shape[axis_] - 1),
                            s.stop if s.stop is not None else -(shape[axis_] + 1),
                            axis_,
                            s.step,
                        ]
                    )
            elif isinstance(s, Tensor):
                if s.is_scalar:
                    # A scalar index selects one position: slice [s, s+1) on
                    # this axis, then squeeze the resulting unit dimension.
                    scalar_indices.append([s, s + 1, axis_, 1])
                    to_squeeze.append(axis_)
                else:
                    non_scalar_indices.append((axis_, s))
            else:
                raise TypeError(f"Unexpected type {type(s)}: slice or int expected.")
        # Non-scalar-indexing requires the use of ONNX Gather operation.
        # Slicing can be implemented efficiently using ONNX's Slice operation.
        # Scalar-indexing can be implemented using either Gather or with the Slice operation.
        # We map scalar-indexing into the Slice operation, except in the special case
        # of a single scalar-index (with no other sliced_index), which we map directly
        # to a Gather.
        if not (sliced_indices or scalar_indices or non_scalar_indices):
            # Edge case: no index specified. Eg. A[:, :]
            return op.Identity(self)
        if not sliced_indices and len(scalar_indices) == 1:
            # Special case of indexing along a single axis: A[i], A[:, i], A[:, :, i] etc.
            # promote integer input to tensor
            axis = to_squeeze[0]
            index_value = index[axis]
            # use Gather to perform indexing
            result = op.Gather(self, index_value, axis=axis)
        elif sliced_indices or scalar_indices:
            # Fold scalar indices into the Slice rows and issue a single Slice;
            # transposing turns the rows into parallel starts/ends/axes/steps.
            sliced_indices = sliced_indices + scalar_indices
            indices = np.array(sliced_indices, dtype=np.int64).T
            # NOTE(review): these intermediate Tensors are constructed without
            # passing self._opset, so they fall back to the default opset —
            # confirm that is intended when a non-default opset is in use.
            starts = Tensor(indices[0])
            ends = Tensor(indices[1])
            axes = Tensor(indices[2])
            steps = Tensor(indices[3])
            result = op.Slice(self, starts, ends, axes, steps)
            if to_squeeze:
                # Drop the unit axes introduced by scalar indices.
                result = Tensor(np.squeeze(result.value, axis=tuple(to_squeeze)))
        else:
            result = self
        # Apply any remaining non-scalar (gather) indices, one axis at a time.
        for axis, value in non_scalar_indices:
            result = op.Gather(result, value, axis=axis)
        return result

    def __mod__(self, other):
        # Floating-point types need fmod=1 (C-style fmod) per the ONNX Mod
        # operator; integer types use the default integer-mod semantics.
        if self.onnx_dtype in {
            ir.DataType.FLOAT,
            ir.DataType.DOUBLE,
            ir.DataType.FLOAT16,
            ir.DataType.BFLOAT16,
        }:
            return self._opset.Mod(self, other, fmod=1)
        return self._opset.Mod(self, other)

    def __ne__(self, other):
        # ONNX has no NotEqual operator; compose Equal followed by Not.
        temp = self._opset.Equal(self, other)
        return self._opset.Not(temp)

    # The operators below delegate to the corresponding ONNX op in the opset.
    # The reflected variants (__r*__) swap operand order so e.g. `1 + t` works.
    def __neg__(self):
        return self._opset.Neg(self)

    def __add__(self, other):
        return self._opset.Add(self, other)

    def __radd__(self, other):
        return self._opset.Add(other, self)

    def __and__(self, other):
        return self._opset.And(self, other)

    def __rand__(self, other):
        return self._opset.And(other, self)

    def __mul__(self, other):
        return self._opset.Mul(self, other)

    def __rmul__(self, other):
        return self._opset.Mul(other, self)

    def __matmul__(self, other):
        return self._opset.MatMul(self, other)

    def __or__(self, other):
        return self._opset.Or(self, other)

    def __pow__(self, other):
        return self._opset.Pow(self, other)

    def __sub__(self, other):
        return self._opset.Sub(self, other)

    def __rsub__(self, other):
        return self._opset.Sub(other, self)

    def __truediv__(self, other):
        return self._opset.Div(self, other)

    def __lt__(self, other):
        return self._opset.Less(self, other)

    def __le__(self, other):
        return self._opset.LessOrEqual(self, other)

    # NOTE: defining __eq__ implicitly sets __hash__ to None, so Tensor
    # instances are unhashable; comparisons return element-wise ONNX results,
    # not Python booleans.
    def __eq__(self, other):
        return self._opset.Equal(self, other)

    def __ge__(self, other):
        return self._opset.GreaterOrEqual(self, other)

    def __gt__(self, other):
        return self._opset.Greater(self, other)
|