File size: 1,074 Bytes
b678162
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import numpy as np
import torch
from typing import List
from torch import Tensor


class TensorList:
    def __init__(self, tensor_list: List[Tensor] | Tensor, cumsum):
        self._len = len(tensor_list)
        if isinstance(tensor_list, List):
            tensor_list = torch.cat(tensor_list, dim=0)
        self._data = tensor_list
        self._cumsum = cumsum

    def __len__(self):
        return self._len

    def __getitem__(self, idx):
        start_idx = self._cumsum[idx]
        end_idx = self._cumsum[idx+1]
        return self._data[start_idx:end_idx]

    def cumsum(self):
        return self._cumsum


def compute_cumsum(tensors: List[Tensor]):
    """Return the int64 dim-0 offsets of ``tensors`` once concatenated.

    The result has ``len(tensors) + 1`` entries: a leading 0 followed by the
    running total of each tensor's first-dimension size.
    """
    lengths = [0]
    lengths.extend(t.shape[0] for t in tensors)
    offsets = torch.tensor(lengths, dtype=torch.int64)
    return offsets.cumsum(dim=0)


def make_tensorlist(tensor_list: List[Tensor]):
    """Build a ``TensorList`` from individual tensors, deriving the offsets."""
    offsets = compute_cumsum(tensor_list)
    return TensorList(tensor_list, offsets)


def compute_cumsum_np(tensors: List[np.ndarray]):
    """NumPy counterpart of ``compute_cumsum``.

    Returns an int64 array of ``len(tensors) + 1`` offsets: a leading 0
    followed by the running total of each array's first-dimension size.
    """
    sizes = np.array([0, *(arr.shape[0] for arr in tensors)], dtype=np.int64)
    return sizes.cumsum(axis=0)