File size: 14,780 Bytes
7934b29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
# Copyright (c) 2023, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os

import soundfile as sf
import torch
from utils.constants import BLANK_TOKEN, SPACE_TOKEN, V_NEGATIVE_NUM


def get_batch_starts_ends(manifest_filepath, batch_size):
    """
    Get the start and end line ids for each 'batch' of manifest lines.

    Returns two parallel lists: ``starts[i]`` is the (0-based) index of the
    first manifest line of batch ``i`` and ``ends[i]`` is the index of its
    last line (the final entry is the total line count, i.e. one past the
    last valid index, so the last batch always consumes the remainder).
    """

    # count how many utterances (JSON lines) the manifest contains
    with open(manifest_filepath, 'r') as f:
        num_lines_in_manifest = sum(1 for _ in f)

    starts = list(range(0, num_lines_in_manifest, batch_size))
    # each batch ends one line before the next batch starts; the very last
    # batch ends at the total line count
    ends = [next_start - 1 for next_start in starts]
    del ends[0]
    ends.append(num_lines_in_manifest)

    return starts, ends


def is_entry_in_any_lines(manifest_filepath, entry):
    """
    Returns True if entry is a key in any of the JSON lines in manifest_filepath
    """

    with open(manifest_filepath, 'r') as f:
        for line in f:
            data = json.loads(line)

            # stop at the first line that has the key - no need to parse the
            # rest of the manifest (the original version kept reading to EOF)
            if entry in data:
                return True

    return False


def is_entry_in_all_lines(manifest_filepath, entry):
    """
    Returns True if entry is a key in all of the JSON lines in manifest_filepath.

    Note: an empty manifest vacuously returns True, matching ``all()``.
    """
    with open(manifest_filepath, 'r') as f:
        return all(entry in json.loads(line) for line in f)


def get_manifest_lines_batch(manifest_filepath, start, end):
    """
    Return the JSON-parsed manifest lines with (0-based) indices in [start, end].

    ``end`` is inclusive; if ``end`` is past the last line of the file (as it is
    for the final batch produced by get_batch_starts_ends), all lines from
    ``start`` onwards are returned.
    """
    manifest_lines_batch = []
    with open(manifest_filepath, "r") as f:
        for line_i, line in enumerate(f):
            # bug fix: the line at index `end` must be part of the batch.
            # The previous version broke out at `line_i == end` *without*
            # appending (unless start == end), silently dropping the last
            # utterance of every non-final batch.
            if start <= line_i <= end:
                manifest_lines_batch.append(json.loads(line))
            if line_i >= end:
                break
    return manifest_lines_batch


def get_char_tokens(text, model):
    """
    Map each character of ``text`` to its index in the model's vocabulary.

    Characters not present in the vocabulary are mapped to len(vocabulary),
    which serves as the unk token id (same as the blank token id).
    """
    vocabulary = model.decoder.vocabulary
    unk_id = len(vocabulary)
    return [vocabulary.index(char) if char in vocabulary else unk_id for char in text]


