| | import torch |
| |
|
class NestedTensor:
    """A list of tensors that is operated on element by element.

    Arithmetic, indexing, ``to`` and ``chunk`` are applied to each contained
    tensor independently and the per-tensor results are returned in a new
    ``NestedTensor``; the receiver is never modified.  Metadata accessors
    (``size``, ``shape``, ``device``, ``dtype``, ``layout``) report the values
    of the *first* tensor only and therefore raise ``IndexError`` when the
    container is empty.
    """

    def __init__(self, tensors):
        """Store a shallow copy of ``tensors`` (any iterable) as a list."""
        # list() decouples our container from the caller's; the tensor
        # objects themselves are shared, not cloned.
        self.tensors = list(tensors)
        # Marker attribute, always True on instances of this class.
        self.is_nested = True

    def __repr__(self):
        return f"{type(self).__name__}({self.tensors!r})"

    def _copy(self):
        """Return a new NestedTensor holding the same tensor objects."""
        return NestedTensor(self.tensors)

    def apply_operation(self, other, operation):
        """Apply ``operation(tensor, operand)`` to every contained tensor.

        Args:
            other: Either a ``NestedTensor``, in which case the operand for
                position ``i`` is ``other.tensors[i]``, or any other object,
                which is passed unchanged as the operand for every position.
            operation: Two-argument callable; its return value becomes the
                entry at the same position of the result.

        Returns:
            A new ``NestedTensor`` with one result per contained tensor.

        Raises:
            IndexError: If ``other`` is a ``NestedTensor`` with fewer tensors
                than ``self``.  Extra tensors in a longer ``other`` are
                ignored.
        """
        result = self._copy()
        if isinstance(other, NestedTensor):
            operands = other.tensors
            result.tensors = [
                operation(tensor, operands[index])
                for index, tensor in enumerate(result.tensors)
            ]
        else:
            result.tensors = [operation(tensor, other) for tensor in result.tensors]
        return result

    def __add__(self, b):
        return self.apply_operation(b, lambda x, y: x + y)

    def __radd__(self, b):
        # Reflected form: evaluates ``b + tensor`` so that ``scalar + nested``
        # (and ``sum([...])``, which starts from 0) work.
        return self.apply_operation(b, lambda x, y: y + x)

    def __sub__(self, b):
        return self.apply_operation(b, lambda x, y: x - y)

    def __rsub__(self, b):
        # Reflected form: ``b - tensor``; operand order matters here.
        return self.apply_operation(b, lambda x, y: y - x)

    def __mul__(self, b):
        return self.apply_operation(b, lambda x, y: x * y)

    def __rmul__(self, b):
        # Reflected form: evaluates ``b * tensor`` so ``scalar * nested`` works.
        return self.apply_operation(b, lambda x, y: y * x)

    def __truediv__(self, b):
        return self.apply_operation(b, lambda x, y: x / y)

    def __rtruediv__(self, b):
        # Reflected form: ``b / tensor``; operand order matters here.
        return self.apply_operation(b, lambda x, y: y / x)

    def __getitem__(self, *args, **kwargs):
        """Index every contained tensor with the same key.

        Returns a new ``NestedTensor`` whose entries are the results of
        ``tensor.__getitem__(*args, **kwargs)``.
        """
        return self.apply_operation(None, lambda x, y: x.__getitem__(*args, **kwargs))

    def unbind(self):
        """Return the internal list of tensors (the list itself, not a copy)."""
        return self.tensors

    def to(self, *args, **kwargs):
        """Return a new NestedTensor with ``.to(*args, **kwargs)`` applied to each tensor."""
        result = self._copy()
        result.tensors = [tensor.to(*args, **kwargs) for tensor in result.tensors]
        return result

    def new_ones(self, *args, **kwargs):
        """Delegate to ``new_ones`` of the first tensor and return its result.

        The return value is whatever the first tensor's ``new_ones`` produces;
        it is not wrapped in a ``NestedTensor``.
        """
        return self.tensors[0].new_ones(*args, **kwargs)

    def float(self):
        """Return a new NestedTensor with every tensor converted via ``to(dtype=torch.float)``."""
        return self.to(dtype=torch.float)

    def chunk(self, *args, **kwargs):
        """Call ``chunk(*args, **kwargs)`` on every contained tensor.

        Returns a new ``NestedTensor`` in which each entry is the value
        returned by that tensor's ``chunk`` call (for ``torch.Tensor`` this is
        a tuple of chunks), not a sequence of ``NestedTensor`` objects.
        """
        return self.apply_operation(None, lambda x, y: x.chunk(*args, **kwargs))

    def size(self):
        """Return ``size()`` of the first tensor only."""
        return self.tensors[0].size()

    @property
    def shape(self):
        """Shape of the first tensor only."""
        return self.tensors[0].shape

    @property
    def ndim(self):
        """Largest ``ndim`` among the contained tensors; 0 when empty."""
        return max((tensor.ndim for tensor in self.tensors), default=0)

    @property
    def device(self):
        """Device of the first tensor only."""
        return self.tensors[0].device

    @property
    def dtype(self):
        """Dtype of the first tensor only."""
        return self.tensors[0].dtype

    @property
    def layout(self):
        """Layout of the first tensor only."""
        return self.tensors[0].layout
| |
|
| |
|
def cat_nested(tensors, *args, **kwargs):
    """Concatenate several NestedTensors position by position.

    For every position ``i`` in the first element's tensor list, the ``i``-th
    tensor of each element of ``tensors`` is collected (in order) and the
    collection is passed to ``torch.cat`` along with ``*args`` and
    ``**kwargs``.

    Args:
        tensors: Sequence of objects exposing a ``tensors`` list.  The number
            of positions is taken from ``tensors[0]``, so an empty sequence
            raises ``IndexError``.
        *args, **kwargs: Forwarded unchanged to ``torch.cat``.

    Returns:
        A ``NestedTensor`` holding one concatenated tensor per position.
    """
    # The first element alone decides how many positions are concatenated.
    slot_count = len(tensors[0].tensors)
    merged = [
        torch.cat([nested.tensors[slot] for nested in tensors], *args, **kwargs)
        for slot in range(slot_count)
    ]
    return NestedTensor(merged)
| |
|