File size: 6,830 Bytes
a257816
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
import torch


def subsequent_mask(
        size: int,
        device: torch.device = torch.device("cpu"),
) -> torch.Tensor:
    """Build a causal (lower-triangular) boolean mask of shape (size, size).

    Entry ``[i, j]`` is True when step ``i`` may attend to step ``j``, i.e.
    exactly when ``j <= i``. This is the mask for a decoder running in
    auto-regressive mode, where each step only sees itself and the steps to
    its left.

    The encoder normally uses full attention and needs no mask when
    streaming is not required; for streaming it uses chunk-based attention
    instead (see ``subsequent_chunk_mask``).

    Args:
        size (int): number of steps; the mask is (size, size).
        device (torch.device): device on which the mask is allocated.

    Returns:
        torch.Tensor: bool tensor of shape (size, size).

    Examples:
        >>> subsequent_mask(3)
        tensor([[ True, False, False],
                [ True,  True, False],
                [ True,  True,  True]])
    """
    steps = torch.arange(size, device=device)
    # Broadcast a row of column indices (1, size) against a column of row
    # indices (size, 1): True on and below the diagonal.
    return steps.unsqueeze(0) <= steps.unsqueeze(1)


def subsequent_chunk_mask(
        size: int,
        chunk_size: int,
        num_left_chunks: int = -1,
        device: torch.device = torch.device("cpu"),
) -> torch.Tensor:
    """Create a chunk-based attention mask of shape (size, size).

    This is for the streaming encoder. Steps are grouped into consecutive
    chunks of ``chunk_size``. Step ``i`` belongs to chunk ``i // chunk_size``
    and may attend to every step of its own chunk and of the chunks before
    it (all of them, or only the last ``num_left_chunks``), but never to a
    later chunk.

    Args:
        size (int): size of mask
        chunk_size (int): size of chunk; must be positive when size > 0
        num_left_chunks (int): number of left chunks
            <0: use full chunk (every earlier chunk is visible)
            >=0: use num_left_chunks
        device (torch.device): "cpu" or "cuda" or torch.Tensor.device

    Returns:
        torch.Tensor: bool mask of shape (size, size)

    Raises:
        ValueError: if ``size > 0`` and ``chunk_size <= 0``.

    Examples:
        >>> subsequent_chunk_mask(4, 2)
        tensor([[ True,  True, False, False],
                [ True,  True, False, False],
                [ True,  True,  True,  True],
                [ True,  True,  True,  True]])
    """
    if size == 0:
        # Nothing to mask. This check runs before the chunk_size check so an
        # empty input is accepted whatever chunk_size is (a caller may pass
        # chunk_size == max_len == 0).
        return torch.zeros(0, 0, dtype=torch.bool, device=device)
    if chunk_size <= 0:
        raise ValueError(
            "chunk_size must be positive, got {}".format(chunk_size))

    pos = torch.arange(size, device=device)
    chunk_idx = torch.div(pos, chunk_size, rounding_mode="floor")
    cols = pos.unsqueeze(0)  # (1, size): column index of each key step

    # Exclusive end of each row's visible window = end of its own chunk.
    # Column indices never exceed size - 1, so no clamp to ``size`` is needed.
    chunk_end = (chunk_idx + 1) * chunk_size
    mask = cols < chunk_end.unsqueeze(1)  # (size, size)

    if num_left_chunks >= 0:
        # Inclusive start = first step of the earliest visible chunk. A
        # negative start simply admits every column (all columns are >= 0).
        chunk_start = (chunk_idx - num_left_chunks) * chunk_size
        mask = mask & (cols >= chunk_start.unsqueeze(1))
    return mask