def get_y_and_boundary_info_for_utt(text, model, separator):
    """
    Get y_token_ids_with_blanks, token_info, word_info and segment_info for the text provided, tokenized 
    by the model provided.
    y_token_ids_with_blanks is a list of the indices of the text tokens with the blank token id in between every
    text token.
    token_info, word_info and segment_info are lists of dictionaries containing information about 
    where the tokens/words/segments start and end.
    For example, 'hi world | hey ' with separator = '|' and tokenized by a BPE tokenizer can have token_info like:
    token_info = [
        {'text': '<b>', 's_start': 0, 's_end': 0},
        {'text': '▁hi', 's_start': 1, 's_end': 1},
        {'text': '<b>', 's_start': 2, 's_end': 2},
        {'text': '▁world', 's_start': 3, 's_end': 3},
        {'text': '<b>', 's_start': 4, 's_end': 4},
        {'text': '▁he', 's_start': 5, 's_end': 5},
        {'text': '<b>', 's_start': 6, 's_end': 6},
        {'text': 'y', 's_start': 7, 's_end': 7},
        {'text': '<b>', 's_start': 8, 's_end': 8},    
    ]
    's_start' and 's_end' indicate where in the sequence of tokens does each token start and end.

    The word_info will be as follows:
    word_info = [
        {'text': 'hi', 's_start': 1, 's_end': 1},
        {'text': 'world', 's_start': 3, 's_end': 3},
        {'text': 'hey', 's_start': 5, 's_end': 7},
    ]
    's_start' and 's_end' indicate where in the sequence of tokens does each word start and end.

    segment_info will be as follows:
    segment_info = [
        {'text': 'hi world', 's_start': 1, 's_end': 3},
        {'text': 'hey', 's_start': 5, 's_end': 7},
    ]
    's_start' and 's_end' indicate where in the sequence of tokens does each segment start and end.

    Raises:
        RuntimeError: if the model has neither a `tokenizer` attribute nor a
            `decoder.vocabulary` attribute, i.e. we cannot tokenize the text.
    """

    if not separator:  # if separator is not defined - treat the whole text as one segment
        segments = [text]
    else:
        segments = text.split(separator)

    # remove any spaces at start and end of segments
    segments = [seg.strip() for seg in segments]

    # Branch 1: model has a subword (e.g. BPE) tokenizer.
    if hasattr(model, 'tokenizer'):

        # NOTE(review): assumes the CTC blank id equals the vocabulary size -
        # the existing TODO below flags this as unverified.
        BLANK_ID = len(model.decoder.vocabulary)  # TODO: check

        # the token sequence always starts with a blank (standard CTC layout:
        # blank, token, blank, token, ..., blank)
        y_token_ids_with_blanks = [BLANK_ID]
        token_info = [{"text": BLANK_TOKEN, "s_start": 0, "s_end": 0,}]
        word_info = []
        segment_info = []

        segment_s_pointer = 1  # first segment will start at s=1 because s=0 is a blank
        word_s_pointer = 1  # first word will start at s=1 because s=0 is a blank

        for segment in segments:
            words = segment.split(" ")  # we define words to be space-separated sub-strings
            for word in words:

                word_tokens = model.tokenizer.text_to_tokens(word)
                word_ids = model.tokenizer.text_to_ids(word)
                for token, id_ in zip(word_tokens, word_ids):
                    # add the text token and the blank that follows it
                    # to our token-based variables
                    y_token_ids_with_blanks.extend([id_, BLANK_ID])
                    token_info.extend(
                        [
                            {
                                "text": token,
                                "s_start": len(y_token_ids_with_blanks) - 2,
                                "s_end": len(y_token_ids_with_blanks) - 2,
                            },
                            {
                                "text": BLANK_TOKEN,
                                "s_start": len(y_token_ids_with_blanks) - 1,
                                "s_end": len(y_token_ids_with_blanks) - 1,
                            },
                        ]
                    )

                # add the word to word_info and increment the word_s_pointer
                # (each token occupies 2 positions: the token itself + the
                # blank after it, hence the factor of 2)
                word_info.append(
                    {
                        "text": word,
                        "s_start": word_s_pointer,
                        "s_end": word_s_pointer + (len(word_tokens) - 1) * 2,  # TODO check this,
                    }
                )
                word_s_pointer += len(word_tokens) * 2  # TODO check this

            # add the segment to segment_info and increment the segment_s_pointer
            segment_tokens = model.tokenizer.text_to_tokens(segment)
            segment_info.append(
                {
                    "text": segment,
                    "s_start": segment_s_pointer,
                    "s_end": segment_s_pointer + (len(segment_tokens) - 1) * 2,
                }
            )
            segment_s_pointer += len(segment_tokens) * 2

        return y_token_ids_with_blanks, token_info, word_info, segment_info

    # Branch 2: character-based model - tokens are single characters, and we
    # must insert explicit space tokens between words ourselves.
    elif hasattr(model.decoder, "vocabulary"):  # i.e. tokenization is simply character-based

        BLANK_ID = len(model.decoder.vocabulary)  # TODO: check this is correct
        SPACE_ID = model.decoder.vocabulary.index(" ")

        y_token_ids_with_blanks = [BLANK_ID]
        token_info = [{"text": BLANK_TOKEN, "s_start": 0, "s_end": 0,}]
        word_info = []
        segment_info = []

        segment_s_pointer = 1  # first segment will start at s=1 because s=0 is a blank
        word_s_pointer = 1  # first word will start at s=1 because s=0 is a blank

        for i_segment, segment in enumerate(segments):
            words = segment.split(" ")  # we define words to be space-separated characters
            for i_word, word in enumerate(words):

                # convert string to list of characters
                word_tokens = list(word)
                # convert list of characters to list of their ids in the vocabulary
                word_ids = get_char_tokens(word, model)
                for token, id_ in zip(word_tokens, word_ids):
                    # add the text token and the blank that follows it
                    # to our token-based variables
                    y_token_ids_with_blanks.extend([id_, BLANK_ID])
                    token_info.extend(
                        [
                            {
                                "text": token,
                                "s_start": len(y_token_ids_with_blanks) - 2,
                                "s_end": len(y_token_ids_with_blanks) - 2,
                            },
                            {
                                "text": BLANK_TOKEN,
                                "s_start": len(y_token_ids_with_blanks) - 1,
                                "s_end": len(y_token_ids_with_blanks) - 1,
                            },
                        ]
                    )

                # add space token (and the blank after it) unless this is the final word in the final segment
                if not (i_segment == len(segments) - 1 and i_word == len(words) - 1):
                    y_token_ids_with_blanks.extend([SPACE_ID, BLANK_ID])
                    token_info.extend(
                        (
                            {
                                "text": SPACE_TOKEN,
                                "s_start": len(y_token_ids_with_blanks) - 2,
                                "s_end": len(y_token_ids_with_blanks) - 2,
                            },
                            {
                                "text": BLANK_TOKEN,
                                "s_start": len(y_token_ids_with_blanks) - 1,
                                "s_end": len(y_token_ids_with_blanks) - 1,
                            },
                        )
                    )
                # add the word to word_info and increment the word_s_pointer.
                # `len(word_tokens) * 2 - 2` == `(len(word_tokens) - 1) * 2`,
                # same formula as the BPE branch; the pointer increment has an
                # extra +2 here to skip the space token + its blank.
                word_info.append(
                    {
                        "text": word,
                        "s_start": word_s_pointer,
                        "s_end": word_s_pointer + len(word_tokens) * 2 - 2,  # TODO check this,
                    }
                )
                word_s_pointer += len(word_tokens) * 2 + 2  # TODO check this

            # add the segment to segment_info and increment the segment_s_pointer
            # (+2 skips the inter-segment space token and its trailing blank)
            segment_tokens = get_char_tokens(segment, model)
            segment_info.append(
                {
                    "text": segment,
                    "s_start": segment_s_pointer,
                    "s_end": segment_s_pointer + (len(segment_tokens) - 1) * 2,
                }
            )
            segment_s_pointer += len(segment_tokens) * 2 + 2

        return y_token_ids_with_blanks, token_info, word_info, segment_info

    else:
        raise RuntimeError("Cannot get tokens of this model.")


