| |
|
|
| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
|
|
| from typing import Optional |
|
|
| import six |
| import torch |
| import numpy as np |
|
|
|
|
def sequence_mask(
    lengths,
    maxlen: Optional[int] = None,
    dtype: torch.dtype = torch.float32,
    device: Optional[torch.device] = None,
) -> torch.Tensor:
    """Build a length mask like TensorFlow's ``tf.sequence_mask``.

    :param lengths: 1-D (or N-D) integer tensor of sequence lengths.
    :param maxlen: number of mask columns; defaults to ``lengths.max()``.
    :param dtype: dtype of the returned mask (default ``torch.float32``).
    :param device: optional device to move the result to; when ``None``
        the mask stays on ``lengths``' device.
    :return: tensor of shape ``lengths.shape + (maxlen,)`` where entry
        ``[..., j]`` is 1 iff ``j < lengths[...]``.
    """
    if maxlen is None:
        maxlen = lengths.max()
    # Create the position index directly on the lengths' device instead of
    # allocating on CPU and copying afterwards.
    row_vector = torch.arange(maxlen, device=lengths.device)
    # Broadcast compare positions (maxlen,) against lengths (..., 1).
    matrix = torch.unsqueeze(lengths, dim=-1)
    mask = row_vector < matrix
    # Comparison output carries no grad; detach keeps the intent explicit.
    mask = mask.detach()

    return mask.type(dtype).to(device) if device is not None else mask.type(dtype)
|
|
|
|
def end_detect(ended_hyps, i, M=3, d_end=np.log(1 * np.exp(-10))):
    """End detection for beam-search decoding.

    Described in Eq. (50) of S. Watanabe et al.,
    "Hybrid CTC/Attention Architecture for End-to-End Speech Recognition".

    :param list ended_hyps: already-ended hypotheses; each is a dict with at
        least ``"score"`` (float) and ``"yseq"`` (token sequence whose length
        is the hypothesis output length).
    :param int i: current output length being considered.
    :param int M: number of most recent lengths (``i, i-1, ..., i-M+1``) to check.
    :param float d_end: score-gap threshold (default evaluates to -10).
    :return: True when, for all M recent lengths, the best ended hypothesis of
        that length scores worse than the overall best by more than ``d_end``.
    :rtype: bool
    """
    if not ended_hyps:
        return False
    count = 0
    # max() with key= is the O(n) idiom for sorted(..., reverse=True)[0]
    # and breaks ties the same way (first maximum wins).
    best_hyp = max(ended_hyps, key=lambda x: x["score"])
    # Plain builtin range; six.moves.range was Python 2 compatibility only.
    for m in range(M):
        # Ended hypotheses whose output length is i - m.
        hyp_length = i - m
        hyps_same_length = [x for x in ended_hyps if len(x["yseq"]) == hyp_length]
        if hyps_same_length:
            best_hyp_same_length = max(hyps_same_length, key=lambda x: x["score"])
            if best_hyp_same_length["score"] - best_hyp["score"] < d_end:
                count += 1

    return count == M
|
|