# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

import torch
from omegaconf import DictConfig

from nemo.core.classes import NeuralModule
from nemo.core.neural_types import LengthsType, LogprobsType, NeuralType, PredictionsType


class ViterbiDecoderWithGraph(NeuralModule):
"""Viterbi Decoder with WFSA (Weighted Finite State Automaton) graphs.
Note:
Requires k2 v1.14 or later to be installed to use this module.
Decoder can be set up via the config, and optionally be passed keyword arguments as follows.
Examples:
.. code-block:: yaml
model: # Model config
...
graph_module_cfg: # Config for graph modules, e.g. ViterbiDecoderWithGraph
split_batch_size: 0
backend_cfg:
topo_type: "default" # other options: "compact", "shared_blank", "minimal"
topo_with_self_loops: true
token_lm: <token_lm_path> # must be provided for criterion_type: "map"
Args:
num_classes: Number of target classes for the decoder network to predict.
(Excluding the blank token).
backend: Which backend to use for decoding. Currently only `k2` is supported.
dec_type: Type of decoding graph to use. Choices: `topo` and `token_lm`,
with `topo` standing for the loss topology graph only
and `token_lm` for the topology composed with a token_lm graph.
return_type: Type of output. Choices: `1best` and `lattice`.
`1best` is represented as a list of 1D tensors.
`lattice` can be of type corresponding to the backend (e.g. k2.Fsa).
return_ilabels: For return_type=`1best`.
Whether to return input labels of a lattice (otherwise output labels).
output_aligned: For return_type=`1best`.
Whether the tensors length will correspond to log_probs_length
and the labels will be aligned to the frames of emission
(otherwise there will be only the necessary labels).
split_batch_size: Local batch size. Used for memory consumption reduction at the cost of speed performance.
Effective if complies 0 < split_batch_size < batch_size.
graph_module_cfg: Optional Dict of (str, value) pairs that are passed to the backend graph decoder.
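
    A minimal usage sketch (assuming k2 is installed; the vocabulary size and shapes
    below are illustrative placeholders, not values prescribed by the API):

    .. code-block:: python

        decoder = ViterbiDecoderWithGraph(num_classes=28)  # 28 tokens, blank excluded
        log_probs = torch.randn(4, 100, 29).log_softmax(dim=-1)  # (B, T, num_classes + 1)
        log_probs_length = torch.tensor([100, 90, 80, 70])
        predictions, lengths, probs = decoder(log_probs, log_probs_length)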
"""
@property
def input_types(self):
"""Returns definitions of module input ports.
"""
return {
"log_probs": NeuralType(("B", "T", "D") if self._3d_input else ("B", "T", "T", "D"), LogprobsType()),
"input_lengths": NeuralType(tuple("B"), LengthsType()),
        }

    @property
    def output_types(self):
        """Returns definitions of module output ports.
        """
        return {"predictions": NeuralType(("B", "T"), PredictionsType())}

def __init__(
self,
num_classes,
backend: str = "k2",
dec_type: str = "topo",
return_type: str = "1best",
return_ilabels: bool = True,
output_aligned: bool = True,
split_batch_size: int = 0,
graph_module_cfg: Optional[DictConfig] = None,
    ):
        super().__init__()
        self._blank = num_classes
self.return_ilabels = return_ilabels
self.output_aligned = output_aligned
self.split_batch_size = split_batch_size
self.dec_type = dec_type
if return_type == "1best":
self.return_lattices = False
elif return_type == "lattice":
self.return_lattices = True
elif return_type == "nbest":
raise NotImplementedError(f"return_type {return_type} is not supported at the moment")
else:
raise ValueError(f"Unsupported return_type: {return_type}")

        # the total number of classes passed to the backend decoder includes the blank,
        # hence num_classes=self._blank + 1 below
if backend == "k2":
if self.dec_type == "topo":
from nemo.collections.asr.parts.k2.graph_decoders import CtcDecoder as Decoder
elif self.dec_type == "topo_rnnt_ali":
from nemo.collections.asr.parts.k2.graph_decoders import RnntAligner as Decoder
elif self.dec_type == "token_lm":
from nemo.collections.asr.parts.k2.graph_decoders import TokenLMDecoder as Decoder
elif self.dec_type == "loose_ali":
raise NotImplementedError()
elif self.dec_type == "tlg":
raise NotImplementedError(f"dec_type {self.dec_type} is not supported at the moment")
else:
raise ValueError(f"Unsupported dec_type: {self.dec_type}")
self._decoder = Decoder(num_classes=self._blank + 1, blank=self._blank, cfg=graph_module_cfg)
        elif backend == "gtn":
            raise NotImplementedError("gtn-backed decoding is not implemented")
        else:
            raise ValueError(f"Unsupported backend: {backend}")
        # RNNT alignment consumes 4D joint log-probs (B, T, U, D); other dec_types use 3D inputs
        self._3d_input = self.dec_type != "topo_rnnt_ali"

    def update_graph(self, graph):
        """Updates the graph of the backend graph decoder.
        """
        self._decoder.update_graph(graph)

def _forward_impl(self, log_probs, log_probs_length, targets=None, target_length=None):
        # exactly one of (targets, target_length) being None is an error
        if (targets is None) != (target_length is None):
raise RuntimeError(
f"Both targets and target_length have to be None or not None: {targets}, {target_length}"
)
# do not use self.return_lattices for now
if targets is None:
align = False
decode_func = lambda a, b: self._decoder.decode(
a, b, return_lattices=False, return_ilabels=self.return_ilabels, output_aligned=self.output_aligned
)
else:
align = True
decode_func = lambda a, b, c, d: self._decoder.align(
a, b, c, d, return_lattices=False, return_ilabels=False, output_aligned=True
)
batch_size = log_probs.shape[0]
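        # if split_batch_size is set, decode the batch in chunks of at most that many
        # utterances to bound peak memory; each chunk is trimmed to its longest sequence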
        if 0 < self.split_batch_size <= batch_size:
predictions = []
probs = []
for batch_idx in range(0, batch_size, self.split_batch_size):
begin = batch_idx
end = min(begin + self.split_batch_size, batch_size)
log_probs_length_part = log_probs_length[begin:end]
log_probs_part = log_probs[begin:end, : log_probs_length_part.max()]
if align:
target_length_part = target_length[begin:end]
targets_part = targets[begin:end, : target_length_part.max()]
predictions_part, probs_part = decode_func(
log_probs_part, log_probs_length_part, targets_part, target_length_part
)
del targets_part, target_length_part
else:
predictions_part, probs_part = decode_func(log_probs_part, log_probs_length_part)
del log_probs_part, log_probs_length_part
predictions += predictions_part
probs += probs_part
else:
predictions, probs = (
decode_func(log_probs, log_probs_length, targets, target_length)
if align
else decode_func(log_probs, log_probs_length)
)
assert len(predictions) == len(probs)
        return predictions, probs

@torch.no_grad()
def forward(self, log_probs, log_probs_length):
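        """Decodes log_probs into 1-best label sequences.

        Returns (predictions, lengths, probs): predictions and probs are padded
        to the longest decoded sequence with the blank id and 1.0, respectively.
        """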
if self.dec_type == "looseali":
raise RuntimeError(f"Decoder with dec_type=`{self.dec_type}` is not intended for regular decoding.")
predictions, probs = self._forward_impl(log_probs, log_probs_length)
lengths = torch.tensor([len(pred) for pred in predictions], device=predictions[0].device)
        predictions_tensor = torch.full(
            (len(predictions), lengths.max()), self._blank, device=predictions[0].device
        )
        probs_tensor = torch.full((len(probs), lengths.max()), 1.0, device=predictions[0].device)
for i, (pred, prob) in enumerate(zip(predictions, probs)):
predictions_tensor[i, : lengths[i]] = pred
probs_tensor[i, : lengths[i]] = prob
        return predictions_tensor, lengths, probs_tensor

@torch.no_grad()
def align(self, log_probs, log_probs_length, targets, target_length):
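        """Aligns targets to log_probs frame-by-frame.

        Utterances with empty targets or with more target tokens than frames are
        skipped during decoding and returned as empty tensors.
        """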
len_enough = (log_probs_length >= target_length) & (target_length > 0)
        if torch.all(len_enough) or self.dec_type == "loose_ali":
results = self._forward_impl(log_probs, log_probs_length, targets, target_length)
else:
results = self._forward_impl(
log_probs[len_enough], log_probs_length[len_enough], targets[len_enough], target_length[len_enough]
)
for i, computed in enumerate(len_enough):
if not computed:
results[0].insert(i, torch.empty(0, dtype=torch.int32))
results[1].insert(i, torch.empty(0, dtype=torch.float))
return results
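

# A minimal smoke-test sketch (not part of the original module). It assumes NeMo and
# k2 are installed; the vocabulary size, shapes, and targets below are illustrative
# placeholders rather than values prescribed by the API.
if __name__ == "__main__":
    num_classes = 28  # hypothetical token vocabulary size, excluding blank
    decoder = ViterbiDecoderWithGraph(num_classes=num_classes)

    batch, frames, target_len = 4, 100, 10
    # random (B, T, D) log-probabilities over num_classes + 1 outputs (tokens + blank)
    log_probs = torch.randn(batch, frames, num_classes + 1).log_softmax(dim=-1)
    log_probs_length = torch.tensor([frames, frames - 10, frames - 20, frames - 30])

    # frame-level alignment of known targets against the emissions
    targets = torch.randint(0, num_classes, (batch, target_len))
    target_length = torch.full((batch,), target_len)
    alignments, probs = decoder.align(log_probs, log_probs_length, targets, target_length)
    print([a.shape for a in alignments])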