|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import json |
|
|
import os |
|
|
|
|
|
import soundfile as sf |
|
|
import torch |
|
|
from utils.constants import BLANK_TOKEN, SPACE_TOKEN, V_NEGATIVE_NUM |
|
|
|
|
|
|
|
|
def get_batch_starts_ends(manifest_filepath, batch_size):
    """
    Get the start and end ids of the lines we will use for each 'batch'.

    Indices are 0-based line numbers into the manifest file. Each batch spans
    `batch_size` lines; the final batch's end index is set to the total line
    count so the reader simply consumes the file to its end.
    """
    # Count the utterances in the manifest (one JSON entry per line).
    with open(manifest_filepath, 'r') as f:
        num_lines_in_manifest = sum(1 for _ in f)

    starts = list(range(0, num_lines_in_manifest, batch_size))

    # Each batch ends one line before the next batch starts; the last batch
    # runs through to the end of the manifest.
    ends = [start - 1 for start in starts]
    del ends[0]
    ends.append(num_lines_in_manifest)

    return starts, ends
|
|
|
|
|
|
|
|
def is_entry_in_any_lines(manifest_filepath, entry):
    """
    Returns True if entry is a key in any of the JSON lines in manifest_filepath.

    Args:
        manifest_filepath: path to a JSON-lines manifest (one JSON object per line).
        entry: key to look for in each JSON object.

    Returns:
        True as soon as any line contains `entry` as a key, else False.
    """
    with open(manifest_filepath, 'r') as f:
        for line in f:
            # Early exit: once a line contains the key there is no need to
            # parse the rest of the manifest (the original scanned every line).
            if entry in json.loads(line):
                return True

    return False
|
|
|
|
|
|
|
|
def is_entry_in_all_lines(manifest_filepath, entry):
    """
    Returns True if entry is a key in all of the JSON lines in manifest_filepath.

    Short-circuits to False on the first line that lacks the key. An empty
    manifest yields True (vacuous truth).
    """
    with open(manifest_filepath, 'r') as f:
        return all(entry in json.loads(line) for line in f)
|
|
|
|
|
|
|
|
def get_manifest_lines_batch(manifest_filepath, start, end):
    """
    Return the JSON-decoded manifest lines with 0-based indices start..end (inclusive).

    BUG FIX: the original only appended the line at index `end` in the special
    case start == end; for multi-line batches it hit `if line_i == end: break`
    BEFORE appending, silently dropping the last line of every in-range batch.
    The end line is now always included.

    An `end` beyond the last line (as produced for the final batch by
    get_batch_starts_ends) is harmless: the loop simply exhausts the file.
    """
    manifest_lines_batch = []
    with open(manifest_filepath, "r") as f:
        for line_i, line in enumerate(f):
            # Skip lines before the batch window.
            if line_i < start:
                continue

            manifest_lines_batch.append(json.loads(line))

            # Stop once the (inclusive) end of the window has been consumed.
            if line_i == end:
                break

    return manifest_lines_batch
|
|
|
|
|
|
|
|
def get_char_tokens(text, model):
    """
    Map each character of `text` to its index in the model's character
    vocabulary. Characters not present in the vocabulary are mapped to
    len(vocabulary), i.e. the first id outside the vocabulary.
    """
    vocab = model.decoder.vocabulary
    oov_id = len(vocab)
    return [vocab.index(ch) if ch in vocab else oov_id for ch in text]
