| import torch |
|
|
class NestedTensor:
    """A list of independent tensors that can be used like a single tensor.

    Arithmetic, indexing, ``to``/``float`` and ``chunk`` are applied to every
    member tensor independently and return a new ``NestedTensor``; the
    original object is left unchanged. The metadata accessors (``shape``,
    ``size``, ``device``, ``dtype``, ``layout``, ``new_ones``) delegate to the
    *first* member, so they raise ``IndexError`` on an empty container.
    ``ndim`` instead reports the largest rank among the members.
    """

    def __init__(self, tensors):
        # Copy into a fresh list so the caller's sequence and this object's
        # member list can be modified independently of each other.
        self.tensors = list(tensors)
        # Mirrors the ``torch.Tensor.is_nested`` attribute.
        self.is_nested = True

    def __repr__(self):
        return f"{type(self).__name__}({self.tensors!r})"

    def _copy(self):
        """Return a new NestedTensor holding a shallow copy of the member list."""
        return NestedTensor(self.tensors)

    def apply_operation(self, other, operation):
        """Apply ``operation`` to every member tensor and collect the results.

        Args:
            other: Either another ``NestedTensor``, whose members are paired
                with this object's members position by position, or any other
                value, which is passed unchanged to every call.
            operation: Callable ``(member, paired_value) -> result``.

        Returns:
            A new ``NestedTensor`` of the results; ``self`` is not modified.

        Raises:
            ValueError: If ``other`` is a ``NestedTensor`` holding a different
                number of members than ``self``.
        """
        if isinstance(other, NestedTensor):
            # Pairing requires equal lengths; anything else would silently
            # drop members or fail with an opaque IndexError.
            if len(other.tensors) != len(self.tensors):
                raise ValueError(
                    "NestedTensor size mismatch: "
                    f"{len(self.tensors)} vs {len(other.tensors)} tensors"
                )
            return NestedTensor(
                operation(mine, theirs)
                for mine, theirs in zip(self.tensors, other.tensors)
            )
        return NestedTensor(operation(t, other) for t in self.tensors)

    def __add__(self, b):
        return self.apply_operation(b, lambda x, y: x + y)

    def __radd__(self, b):
        # Reflected form so expressions such as ``1 + nested`` work.
        return self.apply_operation(b, lambda x, y: y + x)

    def __sub__(self, b):
        return self.apply_operation(b, lambda x, y: x - y)

    def __rsub__(self, b):
        # Operand order matters: computes ``b - member``.
        return self.apply_operation(b, lambda x, y: y - x)

    def __mul__(self, b):
        return self.apply_operation(b, lambda x, y: x * y)

    def __rmul__(self, b):
        # Reflected form so expressions such as ``2 * nested`` work.
        return self.apply_operation(b, lambda x, y: y * x)

    def __truediv__(self, b):
        return self.apply_operation(b, lambda x, y: x / y)

    def __rtruediv__(self, b):
        # Operand order matters: computes ``b / member``.
        return self.apply_operation(b, lambda x, y: y / x)

    def __neg__(self):
        return self.apply_operation(None, lambda x, _: -x)

    def __getitem__(self, *args, **kwargs):
        """Index every member tensor with the same key."""
        return self.apply_operation(None, lambda x, _: x.__getitem__(*args, **kwargs))

    def unbind(self):
        """Return the member tensors.

        NOTE: this is the object's own list, not a copy; mutating it mutates
        this NestedTensor.
        """
        return self.tensors

    def to(self, *args, **kwargs):
        """Return a new NestedTensor with ``Tensor.to(*args, **kwargs)`` applied to each member."""
        return NestedTensor(t.to(*args, **kwargs) for t in self.tensors)

    def new_ones(self, *args, **kwargs):
        """Delegate to the first member's ``new_ones``; returns a plain tensor."""
        return self.tensors[0].new_ones(*args, **kwargs)

    def float(self):
        """Return a copy with every member cast to ``torch.float``."""
        return self.to(dtype=torch.float)

    def chunk(self, *args, **kwargs):
        """Apply ``Tensor.chunk`` to every member.

        NOTE: each entry of the result is the tuple returned by
        ``Tensor.chunk``, not a single tensor.
        """
        return self.apply_operation(None, lambda x, _: x.chunk(*args, **kwargs))

    def size(self, *args, **kwargs):
        """Return the first member's size; forwards the optional ``dim`` argument."""
        return self.tensors[0].size(*args, **kwargs)

    @property
    def shape(self):
        """Shape of the first member."""
        return self.tensors[0].shape

    @property
    def ndim(self):
        """Largest ``ndim`` among the members (0 for an empty container)."""
        return max((t.ndim for t in self.tensors), default=0)

    @property
    def device(self):
        """Device of the first member."""
        return self.tensors[0].device

    @property
    def dtype(self):
        """Dtype of the first member."""
        return self.tensors[0].dtype

    @property
    def layout(self):
        """Layout of the first member."""
        return self.tensors[0].layout
|
|
|
|
def cat_nested(tensors, *args, **kwargs):
    """Concatenate NestedTensors member by member.

    The i-th tensor of the result is ``torch.cat`` applied to the i-th member
    of every input, so all inputs must hold the same number of members.

    Args:
        tensors: Iterable of ``NestedTensor`` objects to concatenate.
        *args, **kwargs: Forwarded to ``torch.cat`` (e.g. ``dim``).

    Returns:
        A new ``NestedTensor`` with one concatenated tensor per member position.

    Raises:
        ValueError: If ``tensors`` is empty or the inputs hold different
            numbers of members.
    """
    tensors = list(tensors)
    if not tensors:
        raise ValueError("cat_nested() expects at least one NestedTensor")
    counts = {len(nt.tensors) for nt in tensors}
    if len(counts) != 1:
        # Without this check, shorter inputs raise IndexError and longer
        # inputs have their extra members silently dropped.
        raise ValueError(
            f"cat_nested() inputs hold different numbers of tensors: {sorted(counts)}"
        )
    # zip(*...) groups the i-th member of every input together.
    groups = zip(*(nt.tensors for nt in tensors))
    return NestedTensor([torch.cat(list(group), *args, **kwargs) for group in groups])
|
|