File size: 6,714 Bytes
d596074 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 | from typing import List
import k2
import torch
from torch import nn
from icefall.mmi_graph_compiler import MmiTrainingGraphCompiler
def _compute_mmi_loss_exact_optimized(
    dense_fsa_vec: k2.DenseFsaVec,
    texts: List[str],
    graph_compiler: MmiTrainingGraphCompiler,
    den_scale: float = 1.0,
    beam_size: float = 8.0,
) -> torch.Tensor:
    """
    Compute the MMI loss with a single call to k2.intersect_dense.

    The function name contains `exact`, which means it uses
    k2.intersect_dense (not k2.intersect_dense_pruned); `beam_size` is only
    used as the output beam of that intersection.
    `optimized` in the function name means this function is optimized
    in that it calls k2.intersect_dense only once, on the numerator graphs
    interleaved with the (single, shared) denominator graph.

    Note:
      It is faster at the cost of using more memory.

    Args:
      dense_fsa_vec:
        It contains the neural network output.
      texts:
        The transcript. Each element consists of space(s) separated words.
      graph_compiler:
        Used to build num_graphs and den_graphs.
      den_scale:
        The scale applied to the denominator tot_scores.
      beam_size:
        Passed to k2.intersect_dense as `output_beam`.

    Returns:
      Return a scalar loss. It is the sum over utterances in a batch,
      without normalization. As in
      :func:`_compute_mmi_loss_exact_non_optimized`, utterances whose
      combined score is +/-inf are replaced by 0 before summing, so both
      exact variants produce the same loss.
    """
    num_graphs, den_graphs = graph_compiler.compile(texts, replicate_den=False)

    device = num_graphs.device

    num_fsas = num_graphs.shape[0]
    assert dense_fsa_vec.dim0() == num_fsas

    # With replicate_den=False there is exactly one denominator graph.
    assert den_graphs.shape[0] == 1

    # Concatenating num_graphs and den_graphs lets us call
    # k2.intersect_dense once instead of twice. After the cat, index
    # `num_fsas` (the last FSA) is the denominator graph.
    num_den_graphs = k2.cat([num_graphs, den_graphs])

    # NOTE: The a_to_b_map in k2.intersect_dense must be sorted, so the
    # graphs are reordered as [num_0, den, num_1, den, num_2, den, ...],
    # i.e. indexes [0, num_fsas, 1, num_fsas, 2, num_fsas, ...].
    num_den_graphs_indexes = (
        torch.stack(
            [
                torch.arange(num_fsas, dtype=torch.int32),
                torch.full((num_fsas,), num_fsas, dtype=torch.int32),
            ],
            dim=1,
        )
        .reshape(-1)
        .to(device)
    )
    num_den_reordered_graphs = k2.index(num_den_graphs, num_den_graphs_indexes)

    # Graphs 2*i (numerator) and 2*i + 1 (denominator) are both matched
    # against sequence i of dense_fsa_vec: [0, 0, 1, 1, 2, 2, ...]
    a_to_b_map = (
        torch.arange(num_fsas, dtype=torch.int32).repeat_interleave(2).to(device)
    )

    num_den_lats = k2.intersect_dense(
        num_den_reordered_graphs,
        dense_fsa_vec,
        output_beam=beam_size,
        # Same limit as _compute_mmi_loss_exact_non_optimized; this single
        # call produces both the numerator and denominator lattices.
        max_arcs=2147483600,
        a_to_b_map=a_to_b_map,
    )

    num_den_tot_scores = num_den_lats.get_tot_scores(
        log_semiring=True, use_double_scores=True
    )

    # Undo the interleaving: even entries are numerators, odd are denominators.
    num_tot_scores = num_den_tot_scores[::2]
    den_tot_scores = num_den_tot_scores[1::2]

    tot_scores = num_tot_scores - den_scale * den_tot_scores
    # Drop utterances with infinite scores (e.g. no surviving path) so one
    # bad utterance cannot turn the whole batch loss/gradient into inf/NaN.
    tot_scores = tot_scores.masked_fill(torch.isinf(tot_scores), 0.0)

    loss = -1 * tot_scores.sum()
    return loss
def _compute_mmi_loss_exact_non_optimized(
    dense_fsa_vec: k2.DenseFsaVec,
    texts: List[str],
    graph_compiler: MmiTrainingGraphCompiler,
    den_scale: float = 1.0,
    beam_size: float = 8.0,
) -> torch.Tensor:
    """
    Compute the MMI loss using two separate k2.intersect_dense calls,
    one for the numerator graphs and one for the denominator graphs.

    See :func:`_compute_mmi_loss_exact_optimized` for the meaning
    of the arguments. `beam_size` is used as `output_beam` for both
    intersections.

    Note:
      Compared with the single-intersection variant it uses less memory,
      but it is slower.

    Returns:
      A scalar loss: the negated sum over the batch of
      ``num_tot_score - den_scale * den_tot_score``. Per-utterance values
      that are +/-inf are replaced by 0 before summing, so such utterances
      contribute nothing to the loss or its gradient.
    """
    # replicate_den=True: presumably one denominator copy per utterance,
    # so both graph sets line up with dense_fsa_vec -- TODO confirm.
    num_graphs, den_graphs = graph_compiler.compile(texts, replicate_den=True)

    # Passed to k2.intersect_dense as the arc limit of each output lattice.
    arc_limit = 2147483600

    def _log_tot_scores(graphs):
        """Intersect `graphs` with the network output; return log-semiring totals."""
        lattice = k2.intersect_dense(
            graphs, dense_fsa_vec, output_beam=beam_size, max_arcs=arc_limit
        )
        return lattice.get_tot_scores(log_semiring=True, use_double_scores=True)

    num_scores = _log_tot_scores(num_graphs)
    den_scores = _log_tot_scores(den_graphs)

    per_utterance = num_scores - den_scale * den_scores
    per_utterance = per_utterance.masked_fill(per_utterance.isinf(), 0.0)

    return -per_utterance.sum()
