File size: 7,426 Bytes
08157a5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from __future__ import annotations

from typing import Any, Optional

import numpy as np

from onnxscript import ir, onnx_opset
from onnxscript._internal import autocast


class Tensor:
    """An implementation of ONNX tensors as a thin wrapper around numpy arrays.

    Python operators on a ``Tensor`` are overloaded to call the corresponding
    ONNX operators of the tensor's opset, so they follow ONNX/ONNXScript
    semantics rather than numpy semantics (e.g. ``a == b`` returns the result
    of the ONNX ``Equal`` operator, not a Python ``bool``). Because ``__eq__``
    is overridden without a ``__hash__``, instances are unhashable.

    Args:
        nparray: The wrapped numpy array, or ``None`` for a tensor that has no
            value (reading ``value`` then raises ``ValueError``).
        opset: Opset object providing the ONNX operators used by the
            overloads. A falsy value falls back to ``onnx_opset.opset18``.

    Raises:
        TypeError: If ``nparray`` is neither a numpy array nor ``None``.
    """

    def __init__(self, nparray: Optional[np.ndarray], opset=None):
        if nparray is not None and not isinstance(nparray, np.ndarray):
            raise TypeError(
                f"Unexpected type {type(nparray)}. It must be a numpy array or None."
            )

        self._nparray = nparray
        # FIXME(justinhuby): Create a better way to determine the opset version
        self._opset: Any = opset or onnx_opset.opset18

    @property
    def value(self) -> np.ndarray:
        """The wrapped numpy array.

        Raises:
            ValueError: If the tensor was created without a value.
        """
        if self._nparray is None:
            raise ValueError("Tensor does not have a value.")
        return self._nparray

    @property
    def rank(self) -> int:
        """Number of dimensions of the wrapped array."""
        return len(self.value.shape)

    @property
    def is_scalar(self) -> bool:
        """True when the tensor has rank 0."""
        return self.rank == 0

    @property
    def shape(self) -> tuple[int, ...]:
        """Shape of the wrapped array."""
        return self.value.shape

    @property
    def dtype(self) -> np.dtype:
        """numpy dtype of the wrapped array."""
        return self.value.dtype

    @property
    def onnx_dtype(self) -> int:
        """ONNX element type corresponding to ``dtype``."""
        return ir.DataType.from_numpy(self.dtype)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self.value!r})"

    # Python conversions delegate to the wrapped numpy array.

    def __bool__(self) -> bool:
        return bool(self.value)

    def __int__(self) -> int:
        return int(self.value)

    def __float__(self) -> float:
        return float(self.value)

    def __len__(self) -> int:
        # Size of the first dimension; a rank-0 tensor has none (IndexError).
        return self.shape[0]

    def __index__(self) -> int:
        return self.value.__index__()

    def __getitem__(self, index):
        """Index the tensor with numpy-like syntax, lowered to ONNX operators.

        Each index component applies to the axis at the same position and may be:

        * a ``slice``, lowered to ``Slice`` (a bare ``:`` keeps the axis as is);
        * a value that ``autocast.cast_pyvalue_to_os_tensor`` promotes to a
          rank-0 ``Tensor`` (e.g. a Python int), which selects one position and
          drops the axis;
        * a non-scalar ``Tensor``, lowered to ``Gather`` along that axis.

        Non-scalar indices are applied one after another with ``Gather``. When
        several are given, this is NOT numpy's broadcasting (advanced-indexing)
        semantics.

        Raises:
            RuntimeError: If the opset version is below 13.
            ValueError: If there are more index components than dimensions, or
                a slice step is zero.
            TypeError: If a component is neither a slice nor a tensor.
        """
        op = self._opset
        if op.version < 13:
            raise RuntimeError("Indexing requires opset 13 or later.")
        if not isinstance(index, tuple):
            # Normalize representation to a tuple.
            # A single index-value is equivalent to a tuple with a single element.
            index = (index,)
        if len(index) > self.rank:
            raise ValueError(
                f"Number of indices {len(index)} is greater than rank {self.rank}"
            )

        # Promote integer indices to tensors of rank 0
        index = [autocast.cast_pyvalue_to_os_tensor(x) for x in index]
        # Process all elements in index
        shape = self.shape
        # Entries of sliced_indices / scalar_indices are [start, end, axis, step].
        sliced_indices = []
        scalar_indices = []
        to_squeeze = []  # axes removed by scalar indexing
        non_scalar_indices = []  # (axis, index tensor), in increasing axis order
        for axis_, s in enumerate(index):
            if isinstance(s, slice):
                if s.start is None and s.stop is None and s.step is None:
                    continue
                if s.step is not None and s.step == 0:
                    # ONNX Slice forbids zero steps; fail early with Python's message.
                    raise ValueError("slice step cannot be zero")
                if s.step is None or s.step > 0:
                    sliced_indices.append(
                        [
                            s.start or 0,
                            s.stop if s.stop is not None else shape[axis_],
                            axis_,
                            s.step or 1,
                        ]
                    )
                else:
                    # Negative step: default start is the last element. The default
                    # end -(dim + 1) becomes -1 once ONNX Slice adds dim to negative
                    # ends, so the slice runs through index 0.
                    sliced_indices.append(
                        [
                            s.start if s.start is not None else (shape[axis_] - 1),
                            s.stop if s.stop is not None else -(shape[axis_] + 1),
                            axis_,
                            s.step,
                        ]
                    )
            elif isinstance(s, Tensor):
                if s.is_scalar:
                    # Scalar index i == slice i:i+1, then drop the length-1 axis.
                    scalar_indices.append([s, s + 1, axis_, 1])
                    to_squeeze.append(axis_)
                else:
                    non_scalar_indices.append((axis_, s))
            else:
                raise TypeError(f"Unexpected type {type(s)}: slice or int expected.")

        # Non-scalar-indexing requires the use of ONNX Gather operation.
        # Slicing can be implemented efficiently using ONNX's Slice operation.
        # Scalar-indexing can be implemented using either Gather or with the Slice operation.
        # We map scalar-indexing into the Slice operation, except in the special case
        # of a single scalar-index (with no other sliced_index), which we map directly
        # to a Gather.

        if not (sliced_indices or scalar_indices or non_scalar_indices):
            # Edge case: no index specified. Eg. A[:, :]
            return op.Identity(self)
        if not sliced_indices and len(scalar_indices) == 1:
            # Special case of indexing along a single axis: A[i], A[:, i], A[:, :, i] etc.
            # Gather with a rank-0 index removes the indexed axis.
            axis = to_squeeze[0]
            index_value = index[axis]
            result = op.Gather(self, index_value, axis=axis)
        elif sliced_indices or scalar_indices:
            sliced_indices = sliced_indices + scalar_indices
            indices = np.array(sliced_indices, dtype=np.int64).T
            starts = Tensor(indices[0])
            ends = Tensor(indices[1])
            axes = Tensor(indices[2])
            steps = Tensor(indices[3])
            result = op.Slice(self, starts, ends, axes, steps)
            if to_squeeze:
                # Drop the length-1 axes left by scalar indices (done eagerly in
                # numpy); keep this tensor's opset on the new wrapper.
                result = Tensor(
                    np.squeeze(result.value, axis=tuple(to_squeeze)),
                    opset=self._opset,
                )
        else:
            result = self

        # Axis numbers in non_scalar_indices refer to the ORIGINAL tensor. By now,
        # every scalar-indexed axis has already been removed from ``result``, and
        # each earlier Gather with a rank-q index replaced one axis by q axes.
        # Translate each original axis into the corresponding axis of ``result``.
        rank_delta = 0
        for axis, value in non_scalar_indices:
            dropped_before = sum(1 for squeezed in to_squeeze if squeezed < axis)
            result = op.Gather(result, value, axis=axis - dropped_before + rank_delta)
            rank_delta += value.rank - 1

        return result

    # Operator overloads: each dispatches to the ONNX operator with the same
    # meaning in this tensor's opset. Reflected (__r*__) variants pass ``other``
    # first so the Python operand order is preserved.

    def __mod__(self, other):
        """Compute ``self % other`` with the ONNX ``Mod`` operator.

        Floating-point inputs use ``fmod=1`` because ONNX requires it for
        floating-point types. The result then takes the sign of the dividend
        (like C ``fmod``), not of the divisor as with Python's ``%``.
        """
        if self.onnx_dtype in {
            ir.DataType.FLOAT,
            ir.DataType.DOUBLE,
            ir.DataType.FLOAT16,
            ir.DataType.BFLOAT16,
        }:
            return self._opset.Mod(self, other, fmod=1)
        return self._opset.Mod(self, other)

    def __ne__(self, other):
        """Element-wise ``!=``, expressed as ``Not(Equal(self, other))``."""
        temp = self._opset.Equal(self, other)
        return self._opset.Not(temp)

    def __neg__(self):
        return self._opset.Neg(self)

    def __add__(self, other):
        return self._opset.Add(self, other)

    def __radd__(self, other):
        return self._opset.Add(other, self)

    def __and__(self, other):
        return self._opset.And(self, other)

    def __rand__(self, other):
        return self._opset.And(other, self)

    def __mul__(self, other):
        return self._opset.Mul(self, other)

    def __rmul__(self, other):
        return self._opset.Mul(other, self)

    def __matmul__(self, other):
        return self._opset.MatMul(self, other)

    def __or__(self, other):
        return self._opset.Or(self, other)

    def __pow__(self, other):
        return self._opset.Pow(self, other)

    def __sub__(self, other):
        return self._opset.Sub(self, other)

    def __rsub__(self, other):
        return self._opset.Sub(other, self)

    def __truediv__(self, other):
        return self._opset.Div(self, other)

    def __lt__(self, other):
        return self._opset.Less(self, other)

    def __le__(self, other):
        return self._opset.LessOrEqual(self, other)

    def __eq__(self, other):
        return self._opset.Equal(self, other)

    def __ge__(self, other):
        return self._opset.GreaterOrEqual(self, other)

    def __gt__(self, other):
        return self._opset.Greater(self, other)