def get_batch_tensors_and_boundary_info(manifest_lines_batch, model, separator, align_using_pred_text):
    """
    Returns:
        log_probs, y, T, U (y and U are s.t. every other token is a blank) - these are the tensors we will need
            during Viterbi decoding.
        token_info_list, word_info_list, segment_info_list - these are lists of dictionaries which we will need
            for writing the CTM files with the human-readable alignments.
        pred_text_list - this is a list of the transcriptions from our model which we will save to our output JSON
            file if align_using_pred_text is True.
    """

    # run ASR on the batch; each hypothesis carries the per-frame log-probs
    # (y_sequence), their duration, and the predicted transcription
    audio_filepaths_batch = [line["audio_filepath"] for line in manifest_lines_batch]
    B = len(audio_filepaths_batch)
    with torch.no_grad():
        hypotheses = model.transcribe(audio_filepaths_batch, return_hypotheses=True, batch_size=B)

    log_probs_list_batch = [hyp.y_sequence for hyp in hypotheses]
    T_list_batch = [hyp.y_sequence.shape[0] for hyp in hypotheses]
    pred_text_batch = [hyp.text for hyp in hypotheses]

    # for every utterance in the batch, record y (token ids with interleaved
    # blanks), U (length of y) and the token/word/segment boundary info
    y_list_batch = []
    U_list_batch = []
    token_info_batch = []
    word_info_batch = []
    segment_info_batch = []

    for i_line, line in enumerate(manifest_lines_batch):
        # align against the ASR output instead of the reference text if requested
        text_for_alignment = pred_text_batch[i_line] if align_using_pred_text else line["text"]

        y_utt, token_info_utt, word_info_utt, segment_info_utt = get_y_and_boundary_info_for_utt(
            text_for_alignment, model, separator
        )

        y_list_batch.append(y_utt)
        U_list_batch.append(len(y_utt))
        token_info_batch.append(token_info_utt)
        word_info_batch.append(word_info_utt)
        segment_info_batch.append(segment_info_utt)

    # turn log_probs, y, T, U into dense tensors for fast computation during Viterbi decoding
    T_max = max(T_list_batch)
    U_max = max(U_list_batch)
    # V = the number of tokens in the vocabulary + 1 for the blank token
    V = len(model.decoder.vocabulary) + 1
    T_batch = torch.tensor(T_list_batch)
    U_batch = torch.tensor(U_list_batch)

    # log_probs_batch has shape (B x T_max x V); padded frames are filled with
    # a very negative number so they never win during Viterbi decoding
    log_probs_batch = V_NEGATIVE_NUM * torch.ones((B, T_max, V))
    for b, utt_log_probs in enumerate(log_probs_list_batch):
        num_frames = utt_log_probs.shape[0]
        log_probs_batch[b, :num_frames, :] = utt_log_probs

    # y_batch has shape (B x U_max); the padding value is 'V' so that padded
    # positions stay identifiable when 'log_probs_reordered' is built during
    # Viterbi decoding in a different function
    y_batch = V * torch.ones((B, U_max), dtype=torch.int64)
    for b, y_utt in enumerate(y_list_batch):
        y_batch[b, : U_batch[b]] = torch.tensor(y_utt)

    return (
        log_probs_batch,
        y_batch,
        T_batch,
        U_batch,
        token_info_batch,
        word_info_batch,
        segment_info_batch,
        pred_text_batch,
    )