def _compute_mmi_loss_pruned(
    dense_fsa_vec: k2.DenseFsaVec,
    texts: List[str],
    graph_compiler: MmiTrainingGraphCompiler,
    den_scale: float = 1.0,
    beam_size: float = 8.0,
) -> torch.Tensor:
    """
    Compute the MMI loss, using pruned intersection for the denominator.

    See :func:`_compute_mmi_loss_exact_optimized` for the meaning
    of the arguments. `beam_size` is used as `output_beam` for both the
    numerator and the denominator intersections.

    `pruned` means the denominator uses k2.intersect_dense_pruned; the
    numerator graphs still use k2.intersect_dense.

    Note:
      It uses the least amount of memory, but the loss is not exact due
      to pruning.

    Returns:
      Return a scalar loss. It is the sum over utterances in a batch,
      without normalization. Utterances whose combined score is +/-inf
      are replaced by 0 before summing, as in
      :func:`_compute_mmi_loss_exact_non_optimized`.
    """
    num_graphs, den_graphs = graph_compiler.compile(texts, replicate_den=False)

    # Previously hard-coded to 8.0, which silently ignored `beam_size`
    # for the numerator (the default beam_size is also 8.0).
    num_lats = k2.intersect_dense(num_graphs, dense_fsa_vec, output_beam=beam_size)

    # the values for search_beam/min_active_states/max_active_states
    # are not tuned. You may want to tune them.
    den_lats = k2.intersect_dense_pruned(
        den_graphs,
        dense_fsa_vec,
        search_beam=20.0,
        output_beam=beam_size,
        min_active_states=30,
        max_active_states=10000,
    )

    num_tot_scores = num_lats.get_tot_scores(log_semiring=True, use_double_scores=True)
    den_tot_scores = den_lats.get_tot_scores(log_semiring=True, use_double_scores=True)

    tot_scores = num_tot_scores - den_scale * den_tot_scores
    # Pruning can leave an utterance with no surviving path (an infinite
    # score); zero it out so it cannot make the batch loss/gradient inf/NaN.
    tot_scores = tot_scores.masked_fill(torch.isinf(tot_scores), 0.0)

    loss = -1 * tot_scores.sum()
    return loss
class LFMMILoss(nn.Module):
    """
    Computes Lattice-Free Maximum Mutual Information (LFMMI) loss.

    Numerator and denominator graphs are built from the transcripts by
    `graph_compiler` and intersected with the network output. The loss is
    the negated batch sum of ``num_tot_score - den_scale * den_tot_score``,
    as computed by one of the module-level helpers selected below.

    Args:
      graph_compiler:
        Used to build num_graphs and den_graphs.
      use_pruned_intersect:
        If True, use :func:`_compute_mmi_loss_pruned` (least memory, but
        the loss is not exact due to pruning).
      den_scale:
        The scale applied to the denominator tot_scores.
      beam_size:
        Output beam forwarded to the selected helper.
      use_optimized_exact:
        Only consulted when `use_pruned_intersect` is False. If True, use
        :func:`_compute_mmi_loss_exact_optimized` (one intersection; faster,
        more memory); otherwise :func:`_compute_mmi_loss_exact_non_optimized`
        (two intersections; slower, less memory). Defaults to False, which
        keeps the original behaviour.
    """

    def __init__(
        self,
        graph_compiler: MmiTrainingGraphCompiler,
        use_pruned_intersect: bool = False,
        den_scale: float = 1.0,
        beam_size: float = 8.0,
        use_optimized_exact: bool = False,
    ):
        super().__init__()
        self.graph_compiler = graph_compiler
        self.den_scale = den_scale
        self.use_pruned_intersect = use_pruned_intersect
        self.beam_size = beam_size
        self.use_optimized_exact = use_optimized_exact

    def forward(
        self,
        dense_fsa_vec: k2.DenseFsaVec,
        texts: List[str],
    ) -> torch.Tensor:
        """
        Args:
          dense_fsa_vec:
            It contains the neural network output.
          texts:
            A list of strings. Each string contains space(s) separated words.
        Returns:
          Return a scalar loss. It is the sum over utterances in a batch,
          without normalization.
        """
        if self.use_pruned_intersect:
            func = _compute_mmi_loss_pruned
        elif self.use_optimized_exact:
            func = _compute_mmi_loss_exact_optimized
        else:
            func = _compute_mmi_loss_exact_non_optimized

        return func(
            dense_fsa_vec=dense_fsa_vec,
            texts=texts,
            graph_compiler=self.graph_compiler,
            den_scale=self.den_scale,
            beam_size=self.beam_size,
        )
|