def add_optional_chunk_mask(xs: torch.Tensor,
                            masks: torch.Tensor,
                            use_dynamic_chunk: bool,
                            use_dynamic_left_chunk: bool,
                            decoding_chunk_size: int,
                            static_chunk_size: int,
                            num_decoding_left_chunks: int,
                            enable_full_context: bool = True,
                            max_chunk_size: int = 25):
    """Apply optional chunk mask for encoder.

    Args:
        xs (torch.Tensor): padded input, (B, L, D), L for max length
        masks (torch.Tensor): mask for xs, (B, 1, L)
        use_dynamic_chunk (bool): whether to use dynamic chunk or not
        use_dynamic_left_chunk (bool): whether to use dynamic left chunk for
            training.
        decoding_chunk_size (int): decoding chunk size for dynamic chunk, it's
            0: default for training, use random dynamic chunk.
            <0: for decoding, use full chunk.
            >0: for decoding, use fixed chunk size as set.
        static_chunk_size (int): chunk size for static chunk training/decoding
            if it's greater than 0, if use_dynamic_chunk is true,
            this parameter will be ignored
        num_decoding_left_chunks (int): number of left chunks, this is for
            decoding, the chunk size is decoding_chunk_size.
            >=0: use num_decoding_left_chunks
            <0: use all left chunks
        enable_full_context (bool):
            True: chunk size is either [1, max_chunk_size] or full
                context (max_len)
            False: chunk size ~ U[1, max_chunk_size]
        max_chunk_size (int): upper bound (inclusive) of the randomly sampled
            chunk size in dynamic-chunk training; must be positive.
            Defaults to 25.

    Returns:
        torch.Tensor: chunk mask of the input xs, ``masks & chunk_mask``
            broadcast to (B, L, L) when a chunk mask is applied, otherwise
            ``masks`` itself.
    """
    # Whether to use chunk mask or not
    if use_dynamic_chunk:
        max_len = xs.size(1)
        if decoding_chunk_size < 0:
            chunk_size = max_len
            num_left_chunks = -1
        elif decoding_chunk_size > 0:
            chunk_size = decoding_chunk_size
            num_left_chunks = num_decoding_left_chunks
        else:
            # chunk size is either [1, max_chunk_size] or full context
            # (max_len). With the default of 25: since we use 4 times
            # subsampling and allow up to 1s (100 frames) delay, the maximum
            # frame is 100 / 4 = 25.
            # torch.randint(low, high) needs high > low; clamp so inputs with
            # max_len <= 1 do not crash. For max_len >= 2 this samples exactly
            # as before.
            chunk_size = torch.randint(1, max(max_len, 2), (1, )).item()
            num_left_chunks = -1
            if chunk_size > max_len // 2 and enable_full_context:
                chunk_size = max_len
            else:
                chunk_size = chunk_size % max_chunk_size + 1
                if use_dynamic_left_chunk:
                    max_left_chunks = (max_len - 1) // chunk_size
                    # When a single chunk already spans the whole sequence
                    # there is no left chunk to sample (randint(0, 0) would
                    # raise); keeping -1 gives the same mask in that case.
                    if max_left_chunks > 0:
                        num_left_chunks = torch.randint(0, max_left_chunks,
                                                        (1, )).item()
        chunk_masks = subsequent_chunk_mask(xs.size(1), chunk_size,
                                            num_left_chunks,
                                            xs.device)  # (L, L)
        chunk_masks = chunk_masks.unsqueeze(0)  # (1, L, L)
        chunk_masks = masks & chunk_masks  # (B, L, L)
    elif static_chunk_size > 0:
        num_left_chunks = num_decoding_left_chunks
        chunk_masks = subsequent_chunk_mask(xs.size(1), static_chunk_size,
                                            num_left_chunks,
                                            xs.device)  # (L, L)
        chunk_masks = chunk_masks.unsqueeze(0)  # (1, L, L)
        chunk_masks = masks & chunk_masks  # (B, L, L)
    else:
        chunk_masks = masks
    return chunk_masks


def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
    """Return a boolean mask that is True at the padded positions.

    Row ``b`` of the result is True at every column ``t >= lengths[b]``,
    i.e. at the positions past the real length of sequence ``b``.

    Args:
        lengths (torch.Tensor): Batch of lengths (B,).
        max_len (int): width of the mask; when <= 0 the largest value in
            ``lengths`` is used instead.

    Returns:
        torch.Tensor: bool tensor of shape (B, max_len), True at padding.

    Examples:
        >>> lengths = torch.tensor([5, 3, 2])
        >>> make_pad_mask(lengths)
        tensor([[False, False, False, False, False],
                [False, False, False,  True,  True],
                [False, False,  True,  True,  True]])
    """
    batch = lengths.size(0)
    width = lengths.max().item() if max_len <= 0 else max_len
    columns = torch.arange(0,
                           width,
                           dtype=torch.int64,
                           device=lengths.device)
    # (B, width) column indices vs. (B, 1) lengths -> True where t >= length.
    return columns.unsqueeze(0).expand(batch, width) >= lengths.unsqueeze(-1)