File size: 5,331 Bytes
d596074
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
# Copyright    2022  Xiaomi Corp.        (authors: Wei Kang,
#                                                  Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from typing import List, Optional, Tuple

import k2
import torch
from beam_search import Hypothesis, HypothesisList

from icefall.utils import AttributeDict


class DecodeStream(object):
    def __init__(
        self,
        params: AttributeDict,
        cut_id: str,
        initial_states: List[torch.Tensor],
        decoding_graph: Optional[k2.Fsa] = None,
        device: torch.device = torch.device("cpu"),
    ) -> None:
        """
        Args:
          initial_states:
            Initial decode states of the model, e.g. the return value of
            `get_init_state` in conformer.py
          decoding_graph:
            Decoding graph used for decoding, may be a TrivialGraph or a HLG.
            Used only when decoding_method is fast_beam_search.
          device:
            The device to run this stream.
        """
        if params.decoding_method == "fast_beam_search":
            assert decoding_graph is not None
            assert device == decoding_graph.device

        self.params = params
        self.cut_id = cut_id
        self.LOG_EPS = math.log(1e-10)

        self.states = initial_states

        # It contains a 2-D tensors representing the feature frames.
        self.features: torch.Tensor = None

        self.num_frames: int = 0
        # how many frames have been processed. (before subsampling).
        # we only modify this value in `func:get_feature_frames`.
        self.num_processed_frames: int = 0

        self._done: bool = False

        # The transcript of current utterance.
        self.ground_truth: str = ""

        # The decoding result (partial or final) of current utterance.
        self.hyp: List = []

        # how many frames have been processed, at encoder output
        self.done_frames: int = 0

        # The encoder_embed subsample features (T - 7) // 2
        # The ConvNeXt module needs (7 - 1) // 2 = 3 frames of right padding after subsampling
        self.pad_length = 7 + 2 * 3

        if params.decoding_method == "greedy_search":
            self.hyp = [-1] * (params.context_size - 1) + [params.blank_id]
        elif params.decoding_method == "modified_beam_search":
            self.hyps = HypothesisList()
            self.hyps.add(
                Hypothesis(
                    ys=[-1] * (params.context_size - 1) + [params.blank_id],
                    log_prob=torch.zeros(1, dtype=torch.float32, device=device),
                )
            )
        elif params.decoding_method == "fast_beam_search":
            # The rnnt_decoding_stream for fast_beam_search.
            self.rnnt_decoding_stream: k2.RnntDecodingStream = k2.RnntDecodingStream(
                decoding_graph
            )
        else:
            raise ValueError(f"Unsupported decoding method: {params.decoding_method}")

    @property
    def done(self) -> bool:
        """Return True if all the features are processed."""
        return self._done

    @property
    def id(self) -> str:
        return self.cut_id

    def set_features(
        self,
        features: torch.Tensor,
        tail_pad_len: int = 0,
    ) -> None:
        """Set features tensor of current utterance."""
        assert features.dim() == 2, features.dim()
        self.features = torch.nn.functional.pad(
            features,
            (0, 0, 0, self.pad_length + tail_pad_len),
            mode="constant",
            value=self.LOG_EPS,
        )
        self.num_frames = self.features.size(0)

    def get_feature_frames(self, chunk_size: int) -> Tuple[torch.Tensor, int]:
        """Consume chunk_size frames of features"""
        chunk_length = chunk_size + self.pad_length

        ret_length = min(self.num_frames - self.num_processed_frames, chunk_length)

        ret_features = self.features[
            self.num_processed_frames : self.num_processed_frames + ret_length  # noqa
        ]

        self.num_processed_frames += chunk_size
        if self.num_processed_frames >= self.num_frames:
            self._done = True

        return ret_features, ret_length

    def decoding_result(self) -> List[int]:
        """Obtain current decoding result."""
        if self.params.decoding_method == "greedy_search":
            return self.hyp[self.params.context_size :]  # noqa
        elif self.params.decoding_method == "modified_beam_search":
            best_hyp = self.hyps.get_most_probable(length_norm=True)
            return best_hyp.ys[self.params.context_size :]  # noqa
        else:
            assert self.params.decoding_method == "fast_beam_search"
            return self.hyp