File size: 1,074 Bytes
b678162 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 | import numpy as np
import torch
from typing import List
from torch import Tensor
class TensorList:
def __init__(self, tensor_list: List[Tensor] | Tensor, cumsum):
self._len = len(tensor_list)
if isinstance(tensor_list, List):
tensor_list = torch.cat(tensor_list, dim=0)
self._data = tensor_list
self._cumsum = cumsum
def __len__(self):
return self._len
def __getitem__(self, idx):
start_idx = self._cumsum[idx]
end_idx = self._cumsum[idx+1]
return self._data[start_idx:end_idx]
def cumsum(self):
return self._cumsum
def compute_cumsum(tensors: List[Tensor]):
    """Return int64 row offsets for ``tensors`` laid end to end along dim 0.

    The result has ``len(tensors) + 1`` entries: a leading 0 followed by the
    running total of each tensor's ``shape[0]``.
    """
    lengths = [0]
    for t in tensors:
        lengths.append(t.shape[0])
    return torch.cumsum(torch.tensor(lengths, dtype=torch.int64), dim=0)
def make_tensorlist(tensor_list: List[Tensor]):
    """Wrap ``tensor_list`` in a TensorList, computing its row offsets."""
    offsets = compute_cumsum(tensor_list)
    return TensorList(tensor_list, offsets)
def compute_cumsum_np(tensors: List[np.ndarray]):
    """Return int64 row offsets for ``tensors`` laid end to end along axis 0.

    The result has ``len(tensors) + 1`` entries: a leading 0 followed by the
    running total of each array's ``shape[0]``.
    """
    lengths = [0]
    for arr in tensors:
        lengths.append(arr.shape[0])
    offsets = np.array(lengths, dtype=np.int64)
    return np.cumsum(offsets, axis=0)
|