File size: 4,144 Bytes
838f737
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
import functorch._C
import torch
from functorch._C import dim as _C

from .tree_map import tree_flatten, tree_map
from .wrap_type import wrap_type


# Let the C extension patch torch.Tensor before anything below is defined.
# NOTE(review): what exactly gets patched lives in the C API, not in this
# file — confirm there before moving this call.
_C._patch_tensor_class()

# Re-export the C-implemented entry points under this module's namespace.
dims = _C.dims
DimList = _C.DimList
dimlists = _C.dimlists


class DimensionMismatchError(Exception):
    """Error type for mismatched dimensions.

    Adds nothing to ``Exception``; it exists so callers can catch this case
    specifically.

    NOTE(review): the raise sites are not in this file — presumably raised
    when dimensions that must agree do not; confirm against them.
    """


class DimensionBindError(Exception):
    """Error type for failures while binding a dimension.

    Adds nothing to ``Exception``; it exists so callers can catch this case
    specifically.

    NOTE(review): the raise sites are not in this file — presumably raised
    when a dim cannot be bound (e.g. to a size); confirm against them.
    """


from . import op_properties


# A dict whose values are all True stands in for a set here, which avoids
# having to write C++ bindings for set.
pointwise = {op: True for op in op_properties.pointwise}


class _Tensor:
    """Python-level behaviour shared by ``Dim`` and ``Tensor``.

    Both classes list this one among their bases. Most of the heavy lifting
    is delegated to functions from the C extension, bound here as methods.
    """

    # Simple queries used by the implementation are answered directly here,
    # as a fast path around the slow wrapping/unwrapping logic.

    @property
    def dims(self):
        """Return, in order, the entries of ``self._levels`` that are ``Dim``s."""
        first_class = [level for level in self._levels if isinstance(level, Dim)]
        return tuple(first_class)

    def dim(self):
        """Return ``self.ndim``."""
        return self.ndim

    # Implemented in the C extension and bound onto the class.
    __torch_function__ = classmethod(_C.__torch_function__)
    expand = _C._instancemethod(_C.expand)

    index = _C._instancemethod(_C.index)

    def __repr__(self):
        """Render the underlying tensor followed by its dims and sizes."""
        tensor = self._tensor
        levels = self._levels
        ndim = self.ndim
        # Integer levels are shifted by ``ndim``; every other level is shown
        # as-is. NOTE(review): presumably this turns negative positional
        # indices into non-negative ones — confirm.
        shown = []
        for level in levels:
            if isinstance(level, int):
                shown.append(level + ndim)
            else:
                shown.append(level)
        shown = tuple(shown)
        return f"{tensor}\nwith dims={shown} sizes={tuple(tensor.size())}"


# The two tensor classes as a tuple, i.e. in the form accepted as the second
# argument of isinstance().
TensorLike = (_Tensor, torch.Tensor)


class Dim(_C.Dim, _Tensor):
    """Python-level ``Dim``: the C-implemented ``_C.Dim`` plus the ``_Tensor`` mixin."""

    # _C.Dim is listed before _Tensor in the bases because the Dim API should
    # take precedence for things like size.
    # Tensor defines __format__, but Dims should be printed with their own
    # special formatting, so use the plain object.__format__ instead.
    __format__ = object.__format__


class Tensor(_Tensor, _C.Tensor):
    """Python-level ``Tensor``: the ``_Tensor`` mixin plus the C-implemented ``_C.Tensor``."""

    # Exposed as a static method: called on the class, without an instance.
    from_positional = staticmethod(_C.Tensor_from_positional)
    # Defined on this subclass, so for Tensor instances it takes precedence
    # over the `sum` that `_def("sum")` installs on `_Tensor` further down.
    sum = _C._instancemethod(_C.Tensor_sum)


def cat(tensors, dim, new_dim):
    """Combine ``tensors`` by stacking them and re-indexing the result.

    A fresh dim ``n`` is created with ``dims()``, the tensors are passed to
    ``stack(tensors, n, dim)``, and the stacked result is indexed with
    ``.index([n, dim], new_dim)``.

    NOTE(review): presumably this concatenates along ``dim``, with the packed
    pair ``[n, dim]`` exposed as ``new_dim`` — confirm against the ``stack``
    and ``index`` documentation.

    Args:
        tensors: the tensors to combine; forwarded to ``stack``.
        dim: forwarded to ``stack`` and used in the ``[n, dim]`` pair.
        new_dim: second argument of the final ``index`` call.

    Returns:
        Whatever the final ``index`` call returns.
    """
    # NOTE(review): keep this a plain `n = dims()` assignment — `dims()` is
    # believed to infer how many dims to create (and their names) from the
    # assignment target; confirm before renaming or inlining.
    n = dims()
    return stack(tensors, n, dim).index([n, dim], new_dim)


# Module-level alias for the C extension's wrapper factory; used by `_def`
# and for the module-level `softmax` below.
_wrap = _C._wrap


