Aditeya Kamlesh Prajapati committed
Commit 8096486 · 1 Parent(s): b0e7e58

Add app and modules

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the complete change set.
Files changed (50)
  1. average_checkpoints.py +36 -0
  2. cosine.py +25 -0
  3. espnet/nets/batch_beam_search.py +349 -0
  4. espnet/nets/beam_search.py +510 -0
  5. espnet/nets/ctc_prefix_score.py +357 -0
  6. espnet/nets/e2e_asr_common.py +199 -0
  7. espnet/nets/pytorch_backend/__pycache__/ctc.cpython-310.pyc +0 -0
  8. espnet/nets/pytorch_backend/__pycache__/ctc.cpython-311.pyc +0 -0
  9. espnet/nets/pytorch_backend/__pycache__/e2e_asr_conformer.cpython-310.pyc +0 -0
  10. espnet/nets/pytorch_backend/__pycache__/e2e_asr_conformer.cpython-311.pyc +0 -0
  11. espnet/nets/pytorch_backend/__pycache__/nets_utils.cpython-310.pyc +0 -0
  12. espnet/nets/pytorch_backend/__pycache__/nets_utils.cpython-311.pyc +0 -0
  13. espnet/nets/pytorch_backend/ctc.py +242 -0
  14. espnet/nets/pytorch_backend/decoder/__pycache__/transformer_decoder.cpython-310.pyc +0 -0
  15. espnet/nets/pytorch_backend/decoder/__pycache__/transformer_decoder.cpython-311.pyc +0 -0
  16. espnet/nets/pytorch_backend/decoder/transformer_decoder.py +334 -0
  17. espnet/nets/pytorch_backend/e2e_asr_conformer.py +87 -0
  18. espnet/nets/pytorch_backend/encoder/__pycache__/conformer_encoder.cpython-310.pyc +0 -0
  19. espnet/nets/pytorch_backend/encoder/__pycache__/conformer_encoder.cpython-311.pyc +0 -0
  20. espnet/nets/pytorch_backend/encoder/conformer_encoder.py +303 -0
  21. espnet/nets/pytorch_backend/frontend/__pycache__/resnet.cpython-310.pyc +0 -0
  22. espnet/nets/pytorch_backend/frontend/__pycache__/resnet.cpython-311.pyc +0 -0
  23. espnet/nets/pytorch_backend/frontend/__pycache__/resnet1d.cpython-310.pyc +0 -0
  24. espnet/nets/pytorch_backend/frontend/__pycache__/resnet1d.cpython-311.pyc +0 -0
  25. espnet/nets/pytorch_backend/frontend/resnet.py +237 -0
  26. espnet/nets/pytorch_backend/frontend/resnet1d.py +238 -0
  27. espnet/nets/pytorch_backend/nets_utils.py +306 -0
  28. espnet/nets/pytorch_backend/transformer/__init__.py +1 -0
  29. espnet/nets/pytorch_backend/transformer/__pycache__/__init__.cpython-310.pyc +0 -0
  30. espnet/nets/pytorch_backend/transformer/__pycache__/__init__.cpython-311.pyc +0 -0
  31. espnet/nets/pytorch_backend/transformer/__pycache__/add_sos_eos.cpython-310.pyc +0 -0
  32. espnet/nets/pytorch_backend/transformer/__pycache__/add_sos_eos.cpython-311.pyc +0 -0
  33. espnet/nets/pytorch_backend/transformer/__pycache__/attention.cpython-310.pyc +0 -0
  34. espnet/nets/pytorch_backend/transformer/__pycache__/attention.cpython-311.pyc +0 -0
  35. espnet/nets/pytorch_backend/transformer/__pycache__/embedding.cpython-310.pyc +0 -0
  36. espnet/nets/pytorch_backend/transformer/__pycache__/embedding.cpython-311.pyc +0 -0
  37. espnet/nets/pytorch_backend/transformer/__pycache__/label_smoothing_loss.cpython-310.pyc +0 -0
  38. espnet/nets/pytorch_backend/transformer/__pycache__/label_smoothing_loss.cpython-311.pyc +0 -0
  39. espnet/nets/pytorch_backend/transformer/__pycache__/layer_norm.cpython-310.pyc +0 -0
  40. espnet/nets/pytorch_backend/transformer/__pycache__/layer_norm.cpython-311.pyc +0 -0
  41. espnet/nets/pytorch_backend/transformer/__pycache__/mask.cpython-310.pyc +0 -0
  42. espnet/nets/pytorch_backend/transformer/__pycache__/mask.cpython-311.pyc +0 -0
  43. espnet/nets/pytorch_backend/transformer/__pycache__/positionwise_feed_forward.cpython-310.pyc +0 -0
  44. espnet/nets/pytorch_backend/transformer/__pycache__/positionwise_feed_forward.cpython-311.pyc +0 -0
  45. espnet/nets/pytorch_backend/transformer/__pycache__/repeat.cpython-310.pyc +0 -0
  46. espnet/nets/pytorch_backend/transformer/__pycache__/repeat.cpython-311.pyc +0 -0
  47. espnet/nets/pytorch_backend/transformer/add_sos_eos.py +31 -0
  48. espnet/nets/pytorch_backend/transformer/attention.py +193 -0
  49. espnet/nets/pytorch_backend/transformer/embedding.py +184 -0
  50. espnet/nets/pytorch_backend/transformer/label_smoothing_loss.py +63 -0
average_checkpoints.py ADDED
@@ -0,0 +1,36 @@
+import os
+
+import torch
+
+
+def average_checkpoints(last):
+    avg = None
+    for path in last:
+        states = torch.load(path, map_location=lambda storage, loc: storage)["state_dict"]
+        states = {k[6:]: v for k, v in states.items() if k.startswith("model.")}
+        if avg is None:
+            avg = states
+        else:
+            for k in avg.keys():
+                avg[k] += states[k]
+    # average
+    for k in avg.keys():
+        if avg[k] is not None:
+            if avg[k].is_floating_point():
+                avg[k] /= len(last)
+            else:
+                avg[k] //= len(last)
+    return avg
+
+
+def ensemble(args):
+    last = [
+        os.path.join(args.exp_dir, args.exp_name, f"epoch={n}.ckpt")
+        for n in range(
+            args.max_epochs - 10,
+            args.max_epochs,
+        )
+    ]
+    model_path = os.path.join(args.exp_dir, args.exp_name, "model_avg_10.pth")
+    torch.save(average_checkpoints(last), model_path)
+    return model_path
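For context, a hedged usage sketch of `average_checkpoints` (the experiment directory and epoch numbers below are illustrative assumptions, not values from this commit). The function computes an element-wise mean of the `model.`-prefixed weights across the given checkpoints, which is exactly what `ensemble` does for the last 10 training epochs:

import torch
from average_checkpoints import average_checkpoints  # assumed import path (repo root)

# Assumed paths: ten Lightning checkpoints, here from epochs 65-74 of some run.
ckpt_paths = [f"exp/run1/epoch={n}.ckpt" for n in range(65, 75)]
avg_state_dict = average_checkpoints(ckpt_paths)  # element-wise mean of weights
torch.save(avg_state_dict, "exp/run1/model_avg_10.pth")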
cosine.py ADDED
@@ -0,0 +1,25 @@
+import math
+
+import torch
+
+
+class WarmupCosineScheduler(torch.optim.lr_scheduler._LRScheduler):
+    def __init__(
+        self,
+        optimizer: torch.optim.Optimizer,
+        warmup_epochs: int,
+        total_epochs: int,
+        steps_per_epoch: int,
+        last_epoch=-1,
+        verbose=False,
+    ):
+        self.warmup_steps = warmup_epochs * steps_per_epoch
+        self.total_steps = total_epochs * steps_per_epoch
+        super().__init__(optimizer, last_epoch=last_epoch, verbose=verbose)
+
+    def get_lr(self):
+        if self._step_count < self.warmup_steps:
+            return [self._step_count / self.warmup_steps * base_lr for base_lr in self.base_lrs]
+        decay_steps = self.total_steps - self.warmup_steps
+        cos_val = math.cos(math.pi * (self._step_count - self.warmup_steps) / decay_steps)
+        return [0.5 * base_lr * (1 + cos_val) for base_lr in self.base_lrs]
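A minimal usage sketch (the model, epoch counts, and base learning rate are assumed values, not taken from this commit). The scheduler counts optimizer steps, not epochs, so `scheduler.step()` belongs in the inner training loop: the learning rate ramps linearly from 0 to the base rate over `warmup_epochs * steps_per_epoch` steps, then follows a half-cosine down toward 0 over the remaining steps.

import torch

model = torch.nn.Linear(8, 2)  # placeholder model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scheduler = WarmupCosineScheduler(
    optimizer, warmup_epochs=5, total_epochs=75, steps_per_epoch=100
)
for step in range(75 * 100):
    optimizer.step()   # loss computation and backward pass omitted
    scheduler.step()   # per optimizer step, not per epoch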
espnet/nets/batch_beam_search.py ADDED
@@ -0,0 +1,349 @@
+"""Parallel beam search module."""
+
+import logging
+from typing import Any, Dict, List, NamedTuple, Tuple
+
+import torch
+
+from espnet.nets.beam_search import BeamSearch, Hypothesis
+from torch.nn.utils.rnn import pad_sequence
+
+
+class BatchHypothesis(NamedTuple):
+    """Batchified/vectorized hypothesis data type."""
+
+    yseq: torch.Tensor = torch.tensor([])  # (batch, maxlen)
+    score: torch.Tensor = torch.tensor([])  # (batch,)
+    length: torch.Tensor = torch.tensor([])  # (batch,)
+    scores: Dict[str, torch.Tensor] = dict()  # values: (batch,)
+    states: Dict[str, Dict] = dict()
+
+    def __len__(self) -> int:
+        """Return the batch size."""
+        return len(self.length)
+
+
+class BatchBeamSearch(BeamSearch):
+    """Batch beam search implementation."""
+
+    def batchfy(self, hyps: List[Hypothesis]) -> BatchHypothesis:
+        """Convert list to batch."""
+        if len(hyps) == 0:
+            return BatchHypothesis()
+        yseq = pad_sequence(
+            [h.yseq for h in hyps], batch_first=True, padding_value=self.eos
+        )
+        return BatchHypothesis(
+            yseq=yseq,
+            length=torch.tensor(
+                [len(h.yseq) for h in hyps], dtype=torch.int64, device=yseq.device
+            ),
+            score=torch.tensor([h.score for h in hyps]).to(yseq.device),
+            scores={
+                k: torch.tensor([h.scores[k] for h in hyps], device=yseq.device)
+                for k in self.scorers
+            },
+            states={k: [h.states[k] for h in hyps] for k in self.scorers},
+        )
+
+    def _batch_select(self, hyps: BatchHypothesis, ids: List[int]) -> BatchHypothesis:
+        return BatchHypothesis(
+            yseq=hyps.yseq[ids],
+            score=hyps.score[ids],
+            length=hyps.length[ids],
+            scores={k: v[ids] for k, v in hyps.scores.items()},
+            states={
+                k: [self.scorers[k].select_state(v, i) for i in ids]
+                for k, v in hyps.states.items()
+            },
+        )
+
+    def _select(self, hyps: BatchHypothesis, i: int) -> Hypothesis:
+        return Hypothesis(
+            yseq=hyps.yseq[i, : hyps.length[i]],
+            score=hyps.score[i],
+            scores={k: v[i] for k, v in hyps.scores.items()},
+            states={
+                k: self.scorers[k].select_state(v, i) for k, v in hyps.states.items()
+            },
+        )
+
+    def unbatchfy(self, batch_hyps: BatchHypothesis) -> List[Hypothesis]:
+        """Revert batch to list."""
+        return [
+            Hypothesis(
+                yseq=batch_hyps.yseq[i][: batch_hyps.length[i]],
+                score=batch_hyps.score[i],
+                scores={k: batch_hyps.scores[k][i] for k in self.scorers},
+                states={
+                    k: v.select_state(batch_hyps.states[k], i)
+                    for k, v in self.scorers.items()
+                },
+            )
+            for i in range(len(batch_hyps.length))
+        ]
+
+    def batch_beam(
+        self, weighted_scores: torch.Tensor, ids: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Batch-compute topk full token ids and partial token ids.
+
+        Args:
+            weighted_scores (torch.Tensor): The weighted sum scores for each token.
+                Its shape is `(n_beam, self.vocab_size)`.
+            ids (torch.Tensor): The partial token ids to compute topk.
+                Its shape is `(n_beam, self.pre_beam_size)`.
+
+        Returns:
+            Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+                The topk full (prev_hyp, new_token) ids
+                and partial (prev_hyp, new_token) ids.
+                Their shapes are all `(self.beam_size,)`
+
+        """
+        top_ids = weighted_scores.view(-1).topk(self.beam_size)[1]
+        # Because of the flatten above, `top_ids` is organized as:
+        # [hyp1 * V + token1, hyp2 * V + token2, ..., hypK * V + tokenK],
+        # where V is `self.n_vocab` and K is `self.beam_size`
+        prev_hyp_ids = torch.div(top_ids, self.n_vocab, rounding_mode="trunc")
+        new_token_ids = top_ids % self.n_vocab
+        return prev_hyp_ids, new_token_ids, prev_hyp_ids, new_token_ids
+
+    def init_hyp(self, x: torch.Tensor) -> BatchHypothesis:
+        """Get an initial hypothesis data.
+
+        Args:
+            x (torch.Tensor): The encoder output feature
+
+        Returns:
+            Hypothesis: The initial hypothesis.
+
+        """
+        init_states = dict()
+        init_scores = dict()
+        for k, d in self.scorers.items():
+            init_states[k] = d.batch_init_state(x)
+            init_scores[k] = 0.0
+        return self.batchfy(
+            [
+                Hypothesis(
+                    score=0.0,
+                    scores=init_scores,
+                    states=init_states,
+                    yseq=torch.tensor([self.sos], device=x.device),
+                )
+            ]
+        )
+
+    def score_full(
+        self, hyp: BatchHypothesis, x: torch.Tensor
+    ) -> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]:
+        """Score new hypothesis by `self.full_scorers`.
+
+        Args:
+            hyp (Hypothesis): Hypothesis with prefix tokens to score
+            x (torch.Tensor): Corresponding input feature
+
+        Returns:
+            Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: Tuple of
+                score dict of `hyp` that has string keys of `self.full_scorers`
+                and tensor score values of shape: `(self.n_vocab,)`,
+                and state dict that has string keys
+                and state values of `self.full_scorers`
+
+        """
+        scores = dict()
+        states = dict()
+        for k, d in self.full_scorers.items():
+            scores[k], states[k] = d.batch_score(hyp.yseq, hyp.states[k], x)
+        return scores, states
+
+    def score_partial(
+        self, hyp: BatchHypothesis, ids: torch.Tensor, x: torch.Tensor
+    ) -> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]:
+        """Score new hypothesis by `self.part_scorers`.
+
+        Args:
+            hyp (Hypothesis): Hypothesis with prefix tokens to score
+            ids (torch.Tensor): 2D tensor of new partial tokens to score
+            x (torch.Tensor): Corresponding input feature
+
+        Returns:
+            Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: Tuple of
+                score dict of `hyp` that has string keys of `self.part_scorers`
+                and tensor score values of shape: `(self.n_vocab,)`,
+                and state dict that has string keys
+                and state values of `self.part_scorers`
+
+        """
+        scores = dict()
+        states = dict()
+        for k, d in self.part_scorers.items():
+            scores[k], states[k] = d.batch_score_partial(
+                hyp.yseq, ids, hyp.states[k], x
+            )
+        return scores, states
+
+    def merge_states(self, states: Any, part_states: Any, part_idx: int) -> Any:
+        """Merge states for new hypothesis.
+
+        Args:
+            states: states of `self.full_scorers`
+            part_states: states of `self.part_scorers`
+            part_idx (int): The new token id for `part_scores`
+
+        Returns:
+            Dict[str, torch.Tensor]: The new state dict.
+                Its keys are names of `self.full_scorers` and `self.part_scorers`.
+                Its values are states of the scorers.
+
+        """
+        new_states = dict()
+        for k, v in states.items():
+            new_states[k] = v
+        for k, v in part_states.items():
+            new_states[k] = v
+        return new_states
+
+    def search(self, running_hyps: BatchHypothesis, x: torch.Tensor) -> BatchHypothesis:
+        """Search new tokens for running hypotheses and encoded speech x.
+
+        Args:
+            running_hyps (BatchHypothesis): Running hypotheses on beam
+            x (torch.Tensor): Encoded speech feature (T, D)
+
+        Returns:
+            BatchHypothesis: Best sorted hypotheses
+
+        """
+        n_batch = len(running_hyps)
+        part_ids = None  # no pre-beam
+        # batch scoring
+        weighted_scores = torch.zeros(
+            n_batch, self.n_vocab, dtype=x.dtype, device=x.device
+        )
+        scores, states = self.score_full(running_hyps, x.expand(n_batch, *x.shape))
+        for k in self.full_scorers:
+            weighted_scores += self.weights[k] * scores[k]
+        # partial scoring
+        if self.do_pre_beam:
+            pre_beam_scores = (
+                weighted_scores
+                if self.pre_beam_score_key == "full"
+                else scores[self.pre_beam_score_key]
+            )
+            part_ids = torch.topk(pre_beam_scores, self.pre_beam_size, dim=-1)[1]
+        # NOTE(takaaki-hori): Unlike BeamSearch, we assume that score_partial returns
+        # full-size score matrices, which have non-zero scores for part_ids and zeros
+        # for others.
+        part_scores, part_states = self.score_partial(running_hyps, part_ids, x)
+        for k in self.part_scorers:
+            weighted_scores += self.weights[k] * part_scores[k]
+        # add previous hyp scores
+        weighted_scores += running_hyps.score.to(
+            dtype=x.dtype, device=x.device
+        ).unsqueeze(1)
+
+        # TODO(karita): do not use list. use batch instead
+        # see also https://github.com/espnet/espnet/pull/1402#discussion_r354561029
+        # update hyps
+        best_hyps = []
+        prev_hyps = self.unbatchfy(running_hyps)
+        for (
+            full_prev_hyp_id,
+            full_new_token_id,
+            part_prev_hyp_id,
+            part_new_token_id,
+        ) in zip(*self.batch_beam(weighted_scores, part_ids)):
+            prev_hyp = prev_hyps[full_prev_hyp_id]
+            best_hyps.append(
+                Hypothesis(
+                    score=weighted_scores[full_prev_hyp_id, full_new_token_id],
+                    yseq=self.append_token(prev_hyp.yseq, full_new_token_id),
+                    scores=self.merge_scores(
+                        prev_hyp.scores,
+                        {k: v[full_prev_hyp_id] for k, v in scores.items()},
+                        full_new_token_id,
+                        {k: v[part_prev_hyp_id] for k, v in part_scores.items()},
+                        part_new_token_id,
+                    ),
+                    states=self.merge_states(
+                        {
+                            k: self.full_scorers[k].select_state(v, full_prev_hyp_id)
+                            for k, v in states.items()
+                        },
+                        {
+                            k: self.part_scorers[k].select_state(
+                                v, part_prev_hyp_id, part_new_token_id
+                            )
+                            for k, v in part_states.items()
+                        },
+                        part_new_token_id,
+                    ),
+                )
+            )
+        return self.batchfy(best_hyps)
+
+    def post_process(
+        self,
+        i: int,
+        maxlen: int,
+        maxlenratio: float,
+        running_hyps: BatchHypothesis,
+        ended_hyps: List[Hypothesis],
+    ) -> BatchHypothesis:
+        """Perform post-processing of beam search iterations.
+
+        Args:
+            i (int): The length of hypothesis tokens.
+            maxlen (int): The maximum length of tokens in beam search.
+            maxlenratio (float): The maximum length ratio in beam search.
+            running_hyps (BatchHypothesis): The running hypotheses in beam search.
+            ended_hyps (List[Hypothesis]): The ended hypotheses in beam search.
+
+        Returns:
+            BatchHypothesis: The new running hypotheses.
+
+        """
+        n_batch = running_hyps.yseq.shape[0]
+        logging.debug(f"the number of running hypotheses: {n_batch}")
+        if self.token_list is not None:
+            logging.debug(
+                "best hypo: "
+                + "".join(
+                    [
+                        self.token_list[x]
+                        for x in running_hyps.yseq[0, 1 : running_hyps.length[0]]
+                    ]
+                )
+            )
+        # add eos in the final loop so that at least one hypothesis ends
+        if i == maxlen - 1:
+            logging.debug("adding <eos> in the last position in the loop")
+            yseq_eos = torch.cat(
+                (
+                    running_hyps.yseq,
+                    torch.full(
+                        (n_batch, 1),
+                        self.eos,
+                        device=running_hyps.yseq.device,
+                        dtype=torch.int64,
+                    ),
+                ),
+                1,
+            )
+            running_hyps.yseq.resize_as_(yseq_eos)
+            running_hyps.yseq[:] = yseq_eos
+            running_hyps.length[:] = yseq_eos.shape[1]
+
+        # add ended hypotheses to a final list, and remove them from current hypotheses
+        # (this can be a problem when the number of hyps < beam)
+        is_eos = (
+            running_hyps.yseq[torch.arange(n_batch), running_hyps.length - 1]
+            == self.eos
+        )
+        for b in torch.nonzero(is_eos, as_tuple=False).view(-1):
+            hyp = self._select(running_hyps, b)
+            ended_hyps.append(hyp)
+        remained_ids = torch.nonzero(is_eos == 0, as_tuple=False).view(-1)
+        return self._batch_select(running_hyps, remained_ids)
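The flattened top-k in `batch_beam` is the heart of the vectorization, so a small standalone illustration may help (the scores below are made up): flattening the `(n_beam, n_vocab)` score matrix lets a single `topk` call pick the best (hypothesis, token) pairs jointly, and integer division/modulo recover the two indices.

import torch

n_vocab, beam_size = 5, 3
weighted_scores = torch.tensor([[0.1, 0.9, 0.2, 0.0, 0.3],
                                [0.8, 0.0, 0.7, 0.1, 0.2]])
top_ids = weighted_scores.view(-1).topk(beam_size)[1]  # flat indices: [1, 5, 7]
prev_hyp_ids = torch.div(top_ids, n_vocab, rounding_mode="trunc")  # [0, 1, 1]
new_token_ids = top_ids % n_vocab                                  # [1, 0, 2]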
espnet/nets/beam_search.py ADDED
@@ -0,0 +1,510 @@
+"""Beam search module."""
+
+import logging
+from itertools import chain
+from typing import Any, Dict, List, NamedTuple, Tuple, Union
+
+import torch
+
+from espnet.nets.e2e_asr_common import end_detect
+from espnet.nets.scorer_interface import PartialScorerInterface, ScorerInterface
+
+
+class Hypothesis(NamedTuple):
+    """Hypothesis data type."""
+
+    yseq: torch.Tensor
+    score: Union[float, torch.Tensor] = 0
+    scores: Dict[str, Union[float, torch.Tensor]] = dict()
+    states: Dict[str, Any] = dict()
+
+    def asdict(self) -> dict:
+        """Convert data to JSON-friendly dict."""
+        return self._replace(
+            yseq=self.yseq.tolist(),
+            score=float(self.score),
+            scores={k: float(v) for k, v in self.scores.items()},
+        )._asdict()
+
+
+class BeamSearch(torch.nn.Module):
+    """Beam search implementation."""
+
+    def __init__(
+        self,
+        scorers: Dict[str, ScorerInterface],
+        weights: Dict[str, float],
+        beam_size: int,
+        vocab_size: int,
+        sos: int,
+        eos: int,
+        token_list: List[str] = None,
+        pre_beam_ratio: float = 1.5,
+        pre_beam_score_key: str = None,
+    ):
+        """Initialize beam search.
+
+        Args:
+            scorers (dict[str, ScorerInterface]): Dict of decoder modules
+                e.g., Decoder, CTCPrefixScorer, LM
+                The scorer will be ignored if it is `None`
+            weights (dict[str, float]): Dict of weights for each scorer
+                The scorer will be ignored if its weight is 0
+            beam_size (int): The number of hypotheses kept during search
+            vocab_size (int): The size of the vocabulary
+            sos (int): Start of sequence id
+            eos (int): End of sequence id
+            token_list (list[str]): List of tokens for debug log
+            pre_beam_score_key (str): key of scores to perform pre-beam search
+            pre_beam_ratio (float): beam size in the pre-beam search
+                will be `int(pre_beam_ratio * beam_size)`
+
+        """
+        super().__init__()
+        # set scorers
+        self.weights = weights
+        self.scorers = dict()
+        self.full_scorers = dict()
+        self.part_scorers = dict()
+        # this module dict is required for recursive cast
+        # `self.to(device, dtype)` in `recog.py`
+        self.nn_dict = torch.nn.ModuleDict()
+        for k, v in scorers.items():
+            w = weights.get(k, 0)
+            if w == 0 or v is None:
+                continue
+            assert isinstance(
+                v, ScorerInterface
+            ), f"{k} ({type(v)}) does not implement ScorerInterface"
+            self.scorers[k] = v
+            if isinstance(v, PartialScorerInterface):
+                self.part_scorers[k] = v
+            else:
+                self.full_scorers[k] = v
+            if isinstance(v, torch.nn.Module):
+                self.nn_dict[k] = v
+
+        # set configurations
+        self.sos = sos
+        self.eos = eos
+        self.token_list = token_list
+        self.pre_beam_size = int(pre_beam_ratio * beam_size)
+        self.beam_size = beam_size
+        self.n_vocab = vocab_size
+        if (
+            pre_beam_score_key is not None
+            and pre_beam_score_key != "full"
+            and pre_beam_score_key not in self.full_scorers
+        ):
+            raise KeyError(f"{pre_beam_score_key} is not found in {self.full_scorers}")
+        self.pre_beam_score_key = pre_beam_score_key
+        self.do_pre_beam = (
+            self.pre_beam_score_key is not None
+            and self.pre_beam_size < self.n_vocab
+            and len(self.part_scorers) > 0
+        )
+
+    def init_hyp(self, x: torch.Tensor) -> List[Hypothesis]:
+        """Get an initial hypothesis data.
+
+        Args:
+            x (torch.Tensor): The encoder output feature
+
+        Returns:
+            Hypothesis: The initial hypothesis.
+
+        """
+        init_states = dict()
+        init_scores = dict()
+        for k, d in self.scorers.items():
+            init_states[k] = d.init_state(x)
+            init_scores[k] = 0.0
+        return [
+            Hypothesis(
+                score=0.0,
+                scores=init_scores,
+                states=init_states,
+                yseq=torch.tensor([self.sos], device=x.device),
+            )
+        ]
+
+    @staticmethod
+    def append_token(xs: torch.Tensor, x: int) -> torch.Tensor:
+        """Append new token to prefix tokens.
+
+        Args:
+            xs (torch.Tensor): The prefix token
+            x (int): The new token to append
+
+        Returns:
+            torch.Tensor: New tensor containing xs + [x] with xs.dtype and xs.device
+
+        """
+        x = torch.tensor([x], dtype=xs.dtype, device=xs.device)
+        return torch.cat((xs, x))
+
+    def score_full(
+        self, hyp: Hypothesis, x: torch.Tensor
+    ) -> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]:
+        """Score new hypothesis by `self.full_scorers`.
+
+        Args:
+            hyp (Hypothesis): Hypothesis with prefix tokens to score
+            x (torch.Tensor): Corresponding input feature
+
+        Returns:
+            Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: Tuple of
+                score dict of `hyp` that has string keys of `self.full_scorers`
+                and tensor score values of shape: `(self.n_vocab,)`,
+                and state dict that has string keys
+                and state values of `self.full_scorers`
+
+        """
+        scores = dict()
+        states = dict()
+        for k, d in self.full_scorers.items():
+            scores[k], states[k] = d.score(hyp.yseq, hyp.states[k], x)
+        return scores, states
+
+    def score_partial(
+        self, hyp: Hypothesis, ids: torch.Tensor, x: torch.Tensor
+    ) -> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]:
+        """Score new hypothesis by `self.part_scorers`.
+
+        Args:
+            hyp (Hypothesis): Hypothesis with prefix tokens to score
+            ids (torch.Tensor): 1D tensor of new partial tokens to score
+            x (torch.Tensor): Corresponding input feature
+
+        Returns:
+            Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: Tuple of
+                score dict of `hyp` that has string keys of `self.part_scorers`
+                and tensor score values of shape: `(len(ids),)`,
+                and state dict that has string keys
+                and state values of `self.part_scorers`
+
+        """
+        scores = dict()
+        states = dict()
+        for k, d in self.part_scorers.items():
+            scores[k], states[k] = d.score_partial(hyp.yseq, ids, hyp.states[k], x)
+        return scores, states
+
+    def beam(
+        self, weighted_scores: torch.Tensor, ids: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Compute topk full token ids and partial token ids.
+
+        Args:
+            weighted_scores (torch.Tensor): The weighted sum scores for each token.
+                Its shape is `(self.n_vocab,)`.
+            ids (torch.Tensor): The partial token ids to compute topk
+
+        Returns:
+            Tuple[torch.Tensor, torch.Tensor]:
+                The topk full token ids and partial token ids.
+                Their shapes are `(self.beam_size,)`
+
+        """
+        # no pre beam performed
+        if weighted_scores.size(0) == ids.size(0):
+            top_ids = weighted_scores.topk(self.beam_size)[1]
+            return top_ids, top_ids
+
+        # mask pruned in pre-beam not to select in topk
+        tmp = weighted_scores[ids]
+        weighted_scores[:] = -float("inf")
+        weighted_scores[ids] = tmp
+        top_ids = weighted_scores.topk(self.beam_size)[1]
+        local_ids = weighted_scores[ids].topk(self.beam_size)[1]
+        return top_ids, local_ids
+
+    @staticmethod
+    def merge_scores(
+        prev_scores: Dict[str, float],
+        next_full_scores: Dict[str, torch.Tensor],
+        full_idx: int,
+        next_part_scores: Dict[str, torch.Tensor],
+        part_idx: int,
+    ) -> Dict[str, torch.Tensor]:
+        """Merge scores for new hypothesis.
+
+        Args:
+            prev_scores (Dict[str, float]):
+                The previous hypothesis scores by `self.scorers`
+            next_full_scores (Dict[str, torch.Tensor]): scores by `self.full_scorers`
+            full_idx (int): The next token id for `next_full_scores`
+            next_part_scores (Dict[str, torch.Tensor]):
+                scores of partial tokens by `self.part_scorers`
+            part_idx (int): The new token id for `next_part_scores`
+
+        Returns:
+            Dict[str, torch.Tensor]: The new score dict.
+                Its keys are names of `self.full_scorers` and `self.part_scorers`.
+                Its values are scalar tensors by the scorers.
+
+        """
+        new_scores = dict()
+        for k, v in next_full_scores.items():
+            new_scores[k] = prev_scores[k] + v[full_idx]
+        for k, v in next_part_scores.items():
+            new_scores[k] = prev_scores[k] + v[part_idx]
+        return new_scores
+
+    def merge_states(self, states: Any, part_states: Any, part_idx: int) -> Any:
+        """Merge states for new hypothesis.
+
+        Args:
+            states: states of `self.full_scorers`
+            part_states: states of `self.part_scorers`
+            part_idx (int): The new token id for `part_scores`
+
+        Returns:
+            Dict[str, torch.Tensor]: The new state dict.
+                Its keys are names of `self.full_scorers` and `self.part_scorers`.
+                Its values are states of the scorers.
+
+        """
+        new_states = dict()
+        for k, v in states.items():
+            new_states[k] = v
+        for k, d in self.part_scorers.items():
+            new_states[k] = d.select_state(part_states[k], part_idx)
+        return new_states
+
+    def search(
+        self, running_hyps: List[Hypothesis], x: torch.Tensor
+    ) -> List[Hypothesis]:
+        """Search new tokens for running hypotheses and encoded speech x.
+
+        Args:
+            running_hyps (List[Hypothesis]): Running hypotheses on beam
+            x (torch.Tensor): Encoded speech feature (T, D)
+
+        Returns:
+            List[Hypothesis]: Best sorted hypotheses
+
+        """
+        best_hyps = []
+        part_ids = torch.arange(self.n_vocab, device=x.device)  # no pre-beam
+        for hyp in running_hyps:
+            # scoring
+            weighted_scores = torch.zeros(self.n_vocab, dtype=x.dtype, device=x.device)
+            scores, states = self.score_full(hyp, x)
+            for k in self.full_scorers:
+                weighted_scores += self.weights[k] * scores[k]
+            # partial scoring
+            if self.do_pre_beam:
+                pre_beam_scores = (
+                    weighted_scores
+                    if self.pre_beam_score_key == "full"
+                    else scores[self.pre_beam_score_key]
+                )
+                part_ids = torch.topk(pre_beam_scores, self.pre_beam_size)[1]
+            part_scores, part_states = self.score_partial(hyp, part_ids, x)
+            for k in self.part_scorers:
+                weighted_scores[part_ids] += self.weights[k] * part_scores[k]
+            # add previous hyp score
+            weighted_scores += hyp.score
+
+            # update hyps
+            for j, part_j in zip(*self.beam(weighted_scores, part_ids)):
+                # will be (2 x beam at most)
+                best_hyps.append(
+                    Hypothesis(
+                        score=weighted_scores[j],
+                        yseq=self.append_token(hyp.yseq, j),
+                        scores=self.merge_scores(
+                            hyp.scores, scores, j, part_scores, part_j
+                        ),
+                        states=self.merge_states(states, part_states, part_j),
+                    )
+                )
+
+            # sort and prune 2 x beam -> beam
+            best_hyps = sorted(best_hyps, key=lambda x: x.score, reverse=True)[
+                : min(len(best_hyps), self.beam_size)
+            ]
+        return best_hyps
+
+    def forward(
+        self, x: torch.Tensor, maxlenratio: float = 0.0, minlenratio: float = 0.0
+    ) -> List[Hypothesis]:
+        """Perform beam search.
+
+        Args:
+            x (torch.Tensor): Encoded speech feature (T, D)
+            maxlenratio (float): Input length ratio to obtain max output length.
+                If maxlenratio=0.0 (default), it uses an end-detect function
+                to automatically find maximum hypothesis lengths.
+                If maxlenratio<0.0, its absolute value is interpreted
+                as a constant max output length.
+            minlenratio (float): Input length ratio to obtain min output length.
+
+        Returns:
+            list[Hypothesis]: N-best decoding results
+
+        """
+        # set length bounds
+        if maxlenratio == 0:
+            maxlen = x.shape[0]
+        elif maxlenratio < 0:
+            maxlen = -1 * int(maxlenratio)
+        else:
+            maxlen = max(1, int(maxlenratio * x.size(0)))
+        minlen = int(minlenratio * x.size(0))
+        logging.debug("decoder input length: " + str(x.shape[0]))
+        logging.debug("max output length: " + str(maxlen))
+        logging.debug("min output length: " + str(minlen))
+
+        # main loop of prefix search
+        running_hyps = self.init_hyp(x)
+        ended_hyps = []
+        for i in range(maxlen):
+            logging.debug("position " + str(i))
+            best = self.search(running_hyps, x)
+            # post process of one iteration
+            running_hyps = self.post_process(i, maxlen, maxlenratio, best, ended_hyps)
+            # end detection
+            if maxlenratio == 0.0 and end_detect([h.asdict() for h in ended_hyps], i):
+                logging.debug(f"end detected at {i}")
+                break
+            if len(running_hyps) == 0:
+                logging.debug("no hypothesis. Finish decoding.")
+                break
+            else:
+                logging.debug(f"remained hypotheses: {len(running_hyps)}")
+
+        nbest_hyps = sorted(ended_hyps, key=lambda x: x.score, reverse=True)
+        # check the number of hypotheses reaching eos
+        if len(nbest_hyps) == 0:
+            logging.warning(
+                "there are no N-best results, performing recognition "
+                "again with a smaller minlenratio."
+            )
+            return (
+                []
+                if minlenratio < 0.1
+                else self.forward(x, maxlenratio, max(0.0, minlenratio - 0.1))
+            )
+
+        # report the best result
+        best = nbest_hyps[0]
+        for k, v in best.scores.items():
+            logging.debug(
+                f"{v:6.2f} * {self.weights[k]:3} = {v * self.weights[k]:6.2f} for {k}"
+            )
+        logging.debug(f"total log probability: {best.score:.2f}")
+        logging.debug(f"normalized log probability: {best.score / len(best.yseq):.2f}")
+        logging.debug(f"total number of ended hypotheses: {len(nbest_hyps)}")
+        if self.token_list is not None:
+            logging.debug(
+                "best hypo: "
+                + "".join([self.token_list[x] for x in best.yseq[1:-1]])
+                + "\n"
+            )
+        return nbest_hyps
+
+    def post_process(
+        self,
+        i: int,
+        maxlen: int,
+        maxlenratio: float,
+        running_hyps: List[Hypothesis],
+        ended_hyps: List[Hypothesis],
+    ) -> List[Hypothesis]:
+        """Perform post-processing of beam search iterations.
+
+        Args:
+            i (int): The length of hypothesis tokens.
+            maxlen (int): The maximum length of tokens in beam search.
+            maxlenratio (float): The maximum length ratio in beam search.
+            running_hyps (List[Hypothesis]): The running hypotheses in beam search.
+            ended_hyps (List[Hypothesis]): The ended hypotheses in beam search.
+
+        Returns:
+            List[Hypothesis]: The new running hypotheses.
+
+        """
+        logging.debug(f"the number of running hypotheses: {len(running_hyps)}")
+        if self.token_list is not None:
+            logging.debug(
+                "best hypo: "
+                + "".join([self.token_list[x] for x in running_hyps[0].yseq[1:]])
+            )
+        # add eos in the final loop so that at least one hypothesis ends
+        if i == maxlen - 1:
+            logging.debug("adding <eos> in the last position in the loop")
+            running_hyps = [
+                h._replace(yseq=self.append_token(h.yseq, self.eos))
+                for h in running_hyps
+            ]
+
+        # add ended hypotheses to a final list, and remove them from current hypotheses
+        # (this can be a problem when the number of hyps < beam)
+        remained_hyps = []
+        for hyp in running_hyps:
+            if hyp.yseq[-1] == self.eos:
+                # e.g., Word LM needs to add final <eos> score
+                for k, d in chain(self.full_scorers.items(), self.part_scorers.items()):
+                    s = d.final_score(hyp.states[k])
+                    hyp.scores[k] += s
+                    hyp = hyp._replace(score=hyp.score + self.weights[k] * s)
+                ended_hyps.append(hyp)
+            else:
+                remained_hyps.append(hyp)
+        return remained_hyps
+
+
+def beam_search(
+    x: torch.Tensor,
+    sos: int,
+    eos: int,
+    beam_size: int,
+    vocab_size: int,
+    scorers: Dict[str, ScorerInterface],
+    weights: Dict[str, float],
+    token_list: List[str] = None,
+    maxlenratio: float = 0.0,
+    minlenratio: float = 0.0,
+    pre_beam_ratio: float = 1.5,
+    pre_beam_score_key: str = "full",
+) -> list:
+    """Perform beam search with scorers.
+
+    Args:
+        x (torch.Tensor): Encoded speech feature (T, D)
+        sos (int): Start of sequence id
+        eos (int): End of sequence id
+        beam_size (int): The number of hypotheses kept during search
+        vocab_size (int): The size of the vocabulary
+        scorers (dict[str, ScorerInterface]): Dict of decoder modules
+            e.g., Decoder, CTCPrefixScorer, LM
+            The scorer will be ignored if it is `None`
+        weights (dict[str, float]): Dict of weights for each scorer
+            The scorer will be ignored if its weight is 0
+        token_list (list[str]): List of tokens for debug log
+        maxlenratio (float): Input length ratio to obtain max output length.
+            If maxlenratio=0.0 (default), it uses an end-detect function
+            to automatically find maximum hypothesis lengths.
+        minlenratio (float): Input length ratio to obtain min output length.
+        pre_beam_score_key (str): key of scores to perform pre-beam search
+        pre_beam_ratio (float): beam size in the pre-beam search
+            will be `int(pre_beam_ratio * beam_size)`
+
+    Returns:
+        list: N-best decoding results
+
+    """
+    ret = BeamSearch(
+        scorers,
+        weights,
+        beam_size=beam_size,
+        vocab_size=vocab_size,
+        pre_beam_ratio=pre_beam_ratio,
+        pre_beam_score_key=pre_beam_score_key,
+        sos=sos,
+        eos=eos,
+        token_list=token_list,
+    ).forward(x=x, maxlenratio=maxlenratio, minlenratio=minlenratio)
+    return [h.asdict() for h in ret]
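A hedged construction sketch for `BeamSearch` (the `decoder` and `lm` objects, id variables, and weights below are placeholders for scorers implementing `ScorerInterface`; none of them are defined in this commit):

beam = BeamSearch(
    scorers={"decoder": decoder, "lm": lm},  # placeholder scorer objects
    weights={"decoder": 1.0, "lm": 0.3},     # weight-0 scorers are dropped
    beam_size=10,
    vocab_size=len(token_list),
    sos=sos_id,
    eos=eos_id,
    token_list=token_list,
)
nbest = beam(enc_out)            # enc_out: encoder output of shape (T, D)
best_token_ids = nbest[0].yseq   # includes the sos/eos ids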
espnet/nets/ctc_prefix_score.py ADDED
@@ -0,0 +1,357 @@
+#!/usr/bin/env python3
+
+# Copyright 2018 Mitsubishi Electric Research Labs (Takaaki Hori)
+# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
+
+import numpy as np
+import torch
+
+
+class CTCPrefixScoreTH(object):
+    """Batch processing of CTCPrefixScore
+
+    which is based on Algorithm 2 in WATANABE et al.
+    "HYBRID CTC/ATTENTION ARCHITECTURE FOR END-TO-END SPEECH RECOGNITION,"
+    but extended to efficiently compute the label probabilities for multiple
+    hypotheses simultaneously
+    See also Seki et al. "Vectorized Beam Search for CTC-Attention-Based
+    Speech Recognition," In INTERSPEECH (pp. 3825-3829), 2019.
+    """
+
+    def __init__(self, x, xlens, blank, eos, margin=0):
+        """Construct CTC prefix scorer
+
+        :param torch.Tensor x: input label posterior sequences (B, T, O)
+        :param torch.Tensor xlens: input lengths (B,)
+        :param int blank: blank label id
+        :param int eos: end-of-sequence id
+        :param int margin: margin parameter for windowing (0 means no windowing)
+        """
+        # In the comment lines,
+        # we assume T: input_length, B: batch size, W: beam width, O: output dim.
+        self.logzero = -10000000000.0
+        self.blank = blank
+        self.eos = eos
+        self.batch = x.size(0)
+        self.input_length = x.size(1)
+        self.odim = x.size(2)
+        self.dtype = x.dtype
+        self.device = (
+            torch.device("cuda:%d" % x.get_device())
+            if x.is_cuda
+            else torch.device("cpu")
+        )
+        # Pad the rest of posteriors in the batch
+        # TODO(takaaki-hori): need a better way without for-loops
+        for i, l in enumerate(xlens):
+            if l < self.input_length:
+                x[i, l:, :] = self.logzero
+                x[i, l:, blank] = 0
+        # Reshape input x
+        xn = x.transpose(0, 1)  # (B, T, O) -> (T, B, O)
+        xb = xn[:, :, self.blank].unsqueeze(2).expand(-1, -1, self.odim)
+        self.x = torch.stack([xn, xb])  # (2, T, B, O)
+        self.end_frames = torch.as_tensor(xlens) - 1
+
+        # Setup CTC windowing
+        self.margin = margin
+        if margin > 0:
+            self.frame_ids = torch.arange(
+                self.input_length, dtype=self.dtype, device=self.device
+            )
+        # Base indices for index conversion
+        self.idx_bh = None
+        self.idx_b = torch.arange(self.batch, device=self.device)
+        self.idx_bo = (self.idx_b * self.odim).unsqueeze(1)
+
+    def __call__(self, y, state, scoring_ids=None, att_w=None):
+        """Compute CTC prefix scores for next labels
+
+        :param list y: prefix label sequences
+        :param tuple state: previous CTC state
+        :param torch.Tensor scoring_ids: selected next-label ids for pre-selection of hypotheses (BW, O)
+        :param torch.Tensor att_w: attention weights to decide CTC window
+        :return new_state, ctc_local_scores (BW, O)
+        """
+        output_length = len(y[0]) - 1  # ignore sos
+        last_ids = [yi[-1] for yi in y]  # last output label ids
+        n_bh = len(last_ids)  # batch * hyps
+        n_hyps = n_bh // self.batch  # assuming each utterance has the same # of hyps
+        self.scoring_num = scoring_ids.size(-1) if scoring_ids is not None else 0
+        # prepare state info
+        if state is None:
+            r_prev = torch.full(
+                (self.input_length, 2, self.batch, n_hyps),
+                self.logzero,
+                dtype=self.dtype,
+                device=self.device,
+            )
+            r_prev[:, 1] = torch.cumsum(self.x[0, :, :, self.blank], 0).unsqueeze(2)
+            r_prev = r_prev.view(-1, 2, n_bh)
+            s_prev = 0.0
+            f_min_prev = 0
+            f_max_prev = 1
+        else:
+            r_prev, s_prev, f_min_prev, f_max_prev = state
+
+        # select input dimensions for scoring
+        if self.scoring_num > 0:
+            scoring_idmap = torch.full(
+                (n_bh, self.odim), -1, dtype=torch.long, device=self.device
+            )
+            snum = self.scoring_num
+            if self.idx_bh is None or n_bh > len(self.idx_bh):
+                self.idx_bh = torch.arange(n_bh, device=self.device).view(-1, 1)
+            scoring_idmap[self.idx_bh[:n_bh], scoring_ids] = torch.arange(
+                snum, device=self.device
+            )
+            scoring_idx = (
+                scoring_ids + self.idx_bo.repeat(1, n_hyps).view(-1, 1)
+            ).view(-1)
+            x_ = torch.index_select(
+                self.x.view(2, -1, self.batch * self.odim), 2, scoring_idx
+            ).view(2, -1, n_bh, snum)
+        else:
+            scoring_ids = None
+            scoring_idmap = None
+            snum = self.odim
+            x_ = self.x.unsqueeze(3).repeat(1, 1, 1, n_hyps, 1).view(2, -1, n_bh, snum)
+
+        # new CTC forward probs are prepared as a (T x 2 x BW x S) tensor
+        # that corresponds to r_t^n(h) and r_t^b(h) in a batch.
+        r = torch.full(
+            (self.input_length, 2, n_bh, snum),
+            self.logzero,
+            dtype=self.dtype,
+            device=self.device,
+        )
+        if output_length == 0:
+            r[0, 0] = x_[0, 0]
+
+        r_sum = torch.logsumexp(r_prev, 1)
+        log_phi = r_sum.unsqueeze(2).repeat(1, 1, snum)
+        if scoring_ids is not None:
+            for idx in range(n_bh):
+                pos = scoring_idmap[idx, last_ids[idx]]
+                if pos >= 0:
+                    log_phi[:, idx, pos] = r_prev[:, 1, idx]
+        else:
+            for idx in range(n_bh):
+                log_phi[:, idx, last_ids[idx]] = r_prev[:, 1, idx]
+
+        # decide start and end frames based on attention weights
+        if att_w is not None and self.margin > 0:
+            f_arg = torch.matmul(att_w, self.frame_ids)
+            f_min = max(int(f_arg.min().cpu()), f_min_prev)
+            f_max = max(int(f_arg.max().cpu()), f_max_prev)
+            start = min(f_max_prev, max(f_min - self.margin, output_length, 1))
+            end = min(f_max + self.margin, self.input_length)
+        else:
+            f_min = f_max = 0
+            start = max(output_length, 1)
+            end = self.input_length
+
+        # compute forward probabilities log(r_t^n(h)) and log(r_t^b(h))
+        for t in range(start, end):
+            rp = r[t - 1]
+            rr = torch.stack([rp[0], log_phi[t - 1], rp[0], rp[1]]).view(
+                2, 2, n_bh, snum
+            )
+            r[t] = torch.logsumexp(rr, 1) + x_[:, t]
+
+        # compute log prefix probabilities log(psi)
+        log_phi_x = torch.cat((log_phi[0].unsqueeze(0), log_phi[:-1]), dim=0) + x_[0]
+        if scoring_ids is not None:
+            log_psi = torch.full(
+                (n_bh, self.odim), self.logzero, dtype=self.dtype, device=self.device
+            )
+            log_psi_ = torch.logsumexp(
+                torch.cat((log_phi_x[start:end], r[start - 1, 0].unsqueeze(0)), dim=0),
+                dim=0,
+            )
+            for si in range(n_bh):
+                log_psi[si, scoring_ids[si]] = log_psi_[si]
+        else:
+            log_psi = torch.logsumexp(
+                torch.cat((log_phi_x[start:end], r[start - 1, 0].unsqueeze(0)), dim=0),
+                dim=0,
+            )
+
+        for si in range(n_bh):
+            log_psi[si, self.eos] = r_sum[self.end_frames[si // n_hyps], si]
+
+        # exclude blank probs
+        log_psi[:, self.blank] = self.logzero
+
+        return (log_psi - s_prev), (r, log_psi, f_min, f_max, scoring_idmap)
+
+    def index_select_state(self, state, best_ids):
+        """Select CTC states according to best ids
+
+        :param state : CTC state
+        :param best_ids : index numbers selected by beam pruning (B, W)
+        :return selected_state
+        """
+        r, s, f_min, f_max, scoring_idmap = state
+        # convert ids to BHO space
+        n_bh = len(s)
+        n_hyps = n_bh // self.batch
+        vidx = (best_ids + (self.idx_b * (n_hyps * self.odim)).view(-1, 1)).view(-1)
+        # select hypothesis scores
+        s_new = torch.index_select(s.view(-1), 0, vidx)
+        s_new = s_new.view(-1, 1).repeat(1, self.odim).view(n_bh, self.odim)
+        # convert ids to BHS space (S: scoring_num)
+        if scoring_idmap is not None:
+            snum = self.scoring_num
+            hyp_idx = (best_ids // self.odim + (self.idx_b * n_hyps).view(-1, 1)).view(
+                -1
+            )
+            label_ids = torch.fmod(best_ids, self.odim).view(-1)
+            score_idx = scoring_idmap[hyp_idx, label_ids]
+            score_idx[score_idx == -1] = 0
+            vidx = score_idx + hyp_idx * snum
+        else:
+            snum = self.odim
+        # select forward probabilities
+        r_new = torch.index_select(r.view(-1, 2, n_bh * snum), 2, vidx).view(
+            -1, 2, n_bh
+        )
+        return r_new, s_new, f_min, f_max
+
+    def extend_prob(self, x):
+        """Extend CTC prob.
+
+        :param torch.Tensor x: input label posterior sequences (B, T, O)
+        """
+
+        if self.x.shape[1] < x.shape[1]:  # self.x (2,T,B,O); x (B,T,O)
+            # Pad the rest of posteriors in the batch
+            # TODO(takaaki-hori): need a better way without for-loops
+            xlens = [x.size(1)]
+            for i, l in enumerate(xlens):
+                if l < self.input_length:
+                    x[i, l:, :] = self.logzero
+                    x[i, l:, self.blank] = 0
+            tmp_x = self.x
+            xn = x.transpose(0, 1)  # (B, T, O) -> (T, B, O)
+            xb = xn[:, :, self.blank].unsqueeze(2).expand(-1, -1, self.odim)
+            self.x = torch.stack([xn, xb])  # (2, T, B, O)
+            self.x[:, : tmp_x.shape[1], :, :] = tmp_x
+            self.input_length = x.size(1)
+            self.end_frames = torch.as_tensor(xlens) - 1
+
+    def extend_state(self, state):
+        """Compute CTC prefix state.
+
+        :param state : CTC state
+        :return ctc_state
+        """
+
+        if state is None:
+            # nothing to do
+            return state
+        else:
+            r_prev, s_prev, f_min_prev, f_max_prev = state
+
+            r_prev_new = torch.full(
+                (self.input_length, 2),
+                self.logzero,
+                dtype=self.dtype,
+                device=self.device,
+            )
+            start = max(r_prev.shape[0], 1)
+            r_prev_new[0:start] = r_prev
+            for t in range(start, self.input_length):
+                r_prev_new[t, 1] = r_prev_new[t - 1, 1] + self.x[0, t, :, self.blank]
+
+            return (r_prev_new, s_prev, f_min_prev, f_max_prev)
+
+
+class CTCPrefixScore(object):
+    """Compute CTC label sequence scores
+
+    which is based on Algorithm 2 in WATANABE et al.
+    "HYBRID CTC/ATTENTION ARCHITECTURE FOR END-TO-END SPEECH RECOGNITION,"
+    but extended to efficiently compute the probabilities of multiple labels
+    simultaneously
+    """
+
+    def __init__(self, x, blank, eos, xp):
+        self.xp = xp
+        self.logzero = -10000000000.0
+        self.blank = blank
+        self.eos = eos
+        self.input_length = len(x)
+        self.x = x
+
+    def initial_state(self):
+        """Obtain an initial CTC state
+
+        :return: CTC state
+        """
+        # initial CTC state is made of a frame x 2 tensor that corresponds to
+        # r_t^n(<sos>) and r_t^b(<sos>), where 0 and 1 of axis=1 represent
+        # superscripts n and b (non-blank and blank), respectively.
+        r = self.xp.full((self.input_length, 2), self.logzero, dtype=np.float32)
+        r[0, 1] = self.x[0, self.blank]
+        for i in range(1, self.input_length):
+            r[i, 1] = r[i - 1, 1] + self.x[i, self.blank]
+        return r
+
+    def __call__(self, y, cs, r_prev):
+        """Compute CTC prefix scores for next labels
+
+        :param y : prefix label sequence
+        :param cs : array of next labels
+        :param r_prev: previous CTC state
+        :return ctc_scores, ctc_states
+        """
+        # initialize CTC states
+        output_length = len(y) - 1  # ignore sos
+        # new CTC states are prepared as a frame x (n or b) x n_labels tensor
+        # that corresponds to r_t^n(h) and r_t^b(h).
+        r = self.xp.ndarray((self.input_length, 2, len(cs)), dtype=np.float32)
+        xs = self.x[:, cs]
+        if output_length == 0:
+            r[0, 0] = xs[0]
+            r[0, 1] = self.logzero
+        else:
+            r[output_length - 1] = self.logzero
+
+        # prepare forward probabilities for the last label
+        r_sum = self.xp.logaddexp(
+            r_prev[:, 0], r_prev[:, 1]
+        )  # log(r_t^n(g) + r_t^b(g))
+        last = y[-1]
+        if output_length > 0 and last in cs:
+            log_phi = self.xp.ndarray((self.input_length, len(cs)), dtype=np.float32)
+            for i in range(len(cs)):
+                log_phi[:, i] = r_sum if cs[i] != last else r_prev[:, 1]
+        else:
+            log_phi = r_sum
+
+        # compute forward probabilities log(r_t^n(h)), log(r_t^b(h)),
+        # and log prefix probabilities log(psi)
+        start = max(output_length, 1)
+        log_psi = r[start - 1, 0]
+        for t in range(start, self.input_length):
+            r[t, 0] = self.xp.logaddexp(r[t - 1, 0], log_phi[t - 1]) + xs[t]
+            r[t, 1] = (
+                self.xp.logaddexp(r[t - 1, 0], r[t - 1, 1]) + self.x[t, self.blank]
+            )
+            log_psi = self.xp.logaddexp(log_psi, log_phi[t - 1] + xs[t])
+
+        # get P(...eos|X) that ends with the prefix itself
+        eos_pos = self.xp.where(cs == self.eos)[0]
+        if len(eos_pos) > 0:
+            log_psi[eos_pos] = r_sum[-1]  # log(r_T^n(g) + r_T^b(g))
+
+        # exclude blank probs
+        blank_pos = self.xp.where(cs == self.blank)[0]
+        if len(blank_pos) > 0:
+            log_psi[blank_pos] = self.logzero
+
+        # return the log prefix probability and CTC states, where the label axis
+        # of the CTC states is moved to the first axis to slice it easily
+        return log_psi, self.xp.rollaxis(r, 2)
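For readers cross-checking `CTCPrefixScore.__call__` against Algorithm 2 of the cited Watanabe et al. paper, the loop implements the following log-domain recursion for a prefix g extended by a label c (h = g·c); this restatement under that reading is an aid, not text from the commit. Here ⊕ denotes logaddexp:

\begin{aligned}
\phi_t &= \begin{cases} r_t^b(g) & \text{if } c = \mathrm{last}(g)\\ r_t^b(g) \oplus r_t^n(g) & \text{otherwise} \end{cases}\\
r_t^n(h) &= \bigl(r_{t-1}^n(h) \oplus \phi_{t-1}\bigr) + \log p(c \mid x_t)\\
r_t^b(h) &= \bigl(r_{t-1}^n(h) \oplus r_{t-1}^b(h)\bigr) + \log p(\mathrm{blank} \mid x_t)\\
\log\psi(h) &= r_{\mathrm{start}-1}^n(h) \oplus \bigoplus_{t \ge \mathrm{start}} \bigl(\phi_{t-1} + \log p(c \mid x_t)\bigr)
\end{aligned}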
espnet/nets/e2e_asr_common.py ADDED
@@ -0,0 +1,199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # encoding: utf-8
3
+
4
+ # Copyright 2017 Johns Hopkins University (Shinji Watanabe)
5
+ # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
6
+
7
+ """Common functions for ASR."""
8
+
9
+ import json
10
+ import logging
11
+ import sys
12
+ from itertools import groupby
13
+
14
+ import numpy as np
15
+
16
+
17
+ def end_detect(ended_hyps, i, M=3, D_end=np.log(1 * np.exp(-10))):
18
+ """End detection.
19
+
20
+ described in Eq. (50) of S. Watanabe et al
21
+ "Hybrid CTC/Attention Architecture for End-to-End Speech Recognition"
22
+
23
+ :param ended_hyps:
24
+ :param i:
25
+ :param M:
26
+ :param D_end:
27
+ :return:
28
+ """
29
+ if len(ended_hyps) == 0:
30
+ return False
31
+ count = 0
32
+ best_hyp = sorted(ended_hyps, key=lambda x: x["score"], reverse=True)[0]
33
+ for m in range(M):
34
+ # get ended_hyps with their length is i - m
35
+ hyp_length = i - m
36
+ hyps_same_length = [x for x in ended_hyps if len(x["yseq"]) == hyp_length]
37
+ if len(hyps_same_length) > 0:
38
+ best_hyp_same_length = sorted(
39
+ hyps_same_length, key=lambda x: x["score"], reverse=True
40
+ )[0]
41
+ if best_hyp_same_length["score"] - best_hyp["score"] < D_end:
42
+ count += 1
43
+
44
+ if count == M:
45
+ return True
46
+ else:
47
+ return False
48
+
49
+
50
+ class ErrorCalculator(object):
51
+ """Calculate CER and WER for E2E_ASR and CTC models during training.
52
+
53
+ :param y_hats: numpy array with predicted text
54
+ :param y_pads: numpy array with true (target) text
55
+ :param char_list:
56
+ :param sym_space:
57
+ :param sym_blank:
58
+ :return:
59
+ """
60
+
61
+ def __init__(
62
+ self, char_list, sym_space, sym_blank, report_cer=False, report_wer=False
63
+ ):
64
+ """Construct an ErrorCalculator object."""
65
+ super(ErrorCalculator, self).__init__()
66
+
67
+ self.report_cer = report_cer
68
+ self.report_wer = report_wer
69
+
70
+ self.char_list = char_list
71
+ self.space = sym_space
72
+ self.blank = sym_blank
73
+ self.idx_blank = self.char_list.index(self.blank)
74
+ if self.space in self.char_list:
75
+ self.idx_space = self.char_list.index(self.space)
76
+ else:
77
+ self.idx_space = None
78
+
79
+ def __call__(self, ys_hat, ys_pad, is_ctc=False):
80
+ """Calculate sentence-level WER/CER score.
81
+
82
+ :param torch.Tensor ys_hat: prediction (batch, seqlen)
83
+ :param torch.Tensor ys_pad: reference (batch, seqlen)
84
+ :param bool is_ctc: calculate CER score for CTC
85
+ :return: sentence-level WER score
86
+ :rtype float
87
+ :return: sentence-level CER score
88
+ :rtype float
89
+ """
90
+ cer, wer = None, None
91
+ if is_ctc:
92
+ return self.calculate_cer_ctc(ys_hat, ys_pad)
93
+ elif not self.report_cer and not self.report_wer:
94
+ return cer, wer
95
+
96
+ seqs_hat, seqs_true = self.convert_to_char(ys_hat, ys_pad)
97
+ if self.report_cer:
98
+ cer = self.calculate_cer(seqs_hat, seqs_true)
99
+
100
+ if self.report_wer:
101
+ wer = self.calculate_wer(seqs_hat, seqs_true)
102
+ return cer, wer
103
+
104
+ def calculate_cer_ctc(self, ys_hat, ys_pad):
105
+ """Calculate sentence-level CER score for CTC.
106
+
107
+ :param torch.Tensor ys_hat: prediction (batch, seqlen)
108
+ :param torch.Tensor ys_pad: reference (batch, seqlen)
109
+ :return: average sentence-level CER score
110
+ :rtype float
111
+ """
112
+ import editdistance
113
+
114
+ cers, char_ref_lens = [], []
115
+ for i, y in enumerate(ys_hat):
116
+ y_hat = [x[0] for x in groupby(y)]
117
+ y_true = ys_pad[i]
118
+ seq_hat, seq_true = [], []
119
+ for idx in y_hat:
120
+ idx = int(idx)
121
+ if idx != -1 and idx != self.idx_blank and idx != self.idx_space:
122
+ seq_hat.append(self.char_list[int(idx)])
123
+
124
+ for idx in y_true:
125
+ idx = int(idx)
126
+ if idx != -1 and idx != self.idx_blank and idx != self.idx_space:
127
+ seq_true.append(self.char_list[int(idx)])
128
+
129
+ hyp_chars = "".join(seq_hat)
130
+ ref_chars = "".join(seq_true)
131
+ if len(ref_chars) > 0:
132
+ cers.append(editdistance.eval(hyp_chars, ref_chars))
133
+ char_ref_lens.append(len(ref_chars))
134
+
135
+ cer_ctc = float(sum(cers)) / sum(char_ref_lens) if cers else None
136
+ return cer_ctc
137
+
138
+ def convert_to_char(self, ys_hat, ys_pad):
139
+ """Convert index to character.
140
+
141
+ :param torch.Tensor seqs_hat: prediction (batch, seqlen)
142
+ :param torch.Tensor seqs_true: reference (batch, seqlen)
143
+ :return: token list of prediction
144
+ :rtype list
145
+ :return: token list of reference
146
+ :rtype list
147
+ """
148
+ seqs_hat, seqs_true = [], []
149
+ for i, y_hat in enumerate(ys_hat):
150
+ y_true = ys_pad[i]
151
+ eos_true = np.where(y_true == -1)[0]
152
+ ymax = eos_true[0] if len(eos_true) > 0 else len(y_true)
153
+ # NOTE: padding index (-1) in y_true is used to pad y_hat
154
+ seq_hat = [self.char_list[int(idx)] for idx in y_hat[:ymax]]
155
+ seq_true = [self.char_list[int(idx)] for idx in y_true if int(idx) != -1]
156
+ seq_hat_text = "".join(seq_hat).replace(self.space, " ")
157
+ seq_hat_text = seq_hat_text.replace(self.blank, "")
158
+ seq_true_text = "".join(seq_true).replace(self.space, " ")
159
+ seqs_hat.append(seq_hat_text)
160
+ seqs_true.append(seq_true_text)
161
+ return seqs_hat, seqs_true
162
+
163
+ def calculate_cer(self, seqs_hat, seqs_true):
164
+ """Calculate sentence-level CER score.
165
+
166
+ :param list seqs_hat: prediction
167
+ :param list seqs_true: reference
168
+ :return: average sentence-level CER score
169
+ :rtype float
170
+ """
171
+ import editdistance
172
+
173
+ char_eds, char_ref_lens = [], []
174
+ for i, seq_hat_text in enumerate(seqs_hat):
175
+ seq_true_text = seqs_true[i]
176
+ hyp_chars = seq_hat_text.replace(" ", "")
177
+ ref_chars = seq_true_text.replace(" ", "")
178
+ char_eds.append(editdistance.eval(hyp_chars, ref_chars))
179
+ char_ref_lens.append(len(ref_chars))
180
+ return float(sum(char_eds)) / sum(char_ref_lens)
181
+
182
+ def calculate_wer(self, seqs_hat, seqs_true):
183
+ """Calculate sentence-level WER score.
184
+
185
+ :param list seqs_hat: prediction
186
+ :param list seqs_true: reference
187
+ :return: average sentence-level WER score
188
+ :rtype: float
189
+ """
190
+ import editdistance
191
+
192
+ word_eds, word_ref_lens = [], []
193
+ for i, seq_hat_text in enumerate(seqs_hat):
194
+ seq_true_text = seqs_true[i]
195
+ hyp_words = seq_hat_text.split()
196
+ ref_words = seq_true_text.split()
197
+ word_eds.append(editdistance.eval(hyp_words, ref_words))
198
+ word_ref_lens.append(len(ref_words))
199
+ return float(sum(word_eds)) / sum(word_ref_lens)
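For orientation, here is a minimal sketch of the corpus-level scoring that the error calculator above performs, using the same editdistance package; the hypothesis and reference strings are invented for illustration:

import editdistance

# Hypothetical outputs of convert_to_char(): space-separated word strings.
seqs_hat = ["hello wurld", "good mornin"]
seqs_true = ["hello world", "good morning"]

# CER: character edit distances summed over utterances, divided by total reference chars.
char_eds = [editdistance.eval(h.replace(" ", ""), r.replace(" ", ""))
            for h, r in zip(seqs_hat, seqs_true)]
cer = float(sum(char_eds)) / sum(len(r.replace(" ", "")) for r in seqs_true)

# WER: the same aggregation over word lists.
word_eds = [editdistance.eval(h.split(), r.split()) for h, r in zip(seqs_hat, seqs_true)]
wer = float(sum(word_eds)) / sum(len(r.split()) for r in seqs_true)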
espnet/nets/pytorch_backend/__pycache__/ctc.cpython-310.pyc ADDED
Binary file (8.1 kB).
 
espnet/nets/pytorch_backend/__pycache__/ctc.cpython-311.pyc ADDED
Binary file (15.8 kB).
 
espnet/nets/pytorch_backend/__pycache__/e2e_asr_conformer.cpython-310.pyc ADDED
Binary file (2.82 kB).
 
espnet/nets/pytorch_backend/__pycache__/e2e_asr_conformer.cpython-311.pyc ADDED
Binary file (4.76 kB).
 
espnet/nets/pytorch_backend/__pycache__/nets_utils.cpython-310.pyc ADDED
Binary file (10.1 kB).
 
espnet/nets/pytorch_backend/__pycache__/nets_utils.cpython-311.pyc ADDED
Binary file (13.4 kB).
 
espnet/nets/pytorch_backend/ctc.py ADDED
@@ -0,0 +1,242 @@
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn.functional as F
4
+
5
+ from espnet.nets.pytorch_backend.nets_utils import to_device
6
+
7
+
8
+ class CTC(torch.nn.Module):
9
+ """CTC module
10
+
11
+ :param int odim: dimension of outputs
12
+ :param int eprojs: number of encoder projection units
13
+ :param float dropout_rate: dropout rate (0.0 ~ 1.0)
14
+ :param bool reduce: reduce the CTC loss into a scalar
15
+ """
16
+
17
+ def __init__(self, odim, eprojs, dropout_rate, reduce=True):
18
+ super().__init__()
19
+ self.dropout_rate = dropout_rate
20
+ self.loss = None
21
+ self.ctc_lo = torch.nn.Linear(eprojs, odim)
22
+ self.dropout = torch.nn.Dropout(dropout_rate)
23
+ self.probs = None # for visualization
24
+
25
+ reduction_type = "sum" if reduce else "none"
26
+ self.ctc_loss = torch.nn.CTCLoss(
27
+ reduction=reduction_type, zero_infinity=True
28
+ )
29
+ self.ignore_id = -1
30
+ self.reduce = reduce
31
+
32
+ def loss_fn(self, th_pred, th_target, th_ilen, th_olen):
33
+ th_pred = th_pred.log_softmax(2)
34
+ with torch.backends.cudnn.flags(deterministic=True):
35
+ loss = self.ctc_loss(th_pred, th_target, th_ilen, th_olen)
36
+ # Batch-size average
37
+ loss = loss / th_pred.size(1)
38
+ return loss
39
+
40
+ def forward(self, hs_pad, hlens, ys_pad):
41
+ """CTC forward
42
+
43
+ :param torch.Tensor hs_pad: batch of padded hidden state sequences (B, Tmax, D)
44
+ :param torch.Tensor hlens: batch of lengths of hidden state sequences (B)
45
+ :param torch.Tensor ys_pad:
46
+ batch of padded character id sequence tensor (B, Lmax)
47
+ :return: ctc loss value
48
+ :rtype: torch.Tensor
49
+ """
50
+ # TODO(kan-bayashi): need to make more smart way
51
+ ys = [y[y != self.ignore_id] for y in ys_pad] # parse padded ys
52
+
53
+ # zero padding for hs
54
+ ys_hat = self.ctc_lo(self.dropout(hs_pad))
55
+ ys_hat = ys_hat.transpose(0, 1)
56
+
57
+ olens = to_device(ys_hat, torch.LongTensor([len(s) for s in ys]))
58
+ hlens = hlens.long()
59
+ ys_pad = torch.cat(ys) # without this the code breaks for asr_mix
60
+ self.loss = self.loss_fn(ys_hat, ys_pad, hlens, olens)
61
+
62
+ if self.reduce:
63
+ self.loss = self.loss.sum()
64
+
65
+ return self.loss, ys_hat
66
+
67
+ def softmax(self, hs_pad):
68
+ """softmax of frame activations
69
+
70
+ :param torch.Tensor hs_pad: 3d tensor (B, Tmax, eprojs)
71
+ :return: log softmax applied 3d tensor (B, Tmax, odim)
72
+ :rtype: torch.Tensor
73
+ """
74
+ self.probs = F.softmax(self.ctc_lo(hs_pad), dim=-1)
75
+ return self.probs
76
+
77
+ def log_softmax(self, hs_pad):
78
+ """log_softmax of frame activations
79
+
80
+ :param torch.Tensor hs_pad: 3d tensor (B, Tmax, eprojs)
81
+ :return: log softmax applied 3d tensor (B, Tmax, odim)
82
+ :rtype: torch.Tensor
83
+ """
84
+ return F.log_softmax(self.ctc_lo(hs_pad), dim=-1)
85
+
86
+ def argmax(self, hs_pad):
87
+ """argmax of frame activations
88
+
89
+ :param torch.Tensor hs_pad: 3d tensor (B, Tmax, eprojs)
90
+ :return: argmax applied 2d tensor (B, Tmax)
91
+ :rtype: torch.Tensor
92
+ """
93
+ return torch.argmax(self.ctc_lo(hs_pad), dim=-1)
94
+
95
+ def forced_align(self, h, y, blank_id=0):
96
+ """forced alignment.
97
+
98
+ :param torch.Tensor h: hidden state sequence, 2d tensor (T, D)
99
+ :param torch.Tensor y: id sequence tensor 1d tensor (L)
100
+ :param int blank_id: blank symbol index
101
+ :return: best alignment results
102
+ :rtype: list
103
+ """
104
+
105
+ def interpolate_blank(label, blank_id=0):
106
+ """Insert blank token between every two label token."""
107
+ label = np.expand_dims(label, 1)
108
+ blanks = np.zeros((label.shape[0], 1), dtype=np.int64) + blank_id
109
+ label = np.concatenate([blanks, label], axis=1)
110
+ label = label.reshape(-1)
111
+ label = np.append(label, label[0])
112
+ return label
113
+
114
+ lpz = self.log_softmax(h)
115
+ lpz = lpz.squeeze(0)
116
+
117
+ y_int = interpolate_blank(y, blank_id)
118
+
119
+ logdelta = np.zeros((lpz.size(0), len(y_int))) - 100000000000.0 # log of zero
120
+ state_path = (
121
+ np.zeros((lpz.size(0), len(y_int)), dtype=np.int16) - 1
122
+ ) # state path
123
+
124
+ logdelta[0, 0] = lpz[0][y_int[0]]
125
+ logdelta[0, 1] = lpz[0][y_int[1]]
126
+
127
+ for t in range(1, lpz.size(0)):
128
+ for s in range(len(y_int)):
129
+ if y_int[s] == blank_id or s < 2 or y_int[s] == y_int[s - 2]:
130
+ candidates = np.array([logdelta[t - 1, s], logdelta[t - 1, s - 1]])
131
+ prev_state = [s, s - 1]
132
+ else:
133
+ candidates = np.array(
134
+ [
135
+ logdelta[t - 1, s],
136
+ logdelta[t - 1, s - 1],
137
+ logdelta[t - 1, s - 2],
138
+ ]
139
+ )
140
+ prev_state = [s, s - 1, s - 2]
141
+ logdelta[t, s] = np.max(candidates) + lpz[t][y_int[s]]
142
+ state_path[t, s] = prev_state[np.argmax(candidates)]
143
+
144
+ state_seq = -1 * np.ones((lpz.size(0), 1), dtype=np.int16)
145
+
146
+ candidates = np.array(
147
+ [logdelta[-1, len(y_int) - 1], logdelta[-1, len(y_int) - 2]]
148
+ )
149
+ prev_state = [len(y_int) - 1, len(y_int) - 2]
150
+ state_seq[-1] = prev_state[np.argmax(candidates)]
151
+ for t in range(lpz.size(0) - 2, -1, -1):
152
+ state_seq[t] = state_path[t + 1, state_seq[t + 1, 0]]
153
+
154
+ output_state_seq = []
155
+ for t in range(0, lpz.size(0)):
156
+ output_state_seq.append(y_int[state_seq[t, 0]])
157
+
158
+ return output_state_seq
159
+
160
+ def forced_align_batch(self, hs_pad, ys_pad, ilens, blank_id=0):
161
+ """forced alignment with batch processing.
162
+
163
+ :param torch.Tensor hs_pad: hidden state sequence, 3d tensor (T, B, D)
164
+ :param torch.Tensor ys_pad: id sequence tensor 2d tensor (B, L)
165
+ :param torch.Tensor ilens: Input length of each utterance (B,)
166
+ :param int blank_id: blank symbol index
167
+ :return: best alignment results
168
+ :rtype: list of numpy.array
169
+ """
170
+
171
+ def interpolate_blank(label, olens_int):
172
+ """Insert blank token between every two label token."""
173
+ lab_len = label.shape[1] * 2 + 1
174
+ label_out = np.full((label.shape[0], lab_len), blank_id, dtype=np.int64)
175
+ label_out[:, 1::2] = label
176
+ for b in range(label.shape[0]):
177
+ label_out[b, olens_int[b] * 2 + 1 :] = self.ignore_id
178
+ return label_out
179
+
180
+ neginf = float("-inf") # log of zero
181
+ # lpz = self.log_softmax(hs_pad).cpu().detach().numpy()
182
+ # hs_pad = hs_pad.transpose(1,0)
183
+ lpz = F.log_softmax(hs_pad, dim=-1).cpu().detach().numpy()
184
+ ilens = ilens.cpu().detach().numpy()
185
+
186
+ ys_pad = ys_pad.cpu().detach().numpy()
187
+ ys = [y[y != self.ignore_id] for y in ys_pad] # parse padded ys
188
+ olens = np.array([len(s) for s in ys])
189
+ olens_int = olens * 2 + 1
190
+ ys_int = interpolate_blank(ys_pad, olens_int)
191
+
192
+ Tmax, B, _ = lpz.shape
193
+ Lmax = ys_int.shape[-1]
194
+ logdelta = np.full((Tmax, B, Lmax), neginf, dtype=lpz.dtype)
195
+ state_path = -np.ones(logdelta.shape, dtype=np.int16) # state path
196
+
197
+ b_indx = np.arange(B, dtype=np.int64)
198
+ t_0 = np.zeros(B, dtype=np.int64)
199
+ logdelta[0, :, 0] = lpz[t_0, b_indx, ys_int[:, 0]]
200
+ logdelta[0, :, 1] = lpz[t_0, b_indx, ys_int[:, 1]]
201
+
202
+ s_indx_mat = np.arange(Lmax)[None, :].repeat(B, 0)
203
+ notignore_mat = ys_int != self.ignore_id
204
+ same_lab_mat = np.zeros((B, Lmax), dtype=bool)
205
+ same_lab_mat[:, 3::2] = ys_int[:, 3::2] == ys_int[:, 1:-2:2]
206
+ Lmin = olens_int.min()
207
+ for t in range(1, Tmax):
208
+ s_start = max(0, Lmin - (Tmax - t) * 2)
209
+ s_end = min(Lmax, t * 2 + 2)
210
+ candidates = np.full((B, Lmax, 3), neginf, dtype=logdelta.dtype)
211
+ candidates[:, :, 0] = logdelta[t - 1, :, :]
212
+ candidates[:, 1:, 1] = logdelta[t - 1, :, :-1]
213
+ candidates[:, 3::2, 2] = logdelta[t - 1, :, 1:-2:2]
214
+ candidates[same_lab_mat, 2] = neginf
215
+ candidates_ = candidates[:, s_start:s_end, :]
216
+ idx = candidates_.argmax(-1)
217
+ b_i, s_i = np.ogrid[:B, : idx.shape[-1]]
218
+ nignore = notignore_mat[:, s_start:s_end]
219
+ logdelta[t, :, s_start:s_end][nignore] = (
220
+ candidates_[b_i, s_i, idx][nignore]
221
+ + lpz[t, b_i, ys_int[:, s_start:s_end]][nignore]
222
+ )
223
+ s = s_indx_mat[:, s_start:s_end]
224
+ state_path[t, :, s_start:s_end][nignore] = (s - idx)[nignore]
225
+
226
+ alignments = []
227
+ prev_states = logdelta[
228
+ ilens[:, None] - 1,
229
+ b_indx[:, None],
230
+ np.stack([olens_int - 2, olens_int - 1], -1),
231
+ ].argmax(-1)
232
+ for b in range(B):
233
+ T, L = ilens[b], olens_int[b]
234
+ prev_state = prev_states[b] + L - 2
235
+ ali = np.empty(T, dtype=ys_int.dtype)
236
+ ali[T - 1] = ys_int[b, prev_state]
237
+ for t in range(T - 2, -1, -1):
238
+ prev_state = state_path[t + 1, b, prev_state]
239
+ ali[t] = ys_int[b, prev_state]
240
+ alignments.append(ali)
241
+
242
+ return alignments
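A short usage sketch for this CTC module; the batch size, sequence lengths, and vocabulary size below are placeholders, not values from the repo:

import torch
from espnet.nets.pytorch_backend.ctc import CTC

odim, eprojs = 40, 768                      # vocab size (blank = 0) and encoder width
ctc = CTC(odim, eprojs, dropout_rate=0.1, reduce=True)

hs_pad = torch.randn(2, 50, eprojs)         # (B, Tmax, D) encoder outputs
hlens = torch.tensor([50, 42])              # valid frame counts per utterance
ys_pad = torch.randint(1, odim, (2, 12))    # (B, Lmax) label ids, avoiding blank
ys_pad[1, 9:] = -1                          # pad the shorter target with ignore_id

loss, ys_hat = ctc(hs_pad, hlens, ys_pad)   # scalar loss; ys_hat is (Tmax, B, odim)
greedy = ctc.argmax(hs_pad)                 # (B, Tmax) frame-wise best token ids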
espnet/nets/pytorch_backend/decoder/__pycache__/transformer_decoder.cpython-310.pyc ADDED
Binary file (11.3 kB).
 
espnet/nets/pytorch_backend/decoder/__pycache__/transformer_decoder.cpython-311.pyc ADDED
Binary file (17.6 kB).
 
espnet/nets/pytorch_backend/decoder/transformer_decoder.py ADDED
@@ -0,0 +1,334 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+ # Copyright 2019 Shigeki Karita
5
+ # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
6
+
7
+ """Decoder definition."""
8
+
9
+ from typing import Any, List, Tuple
10
+
11
+ import torch
12
+ from espnet.nets.pytorch_backend.nets_utils import rename_state_dict
13
+ from espnet.nets.pytorch_backend.transformer.attention import MultiHeadedAttention
14
+ from espnet.nets.pytorch_backend.transformer.embedding import PositionalEncoding
15
+ from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm
16
+ from espnet.nets.pytorch_backend.transformer.mask import subsequent_mask
17
+ from espnet.nets.pytorch_backend.transformer.positionwise_feed_forward import PositionwiseFeedForward
18
+ from espnet.nets.pytorch_backend.transformer.repeat import repeat
19
+ from espnet.nets.scorer_interface import BatchScorerInterface
20
+
21
+
22
+ class DecoderLayer(torch.nn.Module):
23
+ """Single decoder layer module.
24
+ :param int size: input dim
25
+ :param espnet.nets.pytorch_backend.transformer.attention.MultiHeadedAttention
26
+ self_attn: self attention module
27
+ :param espnet.nets.pytorch_backend.transformer.attention.MultiHeadedAttention
28
+ src_attn: source attention module
29
+ :param espnet.nets.pytorch_backend.transformer.positionwise_feed_forward.
30
+ PositionwiseFeedForward feed_forward: feed forward layer module
31
+ :param float dropout_rate: dropout rate
32
+ :param bool normalize_before: whether to use layer_norm before the first block
33
+ :param bool concat_after: whether to concat attention layer's input and output
34
+ if True, additional linear will be applied.
35
+ i.e. x -> x + linear(concat(x, att(x)))
36
+ if False, no additional linear will be applied. i.e. x -> x + att(x)
37
+ """
38
+
39
+ def __init__(
40
+ self,
41
+ size,
42
+ self_attn,
43
+ src_attn,
44
+ feed_forward,
45
+ dropout_rate,
46
+ normalize_before=True,
47
+ concat_after=False,
48
+ ):
49
+ """Construct an DecoderLayer object."""
50
+ super(DecoderLayer, self).__init__()
51
+ self.size = size
52
+ self.self_attn = self_attn
53
+ self.src_attn = src_attn
54
+ self.feed_forward = feed_forward
55
+ self.norm1 = LayerNorm(size)
56
+ self.norm2 = LayerNorm(size)
57
+ self.norm3 = LayerNorm(size)
58
+ self.dropout = torch.nn.Dropout(dropout_rate)
59
+ self.normalize_before = normalize_before
60
+ self.concat_after = concat_after
61
+ if self.concat_after:
62
+ self.concat_linear1 = torch.nn.Linear(size + size, size)
63
+ self.concat_linear2 = torch.nn.Linear(size + size, size)
64
+
65
+ def forward(self, tgt, tgt_mask, memory, memory_mask, cache=None):
66
+ """Compute decoded features.
67
+ Args:
68
+ tgt (torch.Tensor):
69
+ decoded previous target features (batch, max_time_out, size)
70
+ tgt_mask (torch.Tensor): mask for x (batch, max_time_out)
71
+ memory (torch.Tensor): encoded source features (batch, max_time_in, size)
72
+ memory_mask (torch.Tensor): mask for memory (batch, max_time_in)
73
+ cache (torch.Tensor): cached output (batch, max_time_out-1, size)
74
+ """
75
+ residual = tgt
76
+ if self.normalize_before:
77
+ tgt = self.norm1(tgt)
78
+
79
+ if cache is None:
80
+ tgt_q = tgt
81
+ tgt_q_mask = tgt_mask
82
+ else:
83
+ # compute only the last frame query keeping dim: max_time_out -> 1
84
+ assert cache.shape == (
85
+ tgt.shape[0],
86
+ tgt.shape[1] - 1,
87
+ self.size,
88
+ ), f"{cache.shape} == {(tgt.shape[0], tgt.shape[1] - 1, self.size)}"
89
+ tgt_q = tgt[:, -1:, :]
90
+ residual = residual[:, -1:, :]
91
+ tgt_q_mask = None
92
+ if tgt_mask is not None:
93
+ tgt_q_mask = tgt_mask[:, -1:, :]
94
+
95
+ if self.concat_after:
96
+ tgt_concat = torch.cat(
97
+ (tgt_q, self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)), dim=-1
98
+ )
99
+ x = residual + self.concat_linear1(tgt_concat)
100
+ else:
101
+ x = residual + self.dropout(self.self_attn(tgt_q, tgt, tgt, tgt_q_mask))
102
+ if not self.normalize_before:
103
+ x = self.norm1(x)
104
+
105
+ residual = x
106
+ if self.normalize_before:
107
+ x = self.norm2(x)
108
+ if self.concat_after:
109
+ x_concat = torch.cat(
110
+ (x, self.src_attn(x, memory, memory, memory_mask)), dim=-1
111
+ )
112
+ x = residual + self.concat_linear2(x_concat)
113
+ else:
114
+ x = residual + self.dropout(self.src_attn(x, memory, memory, memory_mask))
115
+ if not self.normalize_before:
116
+ x = self.norm2(x)
117
+
118
+ residual = x
119
+ if self.normalize_before:
120
+ x = self.norm3(x)
121
+ x = residual + self.dropout(self.feed_forward(x))
122
+ if not self.normalize_before:
123
+ x = self.norm3(x)
124
+
125
+ if cache is not None:
126
+ x = torch.cat([cache, x], dim=1)
127
+
128
+ return x, tgt_mask, memory, memory_mask
129
+
130
+
131
+ def _pre_hook(
132
+ state_dict,
133
+ prefix,
134
+ local_metadata,
135
+ strict,
136
+ missing_keys,
137
+ unexpected_keys,
138
+ error_msgs,
139
+ ):
140
+ # https://github.com/espnet/espnet/commit/3d422f6de8d4f03673b89e1caef698745ec749ea#diff-bffb1396f038b317b2b64dd96e6d3563
141
+ rename_state_dict(prefix + "output_norm.", prefix + "after_norm.", state_dict)
142
+
143
+
144
+ class TransformerDecoder(BatchScorerInterface, torch.nn.Module):
145
+ """Transfomer decoder module.
146
+
147
+ :param int odim: output dim
148
+ :param int attention_dim: dimension of attention
149
+ :param int attention_heads: the number of heads of multi head attention
150
+ :param int linear_units: the number of units of position-wise feed forward
151
+ :param int num_blocks: the number of decoder blocks
152
+ :param float dropout_rate: dropout rate
153
+ :param float attention_dropout_rate: dropout rate for attention
154
+ :param str or torch.nn.Module input_layer: input layer type
155
+ :param bool use_output_layer: whether to use output layer
156
+ :param class pos_enc_class: PositionalEncoding or ScaledPositionalEncoding
157
+ :param bool normalize_before: whether to use layer_norm before the first block
158
+ :param bool concat_after: whether to concat attention layer's input and output
159
+ if True, additional linear will be applied.
160
+ i.e. x -> x + linear(concat(x, att(x)))
161
+ if False, no additional linear will be applied. i.e. x -> x + att(x)
162
+ """
163
+
164
+ def __init__(
165
+ self,
166
+ odim,
167
+ attention_dim=256,
168
+ attention_heads=4,
169
+ linear_units=2048,
170
+ num_blocks=6,
171
+ dropout_rate=0.1,
172
+ positional_dropout_rate=0.1,
173
+ self_attention_dropout_rate=0.1,
174
+ src_attention_dropout_rate=0.1,
175
+ input_layer="embed",
176
+ use_output_layer=True,
177
+ pos_enc_class=PositionalEncoding,
178
+ normalize_before=True,
179
+ concat_after=False,
180
+ layer_drop_rate=0.0,
181
+ ):
182
+ """Construct an Decoder object."""
183
+ torch.nn.Module.__init__(self)
184
+ self._register_load_state_dict_pre_hook(_pre_hook)
185
+ if input_layer == "embed":
186
+ self.embed = torch.nn.Sequential(
187
+ torch.nn.Embedding(odim, attention_dim),
188
+ pos_enc_class(attention_dim, positional_dropout_rate),
189
+ )
190
+ elif input_layer == "linear":
191
+ self.embed = torch.nn.Sequential(
192
+ torch.nn.Linear(odim, attention_dim),
193
+ torch.nn.LayerNorm(attention_dim),
194
+ torch.nn.Dropout(dropout_rate),
195
+ torch.nn.ReLU(),
196
+ pos_enc_class(attention_dim, positional_dropout_rate),
197
+ )
198
+ elif isinstance(input_layer, torch.nn.Module):
199
+ self.embed = torch.nn.Sequential(
200
+ input_layer, pos_enc_class(attention_dim, positional_dropout_rate)
201
+ )
202
+ else:
203
+ raise NotImplementedError("only `embed`, `linear`, or torch.nn.Module is supported.")
204
+ self.normalize_before = normalize_before
205
+ self.decoders = repeat(
206
+ num_blocks,
207
+ lambda lnum: DecoderLayer(
208
+ attention_dim,
209
+ MultiHeadedAttention(
210
+ attention_heads, attention_dim, self_attention_dropout_rate
211
+ ),
212
+ MultiHeadedAttention(
213
+ attention_heads, attention_dim, src_attention_dropout_rate
214
+ ),
215
+ PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
216
+ dropout_rate,
217
+ normalize_before,
218
+ concat_after,
219
+ ),
220
+ layer_drop_rate,
221
+ )
222
+ if self.normalize_before:
223
+ self.after_norm = LayerNorm(attention_dim)
224
+ if use_output_layer:
225
+ self.output_layer = torch.nn.Linear(attention_dim, odim)
226
+ else:
227
+ self.output_layer = None
228
+
229
+ def forward(self, tgt, tgt_mask, memory, memory_mask):
230
+ """Forward decoder.
231
+ :param torch.Tensor tgt: input token ids, int64 (batch, maxlen_out)
232
+ if input_layer == "embed"
233
+ input tensor (batch, maxlen_out, #mels)
234
+ in the other cases
235
+ :param torch.Tensor tgt_mask: input token mask, (batch, maxlen_out)
236
+ dtype=torch.uint8 in PyTorch 1.2-
237
+ dtype=torch.bool in PyTorch 1.2+ (including 1.2)
238
+ :param torch.Tensor memory: encoded memory, float32 (batch, maxlen_in, feat)
239
+ :param torch.Tensor memory_mask: encoded memory mask, (batch, maxlen_in)
240
+ dtype=torch.uint8 in PyTorch 1.2-
241
+ dtype=torch.bool in PyTorch 1.2+ (including 1.2)
242
+ :return x: decoded token score before softmax (batch, maxlen_out, token)
243
+ if use_output_layer is True,
244
+ final block outputs (batch, maxlen_out, attention_dim)
245
+ in the other cases
246
+ :rtype: torch.Tensor
247
+ :return tgt_mask: score mask before softmax (batch, maxlen_out)
248
+ :rtype: torch.Tensor
249
+ """
250
+ x = self.embed(tgt)
251
+ x, tgt_mask, memory, memory_mask = self.decoders(
252
+ x, tgt_mask, memory, memory_mask
253
+ )
254
+ if self.normalize_before:
255
+ x = self.after_norm(x)
256
+ if self.output_layer is not None:
257
+ x = self.output_layer(x)
258
+ return x, tgt_mask
259
+
260
+ def forward_one_step(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
261
+ """Forward one step.
262
+ :param torch.Tensor tgt: input token ids, int64 (batch, maxlen_out)
263
+ :param torch.Tensor tgt_mask: input token mask, (batch, maxlen_out)
264
+ dtype=torch.uint8 in PyTorch 1.2-
265
+ dtype=torch.bool in PyTorch 1.2+ (including 1.2)
266
+ :param torch.Tensor memory: encoded memory, float32 (batch, maxlen_in, feat)
267
+ :param List[torch.Tensor] cache:
268
+ cached output list of (batch, max_time_out-1, size)
269
+ :return y, cache: NN output value and cache per `self.decoders`.
270
+ `y.shape` is (batch, maxlen_out, token)
271
+ :rtype: Tuple[torch.Tensor, List[torch.Tensor]]
272
+ """
273
+ x = self.embed(tgt)
274
+ if cache is None:
275
+ cache = [None] * len(self.decoders)
276
+ new_cache = []
277
+ for c, decoder in zip(cache, self.decoders):
278
+ x, tgt_mask, memory, memory_mask = decoder(
279
+ x, tgt_mask, memory, memory_mask, cache=c
280
+ )
281
+ new_cache.append(x)
282
+
283
+ if self.normalize_before:
284
+ y = self.after_norm(x[:, -1])
285
+ else:
286
+ y = x[:, -1]
287
+ if self.output_layer is not None:
288
+ y = torch.log_softmax(self.output_layer(y), dim=-1)
289
+
290
+ return y, new_cache
291
+
292
+ # beam search API (see ScorerInterface)
293
+ def score(self, ys, state, x):
294
+ """Score."""
295
+ ys_mask = subsequent_mask(len(ys), device=x.device).unsqueeze(0)
296
+ logp, state = self.forward_one_step(
297
+ ys.unsqueeze(0), ys_mask, x.unsqueeze(0), cache=state
298
+ )
299
+ return logp.squeeze(0), state
300
+
301
+ # batch beam search API (see BatchScorerInterface)
302
+ def batch_score(
303
+ self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor
304
+ ) -> Tuple[torch.Tensor, List[Any]]:
305
+ """Score new token batch (required).
306
+ Args:
307
+ ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen).
308
+ states (List[Any]): Scorer states for prefix tokens.
309
+ xs (torch.Tensor):
310
+ The encoder feature that generates ys (n_batch, xlen, n_feat).
311
+ Returns:
312
+ tuple[torch.Tensor, List[Any]]: Tuple of
313
+ batchfied scores for next token with shape of `(n_batch, n_vocab)`
314
+ and next state list for ys.
315
+ """
316
+ # merge states
317
+ n_batch = len(ys)
318
+ n_layers = len(self.decoders)
319
+ if states[0] is None:
320
+ batch_state = None
321
+ else:
322
+ # transpose state of [batch, layer] into [layer, batch]
323
+ batch_state = [
324
+ torch.stack([states[b][l] for b in range(n_batch)])
325
+ for l in range(n_layers)
326
+ ]
327
+
328
+ # batch decoding
329
+ ys_mask = subsequent_mask(ys.size(-1), device=xs.device).unsqueeze(0)
330
+ logp, states = self.forward_one_step(ys, ys_mask, xs, cache=batch_state)
331
+
332
+ # transpose state of [layer, batch] into [batch, layer]
333
+ state_list = [[states[l][b] for l in range(n_layers)] for b in range(n_batch)]
334
+ return logp, state_list
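To illustrate the step-wise decoding API, a sketch of calling score() the way beam search would; the dimensions and token ids below are illustrative only:

import torch
from espnet.nets.pytorch_backend.decoder.transformer_decoder import TransformerDecoder

odim = 40
dec = TransformerDecoder(odim=odim, attention_dim=256, attention_heads=4,
                         linear_units=2048, num_blocks=6)
dec.eval()

memory = torch.randn(30, 256)        # (maxlen_in, attention_dim) for one utterance
ys = torch.tensor([odim - 1, 5, 7])  # <sos> followed by two hypothesis tokens

with torch.no_grad():
    logp, state = dec.score(ys, state=None, x=memory)
# logp: (odim,) log-probabilities for the next token; state caches each layer.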
espnet/nets/pytorch_backend/e2e_asr_conformer.py ADDED
@@ -0,0 +1,87 @@
1
+ #! /usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+
4
+ # Copyright 2023 Imperial College London (Pingchuan Ma)
5
+ # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
6
+
7
+ import torch
8
+
9
+ from espnet.nets.pytorch_backend.frontend.resnet import video_resnet
10
+ from espnet.nets.pytorch_backend.frontend.resnet1d import audio_resnet
11
+ from espnet.nets.pytorch_backend.ctc import CTC
12
+ from espnet.nets.pytorch_backend.encoder.conformer_encoder import ConformerEncoder
13
+ from espnet.nets.pytorch_backend.decoder.transformer_decoder import TransformerDecoder
14
+ from espnet.nets.pytorch_backend.nets_utils import make_non_pad_mask, th_accuracy
15
+ from espnet.nets.pytorch_backend.transformer.add_sos_eos import add_sos_eos
16
+ from espnet.nets.pytorch_backend.transformer.label_smoothing_loss import LabelSmoothingLoss
17
+ from espnet.nets.pytorch_backend.transformer.mask import target_mask
18
+ from espnet.nets.scorers.ctc import CTCPrefixScorer
19
+
20
+
21
+ class E2E(torch.nn.Module):
22
+ def __init__(self, odim, modality, ctc_weight=0.1, ignore_id=-1):
23
+ super().__init__()
24
+
25
+ self.modality = modality
26
+ if modality == "audio":
27
+ self.frontend = audio_resnet()
28
+ elif modality == "video":
29
+ self.frontend = video_resnet()
30
+
31
+ self.proj_encoder = torch.nn.Linear(512, 768)
32
+
33
+ self.encoder = ConformerEncoder(
34
+ attention_dim=768,
35
+ attention_heads=12,
36
+ linear_units=3072,
37
+ num_blocks=12,
38
+ cnn_module_kernel=31,
39
+ )
40
+
41
+ self.decoder = TransformerDecoder(
42
+ odim=odim,
43
+ attention_dim=768,
44
+ attention_heads=12,
45
+ linear_units=3072,
46
+ num_blocks=6,
47
+ )
48
+
49
+ self.blank = 0
50
+ self.sos = odim - 1
51
+ self.eos = odim - 1
52
+ self.odim = odim
53
+ self.ignore_id = ignore_id
54
+
55
+ # loss
56
+ self.ctc_weight = ctc_weight
57
+ self.ctc = CTC(odim, 768, 0.1, reduce=True)
58
+ self.criterion = LabelSmoothingLoss(self.odim, self.ignore_id, 0.1, False)
59
+
60
+ def scorers(self):
61
+ return dict(decoder=self.decoder, ctc=CTCPrefixScorer(self.ctc, self.eos))
62
+
63
+ def forward(self, x, lengths, label):
64
+ if self.modality == "audio":
65
+ lengths = torch.div(lengths, 640, rounding_mode="trunc")
66
+
67
+ padding_mask = make_non_pad_mask(lengths).to(x.device).unsqueeze(-2)
68
+
69
+ x = self.frontend(x)
70
+ x = self.proj_encoder(x)
71
+ x, _ = self.encoder(x, padding_mask)
72
+
73
+ # ctc loss
74
+ loss_ctc, ys_hat = self.ctc(x, lengths, label)
75
+
76
+ # decoder loss
77
+ ys_in_pad, ys_out_pad = add_sos_eos(label, self.sos, self.eos, self.ignore_id)
78
+ ys_mask = target_mask(ys_in_pad, self.ignore_id)
79
+ pred_pad, _ = self.decoder(ys_in_pad, ys_mask, x, padding_mask)
80
+ loss_att = self.criterion(pred_pad, ys_out_pad)
81
+ loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att
82
+
83
+ acc = th_accuracy(
84
+ pred_pad.view(-1, self.odim), ys_out_pad, ignore_label=self.ignore_id
85
+ )
86
+
87
+ return loss, loss_ctc, loss_att, acc
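As a sanity check, the forward pass above can be exercised on random data; the shapes below are placeholders, with raw audio assumed to be organized so that 640 samples collapse to one encoder frame:

import torch
from espnet.nets.pytorch_backend.e2e_asr_conformer import E2E

odim = 40
model = E2E(odim=odim, modality="audio")
model.eval()

x = torch.randn(1, 640 * 100, 1)             # (B, samples, 1) raw waveform
lengths = torch.tensor([640 * 100])           # lengths in samples
label = torch.randint(1, odim - 1, (1, 20))   # ids avoiding blank and <sos>/<eos>

with torch.no_grad():
    loss, loss_ctc, loss_att, acc = model(x, lengths, label)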
espnet/nets/pytorch_backend/encoder/__pycache__/conformer_encoder.cpython-310.pyc ADDED
Binary file (9.5 kB).
 
espnet/nets/pytorch_backend/encoder/__pycache__/conformer_encoder.cpython-311.pyc ADDED
Binary file (15 kB).
 
espnet/nets/pytorch_backend/encoder/conformer_encoder.py ADDED
@@ -0,0 +1,303 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+ # Copyright 2019 Shigeki Karita
5
+ # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
6
+
7
+ """Encoder definition."""
8
+
9
+ import copy
10
+ import torch
11
+ from espnet.nets.pytorch_backend.nets_utils import rename_state_dict
12
+ from espnet.nets.pytorch_backend.transformer.attention import RelPositionMultiHeadedAttention
13
+ from espnet.nets.pytorch_backend.transformer.embedding import RelPositionalEncoding
14
+ from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm
15
+ from espnet.nets.pytorch_backend.transformer.positionwise_feed_forward import PositionwiseFeedForward
16
+ from espnet.nets.pytorch_backend.transformer.repeat import repeat
17
+
18
+
19
+ class ConvolutionModule(torch.nn.Module):
20
+ def __init__(self, channels, kernel_size, bias=True):
21
+ super().__init__()
22
+ assert (kernel_size - 1) % 2 == 0
23
+
24
+ self.pointwise_cov1 = torch.nn.Conv1d(channels, 2 * channels, 1, bias=bias)
25
+ self.depthwise_conv = torch.nn.Conv1d(channels, channels, kernel_size, padding=(kernel_size - 1) // 2, groups=channels, bias=bias)
26
+ self.norm = torch.nn.BatchNorm1d(channels)
27
+ self.pointwise_cov2 = torch.nn.Conv1d(channels, channels, 1, bias=bias)
28
+ self.activation = torch.nn.SiLU(inplace=True)
29
+
30
+ def forward(self, x):
31
+ x = x.transpose(1, 2)
32
+ x = torch.nn.functional.glu(self.pointwise_cov1(x), dim=1)
33
+ x = self.activation(self.norm(self.depthwise_conv(x)))
34
+ x = self.pointwise_cov2(x)
35
+ return x.transpose(1, 2)
36
+
37
+
38
+ class EncoderLayer(torch.nn.Module):
39
+ """Encoder layer module.
40
+
41
+ :param int size: input dim
42
+ :param espnet.nets.pytorch_backend.transformer.attention.
43
+ MultiHeadedAttention or RelPositionMultiHeadedAttention self_attn:
44
+ self attention module
45
+ :param espnet.nets.pytorch_backend.transformer.positionwise_feed_forward.
46
+ PositionwiseFeedForward feed_forward:
47
+ feed forward module
48
+ :param ConvolutionModule conv_module:
49
+ convolution module (defined in this file)
50
+
51
+ :param float dropout_rate: dropout rate
52
+ :param bool normalize_before: whether to use layer_norm before the first block
53
+ :param bool concat_after: whether to concat attention layer's input and output
54
+ if True, additional linear will be applied.
55
+ i.e. x -> x + linear(concat(x, att(x)))
56
+ if False, no additional linear will be applied. i.e. x -> x + att(x)
57
+ :param bool macaron_style: whether to use macaron style for PositionwiseFeedForward
58
+
59
+ """
60
+
61
+ def __init__(
62
+ self,
63
+ size,
64
+ self_attn,
65
+ feed_forward,
66
+ conv_module,
67
+ dropout_rate,
68
+ normalize_before=True,
69
+ concat_after=False,
70
+ macaron_style=False,
71
+ ):
72
+ """Construct an EncoderLayer object."""
73
+ super(EncoderLayer, self).__init__()
74
+ self.self_attn = self_attn
75
+ self.feed_forward = feed_forward
76
+ self.ff_scale = 1.0
77
+ self.conv_module = conv_module
78
+ self.macaron_style = macaron_style
79
+ self.norm_ff = LayerNorm(size) # for the FNN module
80
+ self.norm_mha = LayerNorm(size) # for the MHA module
81
+ if self.macaron_style:
82
+ self.feed_forward_macaron = copy.deepcopy(feed_forward)
83
+ self.ff_scale = 0.5
84
+ # for another FNN module in macaron style
85
+ self.norm_ff_macaron = LayerNorm(size)
86
+ if self.conv_module is not None:
87
+ self.norm_conv = LayerNorm(size) # for the CNN module
88
+ self.norm_final = LayerNorm(size) # for the final output of the block
89
+ self.dropout = torch.nn.Dropout(dropout_rate)
90
+ self.size = size
91
+ self.normalize_before = normalize_before
92
+ self.concat_after = concat_after
93
+ if self.concat_after:
94
+ self.concat_linear = torch.nn.Linear(size + size, size)
95
+
96
+ def forward(self, x_input, mask, cache=None):
97
+ """Compute encoded features.
98
+
99
+ :param torch.Tensor x_input: encoded source features (batch, max_time_in, size)
100
+ :param torch.Tensor mask: mask for x (batch, max_time_in)
101
+ :param torch.Tensor cache: cache for x (batch, max_time_in - 1, size)
102
+ :rtype: Tuple[torch.Tensor, torch.Tensor]
103
+ """
104
+ if isinstance(x_input, tuple):
105
+ x, pos_emb = x_input[0], x_input[1]
106
+ else:
107
+ x, pos_emb = x_input, None
108
+
109
+ # whether to use macaron style
110
+ if self.macaron_style:
111
+ residual = x
112
+ if self.normalize_before:
113
+ x = self.norm_ff_macaron(x)
114
+ x = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(x))
115
+ if not self.normalize_before:
116
+ x = self.norm_ff_macaron(x)
117
+
118
+ # multi-headed self-attention module
119
+ residual = x
120
+ if self.normalize_before:
121
+ x = self.norm_mha(x)
122
+
123
+ if cache is None:
124
+ x_q = x
125
+ else:
126
+ assert cache.shape == (x.shape[0], x.shape[1] - 1, self.size)
127
+ x_q = x[:, -1:, :]
128
+ residual = residual[:, -1:, :]
129
+ mask = None if mask is None else mask[:, -1:, :]
130
+
131
+ if pos_emb is not None:
132
+ x_att = self.self_attn(x_q, x, x, pos_emb, mask)
133
+ else:
134
+ x_att = self.self_attn(x_q, x, x, mask)
135
+
136
+ if self.concat_after:
137
+ x_concat = torch.cat((x, x_att), dim=-1)
138
+ x = residual + self.concat_linear(x_concat)
139
+ else:
140
+ x = residual + self.dropout(x_att)
141
+ if not self.normalize_before:
142
+ x = self.norm_mha(x)
143
+
144
+ # convolution module
145
+ if self.conv_module is not None:
146
+ residual = x
147
+ if self.normalize_before:
148
+ x = self.norm_conv(x)
149
+ x = residual + self.dropout(self.conv_module(x))
150
+ if not self.normalize_before:
151
+ x = self.norm_conv(x)
152
+
153
+ # feed forward module
154
+ residual = x
155
+ if self.normalize_before:
156
+ x = self.norm_ff(x)
157
+ x = residual + self.ff_scale * self.dropout(self.feed_forward(x))
158
+ if not self.normalize_before:
159
+ x = self.norm_ff(x)
160
+
161
+ if self.conv_module is not None:
162
+ x = self.norm_final(x)
163
+
164
+ if cache is not None:
165
+ x = torch.cat([cache, x], dim=1)
166
+
167
+ if pos_emb is not None:
168
+ return (x, pos_emb), mask
169
+ else:
170
+ return x, mask
171
+
172
+
173
+ def _pre_hook(
174
+ state_dict,
175
+ prefix,
176
+ local_metadata,
177
+ strict,
178
+ missing_keys,
179
+ unexpected_keys,
180
+ error_msgs,
181
+ ):
182
+ rename_state_dict(prefix + "input_layer.", prefix + "embed.", state_dict)
183
+ rename_state_dict(prefix + "norm.", prefix + "after_norm.", state_dict)
184
+
185
+
186
+ class ConformerEncoder(torch.nn.Module):
187
+ """Transformer encoder module.
188
+
189
+ :param int idim: input dim
190
+ :param int attention_dim: dimension of attention
191
+ :param int attention_heads: the number of heads of multi head attention
192
+ :param int linear_units: the number of units of position-wise feed forward
193
+ :param int num_blocks: the number of decoder blocks
194
+ :param float dropout_rate: dropout rate
195
+ :param float attention_dropout_rate: dropout rate in attention
196
+ :param float positional_dropout_rate: dropout rate after adding positional encoding
197
+ :param class pos_enc_class: PositionalEncoding or ScaledPositionalEncoding
198
+ :param bool normalize_before: whether to use layer_norm before the first block
199
+ :param bool concat_after: whether to concat attention layer's input and output
200
+ if True, additional linear will be applied.
201
+ i.e. x -> x + linear(concat(x, att(x)))
202
+ if False, no additional linear will be applied. i.e. x -> x + att(x)
203
+ :param str positionwise_layer_type: linear or conv1d
204
+ :param int positionwise_conv_kernel_size: kernel size of positionwise conv1d layer
205
+ :param bool macaron_style: whether to use macaron style for positionwise layer
206
+ :param bool use_cnn_module: whether to use convolution module
207
+ :param bool zero_triu: whether to zero the upper triangular part of attention matrix
208
+ :param int cnn_module_kernel: kernel size of convolution module
209
+ :param int padding_idx: padding_idx for input_layer=embed
210
+ """
211
+
212
+ def __init__(
213
+ self,
214
+ attention_dim=768,
215
+ attention_heads=12,
216
+ linear_units=3072,
217
+ num_blocks=12,
218
+ dropout_rate=0.1,
219
+ positional_dropout_rate=0.1,
220
+ attention_dropout_rate=0.0,
221
+ normalize_before=True,
222
+ concat_after=False,
223
+ macaron_style=True,
224
+ use_cnn_module=True,
225
+ zero_triu=False,
226
+ cnn_module_kernel=31,
227
+ padding_idx=-1,
228
+ relu_type="swish",
229
+ layer_drop_rate=0.0,
230
+ ):
231
+ """Construct an Encoder object."""
232
+ super(ConformerEncoder, self).__init__()
233
+ self._register_load_state_dict_pre_hook(_pre_hook)
234
+
235
+ self.embed = torch.nn.Sequential(RelPositionalEncoding(attention_dim, positional_dropout_rate))
236
+ self.normalize_before = normalize_before
237
+
238
+ positionwise_layer = PositionwiseFeedForward
239
+ positionwise_layer_args = (attention_dim, linear_units, dropout_rate)
240
+
241
+ encoder_attn_layer = RelPositionMultiHeadedAttention
242
+ encoder_attn_layer_args = (attention_heads, attention_dim, attention_dropout_rate, zero_triu)
243
+
244
+ convolution_layer = ConvolutionModule
245
+ convolution_layer_args = (attention_dim, cnn_module_kernel)
246
+
247
+ self.encoders = repeat(
248
+ num_blocks,
249
+ lambda lnum: EncoderLayer(
250
+ attention_dim,
251
+ encoder_attn_layer(*encoder_attn_layer_args),
252
+ positionwise_layer(*positionwise_layer_args),
253
+ convolution_layer(*convolution_layer_args) if use_cnn_module else None,
254
+ dropout_rate,
255
+ normalize_before,
256
+ concat_after,
257
+ macaron_style,
258
+ ),
259
+ layer_drop_rate=layer_drop_rate,
260
+ )
261
+ if self.normalize_before:
262
+ self.after_norm = LayerNorm(attention_dim)
263
+
264
+ def forward(self, xs, masks):
265
+ """Encode input sequence.
266
+
267
+ :param torch.Tensor xs: input tensor
268
+ :param torch.Tensor masks: input mask
269
+ :return: position embedded tensor and mask
270
+ :rtype Tuple[torch.Tensor, torch.Tensor]:
271
+ """
272
+ xs = self.embed(xs)
273
+
274
+ xs, masks = self.encoders(xs, masks)
275
+
276
+ if isinstance(xs, tuple):
277
+ xs = xs[0]
278
+
279
+ if self.normalize_before:
280
+ xs = self.after_norm(xs)
281
+
282
+ return xs, masks
283
+
284
+ def forward_one_step(self, xs, masks, cache=None):
285
+ """Encode input frame.
286
+
287
+ :param torch.Tensor xs: input tensor
288
+ :param torch.Tensor masks: input mask
289
+ :param List[torch.Tensor] cache: cache tensors
290
+ :return: position embedded tensor, mask and new cache
291
+ :rtype Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]:
292
+ """
293
+ xs, masks = self.embed(xs, masks)
294
+
295
+ if cache is None:
296
+ cache = [None for _ in range(len(self.encoders))]
297
+ new_cache = []
298
+ for c, e in zip(cache, self.encoders):
299
+ xs, masks = e(xs, masks, cache=c)
300
+ new_cache.append(xs)
301
+ if self.normalize_before:
302
+ xs = self.after_norm(xs)
303
+ return xs, masks, new_cache
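A minimal sketch of running this encoder standalone on pre-projected features; the batch, lengths, and feature dimension are invented for illustration:

import torch
from espnet.nets.pytorch_backend.encoder.conformer_encoder import ConformerEncoder
from espnet.nets.pytorch_backend.nets_utils import make_non_pad_mask

enc = ConformerEncoder(attention_dim=768, attention_heads=12,
                       linear_units=3072, num_blocks=12)
enc.eval()

xs = torch.randn(2, 100, 768)                        # (B, T, attention_dim)
masks = make_non_pad_mask([100, 80]).unsqueeze(-2)   # (B, 1, T) non-pad mask

with torch.no_grad():
    out, masks = enc(xs, masks)                      # out: (2, 100, 768)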
espnet/nets/pytorch_backend/frontend/__pycache__/resnet.cpython-310.pyc ADDED
Binary file (6.07 kB).
 
espnet/nets/pytorch_backend/frontend/__pycache__/resnet.cpython-311.pyc ADDED
Binary file (10.4 kB).
 
espnet/nets/pytorch_backend/frontend/__pycache__/resnet1d.cpython-310.pyc ADDED
Binary file (6.23 kB).
 
espnet/nets/pytorch_backend/frontend/__pycache__/resnet1d.cpython-311.pyc ADDED
Binary file (10.3 kB).
 
espnet/nets/pytorch_backend/frontend/resnet.py ADDED
@@ -0,0 +1,237 @@
1
+ import torch.nn as nn
2
+
3
+
4
+ def conv3x3(in_planes, out_planes, stride=1):
5
+ """conv3x3.
6
+ :param in_planes: int, number of channels in the input sequence.
7
+ :param out_planes: int, number of channels produced by the convolution.
8
+ :param stride: int, size of the convolving kernel.
9
+ """
10
+ return nn.Conv2d(
11
+ in_planes,
12
+ out_planes,
13
+ kernel_size=3,
14
+ stride=stride,
15
+ padding=1,
16
+ bias=False,
17
+ )
18
+
19
+
20
+ def downsample_basic_block(inplanes, outplanes, stride):
21
+ """downsample_basic_block.
22
+ :param inplanes: int, number of channels in the input sequence.
23
+ :param outplanes: int, number of channels produced by the convolution.
24
+ :param stride: int, size of the convolving kernel.
25
+ """
26
+ return nn.Sequential(
27
+ nn.Conv2d(
28
+ inplanes,
29
+ outplanes,
30
+ kernel_size=1,
31
+ stride=stride,
32
+ bias=False,
33
+ ),
34
+ nn.BatchNorm2d(outplanes),
35
+ )
36
+
37
+
38
+ class BasicBlock(nn.Module):
39
+ expansion = 1
40
+
41
+ def __init__(
42
+ self,
43
+ inplanes,
44
+ planes,
45
+ stride=1,
46
+ downsample=None,
47
+ relu_type="swish",
48
+ ):
49
+ """__init__.
50
+ :param inplanes: int, number of channels in the input sequence.
51
+ :param planes: int, number of channels produced by the convolution.
52
+ :param stride: int, size of the convolving kernel.
53
+ :param downsample: torch.nn.Module or None, projection applied to the residual when the shape changes.
54
+ :param relu_type: str, type of activation function.
55
+ """
56
+ super(BasicBlock, self).__init__()
57
+
58
+ assert relu_type in ["relu", "prelu", "swish"]
59
+
60
+ self.conv1 = conv3x3(inplanes, planes, stride)
61
+ self.bn1 = nn.BatchNorm2d(planes)
62
+
63
+ if relu_type == "relu":
64
+ self.relu1 = nn.ReLU(inplace=True)
65
+ self.relu2 = nn.ReLU(inplace=True)
66
+ elif relu_type == "prelu":
67
+ self.relu1 = nn.PReLU(num_parameters=planes)
68
+ self.relu2 = nn.PReLU(num_parameters=planes)
69
+ elif relu_type == "swish":
70
+ self.relu1 = nn.SiLU(inplace=True)
71
+ self.relu2 = nn.SiLU(inplace=True)
72
+ else:
73
+ raise NotImplementedError
74
+ # --------
75
+
76
+ self.conv2 = conv3x3(planes, planes)
77
+ self.bn2 = nn.BatchNorm2d(planes)
78
+
79
+ self.downsample = downsample
80
+ self.stride = stride
81
+
82
+ def forward(self, x):
83
+ """forward.
84
+ :param x: torch.Tensor, input tensor with input size (B, C, T, H, W).
85
+ """
86
+ residual = x
87
+ out = self.conv1(x)
88
+ out = self.bn1(out)
89
+ out = self.relu1(out)
90
+ out = self.conv2(out)
91
+ out = self.bn2(out)
92
+ if self.downsample is not None:
93
+ residual = self.downsample(x)
94
+
95
+ out += residual
96
+ out = self.relu2(out)
97
+
98
+ return out
99
+
100
+
101
+ class ResNet(nn.Module):
102
+ def __init__(
103
+ self,
104
+ block,
105
+ layers,
106
+ relu_type="swish",
107
+ ):
108
+ super(ResNet, self).__init__()
109
+ self.inplanes = 64
110
+ self.relu_type = relu_type
111
+ self.downsample_block = downsample_basic_block
112
+
113
+ self.layer1 = self._make_layer(block, 64, layers[0])
114
+ self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
115
+ self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
116
+ self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
117
+ self.avgpool = nn.AdaptiveAvgPool2d(1)
118
+
119
+ def _make_layer(self, block, planes, blocks, stride=1):
120
+ """_make_layer.
121
+ :param block: torch.nn.Module, class of blocks.
122
+ :param planes: int, number of channels produced by the convolution.
123
+ :param blocks: int, number of layers in a block.
124
+ :param stride: int, size of the convolving kernel.
125
+ """
126
+ downsample = None
127
+ if stride != 1 or self.inplanes != planes * block.expansion:
128
+ downsample = self.downsample_block(
129
+ inplanes=self.inplanes,
130
+ outplanes=planes * block.expansion,
131
+ stride=stride,
132
+ )
133
+
134
+ layers = []
135
+ layers.append(
136
+ block(
137
+ self.inplanes,
138
+ planes,
139
+ stride,
140
+ downsample,
141
+ relu_type=self.relu_type,
142
+ )
143
+ )
144
+ self.inplanes = planes * block.expansion
145
+ for _ in range(1, blocks):
146
+ layers.append(
147
+ block(
148
+ self.inplanes,
149
+ planes,
150
+ relu_type=self.relu_type,
151
+ )
152
+ )
153
+
154
+ return nn.Sequential(*layers)
155
+
156
+ def forward(self, x):
157
+ """forward.
158
+ :param x: torch.Tensor, input tensor with input size (B, C, T, H, W).
159
+ """
160
+ x = self.layer1(x)
161
+ x = self.layer2(x)
162
+ x = self.layer3(x)
163
+ x = self.layer4(x)
164
+ x = self.avgpool(x)
165
+ x = x.view(x.size(0), -1)
166
+ return x
167
+
168
+
169
+ # -- auxiliary functions
170
+ def threeD_to_2D_tensor(x):
171
+ n_batch, n_channels, s_time, sx, sy = x.shape
172
+ x = x.transpose(1, 2)
173
+ return x.reshape(n_batch * s_time, n_channels, sx, sy)
174
+
175
+
176
+ class Conv3dResNet(nn.Module):
177
+ """Conv3dResNet module"""
178
+
179
+ def __init__(self, backbone_type="resnet", relu_type="swish"):
180
+ """__init__.
181
+ :param backbone_type: str, the type of a visual front-end.
182
+ :param relu_type: str, activation function used in the visual front-end.
183
+ """
184
+ super(Conv3dResNet, self).__init__()
185
+
186
+ self.backbone_type = backbone_type
187
+
188
+ self.frontend_nout = 64
189
+ self.trunk = ResNet(
190
+ BasicBlock,
191
+ [2, 2, 2, 2],
192
+ relu_type=relu_type,
193
+ )
194
+
195
+ # -- frontend3D
196
+ if relu_type == "relu":
197
+ frontend_relu = nn.ReLU(True)
198
+ elif relu_type == "prelu":
199
+ frontend_relu = nn.PReLU(self.frontend_nout)
200
+ elif relu_type == "swish":
201
+ frontend_relu = nn.SiLU(inplace=True)
202
+
203
+ self.frontend3D = nn.Sequential(
204
+ nn.Conv3d(
205
+ in_channels=1,
206
+ out_channels=self.frontend_nout,
207
+ kernel_size=(5, 7, 7),
208
+ stride=(1, 2, 2),
209
+ padding=(2, 3, 3),
210
+ bias=False,
211
+ ),
212
+ nn.BatchNorm3d(self.frontend_nout),
213
+ frontend_relu,
214
+ nn.MaxPool3d(
215
+ kernel_size=(1, 3, 3),
216
+ stride=(1, 2, 2),
217
+ padding=(0, 1, 1),
218
+ ),
219
+ )
220
+
221
+ def forward(self, xs_pad):
222
+ """forward.
223
+ :param xs_pad: torch.Tensor, batch of padded input sequences.
224
+ """
225
+ # -- include Channel dimension
226
+ xs_pad = xs_pad.transpose(2, 1)
227
+ B, C, T, H, W = xs_pad.size()
228
+ xs_pad = self.frontend3D(xs_pad)
229
+ Tnew = xs_pad.shape[2] # output should be B x C2 x Tnew x H x W
230
+ xs_pad = threeD_to_2D_tensor(xs_pad)
231
+ xs_pad = self.trunk(xs_pad)
232
+ xs_pad = xs_pad.view(B, Tnew, xs_pad.size(1))
233
+ return xs_pad
234
+
235
+
236
+ def video_resnet():
237
+ return Conv3dResNet()
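To make the shape bookkeeping in Conv3dResNet concrete, a hypothetical pass over a short grayscale clip; the 88x88 crop size is a common lip-reading choice, assumed here rather than taken from the repo:

import torch
from espnet.nets.pytorch_backend.frontend.resnet import video_resnet

frontend = video_resnet()
frontend.eval()

xs = torch.randn(1, 25, 1, 88, 88)   # (B, T, C, H, W): 25 frames of 88x88 crops

with torch.no_grad():
    feats = frontend(xs)             # (1, 25, 512): one 512-d vector per frame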
espnet/nets/pytorch_backend/frontend/resnet1d.py ADDED
@@ -0,0 +1,238 @@
1
+ import torch.nn as nn
2
+
3
+
4
+ def conv3x3(in_planes, out_planes, stride=1):
5
+ """conv3x3.
6
+ :param in_planes: int, number of channels in the input sequence.
7
+ :param out_planes: int, number of channels produced by the convolution.
8
+ :param stride: int, size of the convolving kernel.
9
+ """
10
+ return nn.Conv1d(
11
+ in_planes,
12
+ out_planes,
13
+ kernel_size=3,
14
+ stride=stride,
15
+ padding=1,
16
+ bias=False,
17
+ )
18
+
19
+
20
+ def downsample_basic_block(inplanes, outplanes, stride):
21
+ """downsample_basic_block.
22
+ :param inplanes: int, number of channels in the input sequence.
23
+ :param outplanes: int, number of channels produced by the convolution.
24
+ :param stride: int, size of the convolving kernel.
25
+ """
26
+ return nn.Sequential(
27
+ nn.Conv1d(
28
+ inplanes,
29
+ outplanes,
30
+ kernel_size=1,
31
+ stride=stride,
32
+ bias=False,
33
+ ),
34
+ nn.BatchNorm1d(outplanes),
35
+ )
36
+
37
+
38
+ class BasicBlock1D(nn.Module):
39
+ expansion = 1
40
+
41
+ def __init__(
42
+ self,
43
+ inplanes,
44
+ planes,
45
+ stride=1,
46
+ downsample=None,
47
+ relu_type="relu",
48
+ ):
49
+ """__init__.
50
+ :param inplanes: int, number of channels in the input sequence.
51
+ :param planes: int, number of channels produced by the convolution.
52
+ :param stride: int, size of the convolving kernel.
53
+ :param downsample: boolean, if True, the temporal resolution is downsampled.
54
+ :param relu_type: str, type of activation function.
55
+ """
56
+ super(BasicBlock1D, self).__init__()
57
+
58
+ assert relu_type in ["relu", "prelu", "swish"]
59
+
60
+ self.conv1 = conv3x3(inplanes, planes, stride)
61
+ self.bn1 = nn.BatchNorm1d(planes)
62
+
63
+ # type of ReLU is an input option
64
+ if relu_type == "relu":
65
+ self.relu1 = nn.ReLU(inplace=True)
66
+ self.relu2 = nn.ReLU(inplace=True)
67
+ elif relu_type == "prelu":
68
+ self.relu1 = nn.PReLU(num_parameters=planes)
69
+ self.relu2 = nn.PReLU(num_parameters=planes)
70
+ elif relu_type == "swish":
71
+ self.relu1 = nn.SiLU(inplace=True)
72
+ self.relu2 = nn.SiLU(inplace=True)
73
+ else:
74
+ raise NotImplementedError
75
+ # --------
76
+
77
+ self.conv2 = conv3x3(planes, planes)
78
+ self.bn2 = nn.BatchNorm1d(planes)
79
+
80
+ self.downsample = downsample
81
+ self.stride = stride
82
+
83
+ def forward(self, x):
84
+ """forward.
85
+ :param x: torch.Tensor, input tensor with input size (B, C, T)
86
+ """
87
+ residual = x
88
+ out = self.conv1(x)
89
+ out = self.bn1(out)
90
+ out = self.relu1(out)
91
+ out = self.conv2(out)
92
+ out = self.bn2(out)
93
+ if self.downsample is not None:
94
+ residual = self.downsample(x)
95
+
96
+ out += residual
97
+ out = self.relu2(out)
98
+
99
+ return out
100
+
101
+
102
+ class ResNet1D(nn.Module):
103
+ def __init__(
104
+ self,
105
+ block,
106
+ layers,
107
+ relu_type="swish",
108
+ a_upsample_ratio=1,
109
+ ):
110
+ """__init__.
111
+ :param block: torch.nn.Module, class of blocks.
112
+ :param layers: List, customised layers in each block.
113
+ :param relu_type: str, type of activation function.
114
+ :param a_upsample_ratio: int, The ratio related to the \
115
+ temporal resolution of output features of the frontend. \
116
+ a_upsample_ratio=1 produce features with a fps of 25.
117
+ """
118
+ super(ResNet1D, self).__init__()
119
+ self.inplanes = 64
120
+ self.relu_type = relu_type
121
+ self.downsample_block = downsample_basic_block
122
+ self.a_upsample_ratio = a_upsample_ratio
123
+
124
+ self.conv1 = nn.Conv1d(
125
+ in_channels=1,
126
+ out_channels=self.inplanes,
127
+ kernel_size=80,
128
+ stride=4,
129
+ padding=38,
130
+ bias=False,
131
+ )
132
+ self.bn1 = nn.BatchNorm1d(self.inplanes)
133
+
134
+ if relu_type == "relu":
135
+ self.relu = nn.ReLU(inplace=True)
136
+ elif relu_type == "prelu":
137
+ self.relu = nn.PReLU(num_parameters=self.inplanes)
138
+ elif relu_type == "swish":
139
+ self.relu = nn.SiLU(inplace=True)
140
+
141
+ self.layer1 = self._make_layer(block, 64, layers[0])
142
+ self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
143
+ self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
144
+ self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
145
+ self.avgpool = nn.AvgPool1d(
146
+ kernel_size=20 // self.a_upsample_ratio,
147
+ stride=20 // self.a_upsample_ratio,
148
+ )
149
+
150
+ def _make_layer(self, block, planes, blocks, stride=1):
151
+ """_make_layer.
152
+ :param block: torch.nn.Module, class of blocks.
153
+ :param planes: int, number of channels produced by the convolution.
154
+ :param blocks: int, number of layers in a block.
155
+ :param stride: int, size of the convolving kernel.
156
+ """
157
+
158
+ downsample = None
159
+ if stride != 1 or self.inplanes != planes * block.expansion:
160
+ downsample = self.downsample_block(
161
+ inplanes=self.inplanes,
162
+ outplanes=planes * block.expansion,
163
+ stride=stride,
164
+ )
165
+
166
+ layers = []
167
+ layers.append(
168
+ block(
169
+ self.inplanes,
170
+ planes,
171
+ stride,
172
+ downsample,
173
+ relu_type=self.relu_type,
174
+ )
175
+ )
176
+ self.inplanes = planes * block.expansion
177
+ for _ in range(1, blocks):
178
+ layers.append(
179
+ block(
180
+ self.inplanes,
181
+ planes,
182
+ relu_type=self.relu_type,
183
+ )
184
+ )
185
+
186
+ return nn.Sequential(*layers)
187
+
188
+ def forward(self, x):
189
+ """forward.
190
+ :param x: torch.Tensor, input tensor with input size (B, C, T)
191
+ """
192
+ x = self.conv1(x)
193
+ x = self.bn1(x)
194
+ x = self.relu(x)
195
+
196
+ x = self.layer1(x)
197
+ x = self.layer2(x)
198
+ x = self.layer3(x)
199
+ x = self.layer4(x)
200
+ x = self.avgpool(x)
201
+ return x
202
+
203
+
204
+ class Conv1dResNet(nn.Module):
205
+ """Conv1dResNet"""
206
+
207
+ def __init__(self, relu_type="swish", a_upsample_ratio=1):
208
+ """__init__.
209
+ :param relu_type: str, Activation function used in an audio front-end.
210
+ :param a_upsample_ratio: int, The ratio related to the \
211
+ temporal resolution of output features of the frontend. \
212
+ a_upsample_ratio=1 produces features at 25 fps.
213
+ """
214
+
215
+ super(Conv1dResNet, self).__init__()
216
+ self.a_upsample_ratio = a_upsample_ratio
217
+ self.trunk = ResNet1D(
218
+ BasicBlock1D,
219
+ [2, 2, 2, 2],
220
+ relu_type=relu_type,
221
+ a_upsample_ratio=a_upsample_ratio,
222
+ )
223
+
224
+ def forward(self, xs_pad):
225
+ """forward.
226
+ :param xs_pad: torch.Tensor, batch of padded input sequences (B, Tmax, idim)
227
+ """
228
+ B, T, C = xs_pad.size()
229
+ xs_pad = xs_pad[:, : T // 640 * 640, :]
230
+ xs_pad = xs_pad.transpose(1, 2)
231
+ xs_pad = self.trunk(xs_pad)
232
+ # -- from B x C x T to B x T x C
233
+ xs_pad = xs_pad.transpose(1, 2)
234
+ return xs_pad
235
+
236
+
237
+ def audio_resnet():
238
+ return Conv1dResNet()
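And the audio counterpart, assuming 16 kHz waveforms so that 640 samples yield one output frame (25 fps):

import torch
from espnet.nets.pytorch_backend.frontend.resnet1d import audio_resnet

frontend = audio_resnet()
frontend.eval()

xs = torch.randn(1, 16000, 1)    # (B, samples, 1): one second of audio

with torch.no_grad():
    feats = frontend(xs)         # (1, 25, 512): 25 feature frames per second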
espnet/nets/pytorch_backend/nets_utils.py ADDED
@@ -0,0 +1,306 @@
+ # -*- coding: utf-8 -*-
+
+ """Network related utility tools."""
+
+ import logging
+ from typing import Dict
+
+ import numpy as np
+ import torch
+
+
+ def to_device(m, x):
+     """Send tensor into the device of the module.
+
+     Args:
+         m (torch.nn.Module): Torch module.
+         x (Tensor): Torch tensor.
+
+     Returns:
+         Tensor: Torch tensor located on the same device as the torch module.
+
+     """
+     if isinstance(m, torch.nn.Module):
+         device = next(m.parameters()).device
+     elif isinstance(m, torch.Tensor):
+         device = m.device
+     else:
+         raise TypeError(
+             "Expected torch.nn.Module or torch.Tensor, " f"but got: {type(m)}"
+         )
+     return x.to(device)
+
+
+ def pad_list(xs, pad_value):
+     """Perform padding for the list of tensors.
+
+     Args:
+         xs (List): List of Tensors [(T_1, `*`), (T_2, `*`), ..., (T_B, `*`)].
+         pad_value (float): Value for padding.
+
+     Returns:
+         Tensor: Padded tensor (B, Tmax, `*`).
+
+     Examples:
+         >>> x = [torch.ones(4), torch.ones(2), torch.ones(1)]
+         >>> x
+         [tensor([1., 1., 1., 1.]), tensor([1., 1.]), tensor([1.])]
+         >>> pad_list(x, 0)
+         tensor([[1., 1., 1., 1.],
+                 [1., 1., 0., 0.],
+                 [1., 0., 0., 0.]])
+
+     """
+     n_batch = len(xs)
+     max_len = max(x.size(0) for x in xs)
+     pad = xs[0].new(n_batch, max_len, *xs[0].size()[1:]).fill_(pad_value)
+
+     for i in range(n_batch):
+         pad[i, : xs[i].size(0)] = xs[i]
+
+     return pad
+
+
+ def make_pad_mask(lengths, xs=None, length_dim=-1, maxlen=None):
+     """Make mask tensor containing indices of padded part.
+
+     Args:
+         lengths (LongTensor or List): Batch of lengths (B,).
+         xs (Tensor, optional): The reference tensor.
+             If set, masks will be the same shape as this tensor.
+         length_dim (int, optional): Dimension indicator of the above tensor.
+             See the example.
+         maxlen (int, optional): If set, the mask is computed with this length
+             instead of the maximum of `lengths`; cannot be combined with `xs`.
+
+     Returns:
+         Tensor: Mask tensor containing indices of padded part.
+             dtype=torch.uint8 in PyTorch < 1.2
+             dtype=torch.bool in PyTorch >= 1.2
+
+     Examples:
+         With only lengths.
+
+         >>> lengths = [5, 3, 2]
+         >>> make_pad_mask(lengths)
+         masks = [[0, 0, 0, 0, 0],
+                  [0, 0, 0, 1, 1],
+                  [0, 0, 1, 1, 1]]
+
+         With the reference tensor.
+
+         >>> xs = torch.zeros((3, 2, 4))
+         >>> make_pad_mask(lengths, xs)
+         tensor([[[0, 0, 0, 0],
+                  [0, 0, 0, 0]],
+                 [[0, 0, 0, 1],
+                  [0, 0, 0, 1]],
+                 [[0, 0, 1, 1],
+                  [0, 0, 1, 1]]], dtype=torch.uint8)
+         >>> xs = torch.zeros((3, 2, 6))
+         >>> make_pad_mask(lengths, xs)
+         tensor([[[0, 0, 0, 0, 0, 1],
+                  [0, 0, 0, 0, 0, 1]],
+                 [[0, 0, 0, 1, 1, 1],
+                  [0, 0, 0, 1, 1, 1]],
+                 [[0, 0, 1, 1, 1, 1],
+                  [0, 0, 1, 1, 1, 1]]], dtype=torch.uint8)
+
+         With the reference tensor and dimension indicator.
+
+         >>> xs = torch.zeros((3, 6, 6))
+         >>> make_pad_mask(lengths, xs, 1)
+         tensor([[[0, 0, 0, 0, 0, 0],
+                  [0, 0, 0, 0, 0, 0],
+                  [0, 0, 0, 0, 0, 0],
+                  [0, 0, 0, 0, 0, 0],
+                  [0, 0, 0, 0, 0, 0],
+                  [1, 1, 1, 1, 1, 1]],
+                 [[0, 0, 0, 0, 0, 0],
+                  [0, 0, 0, 0, 0, 0],
+                  [0, 0, 0, 0, 0, 0],
+                  [1, 1, 1, 1, 1, 1],
+                  [1, 1, 1, 1, 1, 1],
+                  [1, 1, 1, 1, 1, 1]],
+                 [[0, 0, 0, 0, 0, 0],
+                  [0, 0, 0, 0, 0, 0],
+                  [1, 1, 1, 1, 1, 1],
+                  [1, 1, 1, 1, 1, 1],
+                  [1, 1, 1, 1, 1, 1],
+                  [1, 1, 1, 1, 1, 1]]], dtype=torch.uint8)
+         >>> make_pad_mask(lengths, xs, 2)
+         tensor([[[0, 0, 0, 0, 0, 1],
+                  [0, 0, 0, 0, 0, 1],
+                  [0, 0, 0, 0, 0, 1],
+                  [0, 0, 0, 0, 0, 1],
+                  [0, 0, 0, 0, 0, 1],
+                  [0, 0, 0, 0, 0, 1]],
+                 [[0, 0, 0, 1, 1, 1],
+                  [0, 0, 0, 1, 1, 1],
+                  [0, 0, 0, 1, 1, 1],
+                  [0, 0, 0, 1, 1, 1],
+                  [0, 0, 0, 1, 1, 1],
+                  [0, 0, 0, 1, 1, 1]],
+                 [[0, 0, 1, 1, 1, 1],
+                  [0, 0, 1, 1, 1, 1],
+                  [0, 0, 1, 1, 1, 1],
+                  [0, 0, 1, 1, 1, 1],
+                  [0, 0, 1, 1, 1, 1],
+                  [0, 0, 1, 1, 1, 1]]], dtype=torch.uint8)
+
+     """
+     if length_dim == 0:
+         raise ValueError("length_dim cannot be 0: {}".format(length_dim))
+
+     if not isinstance(lengths, list):
+         lengths = lengths.tolist()
+     bs = int(len(lengths))
+     if maxlen is None:
+         if xs is None:
+             maxlen = int(max(lengths))
+         else:
+             maxlen = xs.size(length_dim)
+     else:
+         assert xs is None
+         assert maxlen >= int(max(lengths))
+
+     seq_range = torch.arange(0, maxlen, dtype=torch.int64)
+     seq_range_expand = seq_range.unsqueeze(0).expand(bs, maxlen)
+     seq_length_expand = seq_range_expand.new(lengths).unsqueeze(-1)
+     mask = seq_range_expand >= seq_length_expand
+
+     if xs is not None:
+         assert xs.size(0) == bs, (xs.size(0), bs)
+
+         if length_dim < 0:
+             length_dim = xs.dim() + length_dim
+         # ind = (:, None, ..., None, :, None, ..., None)
+         ind = tuple(
+             slice(None) if i in (0, length_dim) else None for i in range(xs.dim())
+         )
+         mask = mask[ind].expand_as(xs).to(xs.device)
+     return mask
+
+
+ def make_non_pad_mask(lengths, xs=None, length_dim=-1):
+     """Make mask tensor containing indices of non-padded part.
+
+     Args:
+         lengths (LongTensor or List): Batch of lengths (B,).
+         xs (Tensor, optional): The reference tensor.
+             If set, masks will be the same shape as this tensor.
+         length_dim (int, optional): Dimension indicator of the above tensor.
+             See the example.
+
+     Returns:
+         ByteTensor: Mask tensor containing indices of non-padded part.
+             dtype=torch.uint8 in PyTorch < 1.2
+             dtype=torch.bool in PyTorch >= 1.2
+
+     Examples:
+         With only lengths.
+
+         >>> lengths = [5, 3, 2]
+         >>> make_non_pad_mask(lengths)
+         masks = [[1, 1, 1, 1, 1],
+                  [1, 1, 1, 0, 0],
+                  [1, 1, 0, 0, 0]]
+
+         With the reference tensor.
+
+         >>> xs = torch.zeros((3, 2, 4))
+         >>> make_non_pad_mask(lengths, xs)
+         tensor([[[1, 1, 1, 1],
+                  [1, 1, 1, 1]],
+                 [[1, 1, 1, 0],
+                  [1, 1, 1, 0]],
+                 [[1, 1, 0, 0],
+                  [1, 1, 0, 0]]], dtype=torch.uint8)
+         >>> xs = torch.zeros((3, 2, 6))
+         >>> make_non_pad_mask(lengths, xs)
+         tensor([[[1, 1, 1, 1, 1, 0],
+                  [1, 1, 1, 1, 1, 0]],
+                 [[1, 1, 1, 0, 0, 0],
+                  [1, 1, 1, 0, 0, 0]],
+                 [[1, 1, 0, 0, 0, 0],
+                  [1, 1, 0, 0, 0, 0]]], dtype=torch.uint8)
+
+         With the reference tensor and dimension indicator.
+
+         >>> xs = torch.zeros((3, 6, 6))
+         >>> make_non_pad_mask(lengths, xs, 1)
+         tensor([[[1, 1, 1, 1, 1, 1],
+                  [1, 1, 1, 1, 1, 1],
+                  [1, 1, 1, 1, 1, 1],
+                  [1, 1, 1, 1, 1, 1],
+                  [1, 1, 1, 1, 1, 1],
+                  [0, 0, 0, 0, 0, 0]],
+                 [[1, 1, 1, 1, 1, 1],
+                  [1, 1, 1, 1, 1, 1],
+                  [1, 1, 1, 1, 1, 1],
+                  [0, 0, 0, 0, 0, 0],
+                  [0, 0, 0, 0, 0, 0],
+                  [0, 0, 0, 0, 0, 0]],
+                 [[1, 1, 1, 1, 1, 1],
+                  [1, 1, 1, 1, 1, 1],
+                  [0, 0, 0, 0, 0, 0],
+                  [0, 0, 0, 0, 0, 0],
+                  [0, 0, 0, 0, 0, 0],
+                  [0, 0, 0, 0, 0, 0]]], dtype=torch.uint8)
+         >>> make_non_pad_mask(lengths, xs, 2)
+         tensor([[[1, 1, 1, 1, 1, 0],
+                  [1, 1, 1, 1, 1, 0],
+                  [1, 1, 1, 1, 1, 0],
+                  [1, 1, 1, 1, 1, 0],
+                  [1, 1, 1, 1, 1, 0],
+                  [1, 1, 1, 1, 1, 0]],
+                 [[1, 1, 1, 0, 0, 0],
+                  [1, 1, 1, 0, 0, 0],
+                  [1, 1, 1, 0, 0, 0],
+                  [1, 1, 1, 0, 0, 0],
+                  [1, 1, 1, 0, 0, 0],
+                  [1, 1, 1, 0, 0, 0]],
+                 [[1, 1, 0, 0, 0, 0],
+                  [1, 1, 0, 0, 0, 0],
+                  [1, 1, 0, 0, 0, 0],
+                  [1, 1, 0, 0, 0, 0],
+                  [1, 1, 0, 0, 0, 0],
+                  [1, 1, 0, 0, 0, 0]]], dtype=torch.uint8)
+
+     """
+     return ~make_pad_mask(lengths, xs, length_dim)
+
+
+ def th_accuracy(pad_outputs, pad_targets, ignore_label):
+     """Calculate accuracy.
+
+     Args:
+         pad_outputs (Tensor): Prediction tensors (B * Lmax, D).
+         pad_targets (LongTensor): Target label tensors (B, Lmax).
+         ignore_label (int): Ignore label id.
+
+     Returns:
+         float: Accuracy value (0.0 - 1.0).
+
+     """
+     pad_pred = pad_outputs.view(
+         pad_targets.size(0), pad_targets.size(1), pad_outputs.size(1)
+     ).argmax(2)
+     mask = pad_targets != ignore_label
+     numerator = torch.sum(
+         pad_pred.masked_select(mask) == pad_targets.masked_select(mask)
+     )
+     denominator = torch.sum(mask)
+     return float(numerator) / float(denominator)
+
+
+ def rename_state_dict(
+     old_prefix: str, new_prefix: str, state_dict: Dict[str, torch.Tensor]
+ ):
+     """Replace keys of old prefix with new prefix in state dict."""
+     # need this list not to break the dict iterator
+     old_keys = [k for k in state_dict if k.startswith(old_prefix)]
+     if len(old_keys) > 0:
+         logging.warning(f"Rename: {old_prefix} -> {new_prefix}")
+     for k in old_keys:
+         v = state_dict.pop(k)
+         new_k = k.replace(old_prefix, new_prefix)
+         state_dict[new_k] = v
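
A quick sanity check showing how these helpers compose (an illustrative snippet, not taken from the repository):

    import torch

    xs = [torch.ones(4), torch.ones(2)]
    padded = pad_list(xs, 0)            # (2, 4); the shorter sequence is zero-padded
    valid = make_non_pad_mask([4, 2])   # True at non-padded positions
    assert padded.masked_select(~valid).eq(0).all()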
espnet/nets/pytorch_backend/transformer/__init__.py ADDED
@@ -0,0 +1 @@
+ """Initialize sub package."""
espnet/nets/pytorch_backend/transformer/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (205 Bytes)
espnet/nets/pytorch_backend/transformer/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (236 Bytes)
espnet/nets/pytorch_backend/transformer/__pycache__/add_sos_eos.cpython-310.pyc ADDED
Binary file (1.32 kB)
espnet/nets/pytorch_backend/transformer/__pycache__/add_sos_eos.cpython-311.pyc ADDED
Binary file (1.96 kB)
espnet/nets/pytorch_backend/transformer/__pycache__/attention.cpython-310.pyc ADDED
Binary file (7.08 kB)
espnet/nets/pytorch_backend/transformer/__pycache__/attention.cpython-311.pyc ADDED
Binary file (12 kB)
espnet/nets/pytorch_backend/transformer/__pycache__/embedding.cpython-310.pyc ADDED
Binary file (5.9 kB)
espnet/nets/pytorch_backend/transformer/__pycache__/embedding.cpython-311.pyc ADDED
Binary file (11.3 kB)
espnet/nets/pytorch_backend/transformer/__pycache__/label_smoothing_loss.cpython-310.pyc ADDED
Binary file (2.14 kB)
espnet/nets/pytorch_backend/transformer/__pycache__/label_smoothing_loss.cpython-311.pyc ADDED
Binary file (3.71 kB)
espnet/nets/pytorch_backend/transformer/__pycache__/layer_norm.cpython-310.pyc ADDED
Binary file (1.17 kB)
espnet/nets/pytorch_backend/transformer/__pycache__/layer_norm.cpython-311.pyc ADDED
Binary file (1.84 kB)
espnet/nets/pytorch_backend/transformer/__pycache__/mask.cpython-310.pyc ADDED
Binary file (1.18 kB)
espnet/nets/pytorch_backend/transformer/__pycache__/mask.cpython-311.pyc ADDED
Binary file (1.59 kB)
espnet/nets/pytorch_backend/transformer/__pycache__/positionwise_feed_forward.cpython-310.pyc ADDED
Binary file (1.21 kB)
espnet/nets/pytorch_backend/transformer/__pycache__/positionwise_feed_forward.cpython-311.pyc ADDED
Binary file (1.92 kB)
espnet/nets/pytorch_backend/transformer/__pycache__/repeat.cpython-310.pyc ADDED
Binary file (1.73 kB)
espnet/nets/pytorch_backend/transformer/__pycache__/repeat.cpython-311.pyc ADDED
Binary file (2.43 kB)
espnet/nets/pytorch_backend/transformer/add_sos_eos.py ADDED
@@ -0,0 +1,31 @@
+ #!/usr/bin/env python3
+ # -*- coding: utf-8 -*-
+
+ # Copyright 2019 Shigeki Karita
+ # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
+
+ """Utility functions for Transformer."""
+
+ import torch
+
+
+ def add_sos_eos(ys_pad, sos, eos, ignore_id):
+     """Add <sos> and <eos> labels.
+
+     :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
+     :param int sos: index of <sos>
+     :param int eos: index of <eos>
+     :param int ignore_id: index of padding
+     :return: decoder input tensor padded with <eos> (B, Lmax + 1) and
+         decoder target tensor padded with ignore_id (B, Lmax + 1)
+     :rtype: (torch.Tensor, torch.Tensor)
+     """
+     from espnet.nets.pytorch_backend.nets_utils import pad_list
+
+     _sos = ys_pad.new([sos])
+     _eos = ys_pad.new([eos])
+     ys = [y[y != ignore_id] for y in ys_pad]  # parse padded ys
+     ys_in = [torch.cat([_sos, y], dim=0) for y in ys]
+     ys_out = [torch.cat([y, _eos], dim=0) for y in ys]
+     return pad_list(ys_in, eos), pad_list(ys_out, ignore_id)
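
For example, with the hypothetical values sos=1, eos=2, ignore_id=-1:

    import torch

    ys_pad = torch.tensor([[3, 5, -1], [7, -1, -1]])  # -1 marks padding
    ys_in, ys_out = add_sos_eos(ys_pad, sos=1, eos=2, ignore_id=-1)
    # ys_in:  [[1, 3, 5], [1, 7, 2]]   decoder input, padded with <eos>
    # ys_out: [[3, 5, 2], [7, 2, -1]]  decoder target, padded with ignore_id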
espnet/nets/pytorch_backend/transformer/attention.py ADDED
@@ -0,0 +1,193 @@
+ #!/usr/bin/env python3
+ # -*- coding: utf-8 -*-
+
+ # Copyright 2019 Shigeki Karita
+ # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
+
+ """Multi-Head Attention layer definition."""
+
+ import math
+
+ import numpy
+ import torch
+ from torch import nn
+
+
+ class MultiHeadedAttention(nn.Module):
+     """Multi-Head Attention layer.
+
+     Args:
+         n_head (int): The number of heads.
+         n_feat (int): The number of features.
+         dropout_rate (float): Dropout rate.
+     """
+
+     def __init__(self, n_head, n_feat, dropout_rate):
+         """Construct a MultiHeadedAttention object."""
+         super(MultiHeadedAttention, self).__init__()
+         assert n_feat % n_head == 0
+         # We assume d_v always equals d_k
+         self.d_k = n_feat // n_head
+         self.h = n_head
+         self.linear_q = nn.Linear(n_feat, n_feat)
+         self.linear_k = nn.Linear(n_feat, n_feat)
+         self.linear_v = nn.Linear(n_feat, n_feat)
+         self.linear_out = nn.Linear(n_feat, n_feat)
+         self.attn = None
+         self.dropout = nn.Dropout(p=dropout_rate)
+
+     def forward_qkv(self, query, key, value):
+         """Transform query, key and value.
+
+         Args:
+             query (torch.Tensor): Query tensor (#batch, time1, size).
+             key (torch.Tensor): Key tensor (#batch, time2, size).
+             value (torch.Tensor): Value tensor (#batch, time2, size).
+
+         Returns:
+             torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k).
+             torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k).
+             torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k).
+         """
+         n_batch = query.size(0)
+         q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
+         k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
+         v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
+         q = q.transpose(1, 2)  # (batch, head, time1, d_k)
+         k = k.transpose(1, 2)  # (batch, head, time2, d_k)
+         v = v.transpose(1, 2)  # (batch, head, time2, d_k)
+
+         return q, k, v
+
+     def forward_attention(self, value, scores, mask, rtn_attn=False):
+         """Compute attention context vector.
+
+         Args:
+             value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k).
+             scores (torch.Tensor): Attention score (#batch, n_head, time1, time2).
+             mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2).
+             rtn_attn (bool): Whether to also return the attention weights.
+
+         Returns:
+             torch.Tensor: Transformed value (#batch, time1, d_model)
+                 weighted by the attention score (#batch, time1, time2).
+         """
+         n_batch = value.size(0)
+         if mask is not None:
+             mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
+             min_value = torch.finfo(scores.dtype).min
+             scores = scores.masked_fill(mask, min_value)
+             self.attn = torch.softmax(scores, dim=-1).masked_fill(
+                 mask, 0.0
+             )  # (batch, head, time1, time2)
+         else:
+             self.attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)
+
+         p_attn = self.dropout(self.attn)
+         x = torch.matmul(p_attn, value)  # (batch, head, time1, d_k)
+         x = (
+             x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
+         )  # (batch, time1, d_model)
+         if rtn_attn:
+             return self.linear_out(x), self.attn
+         return self.linear_out(x)  # (batch, time1, d_model)
+
+     def forward(self, query, key, value, mask, rtn_attn=False):
+         """Compute scaled dot product attention.
+
+         Args:
+             query (torch.Tensor): Query tensor (#batch, time1, size).
+             key (torch.Tensor): Key tensor (#batch, time2, size).
+             value (torch.Tensor): Value tensor (#batch, time2, size).
+             mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
+                 (#batch, time1, time2).
+             rtn_attn (bool): Whether to also return the attention weights.
+
+         Returns:
+             torch.Tensor: Output tensor (#batch, time1, d_model).
+         """
+         q, k, v = self.forward_qkv(query, key, value)
+         scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
+         return self.forward_attention(v, scores, mask, rtn_attn)
+
+
+ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
+     """Multi-Head Attention layer with relative position encoding (new implementation).
+
+     Details can be found in https://github.com/espnet/espnet/pull/2816.
+     Paper: https://arxiv.org/abs/1901.02860
+
+     Args:
+         n_head (int): The number of heads.
+         n_feat (int): The number of features.
+         dropout_rate (float): Dropout rate.
+         zero_triu (bool): Whether to zero the upper triangular part of
+             the attention matrix.
+     """
+
+     def __init__(self, n_head, n_feat, dropout_rate, zero_triu=False):
+         """Construct a RelPositionMultiHeadedAttention object."""
+         super().__init__(n_head, n_feat, dropout_rate)
+         self.zero_triu = zero_triu
+         # linear transformation for positional encoding
+         self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
+         # these two learnable biases are used in matrix c and matrix d
+         # as described in https://arxiv.org/abs/1901.02860 Section 3.3
+         self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
+         self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
+         torch.nn.init.xavier_uniform_(self.pos_bias_u)
+         torch.nn.init.xavier_uniform_(self.pos_bias_v)
+
+     def rel_shift(self, x):
+         """Compute relative positional encoding.
+
+         Args:
+             x (torch.Tensor): Input tensor (batch, head, time1, 2*time1-1).
+                 time1 means the length of query vector.
+
+         Returns:
+             torch.Tensor: Output tensor.
+         """
+         zero_pad = torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype)
+         x_padded = torch.cat([zero_pad, x], dim=-1)
+
+         x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2))
+         x = x_padded[:, :, 1:].view_as(x)[
+             :, :, :, : x.size(-1) // 2 + 1
+         ]  # only keep the positions from 0 to time2
+
+         if self.zero_triu:
+             ones = torch.ones((x.size(2), x.size(3)), device=x.device)
+             x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :]
+
+         return x
+
+     def forward(self, query, key, value, pos_emb, mask):
+         """Compute 'Scaled Dot Product Attention' with rel. positional encoding.
+
+         Args:
+             query (torch.Tensor): Query tensor (#batch, time1, size).
+             key (torch.Tensor): Key tensor (#batch, time2, size).
+             value (torch.Tensor): Value tensor (#batch, time2, size).
+             pos_emb (torch.Tensor): Positional embedding tensor
+                 (#batch, 2*time1-1, size).
+             mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
+                 (#batch, time1, time2).
+
+         Returns:
+             torch.Tensor: Output tensor (#batch, time1, d_model).
+         """
+         q, k, v = self.forward_qkv(query, key, value)
+         q = q.transpose(1, 2)  # (batch, time1, head, d_k)
+
+         n_batch_pos = pos_emb.size(0)
+         p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
+         p = p.transpose(1, 2)  # (batch, head, 2*time1-1, d_k)
+
+         # (batch, head, time1, d_k)
+         q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
+         # (batch, head, time1, d_k)
+         q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
+
+         # compute attention score
+         # first compute matrix a and matrix c
+         # as described in https://arxiv.org/abs/1901.02860 Section 3.3
+         # (batch, head, time1, time2)
+         matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
+
+         # compute matrix b and matrix d
+         # (batch, head, time1, 2*time1-1)
+         matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
+         matrix_bd = self.rel_shift(matrix_bd)
+
+         scores = (matrix_ac + matrix_bd) / math.sqrt(
+             self.d_k
+         )  # (batch, head, time1, time2)
+
+         return self.forward_attention(v, scores, mask)
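
A short self-attention smoke test (shapes here are illustrative assumptions):

    import torch

    mha = MultiHeadedAttention(n_head=4, n_feat=256, dropout_rate=0.1)
    x = torch.randn(2, 10, 256)
    mask = torch.ones(2, 1, 10, dtype=torch.bool)  # all key positions valid
    out = mha(x, x, x, mask)                       # (2, 10, 256)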
espnet/nets/pytorch_backend/transformer/embedding.py ADDED
@@ -0,0 +1,184 @@
+ #!/usr/bin/env python3
+ # -*- coding: utf-8 -*-
+
+ # Copyright 2019 Shigeki Karita
+ # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
+
+ """Positional Encoding Module."""
+
+ import math
+
+ import torch
+
+
+ def _pre_hook(
+     state_dict,
+     prefix,
+     local_metadata,
+     strict,
+     missing_keys,
+     unexpected_keys,
+     error_msgs,
+ ):
+     """Perform pre-hook in load_state_dict for backward compatibility.
+
+     Note:
+         We saved self.pe until v0.5.2 but have omitted it since.
+         Therefore, we remove the item "pe" from `state_dict` for
+         backward compatibility.
+     """
+     k = prefix + "pe"
+     if k in state_dict:
+         state_dict.pop(k)
+
+
+ class PositionalEncoding(torch.nn.Module):
+     """Positional encoding.
+
+     Args:
+         d_model (int): Embedding dimension.
+         dropout_rate (float): Dropout rate.
+         max_len (int): Maximum input length.
+         reverse (bool): Whether to reverse the input position. Only for
+             the class LegacyRelPositionalEncoding. We remove it in the current
+             class RelPositionalEncoding.
+     """
+
+     def __init__(self, d_model, dropout_rate, max_len=5000, reverse=False):
+         """Construct a PositionalEncoding object."""
+         super(PositionalEncoding, self).__init__()
+         self.d_model = d_model
+         self.reverse = reverse
+         self.xscale = math.sqrt(self.d_model)
+         self.dropout = torch.nn.Dropout(p=dropout_rate)
+         self.pe = None
+         self.extend_pe(torch.tensor(0.0).expand(1, max_len))
+         self._register_load_state_dict_pre_hook(_pre_hook)
+
+     def extend_pe(self, x):
+         """Reset the positional encodings."""
+         if self.pe is not None:
+             if self.pe.size(1) >= x.size(1):
+                 if self.pe.dtype != x.dtype or self.pe.device != x.device:
+                     self.pe = self.pe.to(dtype=x.dtype, device=x.device)
+                 return
+         pe = torch.zeros(x.size(1), self.d_model)
+         if self.reverse:
+             position = torch.arange(
+                 x.size(1) - 1, -1, -1.0, dtype=torch.float32
+             ).unsqueeze(1)
+         else:
+             position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
+         div_term = torch.exp(
+             torch.arange(0, self.d_model, 2, dtype=torch.float32)
+             * -(math.log(10000.0) / self.d_model)
+         )
+         pe[:, 0::2] = torch.sin(position * div_term)
+         pe[:, 1::2] = torch.cos(position * div_term)
+         pe = pe.unsqueeze(0)
+         self.pe = pe.to(device=x.device, dtype=x.dtype)
+
+     def forward(self, x: torch.Tensor):
+         """Add positional encoding.
+
+         Args:
+             x (torch.Tensor): Input tensor (batch, time, `*`).
+
+         Returns:
+             torch.Tensor: Encoded tensor (batch, time, `*`).
+         """
+         self.extend_pe(x)
+         x = x * self.xscale + self.pe[:, : x.size(1)]
+         return self.dropout(x)
+
+
+ class ScaledPositionalEncoding(PositionalEncoding):
+     """Scaled positional encoding module.
+
+     See Sec. 3.2 https://arxiv.org/abs/1809.08895
+
+     Args:
+         d_model (int): Embedding dimension.
+         dropout_rate (float): Dropout rate.
+         max_len (int): Maximum input length.
+     """
+
+     def __init__(self, d_model, dropout_rate, max_len=5000):
+         """Initialize class."""
+         super().__init__(d_model=d_model, dropout_rate=dropout_rate, max_len=max_len)
+         self.alpha = torch.nn.Parameter(torch.tensor(1.0))
+
+     def reset_parameters(self):
+         """Reset parameters."""
+         self.alpha.data = torch.tensor(1.0)
+
+     def forward(self, x):
+         """Add positional encoding.
+
+         Args:
+             x (torch.Tensor): Input tensor (batch, time, `*`).
+
+         Returns:
+             torch.Tensor: Encoded tensor (batch, time, `*`).
+         """
+         self.extend_pe(x)
+         x = x + self.alpha * self.pe[:, : x.size(1)]
+         return self.dropout(x)
+
+
+ class RelPositionalEncoding(torch.nn.Module):
+     """Relative positional encoding module (new implementation).
+
+     Details can be found in https://github.com/espnet/espnet/pull/2816.
+     See: Appendix B in https://arxiv.org/abs/1901.02860
+
+     Args:
+         d_model (int): Embedding dimension.
+         dropout_rate (float): Dropout rate.
+         max_len (int): Maximum input length.
+     """
+
+     def __init__(self, d_model, dropout_rate, max_len=5000):
+         """Construct a RelPositionalEncoding object."""
+         super(RelPositionalEncoding, self).__init__()
+         self.d_model = d_model
+         self.xscale = math.sqrt(self.d_model)
+         self.dropout = torch.nn.Dropout(p=dropout_rate)
+         self.pe = None
+         self.extend_pe(torch.tensor(0.0).expand(1, max_len))
+
+     def extend_pe(self, x):
+         """Reset the positional encodings."""
+         if self.pe is not None:
+             # self.pe contains both positive and negative parts
+             # the length of self.pe is 2 * input_len - 1
+             if self.pe.size(1) >= x.size(1) * 2 - 1:
+                 if self.pe.dtype != x.dtype or self.pe.device != x.device:
+                     self.pe = self.pe.to(dtype=x.dtype, device=x.device)
+                 return
+         # Suppose `i` is the position of the query vector and `j` the position
+         # of the key vector. We use positive relative positions when keys are
+         # to the left (i > j) and negative relative positions otherwise (i < j).
+         pe_positive = torch.zeros(x.size(1), self.d_model)
+         pe_negative = torch.zeros(x.size(1), self.d_model)
+         position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
+         div_term = torch.exp(
+             torch.arange(0, self.d_model, 2, dtype=torch.float32)
+             * -(math.log(10000.0) / self.d_model)
+         )
+         pe_positive[:, 0::2] = torch.sin(position * div_term)
+         pe_positive[:, 1::2] = torch.cos(position * div_term)
+         pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
+         pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
+
+         # Reverse the order of positive indices and concat both positive and
+         # negative indices. This is used to support the shifting trick
+         # as in https://arxiv.org/abs/1901.02860
+         pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
+         pe_negative = pe_negative[1:].unsqueeze(0)
+         pe = torch.cat([pe_positive, pe_negative], dim=1)
+         self.pe = pe.to(device=x.device, dtype=x.dtype)
+
+     def forward(self, x: torch.Tensor):
+         """Add positional encoding.
+
+         Args:
+             x (torch.Tensor): Input tensor (batch, time, `*`).
+
+         Returns:
+             torch.Tensor: Encoded tensor (batch, time, `*`).
+             torch.Tensor: Positional embedding tensor (1, 2*time-1, `*`).
+         """
+         self.extend_pe(x)
+         x = x * self.xscale
+         pos_emb = self.pe[
+             :,
+             self.pe.size(1) // 2 - x.size(1) + 1 : self.pe.size(1) // 2 + x.size(1),
+         ]
+         return self.dropout(x), self.dropout(pos_emb)
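
RelPositionalEncoding returns the scaled input together with the positional embeddings that RelPositionMultiHeadedAttention consumes (illustrative shapes):

    import torch

    pos_enc = RelPositionalEncoding(d_model=256, dropout_rate=0.1)
    x = torch.randn(2, 10, 256)
    x_scaled, pos_emb = pos_enc(x)  # (2, 10, 256) and (1, 19, 256): 2*10-1 relative positions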
espnet/nets/pytorch_backend/transformer/label_smoothing_loss.py ADDED
@@ -0,0 +1,63 @@
+ #!/usr/bin/env python3
+ # -*- coding: utf-8 -*-
+
+ # Copyright 2019 Shigeki Karita
+ # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
+
+ """Label smoothing module."""
+
+ import torch
+ from torch import nn
+
+
+ class LabelSmoothingLoss(nn.Module):
+     """Label-smoothing loss.
+
+     :param int size: the number of classes
+     :param int padding_idx: ignored class id
+     :param float smoothing: smoothing rate (0.0 means conventional CE)
+     :param bool normalize_length: normalize loss by sequence length if True
+     :param torch.nn.Module criterion: loss function to be smoothed
+     """
+
+     def __init__(
+         self,
+         size,
+         padding_idx,
+         smoothing,
+         normalize_length=False,
+         criterion=nn.KLDivLoss(reduction="none"),
+     ):
+         """Construct a LabelSmoothingLoss object."""
+         super(LabelSmoothingLoss, self).__init__()
+         self.criterion = criterion
+         self.padding_idx = padding_idx
+         self.confidence = 1.0 - smoothing
+         self.smoothing = smoothing
+         self.size = size
+         self.true_dist = None
+         self.normalize_length = normalize_length
+
+     def forward(self, x, target):
+         """Compute loss between x and target.
+
+         :param torch.Tensor x: prediction (batch, seqlen, class)
+         :param torch.Tensor target:
+             target signal masked with self.padding_idx (batch, seqlen)
+         :return: scalar float value
+         :rtype: torch.Tensor
+         """
+         assert x.size(2) == self.size
+         batch_size = x.size(0)
+         x = x.view(-1, self.size)
+         target = target.view(-1)
+         with torch.no_grad():
+             true_dist = x.clone()
+             true_dist.fill_(self.smoothing / (self.size - 1))
+             ignore = target == self.padding_idx  # (B,)
+             total = len(target) - ignore.sum().item()
+             target = target.masked_fill(ignore, 0)  # avoid -1 index
+             true_dist.scatter_(1, target.unsqueeze(1), self.confidence)
+         kl = self.criterion(torch.log_softmax(x, dim=1), true_dist)
+         denom = total if self.normalize_length else batch_size
+         return kl.masked_fill(ignore.unsqueeze(1), 0).sum() / denom
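
A hypothetical end-to-end call, assuming a vocabulary of 5 classes and -1 as the padding id:

    import torch

    criterion = LabelSmoothingLoss(size=5, padding_idx=-1, smoothing=0.1)
    logits = torch.randn(2, 3, 5)                     # (batch, seqlen, class)
    target = torch.tensor([[0, 2, -1], [4, -1, -1]])  # -1 positions are ignored
    loss = criterion(logits, target)                  # scalar tensor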