|
|
|
|
|
|
|
|
def get_y_and_boundary_info_for_utt(text, model, separator):
    """
    Get y_token_ids_with_blanks, token_info, word_info and segment_info for the text provided, tokenized
    by the model provided.
    y_token_ids_with_blanks is a list of the indices of the text tokens with the blank token id in between every
    text token.
    token_info, word_info and segment_info are lists of dictionaries containing information about
    where the tokens/words/segments start and end.
    For example, 'hi world | hey ' with separator = '|' and tokenized by a BPE tokenizer can have token_info like:
    token_info = [
        {'text': '<b>', 's_start': 0, 's_end': 0},
        {'text': '▁hi', 's_start': 1, 's_end': 1},
        {'text': '<b>', 's_start': 2, 's_end': 2},
        {'text': '▁world', 's_start': 3, 's_end': 3},
        {'text': '<b>', 's_start': 4, 's_end': 4},
        {'text': '▁he', 's_start': 5, 's_end': 5},
        {'text': '<b>', 's_start': 6, 's_end': 6},
        {'text': 'y', 's_start': 7, 's_end': 7},
        {'text': '<b>', 's_start': 8, 's_end': 8},
    ]
    's_start' and 's_end' indicate where in the sequence of tokens does each token start and end.

    The word_info will be as follows:
    word_info = [
        {'text': 'hi', 's_start': 1, 's_end': 1},
        {'text': 'world', 's_start': 3, 's_end': 3},
        {'text': 'hey', 's_start': 5, 's_end': 7},
    ]
    's_start' and 's_end' indicate where in the sequence of tokens does each word start and end.

    segment_info will be as follows:
    segment_info = [
        {'text': 'hi world', 's_start': 1, 's_end': 3},
        {'text': 'hey', 's_start': 5, 's_end': 7},
    ]
    's_start' and 's_end' indicate where in the sequence of tokens does each segment start and end.
    """

    # Split the text into segments on the separator (if one was given);
    # otherwise the whole text is a single segment.
    if not separator:
        segments = [text]
    else:
        segments = text.split(separator)

    segments = [seg.strip() for seg in segments]

    if hasattr(model, 'tokenizer'):
        # --- BPE-tokenizer branch ---

        # The blank id used here is one past the last vocabulary index.
        BLANK_ID = len(model.decoder.vocabulary)

        # The target sequence starts with a blank; blanks are then interleaved
        # after every real token, so positions alternate blank/token/blank/...
        y_token_ids_with_blanks = [BLANK_ID]
        token_info = [{"text": BLANK_TOKEN, "s_start": 0, "s_end": 0,}]
        word_info = []
        segment_info = []

        # Pointers track where (in y_token_ids_with_blanks) the next
        # word/segment starts; position 0 is the initial blank, so both start at 1.
        segment_s_pointer = 1
        word_s_pointer = 1

        for segment in segments:
            words = segment.split(" ")
            for word in words:
                word_tokens = model.tokenizer.text_to_tokens(word)
                word_ids = model.tokenizer.text_to_ids(word)
                for token, id_ in zip(word_tokens, word_ids):
                    # Append the token id followed by a blank, and record the
                    # sequence positions of both.
                    y_token_ids_with_blanks.extend([id_, BLANK_ID])
                    token_info.extend(
                        [
                            {
                                "text": token,
                                "s_start": len(y_token_ids_with_blanks) - 2,
                                "s_end": len(y_token_ids_with_blanks) - 2,
                            },
                            {
                                "text": BLANK_TOKEN,
                                "s_start": len(y_token_ids_with_blanks) - 1,
                                "s_end": len(y_token_ids_with_blanks) - 1,
                            },
                        ]
                    )

                # A word occupies len(word_tokens) (token, blank) pairs; its
                # s_end is the position of its last non-blank token.
                word_info.append(
                    {
                        "text": word,
                        "s_start": word_s_pointer,
                        "s_end": word_s_pointer + (len(word_tokens) - 1) * 2,
                    }
                )
                # Advance past this word's tokens and their interleaved blanks.
                # NOTE(review): no extra offset is added for inter-word spaces —
                # presumably the BPE tokenizer encodes word-leading whitespace
                # inside the tokens themselves (e.g. '▁' prefix); confirm for
                # the tokenizer in use.
                word_s_pointer += len(word_tokens) * 2

            # Segment boundaries are computed from the segment's own
            # tokenization (which may differ from the concatenation of its
            # per-word tokenizations for some BPE models — TODO confirm).
            segment_tokens = model.tokenizer.text_to_tokens(segment)
            segment_info.append(
                {
                    "text": segment,
                    "s_start": segment_s_pointer,
                    "s_end": segment_s_pointer + (len(segment_tokens) - 1) * 2,
                }
            )
            segment_s_pointer += len(segment_tokens) * 2

        return y_token_ids_with_blanks, token_info, word_info, segment_info

    elif hasattr(model.decoder, "vocabulary"):
        # --- character-vocabulary branch ---

        # Blank id is one past the last vocabulary index; the space character
        # must exist in the vocabulary for this branch.
        BLANK_ID = len(model.decoder.vocabulary)
        SPACE_ID = model.decoder.vocabulary.index(" ")

        # As above: sequence starts with a blank, blanks interleaved throughout.
        y_token_ids_with_blanks = [BLANK_ID]
        token_info = [{"text": BLANK_TOKEN, "s_start": 0, "s_end": 0,}]
        word_info = []
        segment_info = []

        segment_s_pointer = 1
        word_s_pointer = 1

        for i_segment, segment in enumerate(segments):
            words = segment.split(" ")
            for i_word, word in enumerate(words):
                # Each character is one token.
                word_tokens = list(word)
                word_ids = get_char_tokens(word, model)
                for token, id_ in zip(word_tokens, word_ids):
                    y_token_ids_with_blanks.extend([id_, BLANK_ID])
                    token_info.extend(
                        [
                            {
                                "text": token,
                                "s_start": len(y_token_ids_with_blanks) - 2,
                                "s_end": len(y_token_ids_with_blanks) - 2,
                            },
                            {
                                "text": BLANK_TOKEN,
                                "s_start": len(y_token_ids_with_blanks) - 1,
                                "s_end": len(y_token_ids_with_blanks) - 1,
                            },
                        ]
                    )

                # Insert an explicit space token (plus blank) after every word
                # except the very last word of the very last segment.
                if not (i_segment == len(segments) - 1 and i_word == len(words) - 1):
                    y_token_ids_with_blanks.extend([SPACE_ID, BLANK_ID])
                    token_info.extend(
                        (
                            {
                                "text": SPACE_TOKEN,
                                "s_start": len(y_token_ids_with_blanks) - 2,
                                "s_end": len(y_token_ids_with_blanks) - 2,
                            },
                            {
                                "text": BLANK_TOKEN,
                                "s_start": len(y_token_ids_with_blanks) - 1,
                                "s_end": len(y_token_ids_with_blanks) - 1,
                            },
                        )
                    )

                word_info.append(
                    {
                        "text": word,
                        "s_start": word_s_pointer,
                        # len(word_tokens) * 2 - 2 == (len(word_tokens) - 1) * 2,
                        # i.e. the position of the word's last character token.
                        "s_end": word_s_pointer + len(word_tokens) * 2 - 2,
                    }
                )
                # Advance past the word's (char, blank) pairs plus the
                # following (space, blank) pair.
                word_s_pointer += len(word_tokens) * 2 + 2

            segment_tokens = get_char_tokens(segment, model)
            segment_info.append(
                {
                    "text": segment,
                    "s_start": segment_s_pointer,
                    "s_end": segment_s_pointer + (len(segment_tokens) - 1) * 2,
                }
            )
            # +2 accounts for the (space, blank) pair inserted between segments.
            segment_s_pointer += len(segment_tokens) * 2 + 2

        return y_token_ids_with_blanks, token_info, word_info, segment_info

    else:
        # Model exposes neither a BPE tokenizer nor a character vocabulary.
        raise RuntimeError("Cannot get tokens of this model.")
