File size: 4,800 Bytes
d596074
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
# Copyright      2021  Xiaomi Corp.        (authors: Fangjun Kuang)
#
# See ../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, List, Tuple

import torch
from torch.nn.utils.rnn import pad_sequence


def save_alignments(
    alignments: Dict[str, List[int]],
    subsampling_factor: int,
    filename: str,
) -> None:
    """Serialize alignments together with their subsampling factor.

    Args:
      alignments:
        Maps each utterance ID to its framewise alignment (already
        subsampled).
      subsampling_factor:
        The subsampling factor of the model that produced the alignments.
      filename:
        Destination path for the serialized data.
    Returns:
      Return None.
    """
    torch.save(
        {
            "subsampling_factor": subsampling_factor,
            "alignments": alignments,
        },
        filename,
    )


def load_alignments(filename: str) -> Tuple[int, Dict[str, List[int]]]:
    """Read alignments previously written by :func:`save_alignments`.

    Args:
      filename:
        Path to the file containing alignment information.
        The file should be saved by :func:`save_alignments`.
    Returns:
      Return a tuple containing:
        - subsampling_factor: The subsampling_factor used to compute
          the alignments.
        - alignments: A dict mapping utterance IDs to their framewise
          alignments, after subsampling.
    """
    # NOTE(review): weights_only=False deserializes via pickle, so only
    # load files from trusted sources.
    data = torch.load(filename, weights_only=False)
    return data["subsampling_factor"], data["alignments"]


def convert_alignments_to_tensor(
    alignments: Dict[str, List[int]], device: torch.device
) -> Dict[str, torch.Tensor]:
    """Convert each alignment from a list of int to a 1-D torch.Tensor.

    Args:
      alignments:
        A dict mapping utterance IDs to their frame-wise alignments.
      device:
        The device on which the resulting tensors are placed.
    Returns:
      Return a dict with the same keys whose values are 1-D tensors of
      dtype `torch.int64`. int64 is required because the tensors are
      later fed to `torch.nn.functional.one_hot`.
    """
    return {
        utt_id: torch.tensor(ali, dtype=torch.int64, device=device)
        for utt_id, ali in alignments.items()
    }


def lookup_alignments(
    cut_ids: List[str],
    alignments: Dict[str, torch.Tensor],
    num_classes: int,
    log_score: float = -10,
) -> torch.Tensor:
    """Build a framewise score mask for a batch of utterances.

    The result is a 3-D tensor of shape (N, T, C). Within each frame
    (each row), the position given by the alignment holds 0 and every
    other position holds `log_score`. For example, with alignments

        [ [1, 3, 2], [1, 0, 4, 2] ]

    num_classes = 5 and log_score = -10, the result is

        [
          [[-10, 0, -10, -10, -10],
           [-10, -10, -10, 0, -10],
           [-10, -10, 0, -10, -10],
           [0, -10, -10, -10, -10]],
          [[-10, 0, -10, -10, -10],
           [0, -10, -10, -10, -10],
           [-10, -10, -10, -10, 0],
           [-10, -10, 0, -10, -10]]
        ]

    Note: Shorter utterances are padded with token 0, so their trailing
    frames look like an alignment to class 0 (see the first utterance).

    Args:
      cut_ids:
        A list of utterance IDs. Every ID must be a key in `alignments`.
      alignments:
        A dict mapping utterance IDs to 1-D int64 framewise alignments.
      num_classes:
        The max token ID + 1 that appears in the alignments.
      log_score:
        Fill value for positions that do not match the alignment.
    Returns:
      Return a 3-D torch.float32 tensor of shape (N, T, C).
    """
    # Assumes every requested cut has an alignment; a missing ID raises
    # KeyError here.
    selected = [alignments[cut_id] for cut_id in cut_ids]
    padded = pad_sequence(selected, batch_first=True, padding_value=0)
    one_hot = torch.nn.functional.one_hot(padded, num_classes=num_classes)
    # (1 - one_hot) is 1 off the alignment path and 0 on it; multiplying
    # by a python float promotes the int64 tensor to float32.
    return (1 - one_hot) * float(log_score)