def _def(name, *args, **kwargs):
    """Install a wrapped version of ``torch.Tensor.<name>`` on ``_Tensor``.

    The attribute ``name`` is looked up on ``torch.Tensor``, handed to
    ``_wrap`` together with the extra positional and keyword arguments, and
    the result is attached to ``_Tensor`` under the same name through
    ``_C._instancemethod``.

    Args:
        name: attribute name to read from ``torch.Tensor`` and to set on
            ``_Tensor``.
        *args: forwarded unchanged to ``_wrap``.
        **kwargs: forwarded unchanged to ``_wrap``.

    Raises:
        AttributeError: if ``torch.Tensor`` has no attribute ``name``.
    """
    wrapped = _wrap(getattr(torch.Tensor, name), *args, **kwargs)
    as_method = _C._instancemethod(wrapped)
    setattr(_Tensor, name, as_method)


# Bind the C implementations of indexing, stacking and splitting.
# `stack` stays a plain function; the others are wrapped with
# _C._instancemethod so they can be installed as methods below.
t__getitem__ = _C._instancemethod(_C.__getitem__)
stack = _C.stack
split = _C._instancemethod(_C.split)

# note: there is no Python reference implementation of __setitem__
t__setitem__ = _C._instancemethod(_C.__setitem__)
# torch.Tensor.__getitem__ is patched in the C API rather than here, because
# otherwise torch.Tensor will no longer be considered a sequence and things
# will break. The assignment is kept commented out for reference:
# torch.Tensor.__getitem__ = t__getitem__

# Only _Tensor gets the item accessors from Python; the torch.Tensor
# counterparts stay commented out (see above).
_Tensor.__getitem__ = t__getitem__
# torch.Tensor.__setitem__ = t__setitem__
_Tensor.__setitem__ = t__setitem__

# These assignments replace attributes on torch.Tensor itself, so they affect
# every torch.Tensor in the process, not just this package's classes.
torch.Tensor.split = split
_Tensor.split = split
torch.Tensor.expand = _C._instancemethod(_C.expand)
torch.Tensor.index = _C._instancemethod(_C.index)
# NOTE(review): wrap_type is defined in .wrap_type; presumably it fills
# _Tensor with torch.Tensor's API routed through __torch_function__ — confirm.
wrap_type(_Tensor, torch.Tensor, _Tensor.__torch_function__)
# Drop `ndim` from _Tensor's own namespace; it must exist at this point or
# this raises AttributeError. NOTE(review): presumably wrap_type added it and
# the C base classes provide the real `ndim` — confirm.
del _Tensor.ndim

_Tensor.order = _C._instancemethod(_C.order)

# Table of torch.Tensor methods to re-install on _Tensor via _def().
# Each row is (method name, keyword arguments forwarded to _wrap); rows are
# processed in the order listed.
_WRAPPED_METHODS = (
    ("mean", {}),
    ("sum", {}),
    ("all", {}),
    ("amax", {}),
    ("amin", {}),
    ("aminmax", {}),
    ("any", {}),
    ("count_nonzero", {}),
    ("logsumexp", {}),
    ("nanmean", {}),
    ("nansum", {}),
    ("prod", {}),
    ("std", {"keepdim_offset": 2}),
    ("var", {"keepdim_offset": 2}),
    ("max", {"single_dim": True}),
    ("min", {"single_dim": True}),
    ("argmax", {"single_dim": True}),
    ("argmin", {"single_dim": True}),
    ("kthvalue", {"single_dim": True}),
    ("median", {"single_dim": True}),
    ("nanmedian", {"single_dim": True}),
    ("mode", {"single_dim": True}),
    ("sort", {"reduce": False}),
    ("argsort", {"reduce": False}),
    ("unbind", {"single_dim": True}),
    ("chunk", {"dim_offset": 1, "reduce": False}),
    ("cummax", {"single_dim": True, "reduce": False}),
    ("cummin", {"single_dim": True, "reduce": False}),
    ("cumprod", {"single_dim": True, "reduce": False}),
    ("cumprod_", {"single_dim": True, "reduce": False}),
    ("cumsum", {"single_dim": True, "reduce": False}),
    ("cumsum_", {"single_dim": True, "reduce": False}),
    ("logcumsumexp", {"single_dim": True, "reduce": False}),
    ("renorm", {"dim_offset": 1, "single_dim": True, "reduce": False}),
    ("softmax", {"single_dim": True, "reduce": False}),
)
for _method_name, _wrap_kwargs in _WRAPPED_METHODS:
    _def(_method_name, **_wrap_kwargs)
# Remove the temporaries so the module namespace gains no extra names.
del _method_name, _wrap_kwargs, _WRAPPED_METHODS

# Module-level function form of softmax, wrapped with the same options as
# the method registered above.
softmax = _wrap(torch.nn.functional.softmax, single_dim=True, reduce=False)

# stuff to handle in the future, because they require special
# binding logic for dims
# cross
# diag_embed
# diagonal
# diagonal_scatter
# diff
# nanquantile
# quantile
# roll
# rot90
# topk (new dims on output)
# should these all be subsumed by inplace indexing?
# index_add_
# index_add
# index_copy
# index_copy_
# index_fill
# index_fill_
# index_select
# scatter
# scatter_
# scatter_add
# scatter_add_
# scatter_reduce