|
|
|
|
|
|
|
|
def get_batch_tensors_and_boundary_info(manifest_lines_batch, model, separator, align_using_pred_text):
    """
    Returns:
        log_probs, y, T, U (y and U are s.t. every other token is a blank) - these are the tensors we will need
        during Viterbi decoding.
        token_info_list, word_info_list, segment_info_list - these are lists of dictionaries which we will need
        for writing the CTM files with the human-readable alignments.
        pred_text_list - this is a list of the transcriptions from our model which we will save to our output JSON
        file if align_using_pred_text is True.
    """
    # Transcribe the whole batch in a single inference call to obtain
    # per-utterance log-prob sequences (hypothesis.y_sequence) and texts.
    audio_filepaths_batch = [line["audio_filepath"] for line in manifest_lines_batch]
    B = len(audio_filepaths_batch)
    with torch.no_grad():
        hypotheses = model.transcribe(audio_filepaths_batch, return_hypotheses=True, batch_size=B)

    log_probs_list_batch = [hyp.y_sequence for hyp in hypotheses]
    T_list_batch = [hyp.y_sequence.shape[0] for hyp in hypotheses]
    pred_text_batch = [hyp.text for hyp in hypotheses]

    # Build the reference token sequences (with interleaved blanks) and the
    # token/word/segment boundary info for every utterance in the batch.
    y_list_batch = []
    U_list_batch = []
    token_info_batch = []
    word_info_batch = []
    segment_info_batch = []

    for i_line, line in enumerate(manifest_lines_batch):
        # Align against the model's own transcription or the ground-truth text.
        if align_using_pred_text:
            gt_text_for_alignment = pred_text_batch[i_line]
        else:
            gt_text_for_alignment = line["text"]

        y_utt, token_info_utt, word_info_utt, segment_info_utt = get_y_and_boundary_info_for_utt(
            gt_text_for_alignment, model, separator
        )

        y_list_batch.append(y_utt)
        U_list_batch.append(len(y_utt))
        token_info_batch.append(token_info_utt)
        word_info_batch.append(word_info_utt)
        segment_info_batch.append(segment_info_utt)

    T_max = max(T_list_batch)
    U_max = max(U_list_batch)
    V = len(model.decoder.vocabulary) + 1  # vocabulary size including the blank
    T_batch = torch.tensor(T_list_batch)
    U_batch = torch.tensor(U_list_batch)

    # Pad the per-utterance log-probs up to T_max using a very negative fill
    # value so padded frames cannot win during Viterbi decoding.
    log_probs_batch = V_NEGATIVE_NUM * torch.ones((B, T_max, V))
    for b, log_probs_utt in enumerate(log_probs_list_batch):
        t = log_probs_utt.shape[0]
        log_probs_batch[b, :t, :] = log_probs_utt

    # Pad the token-id sequences up to U_max with the out-of-vocabulary id V
    # so padding positions can never match a real token.
    y_batch = V * torch.ones((B, U_max), dtype=torch.int64)
    for b, y_utt in enumerate(y_list_batch):
        U_utt = U_batch[b]
        y_batch[b, :U_utt] = torch.tensor(y_utt)

    return (
        log_probs_batch,
        y_batch,
        T_batch,
        U_batch,
        token_info_batch,
        word_info_batch,
        segment_info_batch,
        pred_text_batch,
    )
|
|
|