| | |
| |
|
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import numpy as np |
| | import torch |
| |
|
| |
|
def insert_blank(label, blank_id=0):
    """Interleave ``blank_id`` around every token of ``label``.

    The result starts and ends with a blank and has a blank between every
    pair of adjacent tokens: ``[b, l0, b, l1, ..., b, l_{n-1}, b]``, i.e. the
    extended label sequence used by CTC.

    Args:
        label: 1-D sequence of token ids (array-like; a CPU ``torch.Tensor``
            is accepted through ``np.asarray``).
        blank_id: id written into every blank position.

    Returns:
        np.ndarray: 1-D array of length ``2 * len(label) + 1``. An empty
        ``label`` yields ``[blank_id]``.

    Raises:
        ValueError: if ``label`` is not one-dimensional.
    """
    label = np.asarray(label)
    if label.ndim != 1:
        raise ValueError(f"label must be 1-D, got shape {label.shape}")
    # At least int64 (the dtype of the blanks); a wider label dtype wins.
    dtype = np.result_type(label.dtype, np.int64)
    extended = np.full(2 * label.shape[0] + 1, blank_id, dtype=dtype)
    # Tokens occupy the odd positions; even positions stay blank. Writing
    # blank_id explicitly (instead of copying element 0) keeps empty input safe.
    extended[1::2] = label
    return extended
| |
|
| |
|
def forced_align(ctc_probs: torch.Tensor, y: torch.Tensor, blank_id=0) -> list:
    """CTC forced alignment via Viterbi search over the blank-augmented labels.

    ``y`` is expanded by ``insert_blank`` to ``[blank, y0, blank, y1, ...,
    blank]`` (S states). At each frame the best predecessor of a state is the
    same state, the previous state, or -- only for a non-blank state whose
    label differs from the one two states back -- the state two positions
    back. Scores are added along the path, so ``ctc_probs`` is expected to
    hold log-probabilities.

    Args:
        ctc_probs: per-frame scores, 2-D tensor of shape (T, D).
        y: target id sequence, 1-D tensor of shape (L,).
        blank_id: index of the blank symbol along the D dimension.

    Returns:
        list: T entries, one symbol id per frame; each entry is ``blank_id``
        or a token of ``y``. Empty when T == 0. If ``y`` cannot be aligned
        within T frames, all final scores are -inf and the returned path is
        arbitrary.
    """
    y_insert_blank = insert_blank(y, blank_id)
    num_frames = ctc_probs.size(0)
    num_states = len(y_insert_blank)
    if num_frames == 0:
        return []

    device = ctc_probs.device
    labels = torch.as_tensor(y_insert_blank, dtype=torch.long, device=device)
    # emit[t, s]: score of emitting the symbol of state s at frame t.
    emit = ctc_probs.index_select(1, labels)
    neg_inf = float("-inf")

    # Skipping s-2 -> s is only allowed into a non-blank state whose label
    # differs from the one two states back; otherwise the blank is mandatory.
    can_skip = torch.zeros(num_states, dtype=torch.bool, device=device)
    can_skip[2:] = (labels[2:] != blank_id) & (labels[2:] != labels[:-2])

    log_alpha = torch.full(
        (num_frames, num_states), neg_inf, dtype=emit.dtype, device=device
    )
    # back[t, s]: best predecessor of state s when entering frame t.
    # int64 so long label sequences cannot overflow the state index.
    back = torch.zeros((num_frames, num_states), dtype=torch.long, device=device)
    state_ids = torch.arange(num_states, device=device)

    # A path may start in the leading blank or directly in the first label.
    log_alpha[0, :2] = emit[0, :2]

    for t in range(1, num_frames):
        stay = log_alpha[t - 1]
        # Shifted copies padded with -inf: state 0 can only come from itself
        # (no wrap-around from the last state via index -1).
        step = torch.full_like(stay, neg_inf)
        step[1:] = stay[:-1]
        skip = torch.full_like(stay, neg_inf)
        skip[2:] = stay[:-2]
        skip = skip.masked_fill(~can_skip, neg_inf)
        # Row order matters: argmax returns the first maximum, so ties prefer
        # staying, then stepping, then skipping.
        candidates = torch.stack([stay, step, skip])
        best = torch.argmax(candidates, dim=0)
        log_alpha[t] = candidates.amax(dim=0) + emit[t]
        back[t] = state_ids - best

    # A complete path ends in the trailing blank or the last label; ties go to
    # the trailing blank.
    final_scores = log_alpha[-1].tolist()
    state = num_states - 1
    if num_states > 1 and final_scores[-2] > final_scores[-1]:
        state = num_states - 2

    # Backtrack on plain Python lists to avoid per-element tensor indexing.
    back_rows = back.tolist()
    states = [state]
    for t in range(num_frames - 1, 0, -1):
        state = back_rows[t][state]
        states.append(state)
    states.reverse()

    return [y_insert_blank[s] for s in states]
| |
|