import numpy as np import torch from typing import List from torch import Tensor class TensorList: def __init__(self, tensor_list: List[Tensor] | Tensor, cumsum): self._len = len(tensor_list) if isinstance(tensor_list, List): tensor_list = torch.cat(tensor_list, dim=0) self._data = tensor_list self._cumsum = cumsum def __len__(self): return self._len def __getitem__(self, idx): start_idx = self._cumsum[idx] end_idx = self._cumsum[idx+1] return self._data[start_idx:end_idx] def cumsum(self): return self._cumsum def compute_cumsum(tensors: List[Tensor]): seq_lens = torch.tensor([0] + [p.shape[0] for p in tensors], dtype=torch.int64) return torch.cumsum(seq_lens, dim=0) def make_tensorlist(tensor_list: List[Tensor]): return TensorList(tensor_list, compute_cumsum(tensor_list)) def compute_cumsum_np(tensors: List[np.ndarray]): seq_lens = np.array([0] + [p.shape[0] for p in tensors], dtype=np.int64) return np.cumsum(seq_lens, axis=0)