Aditeya Kamlesh Prajapati committed
Commit · 8096486
Parent(s): b0e7e58
Add app and modules
This view is limited to 50 files because it contains too many changes. See raw diff
- average_checkpoints.py +36 -0
- cosine.py +25 -0
- espnet/nets/batch_beam_search.py +349 -0
- espnet/nets/beam_search.py +510 -0
- espnet/nets/ctc_prefix_score.py +357 -0
- espnet/nets/e2e_asr_common.py +199 -0
- espnet/nets/pytorch_backend/__pycache__/ctc.cpython-310.pyc +0 -0
- espnet/nets/pytorch_backend/__pycache__/ctc.cpython-311.pyc +0 -0
- espnet/nets/pytorch_backend/__pycache__/e2e_asr_conformer.cpython-310.pyc +0 -0
- espnet/nets/pytorch_backend/__pycache__/e2e_asr_conformer.cpython-311.pyc +0 -0
- espnet/nets/pytorch_backend/__pycache__/nets_utils.cpython-310.pyc +0 -0
- espnet/nets/pytorch_backend/__pycache__/nets_utils.cpython-311.pyc +0 -0
- espnet/nets/pytorch_backend/ctc.py +242 -0
- espnet/nets/pytorch_backend/decoder/__pycache__/transformer_decoder.cpython-310.pyc +0 -0
- espnet/nets/pytorch_backend/decoder/__pycache__/transformer_decoder.cpython-311.pyc +0 -0
- espnet/nets/pytorch_backend/decoder/transformer_decoder.py +334 -0
- espnet/nets/pytorch_backend/e2e_asr_conformer.py +87 -0
- espnet/nets/pytorch_backend/encoder/__pycache__/conformer_encoder.cpython-310.pyc +0 -0
- espnet/nets/pytorch_backend/encoder/__pycache__/conformer_encoder.cpython-311.pyc +0 -0
- espnet/nets/pytorch_backend/encoder/conformer_encoder.py +303 -0
- espnet/nets/pytorch_backend/frontend/__pycache__/resnet.cpython-310.pyc +0 -0
- espnet/nets/pytorch_backend/frontend/__pycache__/resnet.cpython-311.pyc +0 -0
- espnet/nets/pytorch_backend/frontend/__pycache__/resnet1d.cpython-310.pyc +0 -0
- espnet/nets/pytorch_backend/frontend/__pycache__/resnet1d.cpython-311.pyc +0 -0
- espnet/nets/pytorch_backend/frontend/resnet.py +237 -0
- espnet/nets/pytorch_backend/frontend/resnet1d.py +238 -0
- espnet/nets/pytorch_backend/nets_utils.py +306 -0
- espnet/nets/pytorch_backend/transformer/__init__.py +1 -0
- espnet/nets/pytorch_backend/transformer/__pycache__/__init__.cpython-310.pyc +0 -0
- espnet/nets/pytorch_backend/transformer/__pycache__/__init__.cpython-311.pyc +0 -0
- espnet/nets/pytorch_backend/transformer/__pycache__/add_sos_eos.cpython-310.pyc +0 -0
- espnet/nets/pytorch_backend/transformer/__pycache__/add_sos_eos.cpython-311.pyc +0 -0
- espnet/nets/pytorch_backend/transformer/__pycache__/attention.cpython-310.pyc +0 -0
- espnet/nets/pytorch_backend/transformer/__pycache__/attention.cpython-311.pyc +0 -0
- espnet/nets/pytorch_backend/transformer/__pycache__/embedding.cpython-310.pyc +0 -0
- espnet/nets/pytorch_backend/transformer/__pycache__/embedding.cpython-311.pyc +0 -0
- espnet/nets/pytorch_backend/transformer/__pycache__/label_smoothing_loss.cpython-310.pyc +0 -0
- espnet/nets/pytorch_backend/transformer/__pycache__/label_smoothing_loss.cpython-311.pyc +0 -0
- espnet/nets/pytorch_backend/transformer/__pycache__/layer_norm.cpython-310.pyc +0 -0
- espnet/nets/pytorch_backend/transformer/__pycache__/layer_norm.cpython-311.pyc +0 -0
- espnet/nets/pytorch_backend/transformer/__pycache__/mask.cpython-310.pyc +0 -0
- espnet/nets/pytorch_backend/transformer/__pycache__/mask.cpython-311.pyc +0 -0
- espnet/nets/pytorch_backend/transformer/__pycache__/positionwise_feed_forward.cpython-310.pyc +0 -0
- espnet/nets/pytorch_backend/transformer/__pycache__/positionwise_feed_forward.cpython-311.pyc +0 -0
- espnet/nets/pytorch_backend/transformer/__pycache__/repeat.cpython-310.pyc +0 -0
- espnet/nets/pytorch_backend/transformer/__pycache__/repeat.cpython-311.pyc +0 -0
- espnet/nets/pytorch_backend/transformer/add_sos_eos.py +31 -0
- espnet/nets/pytorch_backend/transformer/attention.py +193 -0
- espnet/nets/pytorch_backend/transformer/embedding.py +184 -0
- espnet/nets/pytorch_backend/transformer/label_smoothing_loss.py +63 -0
average_checkpoints.py
ADDED
@@ -0,0 +1,36 @@
import os

import torch


def average_checkpoints(last):
    avg = None
    for path in last:
        states = torch.load(path, map_location=lambda storage, loc: storage)["state_dict"]
        states = {k[6:]: v for k, v in states.items() if k.startswith("model.")}
        if avg is None:
            avg = states
        else:
            for k in avg.keys():
                avg[k] += states[k]
    # average
    for k in avg.keys():
        if avg[k] is not None:
            if avg[k].is_floating_point():
                avg[k] /= len(last)
            else:
                avg[k] //= len(last)
    return avg


def ensemble(args):
    last = [
        os.path.join(args.exp_dir, args.exp_name, f"epoch={n}.ckpt")
        for n in range(
            args.max_epochs - 10,
            args.max_epochs,
        )
    ]
    model_path = os.path.join(args.exp_dir, args.exp_name, "model_avg_10.pth")
    torch.save(average_checkpoints(last), model_path)
    return model_path
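
These two helpers average the last ten epoch checkpoints into one state dict. A minimal usage sketch follows; the exp_dir/exp_name values, the max_epochs number, and the Lightning-style epoch={n}.ckpt files on disk are assumptions for illustration, not part of this commit:

import torch
from types import SimpleNamespace

# Hypothetical args namespace; ensemble() only reads these three attributes.
args = SimpleNamespace(exp_dir="exp", exp_name="demo_run", max_epochs=75)

# Expects exp/demo_run/epoch=65.ckpt ... epoch=74.ckpt to exist, each a
# Lightning checkpoint whose "state_dict" keys start with "model.".
model_path = ensemble(args)         # writes exp/demo_run/model_avg_10.pth
avg_state = torch.load(model_path)  # averaged weights, "model." prefix stripped
# model.load_state_dict(avg_state)  # load into a matching model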
cosine.py
ADDED
@@ -0,0 +1,25 @@
import math

import torch


class WarmupCosineScheduler(torch.optim.lr_scheduler._LRScheduler):
    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        warmup_epochs: int,
        total_epochs: int,
        steps_per_epoch: int,
        last_epoch=-1,
        verbose=False,
    ):
        # warmup and total lengths are counted in optimizer steps, not epochs
        self.warmup_steps = warmup_epochs * steps_per_epoch
        self.total_steps = total_epochs * steps_per_epoch
        super().__init__(optimizer, last_epoch=last_epoch, verbose=verbose)

    def get_lr(self):
        # linear warmup to the base learning rate, then cosine decay toward zero
        if self._step_count < self.warmup_steps:
            return [self._step_count / self.warmup_steps * base_lr for base_lr in self.base_lrs]
        decay_steps = self.total_steps - self.warmup_steps
        cos_val = math.cos(math.pi * (self._step_count - self.warmup_steps) / decay_steps)
        return [0.5 * base_lr * (1 + cos_val) for base_lr in self.base_lrs]
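
Because warmup_steps and total_steps are measured in optimizer steps, the scheduler is meant to be stepped once per batch, not once per epoch. A minimal sketch with a toy model; the shapes and hyperparameters are illustrative only:

import torch

model = torch.nn.Linear(16, 16)  # toy model for illustration
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scheduler = WarmupCosineScheduler(
    optimizer, warmup_epochs=2, total_epochs=10, steps_per_epoch=100
)
for step in range(10 * 100):
    optimizer.step()   # (loss computation omitted)
    scheduler.step()   # per batch: LR ramps up for 200 steps, then cosine-decays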
espnet/nets/batch_beam_search.py
ADDED
@@ -0,0 +1,349 @@
"""Parallel beam search module."""

import logging
from typing import Any, Dict, List, NamedTuple, Tuple

import torch

from espnet.nets.beam_search import BeamSearch, Hypothesis
from torch.nn.utils.rnn import pad_sequence


class BatchHypothesis(NamedTuple):
    """Batchified/vectorized hypothesis data type."""

    yseq: torch.Tensor = torch.tensor([])  # (batch, maxlen)
    score: torch.Tensor = torch.tensor([])  # (batch,)
    length: torch.Tensor = torch.tensor([])  # (batch,)
    scores: Dict[str, torch.Tensor] = dict()  # values: (batch,)
    states: Dict[str, Dict] = dict()

    def __len__(self) -> int:
        """Return the batch size."""
        return len(self.length)


class BatchBeamSearch(BeamSearch):
    """Batch beam search implementation."""

    def batchfy(self, hyps: List[Hypothesis]) -> BatchHypothesis:
        """Convert list to batch."""
        if len(hyps) == 0:
            return BatchHypothesis()
        yseq = pad_sequence(
            [h.yseq for h in hyps], batch_first=True, padding_value=self.eos
        )
        return BatchHypothesis(
            yseq=yseq,
            length=torch.tensor(
                [len(h.yseq) for h in hyps], dtype=torch.int64, device=yseq.device
            ),
            score=torch.tensor([h.score for h in hyps]).to(yseq.device),
            scores={
                k: torch.tensor([h.scores[k] for h in hyps], device=yseq.device)
                for k in self.scorers
            },
            states={k: [h.states[k] for h in hyps] for k in self.scorers},
        )

    def _batch_select(self, hyps: BatchHypothesis, ids: List[int]) -> BatchHypothesis:
        return BatchHypothesis(
            yseq=hyps.yseq[ids],
            score=hyps.score[ids],
            length=hyps.length[ids],
            scores={k: v[ids] for k, v in hyps.scores.items()},
            states={
                k: [self.scorers[k].select_state(v, i) for i in ids]
                for k, v in hyps.states.items()
            },
        )

    def _select(self, hyps: BatchHypothesis, i: int) -> Hypothesis:
        return Hypothesis(
            yseq=hyps.yseq[i, : hyps.length[i]],
            score=hyps.score[i],
            scores={k: v[i] for k, v in hyps.scores.items()},
            states={
                k: self.scorers[k].select_state(v, i) for k, v in hyps.states.items()
            },
        )

    def unbatchfy(self, batch_hyps: BatchHypothesis) -> List[Hypothesis]:
        """Revert batch to list."""
        return [
            Hypothesis(
                yseq=batch_hyps.yseq[i][: batch_hyps.length[i]],
                score=batch_hyps.score[i],
                scores={k: batch_hyps.scores[k][i] for k in self.scorers},
                states={
                    k: v.select_state(batch_hyps.states[k], i)
                    for k, v in self.scorers.items()
                },
            )
            for i in range(len(batch_hyps.length))
        ]

    def batch_beam(
        self, weighted_scores: torch.Tensor, ids: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Batch-compute topk full token ids and partial token ids.

        Args:
            weighted_scores (torch.Tensor): The weighted sum scores for each token.
                Its shape is `(n_beam, self.vocab_size)`.
            ids (torch.Tensor): The partial token ids to compute topk.
                Its shape is `(n_beam, self.pre_beam_size)`.

        Returns:
            Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
                The topk full (prev_hyp, new_token) ids
                and partial (prev_hyp, new_token) ids.
                Their shapes are all `(self.beam_size,)`

        """
        top_ids = weighted_scores.view(-1).topk(self.beam_size)[1]
        # Because of the flatten above, `top_ids` is organized as:
        # [hyp1 * V + token1, hyp2 * V + token2, ..., hypK * V + tokenK],
        # where V is `self.n_vocab` and K is `self.beam_size`
        prev_hyp_ids = torch.div(top_ids, self.n_vocab, rounding_mode="trunc")
        new_token_ids = top_ids % self.n_vocab
        return prev_hyp_ids, new_token_ids, prev_hyp_ids, new_token_ids

    def init_hyp(self, x: torch.Tensor) -> BatchHypothesis:
        """Get an initial hypothesis data.

        Args:
            x (torch.Tensor): The encoder output feature

        Returns:
            Hypothesis: The initial hypothesis.

        """
        init_states = dict()
        init_scores = dict()
        for k, d in self.scorers.items():
            init_states[k] = d.batch_init_state(x)
            init_scores[k] = 0.0
        return self.batchfy(
            [
                Hypothesis(
                    score=0.0,
                    scores=init_scores,
                    states=init_states,
                    yseq=torch.tensor([self.sos], device=x.device),
                )
            ]
        )

    def score_full(
        self, hyp: BatchHypothesis, x: torch.Tensor
    ) -> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]:
        """Score new hypothesis by `self.full_scorers`.

        Args:
            hyp (Hypothesis): Hypothesis with prefix tokens to score
            x (torch.Tensor): Corresponding input feature

        Returns:
            Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: Tuple of
                score dict of `hyp` that has string keys of `self.full_scorers`
                and tensor score values of shape: `(self.n_vocab,)`,
                and state dict that has string keys
                and state values of `self.full_scorers`

        """
        scores = dict()
        states = dict()
        for k, d in self.full_scorers.items():
            scores[k], states[k] = d.batch_score(hyp.yseq, hyp.states[k], x)
        return scores, states

    def score_partial(
        self, hyp: BatchHypothesis, ids: torch.Tensor, x: torch.Tensor
    ) -> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]:
        """Score new hypothesis by `self.part_scorers`.

        Args:
            hyp (Hypothesis): Hypothesis with prefix tokens to score
            ids (torch.Tensor): 2D tensor of new partial tokens to score
            x (torch.Tensor): Corresponding input feature

        Returns:
            Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: Tuple of
                score dict of `hyp` that has string keys of `self.part_scorers`
                and tensor score values of shape: `(self.n_vocab,)`,
                and state dict that has string keys
                and state values of `self.part_scorers`

        """
        scores = dict()
        states = dict()
        for k, d in self.part_scorers.items():
            scores[k], states[k] = d.batch_score_partial(
                hyp.yseq, ids, hyp.states[k], x
            )
        return scores, states

    def merge_states(self, states: Any, part_states: Any, part_idx: int) -> Any:
        """Merge states for new hypothesis.

        Args:
            states: states of `self.full_scorers`
            part_states: states of `self.part_scorers`
            part_idx (int): The new token id for `part_scores`

        Returns:
            Dict[str, torch.Tensor]: The new state dict.
                Its keys are names of `self.full_scorers` and `self.part_scorers`.
                Its values are states of the scorers.

        """
        new_states = dict()
        for k, v in states.items():
            new_states[k] = v
        for k, v in part_states.items():
            new_states[k] = v
        return new_states

    def search(self, running_hyps: BatchHypothesis, x: torch.Tensor) -> BatchHypothesis:
        """Search new tokens for running hypotheses and encoded speech x.

        Args:
            running_hyps (BatchHypothesis): Running hypotheses on beam
            x (torch.Tensor): Encoded speech feature (T, D)

        Returns:
            BatchHypothesis: Best sorted hypotheses

        """
        n_batch = len(running_hyps)
        part_ids = None  # no pre-beam
        # batch scoring
        weighted_scores = torch.zeros(
            n_batch, self.n_vocab, dtype=x.dtype, device=x.device
        )
        scores, states = self.score_full(running_hyps, x.expand(n_batch, *x.shape))
        for k in self.full_scorers:
            weighted_scores += self.weights[k] * scores[k]
        # partial scoring
        if self.do_pre_beam:
            pre_beam_scores = (
                weighted_scores
                if self.pre_beam_score_key == "full"
                else scores[self.pre_beam_score_key]
            )
            part_ids = torch.topk(pre_beam_scores, self.pre_beam_size, dim=-1)[1]
        # NOTE(takaaki-hori): Unlike BeamSearch, we assume that score_partial returns
        # full-size score matrices, which have non-zero scores for part_ids and zeros
        # for others.
        part_scores, part_states = self.score_partial(running_hyps, part_ids, x)
        for k in self.part_scorers:
            weighted_scores += self.weights[k] * part_scores[k]
        # add previous hyp scores
        weighted_scores += running_hyps.score.to(
            dtype=x.dtype, device=x.device
        ).unsqueeze(1)

        # TODO(karita): do not use list. use batch instead
        # see also https://github.com/espnet/espnet/pull/1402#discussion_r354561029
        # update hyps
        best_hyps = []
        prev_hyps = self.unbatchfy(running_hyps)
        for (
            full_prev_hyp_id,
            full_new_token_id,
            part_prev_hyp_id,
            part_new_token_id,
        ) in zip(*self.batch_beam(weighted_scores, part_ids)):
            prev_hyp = prev_hyps[full_prev_hyp_id]
            best_hyps.append(
                Hypothesis(
                    score=weighted_scores[full_prev_hyp_id, full_new_token_id],
                    yseq=self.append_token(prev_hyp.yseq, full_new_token_id),
                    scores=self.merge_scores(
                        prev_hyp.scores,
                        {k: v[full_prev_hyp_id] for k, v in scores.items()},
                        full_new_token_id,
                        {k: v[part_prev_hyp_id] for k, v in part_scores.items()},
                        part_new_token_id,
                    ),
                    states=self.merge_states(
                        {
                            k: self.full_scorers[k].select_state(v, full_prev_hyp_id)
                            for k, v in states.items()
                        },
                        {
                            k: self.part_scorers[k].select_state(
                                v, part_prev_hyp_id, part_new_token_id
                            )
                            for k, v in part_states.items()
                        },
                        part_new_token_id,
                    ),
                )
            )
        return self.batchfy(best_hyps)

    def post_process(
        self,
        i: int,
        maxlen: int,
        maxlenratio: float,
        running_hyps: BatchHypothesis,
        ended_hyps: List[Hypothesis],
    ) -> BatchHypothesis:
        """Perform post-processing of beam search iterations.

        Args:
            i (int): The length of hypothesis tokens.
            maxlen (int): The maximum length of tokens in beam search.
            maxlenratio (int): The maximum length ratio in beam search.
            running_hyps (BatchHypothesis): The running hypotheses in beam search.
            ended_hyps (List[Hypothesis]): The ended hypotheses in beam search.

        Returns:
            BatchHypothesis: The new running hypotheses.

        """
        n_batch = running_hyps.yseq.shape[0]
        logging.debug(f"the number of running hypotheses: {n_batch}")
        if self.token_list is not None:
            logging.debug(
                "best hypo: "
                + "".join(
                    [
                        self.token_list[x]
                        for x in running_hyps.yseq[0, 1 : running_hyps.length[0]]
                    ]
                )
            )
        # add eos in the final loop to avoid that there are no ended hyps
        if i == maxlen - 1:
            logging.debug("adding <eos> in the last position in the loop")
            yseq_eos = torch.cat(
                (
                    running_hyps.yseq,
                    torch.full(
                        (n_batch, 1),
                        self.eos,
                        device=running_hyps.yseq.device,
                        dtype=torch.int64,
                    ),
                ),
                1,
            )
            running_hyps.yseq.resize_as_(yseq_eos)
            running_hyps.yseq[:] = yseq_eos
            running_hyps.length[:] = yseq_eos.shape[1]

        # add ended hypotheses to a final list, and remove them from current hypotheses
        # (this will be a problem if the number of hyps < beam)
        is_eos = (
            running_hyps.yseq[torch.arange(n_batch), running_hyps.length - 1]
            == self.eos
        )
        for b in torch.nonzero(is_eos, as_tuple=False).view(-1):
            hyp = self._select(running_hyps, b)
            ended_hyps.append(hyp)
        remained_ids = torch.nonzero(is_eos == 0, as_tuple=False).view(-1)
        return self._batch_select(running_hyps, remained_ids)
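
The index arithmetic in batch_beam is the core trick of this file: flattening the (n_beam, n_vocab) score matrix lets a single topk pick the best (hypothesis, token) pairs across the whole beam at once. A standalone toy illustration with made-up scores:

import torch

beam_size, n_vocab = 2, 5
weighted_scores = torch.tensor([[0.1, 0.9, 0.2, 0.3, 0.0],
                                [0.4, 0.1, 0.8, 0.2, 0.1]])
top_ids = weighted_scores.view(-1).topk(beam_size)[1]  # flat indices in [0, beam*vocab)
prev_hyp_ids = torch.div(top_ids, n_vocab, rounding_mode="trunc")  # which hypothesis
new_token_ids = top_ids % n_vocab                                  # which token
print(prev_hyp_ids.tolist(), new_token_ids.tolist())  # [0, 1] [1, 2]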
espnet/nets/beam_search.py
ADDED
@@ -0,0 +1,510 @@
"""Beam search module."""

import logging
from itertools import chain
from typing import Any, Dict, List, NamedTuple, Tuple, Union

import torch

from espnet.nets.e2e_asr_common import end_detect
from espnet.nets.scorer_interface import PartialScorerInterface, ScorerInterface


class Hypothesis(NamedTuple):
    """Hypothesis data type."""

    yseq: torch.Tensor
    score: Union[float, torch.Tensor] = 0
    scores: Dict[str, Union[float, torch.Tensor]] = dict()
    states: Dict[str, Any] = dict()

    def asdict(self) -> dict:
        """Convert data to JSON-friendly dict."""
        return self._replace(
            yseq=self.yseq.tolist(),
            score=float(self.score),
            scores={k: float(v) for k, v in self.scores.items()},
        )._asdict()


class BeamSearch(torch.nn.Module):
    """Beam search implementation."""

    def __init__(
        self,
        scorers: Dict[str, ScorerInterface],
        weights: Dict[str, float],
        beam_size: int,
        vocab_size: int,
        sos: int,
        eos: int,
        token_list: List[str] = None,
        pre_beam_ratio: float = 1.5,
        pre_beam_score_key: str = None,
    ):
        """Initialize beam search.

        Args:
            scorers (dict[str, ScorerInterface]): Dict of decoder modules
                e.g., Decoder, CTCPrefixScorer, LM
                The scorer will be ignored if it is `None`
            weights (dict[str, float]): Dict of weights for each scorer
                The scorer will be ignored if its weight is 0
            beam_size (int): The number of hypotheses kept during search
            vocab_size (int): The size of the vocabulary
            sos (int): Start of sequence id
            eos (int): End of sequence id
            token_list (list[str]): List of tokens for debug log
            pre_beam_score_key (str): key of scores to perform pre-beam search
            pre_beam_ratio (float): beam size in the pre-beam search
                will be `int(pre_beam_ratio * beam_size)`

        """
        super().__init__()
        # set scorers
        self.weights = weights
        self.scorers = dict()
        self.full_scorers = dict()
        self.part_scorers = dict()
        # this module dict is required for recursive cast
        # `self.to(device, dtype)` in `recog.py`
        self.nn_dict = torch.nn.ModuleDict()
        for k, v in scorers.items():
            w = weights.get(k, 0)
            if w == 0 or v is None:
                continue
            assert isinstance(
                v, ScorerInterface
            ), f"{k} ({type(v)}) does not implement ScorerInterface"
            self.scorers[k] = v
            if isinstance(v, PartialScorerInterface):
                self.part_scorers[k] = v
            else:
                self.full_scorers[k] = v
            if isinstance(v, torch.nn.Module):
                self.nn_dict[k] = v

        # set configurations
        self.sos = sos
        self.eos = eos
        self.token_list = token_list
        self.pre_beam_size = int(pre_beam_ratio * beam_size)
        self.beam_size = beam_size
        self.n_vocab = vocab_size
        if (
            pre_beam_score_key is not None
            and pre_beam_score_key != "full"
            and pre_beam_score_key not in self.full_scorers
        ):
            raise KeyError(f"{pre_beam_score_key} is not found in {self.full_scorers}")
        self.pre_beam_score_key = pre_beam_score_key
        self.do_pre_beam = (
            self.pre_beam_score_key is not None
            and self.pre_beam_size < self.n_vocab
            and len(self.part_scorers) > 0
        )

    def init_hyp(self, x: torch.Tensor) -> List[Hypothesis]:
        """Get an initial hypothesis data.

        Args:
            x (torch.Tensor): The encoder output feature

        Returns:
            Hypothesis: The initial hypothesis.

        """
        init_states = dict()
        init_scores = dict()
        for k, d in self.scorers.items():
            init_states[k] = d.init_state(x)
            init_scores[k] = 0.0
        return [
            Hypothesis(
                score=0.0,
                scores=init_scores,
                states=init_states,
                yseq=torch.tensor([self.sos], device=x.device),
            )
        ]

    @staticmethod
    def append_token(xs: torch.Tensor, x: int) -> torch.Tensor:
        """Append new token to prefix tokens.

        Args:
            xs (torch.Tensor): The prefix token
            x (int): The new token to append

        Returns:
            torch.Tensor: New tensor containing xs + [x] with xs.dtype and xs.device

        """
        x = torch.tensor([x], dtype=xs.dtype, device=xs.device)
        return torch.cat((xs, x))

    def score_full(
        self, hyp: Hypothesis, x: torch.Tensor
    ) -> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]:
        """Score new hypothesis by `self.full_scorers`.

        Args:
            hyp (Hypothesis): Hypothesis with prefix tokens to score
            x (torch.Tensor): Corresponding input feature

        Returns:
            Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: Tuple of
                score dict of `hyp` that has string keys of `self.full_scorers`
                and tensor score values of shape: `(self.n_vocab,)`,
                and state dict that has string keys
                and state values of `self.full_scorers`

        """
        scores = dict()
        states = dict()
        for k, d in self.full_scorers.items():
            scores[k], states[k] = d.score(hyp.yseq, hyp.states[k], x)
        return scores, states

    def score_partial(
        self, hyp: Hypothesis, ids: torch.Tensor, x: torch.Tensor
    ) -> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]:
        """Score new hypothesis by `self.part_scorers`.

        Args:
            hyp (Hypothesis): Hypothesis with prefix tokens to score
            ids (torch.Tensor): 1D tensor of new partial tokens to score
            x (torch.Tensor): Corresponding input feature

        Returns:
            Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: Tuple of
                score dict of `hyp` that has string keys of `self.part_scorers`
                and tensor score values of shape: `(len(ids),)`,
                and state dict that has string keys
                and state values of `self.part_scorers`

        """
        scores = dict()
        states = dict()
        for k, d in self.part_scorers.items():
            scores[k], states[k] = d.score_partial(hyp.yseq, ids, hyp.states[k], x)
        return scores, states

    def beam(
        self, weighted_scores: torch.Tensor, ids: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute topk full token ids and partial token ids.

        Args:
            weighted_scores (torch.Tensor): The weighted sum scores for each token.
                Its shape is `(self.n_vocab,)`.
            ids (torch.Tensor): The partial token ids to compute topk

        Returns:
            Tuple[torch.Tensor, torch.Tensor]:
                The topk full token ids and partial token ids.
                Their shapes are `(self.beam_size,)`

        """
        # no pre-beam performed
        if weighted_scores.size(0) == ids.size(0):
            top_ids = weighted_scores.topk(self.beam_size)[1]
            return top_ids, top_ids

        # mask tokens pruned by the pre-beam so topk cannot select them
        tmp = weighted_scores[ids]
        weighted_scores[:] = -float("inf")
        weighted_scores[ids] = tmp
        top_ids = weighted_scores.topk(self.beam_size)[1]
        local_ids = weighted_scores[ids].topk(self.beam_size)[1]
        return top_ids, local_ids

    @staticmethod
    def merge_scores(
        prev_scores: Dict[str, float],
        next_full_scores: Dict[str, torch.Tensor],
        full_idx: int,
        next_part_scores: Dict[str, torch.Tensor],
        part_idx: int,
    ) -> Dict[str, torch.Tensor]:
        """Merge scores for new hypothesis.

        Args:
            prev_scores (Dict[str, float]):
                The previous hypothesis scores by `self.scorers`
            next_full_scores (Dict[str, torch.Tensor]): scores by `self.full_scorers`
            full_idx (int): The next token id for `next_full_scores`
            next_part_scores (Dict[str, torch.Tensor]):
                scores of partial tokens by `self.part_scorers`
            part_idx (int): The new token id for `next_part_scores`

        Returns:
            Dict[str, torch.Tensor]: The new score dict.
                Its keys are names of `self.full_scorers` and `self.part_scorers`.
                Its values are scalar tensors by the scorers.

        """
        new_scores = dict()
        for k, v in next_full_scores.items():
            new_scores[k] = prev_scores[k] + v[full_idx]
        for k, v in next_part_scores.items():
            new_scores[k] = prev_scores[k] + v[part_idx]
        return new_scores

    def merge_states(self, states: Any, part_states: Any, part_idx: int) -> Any:
        """Merge states for new hypothesis.

        Args:
            states: states of `self.full_scorers`
            part_states: states of `self.part_scorers`
            part_idx (int): The new token id for `part_scores`

        Returns:
            Dict[str, torch.Tensor]: The new state dict.
                Its keys are names of `self.full_scorers` and `self.part_scorers`.
                Its values are states of the scorers.

        """
        new_states = dict()
        for k, v in states.items():
            new_states[k] = v
        for k, d in self.part_scorers.items():
            new_states[k] = d.select_state(part_states[k], part_idx)
        return new_states

    def search(
        self, running_hyps: List[Hypothesis], x: torch.Tensor
    ) -> List[Hypothesis]:
        """Search new tokens for running hypotheses and encoded speech x.

        Args:
            running_hyps (List[Hypothesis]): Running hypotheses on beam
            x (torch.Tensor): Encoded speech feature (T, D)

        Returns:
            List[Hypothesis]: Best sorted hypotheses

        """
        best_hyps = []
        part_ids = torch.arange(self.n_vocab, device=x.device)  # no pre-beam
        for hyp in running_hyps:
            # scoring
            weighted_scores = torch.zeros(self.n_vocab, dtype=x.dtype, device=x.device)
            scores, states = self.score_full(hyp, x)
            for k in self.full_scorers:
                weighted_scores += self.weights[k] * scores[k]
            # partial scoring
            if self.do_pre_beam:
                pre_beam_scores = (
                    weighted_scores
                    if self.pre_beam_score_key == "full"
                    else scores[self.pre_beam_score_key]
                )
                part_ids = torch.topk(pre_beam_scores, self.pre_beam_size)[1]
            part_scores, part_states = self.score_partial(hyp, part_ids, x)
            for k in self.part_scorers:
                weighted_scores[part_ids] += self.weights[k] * part_scores[k]
            # add previous hyp score
            weighted_scores += hyp.score

            # update hyps
            for j, part_j in zip(*self.beam(weighted_scores, part_ids)):
                # will be (2 x beam at most)
                best_hyps.append(
                    Hypothesis(
                        score=weighted_scores[j],
                        yseq=self.append_token(hyp.yseq, j),
                        scores=self.merge_scores(
                            hyp.scores, scores, j, part_scores, part_j
                        ),
                        states=self.merge_states(states, part_states, part_j),
                    )
                )

            # sort and prune 2 x beam -> beam
            best_hyps = sorted(best_hyps, key=lambda x: x.score, reverse=True)[
                : min(len(best_hyps), self.beam_size)
            ]
        return best_hyps

    def forward(
        self, x: torch.Tensor, maxlenratio: float = 0.0, minlenratio: float = 0.0
    ) -> List[Hypothesis]:
        """Perform beam search.

        Args:
            x (torch.Tensor): Encoded speech feature (T, D)
            maxlenratio (float): Input length ratio to obtain max output length.
                If maxlenratio=0.0 (default), it uses an end-detect function
                to automatically find maximum hypothesis lengths.
                If maxlenratio<0.0, its absolute value is interpreted
                as a constant max output length.
            minlenratio (float): Input length ratio to obtain min output length.

        Returns:
            list[Hypothesis]: N-best decoding results

        """
        # set length bounds
        if maxlenratio == 0:
            maxlen = x.shape[0]
        elif maxlenratio < 0:
            maxlen = -1 * int(maxlenratio)
        else:
            maxlen = max(1, int(maxlenratio * x.size(0)))
        minlen = int(minlenratio * x.size(0))
        logging.debug("decoder input length: " + str(x.shape[0]))
        logging.debug("max output length: " + str(maxlen))
        logging.debug("min output length: " + str(minlen))

        # main loop of prefix search
        running_hyps = self.init_hyp(x)
        ended_hyps = []
        for i in range(maxlen):
            logging.debug("position " + str(i))
            best = self.search(running_hyps, x)
            # post process of one iteration
            running_hyps = self.post_process(i, maxlen, maxlenratio, best, ended_hyps)
            # end detection
            if maxlenratio == 0.0 and end_detect([h.asdict() for h in ended_hyps], i):
                logging.debug(f"end detected at {i}")
                break
            if len(running_hyps) == 0:
                logging.debug("no hypothesis. Finish decoding.")
                break
            else:
                logging.debug(f"remained hypotheses: {len(running_hyps)}")

        nbest_hyps = sorted(ended_hyps, key=lambda x: x.score, reverse=True)
        # check the number of hypotheses reaching eos
        if len(nbest_hyps) == 0:
            logging.warning(
                "there are no N-best results; performing recognition "
                "again with a smaller minlenratio."
            )
            return (
                []
                if minlenratio < 0.1
                else self.forward(x, maxlenratio, max(0.0, minlenratio - 0.1))
            )

        # report the best result
        best = nbest_hyps[0]
        for k, v in best.scores.items():
            logging.debug(
                f"{v:6.2f} * {self.weights[k]:3} = {v * self.weights[k]:6.2f} for {k}"
            )
        logging.debug(f"total log probability: {best.score:.2f}")
        logging.debug(f"normalized log probability: {best.score / len(best.yseq):.2f}")
        logging.debug(f"total number of ended hypotheses: {len(nbest_hyps)}")
        if self.token_list is not None:
            logging.debug(
                "best hypo: "
                + "".join([self.token_list[x] for x in best.yseq[1:-1]])
                + "\n"
            )
        return nbest_hyps

    def post_process(
        self,
        i: int,
        maxlen: int,
        maxlenratio: float,
        running_hyps: List[Hypothesis],
        ended_hyps: List[Hypothesis],
    ) -> List[Hypothesis]:
        """Perform post-processing of beam search iterations.

        Args:
            i (int): The length of hypothesis tokens.
            maxlen (int): The maximum length of tokens in beam search.
            maxlenratio (int): The maximum length ratio in beam search.
            running_hyps (List[Hypothesis]): The running hypotheses in beam search.
            ended_hyps (List[Hypothesis]): The ended hypotheses in beam search.

        Returns:
            List[Hypothesis]: The new running hypotheses.

        """
        logging.debug(f"the number of running hypotheses: {len(running_hyps)}")
        if self.token_list is not None:
            logging.debug(
                "best hypo: "
                + "".join([self.token_list[x] for x in running_hyps[0].yseq[1:]])
            )
        # add eos in the final loop to avoid that there are no ended hyps
        if i == maxlen - 1:
            logging.debug("adding <eos> in the last position in the loop")
            running_hyps = [
                h._replace(yseq=self.append_token(h.yseq, self.eos))
                for h in running_hyps
            ]

        # add ended hypotheses to a final list, and remove them from current hypotheses
        # (this will be a problem if the number of hyps < beam)
        remained_hyps = []
        for hyp in running_hyps:
            if hyp.yseq[-1] == self.eos:
                # e.g., Word LM needs to add final <eos> score
                for k, d in chain(self.full_scorers.items(), self.part_scorers.items()):
                    s = d.final_score(hyp.states[k])
                    hyp.scores[k] += s
                    hyp = hyp._replace(score=hyp.score + self.weights[k] * s)
                ended_hyps.append(hyp)
            else:
                remained_hyps.append(hyp)
        return remained_hyps


def beam_search(
    x: torch.Tensor,
    sos: int,
    eos: int,
    beam_size: int,
    vocab_size: int,
    scorers: Dict[str, ScorerInterface],
    weights: Dict[str, float],
    token_list: List[str] = None,
    maxlenratio: float = 0.0,
    minlenratio: float = 0.0,
    pre_beam_ratio: float = 1.5,
    pre_beam_score_key: str = "full",
) -> list:
    """Perform beam search with scorers.

    Args:
        x (torch.Tensor): Encoded speech feature (T, D)
        sos (int): Start of sequence id
        eos (int): End of sequence id
        beam_size (int): The number of hypotheses kept during search
        vocab_size (int): The size of the vocabulary
        scorers (dict[str, ScorerInterface]): Dict of decoder modules
            e.g., Decoder, CTCPrefixScorer, LM
            The scorer will be ignored if it is `None`
        weights (dict[str, float]): Dict of weights for each scorer
            The scorer will be ignored if its weight is 0
        token_list (list[str]): List of tokens for debug log
        maxlenratio (float): Input length ratio to obtain max output length.
            If maxlenratio=0.0 (default), it uses an end-detect function
            to automatically find maximum hypothesis lengths
        minlenratio (float): Input length ratio to obtain min output length.
        pre_beam_score_key (str): key of scores to perform pre-beam search
        pre_beam_ratio (float): beam size in the pre-beam search
            will be `int(pre_beam_ratio * beam_size)`

    Returns:
        list: N-best decoding results

    """
    ret = BeamSearch(
        scorers,
        weights,
        beam_size=beam_size,
        vocab_size=vocab_size,
        pre_beam_ratio=pre_beam_ratio,
        pre_beam_score_key=pre_beam_score_key,
        sos=sos,
        eos=eos,
        token_list=token_list,
    ).forward(x=x, maxlenratio=maxlenratio, minlenratio=minlenratio)
    return [h.asdict() for h in ret]
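
The masking step in BeamSearch.beam is easy to miss: once a pre-beam has run, only tokens in ids may survive the final topk, even if a pruned token had a higher combined score. A self-contained toy illustration (the values are made up):

import torch

beam_size = 2
weighted_scores = torch.tensor([0.5, 0.1, 0.9, 0.3, 0.7, 0.2])
ids = torch.tensor([0, 3, 4])           # tokens that survived the pre-beam
tmp = weighted_scores[ids]
weighted_scores[:] = -float("inf")      # mask every token ...
weighted_scores[ids] = tmp              # ... except the pre-beam survivors
print(weighted_scores.topk(beam_size)[1].tolist())  # [4, 0]; 0.9 at index 2 was pruned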
espnet/nets/ctc_prefix_score.py
ADDED
@@ -0,0 +1,357 @@
#!/usr/bin/env python3

# Copyright 2018 Mitsubishi Electric Research Labs (Takaaki Hori)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

import numpy as np
import torch


class CTCPrefixScoreTH(object):
    """Batch processing of CTCPrefixScore

    which is based on Algorithm 2 in WATANABE et al.
    "HYBRID CTC/ATTENTION ARCHITECTURE FOR END-TO-END SPEECH RECOGNITION,"
    but extended to efficiently compute the label probabilities for multiple
    hypotheses simultaneously
    See also Seki et al. "Vectorized Beam Search for CTC-Attention-Based
    Speech Recognition," In INTERSPEECH (pp. 3825-3829), 2019.
    """

    def __init__(self, x, xlens, blank, eos, margin=0):
        """Construct CTC prefix scorer

        :param torch.Tensor x: input label posterior sequences (B, T, O)
        :param torch.Tensor xlens: input lengths (B,)
        :param int blank: blank label id
        :param int eos: end-of-sequence id
        :param int margin: margin parameter for windowing (0 means no windowing)
        """
        # In the comment lines,
        # we assume T: input_length, B: batch size, W: beam width, O: output dim.
        self.logzero = -10000000000.0
        self.blank = blank
        self.eos = eos
        self.batch = x.size(0)
        self.input_length = x.size(1)
        self.odim = x.size(2)
        self.dtype = x.dtype
        self.device = (
            torch.device("cuda:%d" % x.get_device())
            if x.is_cuda
            else torch.device("cpu")
        )
        # Pad the rest of posteriors in the batch
        # TODO(takaaki-hori): need a better way without for-loops
        for i, l in enumerate(xlens):
            if l < self.input_length:
                x[i, l:, :] = self.logzero
                x[i, l:, blank] = 0
        # Reshape input x
        xn = x.transpose(0, 1)  # (B, T, O) -> (T, B, O)
        xb = xn[:, :, self.blank].unsqueeze(2).expand(-1, -1, self.odim)
        self.x = torch.stack([xn, xb])  # (2, T, B, O)
        self.end_frames = torch.as_tensor(xlens) - 1

        # Setup CTC windowing
        self.margin = margin
        if margin > 0:
            self.frame_ids = torch.arange(
                self.input_length, dtype=self.dtype, device=self.device
            )
        # Base indices for index conversion
        self.idx_bh = None
        self.idx_b = torch.arange(self.batch, device=self.device)
        self.idx_bo = (self.idx_b * self.odim).unsqueeze(1)

    def __call__(self, y, state, scoring_ids=None, att_w=None):
        """Compute CTC prefix scores for next labels

        :param list y: prefix label sequences
        :param tuple state: previous CTC state
        :param torch.Tensor scoring_ids: selected label ids to score (BW, snum),
            or None to score all labels
        :param torch.Tensor att_w: attention weights to decide CTC window
        :return new_state, ctc_local_scores (BW, O)
        """
        output_length = len(y[0]) - 1  # ignore sos
        last_ids = [yi[-1] for yi in y]  # last output label ids
        n_bh = len(last_ids)  # batch * hyps
        n_hyps = n_bh // self.batch  # assuming each utterance has the same # of hyps
        self.scoring_num = scoring_ids.size(-1) if scoring_ids is not None else 0
        # prepare state info
        if state is None:
            r_prev = torch.full(
                (self.input_length, 2, self.batch, n_hyps),
                self.logzero,
                dtype=self.dtype,
                device=self.device,
            )
            r_prev[:, 1] = torch.cumsum(self.x[0, :, :, self.blank], 0).unsqueeze(2)
            r_prev = r_prev.view(-1, 2, n_bh)
            s_prev = 0.0
            f_min_prev = 0
            f_max_prev = 1
        else:
            r_prev, s_prev, f_min_prev, f_max_prev = state

        # select input dimensions for scoring
        if self.scoring_num > 0:
            scoring_idmap = torch.full(
                (n_bh, self.odim), -1, dtype=torch.long, device=self.device
            )
            snum = self.scoring_num
            if self.idx_bh is None or n_bh > len(self.idx_bh):
                self.idx_bh = torch.arange(n_bh, device=self.device).view(-1, 1)
            scoring_idmap[self.idx_bh[:n_bh], scoring_ids] = torch.arange(
                snum, device=self.device
            )
            scoring_idx = (
                scoring_ids + self.idx_bo.repeat(1, n_hyps).view(-1, 1)
            ).view(-1)
            x_ = torch.index_select(
                self.x.view(2, -1, self.batch * self.odim), 2, scoring_idx
            ).view(2, -1, n_bh, snum)
        else:
            scoring_ids = None
            scoring_idmap = None
            snum = self.odim
            x_ = self.x.unsqueeze(3).repeat(1, 1, 1, n_hyps, 1).view(2, -1, n_bh, snum)

        # new CTC forward probs are prepared as a (T x 2 x BW x S) tensor
        # that corresponds to r_t^n(h) and r_t^b(h) in a batch.
        r = torch.full(
            (self.input_length, 2, n_bh, snum),
            self.logzero,
            dtype=self.dtype,
            device=self.device,
        )
        if output_length == 0:
            r[0, 0] = x_[0, 0]

        r_sum = torch.logsumexp(r_prev, 1)
        log_phi = r_sum.unsqueeze(2).repeat(1, 1, snum)
        if scoring_ids is not None:
            for idx in range(n_bh):
                pos = scoring_idmap[idx, last_ids[idx]]
                if pos >= 0:
                    log_phi[:, idx, pos] = r_prev[:, 1, idx]
        else:
            for idx in range(n_bh):
                log_phi[:, idx, last_ids[idx]] = r_prev[:, 1, idx]

        # decide start and end frames based on attention weights
        if att_w is not None and self.margin > 0:
            f_arg = torch.matmul(att_w, self.frame_ids)
            f_min = max(int(f_arg.min().cpu()), f_min_prev)
            f_max = max(int(f_arg.max().cpu()), f_max_prev)
            start = min(f_max_prev, max(f_min - self.margin, output_length, 1))
            end = min(f_max + self.margin, self.input_length)
        else:
            f_min = f_max = 0
            start = max(output_length, 1)
            end = self.input_length

        # compute forward probabilities log(r_t^n(h)) and log(r_t^b(h))
        for t in range(start, end):
            rp = r[t - 1]
            rr = torch.stack([rp[0], log_phi[t - 1], rp[0], rp[1]]).view(
                2, 2, n_bh, snum
            )
            r[t] = torch.logsumexp(rr, 1) + x_[:, t]

        # compute log prefix probabilities log(psi)
        log_phi_x = torch.cat((log_phi[0].unsqueeze(0), log_phi[:-1]), dim=0) + x_[0]
        if scoring_ids is not None:
            log_psi = torch.full(
                (n_bh, self.odim), self.logzero, dtype=self.dtype, device=self.device
            )
            log_psi_ = torch.logsumexp(
                torch.cat((log_phi_x[start:end], r[start - 1, 0].unsqueeze(0)), dim=0),
                dim=0,
            )
            for si in range(n_bh):
                log_psi[si, scoring_ids[si]] = log_psi_[si]
        else:
            log_psi = torch.logsumexp(
                torch.cat((log_phi_x[start:end], r[start - 1, 0].unsqueeze(0)), dim=0),
                dim=0,
            )

        for si in range(n_bh):
            log_psi[si, self.eos] = r_sum[self.end_frames[si // n_hyps], si]

        # exclude blank probs
        log_psi[:, self.blank] = self.logzero

        return (log_psi - s_prev), (r, log_psi, f_min, f_max, scoring_idmap)

    def index_select_state(self, state, best_ids):
        """Select CTC states according to best ids

        :param state    : CTC state
        :param best_ids : index numbers selected by beam pruning (B, W)
        :return selected_state
        """
        r, s, f_min, f_max, scoring_idmap = state
        # convert ids to BHO space
        n_bh = len(s)
        n_hyps = n_bh // self.batch
        vidx = (best_ids + (self.idx_b * (n_hyps * self.odim)).view(-1, 1)).view(-1)
        # select hypothesis scores
        s_new = torch.index_select(s.view(-1), 0, vidx)
        s_new = s_new.view(-1, 1).repeat(1, self.odim).view(n_bh, self.odim)
        # convert ids to BHS space (S: scoring_num)
        if scoring_idmap is not None:
            snum = self.scoring_num
            hyp_idx = (best_ids // self.odim + (self.idx_b * n_hyps).view(-1, 1)).view(
                -1
            )
            label_ids = torch.fmod(best_ids, self.odim).view(-1)
            score_idx = scoring_idmap[hyp_idx, label_ids]
            score_idx[score_idx == -1] = 0
            vidx = score_idx + hyp_idx * snum
        else:
            snum = self.odim
        # select forward probabilities
        r_new = torch.index_select(r.view(-1, 2, n_bh * snum), 2, vidx).view(
            -1, 2, n_bh
        )
        return r_new, s_new, f_min, f_max

    def extend_prob(self, x):
        """Extend CTC prob.

        :param torch.Tensor x: input label posterior sequences (B, T, O)
        """

        if self.x.shape[1] < x.shape[1]:  # self.x (2,T,B,O); x (B,T,O)
            # Pad the rest of posteriors in the batch
            # TODO(takaaki-hori): need a better way without for-loops
            xlens = [x.size(1)]
            for i, l in enumerate(xlens):
                if l < self.input_length:
                    x[i, l:, :] = self.logzero
                    x[i, l:, self.blank] = 0
            tmp_x = self.x
            xn = x.transpose(0, 1)  # (B, T, O) -> (T, B, O)
            xb = xn[:, :, self.blank].unsqueeze(2).expand(-1, -1, self.odim)
            self.x = torch.stack([xn, xb])  # (2, T, B, O)
            self.x[:, : tmp_x.shape[1], :, :] = tmp_x
            self.input_length = x.size(1)
            self.end_frames = torch.as_tensor(xlens) - 1

    def extend_state(self, state):
        """Compute CTC prefix state.

        :param state : CTC state
        :return ctc_state
        """

        if state is None:
            # nothing to do
            return state
        else:
            r_prev, s_prev, f_min_prev, f_max_prev = state

            r_prev_new = torch.full(
                (self.input_length, 2),
                self.logzero,
                dtype=self.dtype,
                device=self.device,
            )
            start = max(r_prev.shape[0], 1)
            r_prev_new[0:start] = r_prev
            for t in range(start, self.input_length):
                r_prev_new[t, 1] = r_prev_new[t - 1, 1] + self.x[0, t, :, self.blank]

            return (r_prev_new, s_prev, f_min_prev, f_max_prev)


class CTCPrefixScore(object):
    """Compute CTC label sequence scores

    which is based on Algorithm 2 in WATANABE et al.
    "HYBRID CTC/ATTENTION ARCHITECTURE FOR END-TO-END SPEECH RECOGNITION,"
    but extended to efficiently compute the probabilities of multiple labels
    simultaneously
    """

    def __init__(self, x, blank, eos, xp):
        self.xp = xp
        self.logzero = -10000000000.0
        self.blank = blank
        self.eos = eos
        self.input_length = len(x)
        self.x = x

    def initial_state(self):
        """Obtain an initial CTC state

        :return: CTC state
        """
        # initial CTC state is made of a frame x 2 tensor that corresponds to
+
# r_t^n(<sos>) and r_t^b(<sos>), where 0 and 1 of axis=1 represent
|
| 295 |
+
# superscripts n and b (non-blank and blank), respectively.
|
| 296 |
+
r = self.xp.full((self.input_length, 2), self.logzero, dtype=np.float32)
|
| 297 |
+
r[0, 1] = self.x[0, self.blank]
|
| 298 |
+
for i in range(1, self.input_length):
|
| 299 |
+
r[i, 1] = r[i - 1, 1] + self.x[i, self.blank]
|
| 300 |
+
return r
|
| 301 |
+
|
| 302 |
+
def __call__(self, y, cs, r_prev):
|
| 303 |
+
"""Compute CTC prefix scores for next labels
|
| 304 |
+
|
| 305 |
+
:param y : prefix label sequence
|
| 306 |
+
:param cs : array of next labels
|
| 307 |
+
:param r_prev: previous CTC state
|
| 308 |
+
:return ctc_scores, ctc_states
|
| 309 |
+
"""
|
| 310 |
+
# initialize CTC states
|
| 311 |
+
output_length = len(y) - 1 # ignore sos
|
| 312 |
+
# new CTC states are prepared as a frame x (n or b) x n_labels tensor
|
| 313 |
+
# that corresponds to r_t^n(h) and r_t^b(h).
|
| 314 |
+
r = self.xp.ndarray((self.input_length, 2, len(cs)), dtype=np.float32)
|
| 315 |
+
xs = self.x[:, cs]
|
| 316 |
+
if output_length == 0:
|
| 317 |
+
r[0, 0] = xs[0]
|
| 318 |
+
r[0, 1] = self.logzero
|
| 319 |
+
else:
|
| 320 |
+
r[output_length - 1] = self.logzero
|
| 321 |
+
|
| 322 |
+
# prepare forward probabilities for the last label
|
| 323 |
+
r_sum = self.xp.logaddexp(
|
| 324 |
+
r_prev[:, 0], r_prev[:, 1]
|
| 325 |
+
) # log(r_t^n(g) + r_t^b(g))
|
| 326 |
+
last = y[-1]
|
| 327 |
+
if output_length > 0 and last in cs:
|
| 328 |
+
log_phi = self.xp.ndarray((self.input_length, len(cs)), dtype=np.float32)
|
| 329 |
+
for i in range(len(cs)):
|
| 330 |
+
log_phi[:, i] = r_sum if cs[i] != last else r_prev[:, 1]
|
| 331 |
+
else:
|
| 332 |
+
log_phi = r_sum
|
| 333 |
+
|
| 334 |
+
# compute forward probabilities log(r_t^n(h)), log(r_t^b(h)),
|
| 335 |
+
# and log prefix probabilities log(psi)
|
| 336 |
+
start = max(output_length, 1)
|
| 337 |
+
log_psi = r[start - 1, 0]
|
| 338 |
+
for t in range(start, self.input_length):
|
| 339 |
+
r[t, 0] = self.xp.logaddexp(r[t - 1, 0], log_phi[t - 1]) + xs[t]
|
| 340 |
+
r[t, 1] = (
|
| 341 |
+
self.xp.logaddexp(r[t - 1, 0], r[t - 1, 1]) + self.x[t, self.blank]
|
| 342 |
+
)
|
| 343 |
+
log_psi = self.xp.logaddexp(log_psi, log_phi[t - 1] + xs[t])
|
| 344 |
+
|
| 345 |
+
# get P(...eos|X) that ends with the prefix itself
|
| 346 |
+
eos_pos = self.xp.where(cs == self.eos)[0]
|
| 347 |
+
if len(eos_pos) > 0:
|
| 348 |
+
log_psi[eos_pos] = r_sum[-1] # log(r_T^n(g) + r_T^b(g))
|
| 349 |
+
|
| 350 |
+
# exclude blank probs
|
| 351 |
+
blank_pos = self.xp.where(cs == self.blank)[0]
|
| 352 |
+
if len(blank_pos) > 0:
|
| 353 |
+
log_psi[blank_pos] = self.logzero
|
| 354 |
+
|
| 355 |
+
# return the log prefix probability and CTC states, where the label axis
|
| 356 |
+
# of the CTC states is moved to the first axis to slice it easily
|
| 357 |
+
return log_psi, self.xp.rollaxis(r, 2)
|
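A minimal usage sketch for CTCPrefixScore above, assuming NumPy as the xp array backend; the random posterior, vocabulary size, and token ids are illustrative assumptions, not from any recipe:

# usage sketch (illustrative assumptions, not part of the file above)
import numpy as np
from espnet.nets.ctc_prefix_score import CTCPrefixScore

T, V = 20, 10                                # frames and vocabulary size (assumed)
blank, eos = 0, V - 1
# dummy per-frame CTC log-posteriors of shape (T, V)
x = np.log(np.random.dirichlet(np.ones(V), size=T)).astype(np.float32)

scorer = CTCPrefixScore(x, blank, eos, np)
r_prev = scorer.initial_state()              # forward variables for the empty prefix
y = [eos, 3]                                 # <sos> followed by token 3
cs = np.array([1, 2, 3, eos])                # candidate next labels to score
log_psi, states = scorer(y, cs, r_prev)      # prefix scores and per-label CTC states
print(log_psi.shape, states.shape)           # (4,), (4, 20, 2)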
espnet/nets/e2e_asr_common.py
ADDED
@@ -0,0 +1,199 @@
#!/usr/bin/env python3
# encoding: utf-8

# Copyright 2017 Johns Hopkins University (Shinji Watanabe)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

"""Common functions for ASR."""

import json
import logging
import sys
from itertools import groupby

import numpy as np


def end_detect(ended_hyps, i, M=3, D_end=np.log(1 * np.exp(-10))):
    """End detection.

    described in Eq. (50) of S. Watanabe et al
    "Hybrid CTC/Attention Architecture for End-to-End Speech Recognition"

    :param ended_hyps:
    :param i:
    :param M:
    :param D_end:
    :return:
    """
    if len(ended_hyps) == 0:
        return False
    count = 0
    best_hyp = sorted(ended_hyps, key=lambda x: x["score"], reverse=True)[0]
    for m in range(M):
        # get ended_hyps whose length is i - m
        hyp_length = i - m
        hyps_same_length = [x for x in ended_hyps if len(x["yseq"]) == hyp_length]
        if len(hyps_same_length) > 0:
            best_hyp_same_length = sorted(
                hyps_same_length, key=lambda x: x["score"], reverse=True
            )[0]
            if best_hyp_same_length["score"] - best_hyp["score"] < D_end:
                count += 1

    if count == M:
        return True
    else:
        return False


class ErrorCalculator(object):
    """Calculate CER and WER for E2E_ASR and CTC models during training.

    :param y_hats: numpy array with predicted text
    :param y_pads: numpy array with true (target) text
    :param char_list:
    :param sym_space:
    :param sym_blank:
    :return:
    """

    def __init__(
        self, char_list, sym_space, sym_blank, report_cer=False, report_wer=False
    ):
        """Construct an ErrorCalculator object."""
        super(ErrorCalculator, self).__init__()

        self.report_cer = report_cer
        self.report_wer = report_wer

        self.char_list = char_list
        self.space = sym_space
        self.blank = sym_blank
        self.idx_blank = self.char_list.index(self.blank)
        if self.space in self.char_list:
            self.idx_space = self.char_list.index(self.space)
        else:
            self.idx_space = None

    def __call__(self, ys_hat, ys_pad, is_ctc=False):
        """Calculate sentence-level WER/CER score.

        :param torch.Tensor ys_hat: prediction (batch, seqlen)
        :param torch.Tensor ys_pad: reference (batch, seqlen)
        :param bool is_ctc: calculate CER score for CTC
        :return: sentence-level WER score
        :rtype float
        :return: sentence-level CER score
        :rtype float
        """
        cer, wer = None, None
        if is_ctc:
            return self.calculate_cer_ctc(ys_hat, ys_pad)
        elif not self.report_cer and not self.report_wer:
            return cer, wer

        seqs_hat, seqs_true = self.convert_to_char(ys_hat, ys_pad)
        if self.report_cer:
            cer = self.calculate_cer(seqs_hat, seqs_true)

        if self.report_wer:
            wer = self.calculate_wer(seqs_hat, seqs_true)
        return cer, wer

    def calculate_cer_ctc(self, ys_hat, ys_pad):
        """Calculate sentence-level CER score for CTC.

        :param torch.Tensor ys_hat: prediction (batch, seqlen)
        :param torch.Tensor ys_pad: reference (batch, seqlen)
        :return: average sentence-level CER score
        :rtype float
        """
        import editdistance

        cers, char_ref_lens = [], []
        for i, y in enumerate(ys_hat):
            y_hat = [x[0] for x in groupby(y)]
            y_true = ys_pad[i]
            seq_hat, seq_true = [], []
            for idx in y_hat:
                idx = int(idx)
                if idx != -1 and idx != self.idx_blank and idx != self.idx_space:
                    seq_hat.append(self.char_list[int(idx)])

            for idx in y_true:
                idx = int(idx)
                if idx != -1 and idx != self.idx_blank and idx != self.idx_space:
                    seq_true.append(self.char_list[int(idx)])

            hyp_chars = "".join(seq_hat)
            ref_chars = "".join(seq_true)
            if len(ref_chars) > 0:
                cers.append(editdistance.eval(hyp_chars, ref_chars))
                char_ref_lens.append(len(ref_chars))

        cer_ctc = float(sum(cers)) / sum(char_ref_lens) if cers else None
        return cer_ctc

    def convert_to_char(self, ys_hat, ys_pad):
        """Convert index to character.

        :param torch.Tensor seqs_hat: prediction (batch, seqlen)
        :param torch.Tensor seqs_true: reference (batch, seqlen)
        :return: token list of prediction
        :rtype list
        :return: token list of reference
        :rtype list
        """
        seqs_hat, seqs_true = [], []
        for i, y_hat in enumerate(ys_hat):
            y_true = ys_pad[i]
            eos_true = np.where(y_true == -1)[0]
            ymax = eos_true[0] if len(eos_true) > 0 else len(y_true)
            # NOTE: padding index (-1) in y_true is used to pad y_hat
            seq_hat = [self.char_list[int(idx)] for idx in y_hat[:ymax]]
            seq_true = [self.char_list[int(idx)] for idx in y_true if int(idx) != -1]
            seq_hat_text = "".join(seq_hat).replace(self.space, " ")
            seq_hat_text = seq_hat_text.replace(self.blank, "")
            seq_true_text = "".join(seq_true).replace(self.space, " ")
            seqs_hat.append(seq_hat_text)
            seqs_true.append(seq_true_text)
        return seqs_hat, seqs_true

    def calculate_cer(self, seqs_hat, seqs_true):
        """Calculate sentence-level CER score.

        :param list seqs_hat: prediction
        :param list seqs_true: reference
        :return: average sentence-level CER score
        :rtype float
        """
        import editdistance

        char_eds, char_ref_lens = [], []
        for i, seq_hat_text in enumerate(seqs_hat):
            seq_true_text = seqs_true[i]
            hyp_chars = seq_hat_text.replace(" ", "")
            ref_chars = seq_true_text.replace(" ", "")
            char_eds.append(editdistance.eval(hyp_chars, ref_chars))
            char_ref_lens.append(len(ref_chars))
        return float(sum(char_eds)) / sum(char_ref_lens)

    def calculate_wer(self, seqs_hat, seqs_true):
        """Calculate sentence-level WER score.

        :param list seqs_hat: prediction
        :param list seqs_true: reference
        :return: average sentence-level WER score
        :rtype float
        """
        import editdistance

        word_eds, word_ref_lens = [], []
        for i, seq_hat_text in enumerate(seqs_hat):
            seq_true_text = seqs_true[i]
            hyp_words = seq_hat_text.split()
            ref_words = seq_true_text.split()
            word_eds.append(editdistance.eval(hyp_words, ref_words))
            word_ref_lens.append(len(ref_words))
        return float(sum(word_eds)) / sum(word_ref_lens)
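A small sketch of exercising ErrorCalculator on toy index sequences; the char_list, token ids, and -1 padding values below are illustrative assumptions (editdistance is the same package the class imports internally):

# usage sketch (illustrative assumptions, not part of the file above)
import numpy as np
from espnet.nets.e2e_asr_common import ErrorCalculator

char_list = ["<blank>", "a", "b", "c", "<space>"]       # assumed toy vocabulary
calc = ErrorCalculator(
    char_list, sym_space="<space>", sym_blank="<blank>",
    report_cer=True, report_wer=True,
)
ys_hat = np.array([[1, 2, 4, 3, -1]])   # hypothesis "ab c"; -1 pads the sequence
ys_pad = np.array([[1, 2, 4, 2, -1]])   # reference  "ab b"
cer, wer = calc(ys_hat, ys_pad)
print(cer, wer)                          # 1 char edit / 3 chars, 1 word edit / 2 words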
espnet/nets/pytorch_backend/__pycache__/ctc.cpython-310.pyc
ADDED
Binary file (8.1 kB)

espnet/nets/pytorch_backend/__pycache__/ctc.cpython-311.pyc
ADDED
Binary file (15.8 kB)

espnet/nets/pytorch_backend/__pycache__/e2e_asr_conformer.cpython-310.pyc
ADDED
Binary file (2.82 kB)

espnet/nets/pytorch_backend/__pycache__/e2e_asr_conformer.cpython-311.pyc
ADDED
Binary file (4.76 kB)

espnet/nets/pytorch_backend/__pycache__/nets_utils.cpython-310.pyc
ADDED
Binary file (10.1 kB)

espnet/nets/pytorch_backend/__pycache__/nets_utils.cpython-311.pyc
ADDED
Binary file (13.4 kB)
espnet/nets/pytorch_backend/ctc.py
ADDED
@@ -0,0 +1,242 @@
import numpy as np
import torch
import torch.nn.functional as F

from espnet.nets.pytorch_backend.nets_utils import to_device


class CTC(torch.nn.Module):
    """CTC module

    :param int odim: dimension of outputs
    :param int eprojs: number of encoder projection units
    :param float dropout_rate: dropout rate (0.0 ~ 1.0)
    :param bool reduce: reduce the CTC loss into a scalar
    """

    def __init__(self, odim, eprojs, dropout_rate, reduce=True):
        super().__init__()
        self.dropout_rate = dropout_rate
        self.loss = None
        self.ctc_lo = torch.nn.Linear(eprojs, odim)
        self.dropout = torch.nn.Dropout(dropout_rate)
        self.probs = None  # for visualization

        reduction_type = "sum" if reduce else "none"
        self.ctc_loss = torch.nn.CTCLoss(
            reduction=reduction_type, zero_infinity=True
        )
        self.ignore_id = -1
        self.reduce = reduce

    def loss_fn(self, th_pred, th_target, th_ilen, th_olen):
        th_pred = th_pred.log_softmax(2)
        with torch.backends.cudnn.flags(deterministic=True):
            loss = self.ctc_loss(th_pred, th_target, th_ilen, th_olen)
            # Batch-size average
            loss = loss / th_pred.size(1)
        return loss

    def forward(self, hs_pad, hlens, ys_pad):
        """CTC forward

        :param torch.Tensor hs_pad: batch of padded hidden state sequences (B, Tmax, D)
        :param torch.Tensor hlens: batch of lengths of hidden state sequences (B)
        :param torch.Tensor ys_pad:
            batch of padded character id sequence tensor (B, Lmax)
        :return: ctc loss value
        :rtype: torch.Tensor
        """
        # TODO(kan-bayashi): need a smarter way
        ys = [y[y != self.ignore_id] for y in ys_pad]  # parse padded ys

        # zero padding for hs
        ys_hat = self.ctc_lo(self.dropout(hs_pad))
        ys_hat = ys_hat.transpose(0, 1)

        olens = to_device(ys_hat, torch.LongTensor([len(s) for s in ys]))
        hlens = hlens.long()
        ys_pad = torch.cat(ys)  # without this the code breaks for asr_mix
        self.loss = self.loss_fn(ys_hat, ys_pad, hlens, olens)

        if self.reduce:
            self.loss = self.loss.sum()

        return self.loss, ys_hat

    def softmax(self, hs_pad):
        """softmax of frame activations

        :param torch.Tensor hs_pad: 3d tensor (B, Tmax, eprojs)
        :return: softmax applied 3d tensor (B, Tmax, odim)
        :rtype: torch.Tensor
        """
        self.probs = F.softmax(self.ctc_lo(hs_pad), dim=-1)
        return self.probs

    def log_softmax(self, hs_pad):
        """log_softmax of frame activations

        :param torch.Tensor hs_pad: 3d tensor (B, Tmax, eprojs)
        :return: log softmax applied 3d tensor (B, Tmax, odim)
        :rtype: torch.Tensor
        """
        return F.log_softmax(self.ctc_lo(hs_pad), dim=-1)

    def argmax(self, hs_pad):
        """argmax of frame activations

        :param torch.Tensor hs_pad: 3d tensor (B, Tmax, eprojs)
        :return: argmax applied 2d tensor (B, Tmax)
        :rtype: torch.Tensor
        """
        return torch.argmax(self.ctc_lo(hs_pad), dim=-1)

    def forced_align(self, h, y, blank_id=0):
        """forced alignment.

        :param torch.Tensor h: hidden state sequence, 2d tensor (T, D)
        :param torch.Tensor y: id sequence tensor 1d tensor (L)
        :param int blank_id: blank symbol index
        :return: best alignment results
        :rtype: list
        """

        def interpolate_blank(label, blank_id=0):
            """Insert a blank token between every two label tokens."""
            label = np.expand_dims(label, 1)
            blanks = np.zeros((label.shape[0], 1), dtype=np.int64) + blank_id
            label = np.concatenate([blanks, label], axis=1)
            label = label.reshape(-1)
            label = np.append(label, label[0])
            return label

        lpz = self.log_softmax(h)
        lpz = lpz.squeeze(0)

        y_int = interpolate_blank(y, blank_id)

        logdelta = np.zeros((lpz.size(0), len(y_int))) - 100000000000.0  # log of zero
        state_path = (
            np.zeros((lpz.size(0), len(y_int)), dtype=np.int16) - 1
        )  # state path

        logdelta[0, 0] = lpz[0][y_int[0]]
        logdelta[0, 1] = lpz[0][y_int[1]]

        for t in range(1, lpz.size(0)):
            for s in range(len(y_int)):
                if y_int[s] == blank_id or s < 2 or y_int[s] == y_int[s - 2]:
                    candidates = np.array([logdelta[t - 1, s], logdelta[t - 1, s - 1]])
                    prev_state = [s, s - 1]
                else:
                    candidates = np.array(
                        [
                            logdelta[t - 1, s],
                            logdelta[t - 1, s - 1],
                            logdelta[t - 1, s - 2],
                        ]
                    )
                    prev_state = [s, s - 1, s - 2]
                logdelta[t, s] = np.max(candidates) + lpz[t][y_int[s]]
                state_path[t, s] = prev_state[np.argmax(candidates)]

        state_seq = -1 * np.ones((lpz.size(0), 1), dtype=np.int16)

        candidates = np.array(
            [logdelta[-1, len(y_int) - 1], logdelta[-1, len(y_int) - 2]]
        )
        prev_state = [len(y_int) - 1, len(y_int) - 2]
        state_seq[-1] = prev_state[np.argmax(candidates)]
        for t in range(lpz.size(0) - 2, -1, -1):
            state_seq[t] = state_path[t + 1, state_seq[t + 1, 0]]

        output_state_seq = []
        for t in range(0, lpz.size(0)):
            output_state_seq.append(y_int[state_seq[t, 0]])

        return output_state_seq

    def forced_align_batch(self, hs_pad, ys_pad, ilens, blank_id=0):
        """forced alignment with batch processing.

        :param torch.Tensor hs_pad: hidden state sequence, 3d tensor (T, B, D)
        :param torch.Tensor ys_pad: id sequence tensor 2d tensor (B, L)
        :param torch.Tensor ilens: Input length of each utterance (B,)
        :param int blank_id: blank symbol index
        :return: best alignment results
        :rtype: list of numpy.array
        """

        def interpolate_blank(label, olens_int):
            """Insert a blank token between every two label tokens."""
            lab_len = label.shape[1] * 2 + 1
            label_out = np.full((label.shape[0], lab_len), blank_id, dtype=np.int64)
            label_out[:, 1::2] = label
            for b in range(label.shape[0]):
                label_out[b, olens_int[b] * 2 + 1 :] = self.ignore_id
            return label_out

        neginf = float("-inf")  # log of zero
        # lpz = self.log_softmax(hs_pad).cpu().detach().numpy()
        # hs_pad = hs_pad.transpose(1,0)
        lpz = F.log_softmax(hs_pad, dim=-1).cpu().detach().numpy()
        ilens = ilens.cpu().detach().numpy()

        ys_pad = ys_pad.cpu().detach().numpy()
        ys = [y[y != self.ignore_id] for y in ys_pad]  # parse padded ys
        olens = np.array([len(s) for s in ys])
        olens_int = olens * 2 + 1
        ys_int = interpolate_blank(ys_pad, olens_int)

        Tmax, B, _ = lpz.shape
        Lmax = ys_int.shape[-1]
        logdelta = np.full((Tmax, B, Lmax), neginf, dtype=lpz.dtype)
        state_path = -np.ones(logdelta.shape, dtype=np.int16)  # state path

        b_indx = np.arange(B, dtype=np.int64)
        t_0 = np.zeros(B, dtype=np.int64)
        logdelta[0, :, 0] = lpz[t_0, b_indx, ys_int[:, 0]]
        logdelta[0, :, 1] = lpz[t_0, b_indx, ys_int[:, 1]]

        s_indx_mat = np.arange(Lmax)[None, :].repeat(B, 0)
        notignore_mat = ys_int != self.ignore_id
        same_lab_mat = np.zeros((B, Lmax), dtype=bool)
        same_lab_mat[:, 3::2] = ys_int[:, 3::2] == ys_int[:, 1:-2:2]
        Lmin = olens_int.min()
        for t in range(1, Tmax):
            s_start = max(0, Lmin - (Tmax - t) * 2)
            s_end = min(Lmax, t * 2 + 2)
            candidates = np.full((B, Lmax, 3), neginf, dtype=logdelta.dtype)
            candidates[:, :, 0] = logdelta[t - 1, :, :]
            candidates[:, 1:, 1] = logdelta[t - 1, :, :-1]
            candidates[:, 3::2, 2] = logdelta[t - 1, :, 1:-2:2]
            candidates[same_lab_mat, 2] = neginf
            candidates_ = candidates[:, s_start:s_end, :]
            idx = candidates_.argmax(-1)
            b_i, s_i = np.ogrid[:B, : idx.shape[-1]]
            nignore = notignore_mat[:, s_start:s_end]
            logdelta[t, :, s_start:s_end][nignore] = (
                candidates_[b_i, s_i, idx][nignore]
                + lpz[t, b_i, ys_int[:, s_start:s_end]][nignore]
            )
            s = s_indx_mat[:, s_start:s_end]
            state_path[t, :, s_start:s_end][nignore] = (s - idx)[nignore]

        alignments = []
        prev_states = logdelta[
            ilens[:, None] - 1,
            b_indx[:, None],
            np.stack([olens_int - 2, olens_int - 1], -1),
        ].argmax(-1)
        for b in range(B):
            T, L = ilens[b], olens_int[b]
            prev_state = prev_states[b] + L - 2
            ali = np.empty(T, dtype=ys_int.dtype)
            ali[T - 1] = ys_int[b, prev_state]
            for t in range(T - 2, -1, -1):
                prev_state = state_path[t + 1, b, prev_state]
                ali[t] = ys_int[b, prev_state]
            alignments.append(ali)

        return alignments
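A minimal sketch of driving the CTC head on random encoder outputs; batch size, lengths, and target ids are assumptions for illustration:

# usage sketch (illustrative assumptions, not part of the file above)
import torch
from espnet.nets.pytorch_backend.ctc import CTC

odim, eprojs = 40, 768                               # assumed vocabulary / encoder sizes
ctc = CTC(odim, eprojs, dropout_rate=0.1, reduce=True)
hs_pad = torch.randn(2, 50, eprojs)                  # (B, Tmax, eprojs) encoder output
hlens = torch.tensor([50, 45])
ys_pad = torch.full((2, 12), -1, dtype=torch.long)   # -1 marks padding, as in forward()
ys_pad[0, :10] = torch.randint(1, odim, (10,))
ys_pad[1, :8] = torch.randint(1, odim, (8,))
loss, ys_hat = ctc(hs_pad, hlens, ys_pad)
print(loss.item(), ys_hat.shape)                     # scalar loss, (Tmax, B, odim) logits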
espnet/nets/pytorch_backend/decoder/__pycache__/transformer_decoder.cpython-310.pyc
ADDED
Binary file (11.3 kB)

espnet/nets/pytorch_backend/decoder/__pycache__/transformer_decoder.cpython-311.pyc
ADDED
Binary file (17.6 kB)
espnet/nets/pytorch_backend/decoder/transformer_decoder.py
ADDED
@@ -0,0 +1,334 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# Copyright 2019 Shigeki Karita
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

"""Decoder definition."""

from typing import Any, List, Tuple

import torch
from espnet.nets.pytorch_backend.nets_utils import rename_state_dict
from espnet.nets.pytorch_backend.transformer.attention import MultiHeadedAttention
from espnet.nets.pytorch_backend.transformer.embedding import PositionalEncoding
from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm
from espnet.nets.pytorch_backend.transformer.mask import subsequent_mask
from espnet.nets.pytorch_backend.transformer.positionwise_feed_forward import PositionwiseFeedForward
from espnet.nets.pytorch_backend.transformer.repeat import repeat
from espnet.nets.scorer_interface import BatchScorerInterface


class DecoderLayer(torch.nn.Module):
    """Single decoder layer module.

    :param int size: input dim
    :param espnet.nets.pytorch_backend.transformer.attention.MultiHeadedAttention
        self_attn: self attention module
    :param espnet.nets.pytorch_backend.transformer.attention.MultiHeadedAttention
        src_attn: source attention module
    :param espnet.nets.pytorch_backend.transformer.positionwise_feed_forward.
        PositionwiseFeedForward feed_forward: feed forward layer module
    :param float dropout_rate: dropout rate
    :param bool normalize_before: whether to use layer_norm before the first block
    :param bool concat_after: whether to concat attention layer's input and output
        if True, additional linear will be applied.
        i.e. x -> x + linear(concat(x, att(x)))
        if False, no additional linear will be applied. i.e. x -> x + att(x)
    """

    def __init__(
        self,
        size,
        self_attn,
        src_attn,
        feed_forward,
        dropout_rate,
        normalize_before=True,
        concat_after=False,
    ):
        """Construct a DecoderLayer object."""
        super(DecoderLayer, self).__init__()
        self.size = size
        self.self_attn = self_attn
        self.src_attn = src_attn
        self.feed_forward = feed_forward
        self.norm1 = LayerNorm(size)
        self.norm2 = LayerNorm(size)
        self.norm3 = LayerNorm(size)
        self.dropout = torch.nn.Dropout(dropout_rate)
        self.normalize_before = normalize_before
        self.concat_after = concat_after
        if self.concat_after:
            self.concat_linear1 = torch.nn.Linear(size + size, size)
            self.concat_linear2 = torch.nn.Linear(size + size, size)

    def forward(self, tgt, tgt_mask, memory, memory_mask, cache=None):
        """Compute decoded features.

        Args:
            tgt (torch.Tensor):
                decoded previous target features (batch, max_time_out, size)
            tgt_mask (torch.Tensor): mask for x (batch, max_time_out)
            memory (torch.Tensor): encoded source features (batch, max_time_in, size)
            memory_mask (torch.Tensor): mask for memory (batch, max_time_in)
            cache (torch.Tensor): cached output (batch, max_time_out-1, size)
        """
        residual = tgt
        if self.normalize_before:
            tgt = self.norm1(tgt)

        if cache is None:
            tgt_q = tgt
            tgt_q_mask = tgt_mask
        else:
            # compute only the last frame query keeping dim: max_time_out -> 1
            assert cache.shape == (
                tgt.shape[0],
                tgt.shape[1] - 1,
                self.size,
            ), f"{cache.shape} == {(tgt.shape[0], tgt.shape[1] - 1, self.size)}"
            tgt_q = tgt[:, -1:, :]
            residual = residual[:, -1:, :]
            tgt_q_mask = None
            if tgt_mask is not None:
                tgt_q_mask = tgt_mask[:, -1:, :]

        if self.concat_after:
            tgt_concat = torch.cat(
                (tgt_q, self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)), dim=-1
            )
            x = residual + self.concat_linear1(tgt_concat)
        else:
            x = residual + self.dropout(self.self_attn(tgt_q, tgt, tgt, tgt_q_mask))
        if not self.normalize_before:
            x = self.norm1(x)

        residual = x
        if self.normalize_before:
            x = self.norm2(x)
        if self.concat_after:
            x_concat = torch.cat(
                (x, self.src_attn(x, memory, memory, memory_mask)), dim=-1
            )
            x = residual + self.concat_linear2(x_concat)
        else:
            x = residual + self.dropout(self.src_attn(x, memory, memory, memory_mask))
        if not self.normalize_before:
            x = self.norm2(x)

        residual = x
        if self.normalize_before:
            x = self.norm3(x)
        x = residual + self.dropout(self.feed_forward(x))
        if not self.normalize_before:
            x = self.norm3(x)

        if cache is not None:
            x = torch.cat([cache, x], dim=1)

        return x, tgt_mask, memory, memory_mask


def _pre_hook(
    state_dict,
    prefix,
    local_metadata,
    strict,
    missing_keys,
    unexpected_keys,
    error_msgs,
):
    # https://github.com/espnet/espnet/commit/3d422f6de8d4f03673b89e1caef698745ec749ea#diff-bffb1396f038b317b2b64dd96e6d3563
    rename_state_dict(prefix + "output_norm.", prefix + "after_norm.", state_dict)


class TransformerDecoder(BatchScorerInterface, torch.nn.Module):
    """Transformer decoder module.

    :param int odim: output dim
    :param int attention_dim: dimension of attention
    :param int attention_heads: the number of heads of multi head attention
    :param int linear_units: the number of units of position-wise feed forward
    :param int num_blocks: the number of decoder blocks
    :param float dropout_rate: dropout rate
    :param float attention_dropout_rate: dropout rate for attention
    :param str or torch.nn.Module input_layer: input layer type
    :param bool use_output_layer: whether to use output layer
    :param class pos_enc_class: PositionalEncoding or ScaledPositionalEncoding
    :param bool normalize_before: whether to use layer_norm before the first block
    :param bool concat_after: whether to concat attention layer's input and output
        if True, additional linear will be applied.
        i.e. x -> x + linear(concat(x, att(x)))
        if False, no additional linear will be applied. i.e. x -> x + att(x)
    """

    def __init__(
        self,
        odim,
        attention_dim=256,
        attention_heads=4,
        linear_units=2048,
        num_blocks=6,
        dropout_rate=0.1,
        positional_dropout_rate=0.1,
        self_attention_dropout_rate=0.1,
        src_attention_dropout_rate=0.1,
        input_layer="embed",
        use_output_layer=True,
        pos_enc_class=PositionalEncoding,
        normalize_before=True,
        concat_after=False,
        layer_drop_rate=0.0,
    ):
        """Construct a Decoder object."""
        torch.nn.Module.__init__(self)
        self._register_load_state_dict_pre_hook(_pre_hook)
        if input_layer == "embed":
            self.embed = torch.nn.Sequential(
                torch.nn.Embedding(odim, attention_dim),
                pos_enc_class(attention_dim, positional_dropout_rate),
            )
        elif input_layer == "linear":
            self.embed = torch.nn.Sequential(
                torch.nn.Linear(odim, attention_dim),
                torch.nn.LayerNorm(attention_dim),
                torch.nn.Dropout(dropout_rate),
                torch.nn.ReLU(),
                pos_enc_class(attention_dim, positional_dropout_rate),
            )
        elif isinstance(input_layer, torch.nn.Module):
            self.embed = torch.nn.Sequential(
                input_layer, pos_enc_class(attention_dim, positional_dropout_rate)
            )
        else:
            raise NotImplementedError(
                "only `embed`, `linear`, or torch.nn.Module is supported."
            )
        self.normalize_before = normalize_before
        self.decoders = repeat(
            num_blocks,
            lambda lnum: DecoderLayer(
                attention_dim,
                MultiHeadedAttention(
                    attention_heads, attention_dim, self_attention_dropout_rate
                ),
                MultiHeadedAttention(
                    attention_heads, attention_dim, src_attention_dropout_rate
                ),
                PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
                dropout_rate,
                normalize_before,
                concat_after,
            ),
            layer_drop_rate,
        )
        if self.normalize_before:
            self.after_norm = LayerNorm(attention_dim)
        if use_output_layer:
            self.output_layer = torch.nn.Linear(attention_dim, odim)
        else:
            self.output_layer = None

    def forward(self, tgt, tgt_mask, memory, memory_mask):
        """Forward decoder.

        :param torch.Tensor tgt: input token ids, int64 (batch, maxlen_out)
            if input_layer == "embed";
            input tensor (batch, maxlen_out, #mels) in the other cases
        :param torch.Tensor tgt_mask: input token mask, (batch, maxlen_out)
            dtype=torch.uint8 in PyTorch 1.2-
            dtype=torch.bool in PyTorch 1.2+ (include 1.2)
        :param torch.Tensor memory: encoded memory, float32 (batch, maxlen_in, feat)
        :param torch.Tensor memory_mask: encoded memory mask, (batch, maxlen_in)
            dtype=torch.uint8 in PyTorch 1.2-
            dtype=torch.bool in PyTorch 1.2+ (include 1.2)
        :return x: decoded token score before softmax (batch, maxlen_out, token)
            if use_output_layer is True;
            final block outputs (batch, maxlen_out, attention_dim) in the other cases
        :rtype: torch.Tensor
        :return tgt_mask: score mask before softmax (batch, maxlen_out)
        :rtype: torch.Tensor
        """
        x = self.embed(tgt)
        x, tgt_mask, memory, memory_mask = self.decoders(
            x, tgt_mask, memory, memory_mask
        )
        if self.normalize_before:
            x = self.after_norm(x)
        if self.output_layer is not None:
            x = self.output_layer(x)
        return x, tgt_mask

    def forward_one_step(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
        """Forward one step.

        :param torch.Tensor tgt: input token ids, int64 (batch, maxlen_out)
        :param torch.Tensor tgt_mask: input token mask, (batch, maxlen_out)
            dtype=torch.uint8 in PyTorch 1.2-
            dtype=torch.bool in PyTorch 1.2+ (include 1.2)
        :param torch.Tensor memory: encoded memory, float32 (batch, maxlen_in, feat)
        :param List[torch.Tensor] cache:
            cached output list of (batch, max_time_out-1, size)
        :return y, cache: NN output value and cache per `self.decoders`.
            `y.shape` is (batch, maxlen_out, token)
        :rtype: Tuple[torch.Tensor, List[torch.Tensor]]
        """
        x = self.embed(tgt)
        if cache is None:
            cache = [None] * len(self.decoders)
        new_cache = []
        for c, decoder in zip(cache, self.decoders):
            x, tgt_mask, memory, memory_mask = decoder(
                x, tgt_mask, memory, memory_mask, cache=c
            )
            new_cache.append(x)

        if self.normalize_before:
            y = self.after_norm(x[:, -1])
        else:
            y = x[:, -1]
        if self.output_layer is not None:
            y = torch.log_softmax(self.output_layer(y), dim=-1)

        return y, new_cache

    # beam search API (see ScorerInterface)
    def score(self, ys, state, x):
        """Score."""
        ys_mask = subsequent_mask(len(ys), device=x.device).unsqueeze(0)
        logp, state = self.forward_one_step(
            ys.unsqueeze(0), ys_mask, x.unsqueeze(0), cache=state
        )
        return logp.squeeze(0), state

    # batch beam search API (see BatchScorerInterface)
    def batch_score(
        self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor
    ) -> Tuple[torch.Tensor, List[Any]]:
        """Score new token batch (required).

        Args:
            ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen).
            states (List[Any]): Scorer states for prefix tokens.
            xs (torch.Tensor):
                The encoder feature that generates ys (n_batch, xlen, n_feat).

        Returns:
            tuple[torch.Tensor, List[Any]]: Tuple of
                batchfied scores for next token with shape of `(n_batch, n_vocab)`
                and next state list for ys.
        """
        # merge states
        n_batch = len(ys)
        n_layers = len(self.decoders)
        if states[0] is None:
            batch_state = None
        else:
            # transpose state of [batch, layer] into [layer, batch]
            batch_state = [
                torch.stack([states[b][l] for b in range(n_batch)])
                for l in range(n_layers)
            ]

        # batch decoding
        ys_mask = subsequent_mask(ys.size(-1), device=xs.device).unsqueeze(0)
        logp, states = self.forward_one_step(ys, ys_mask, xs, cache=batch_state)

        # transpose state of [layer, batch] into [batch, layer]
        state_list = [[states[l][b] for l in range(n_layers)] for b in range(n_batch)]
        return logp, state_list
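One decoding step through the decoder's ScorerInterface, sketched with illustrative shapes (the 256-dim memory and vocabulary size are assumptions):

# usage sketch (illustrative assumptions, not part of the file above)
import torch
from espnet.nets.pytorch_backend.decoder.transformer_decoder import TransformerDecoder

odim = 40                                       # assumed vocabulary size
decoder = TransformerDecoder(odim=odim, attention_dim=256, attention_heads=4,
                             linear_units=2048, num_blocks=2)
decoder.eval()
memory = torch.randn(30, 256)                   # encoder output for one utterance
ys = torch.tensor([odim - 1, 5, 7])             # <sos> followed by two hypothesis tokens
with torch.no_grad():
    logp, state = decoder.score(ys, state=None, x=memory)
print(logp.shape)                               # (odim,) log-probs for the next token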
espnet/nets/pytorch_backend/e2e_asr_conformer.py
ADDED
@@ -0,0 +1,87 @@
#! /usr/bin/env python
# -*- coding: utf-8 -*-

# Copyright 2023 Imperial College London (Pingchuan Ma)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

import torch

from espnet.nets.pytorch_backend.frontend.resnet import video_resnet
from espnet.nets.pytorch_backend.frontend.resnet1d import audio_resnet
from espnet.nets.pytorch_backend.ctc import CTC
from espnet.nets.pytorch_backend.encoder.conformer_encoder import ConformerEncoder
from espnet.nets.pytorch_backend.decoder.transformer_decoder import TransformerDecoder
from espnet.nets.pytorch_backend.nets_utils import make_non_pad_mask, th_accuracy
from espnet.nets.pytorch_backend.transformer.add_sos_eos import add_sos_eos
from espnet.nets.pytorch_backend.transformer.label_smoothing_loss import LabelSmoothingLoss
from espnet.nets.pytorch_backend.transformer.mask import target_mask
from espnet.nets.scorers.ctc import CTCPrefixScorer


class E2E(torch.nn.Module):
    def __init__(self, odim, modality, ctc_weight=0.1, ignore_id=-1):
        super().__init__()

        self.modality = modality
        if modality == "audio":
            self.frontend = audio_resnet()
        elif modality == "video":
            self.frontend = video_resnet()

        self.proj_encoder = torch.nn.Linear(512, 768)

        self.encoder = ConformerEncoder(
            attention_dim=768,
            attention_heads=12,
            linear_units=3072,
            num_blocks=12,
            cnn_module_kernel=31,
        )

        self.decoder = TransformerDecoder(
            odim=odim,
            attention_dim=768,
            attention_heads=12,
            linear_units=3072,
            num_blocks=6,
        )

        self.blank = 0
        self.sos = odim - 1
        self.eos = odim - 1
        self.odim = odim
        self.ignore_id = ignore_id

        # loss
        self.ctc_weight = ctc_weight
        self.ctc = CTC(odim, 768, 0.1, reduce=True)
        self.criterion = LabelSmoothingLoss(self.odim, self.ignore_id, 0.1, False)

    def scorers(self):
        return dict(decoder=self.decoder, ctc=CTCPrefixScorer(self.ctc, self.eos))

    def forward(self, x, lengths, label):
        if self.modality == "audio":
            lengths = torch.div(lengths, 640, rounding_mode="trunc")

        padding_mask = make_non_pad_mask(lengths).to(x.device).unsqueeze(-2)

        x = self.frontend(x)
        x = self.proj_encoder(x)
        x, _ = self.encoder(x, padding_mask)

        # ctc loss
        loss_ctc, ys_hat = self.ctc(x, lengths, label)

        # decoder loss
        ys_in_pad, ys_out_pad = add_sos_eos(label, self.sos, self.eos, self.ignore_id)
        ys_mask = target_mask(ys_in_pad, self.ignore_id)
        pred_pad, _ = self.decoder(ys_in_pad, ys_mask, x, padding_mask)
        loss_att = self.criterion(pred_pad, ys_out_pad)
        loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att

        acc = th_accuracy(
            pred_pad.view(-1, self.odim), ys_out_pad, ignore_label=self.ignore_id
        )

        return loss, loss_ctc, loss_att, acc
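A construction-only sketch of the E2E wrapper (odim and modality are assumed values). Note that for raw audio, forward() derives encoder frame lengths by integer-dividing the sample lengths by 640, i.e. one encoder frame per 640 input samples:

# usage sketch (illustrative assumptions, not part of the file above)
import torch
from espnet.nets.pytorch_backend.e2e_asr_conformer import E2E

model = E2E(odim=40, modality="audio")
print(sum(p.numel() for p in model.parameters()) / 1e6, "M parameters")
print(model.scorers().keys())   # dict_keys(['decoder', 'ctc']) for joint beam search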
espnet/nets/pytorch_backend/encoder/__pycache__/conformer_encoder.cpython-310.pyc
ADDED
Binary file (9.5 kB)

espnet/nets/pytorch_backend/encoder/__pycache__/conformer_encoder.cpython-311.pyc
ADDED
Binary file (15 kB)
espnet/nets/pytorch_backend/encoder/conformer_encoder.py
ADDED
@@ -0,0 +1,303 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# Copyright 2019 Shigeki Karita
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

"""Encoder definition."""

import copy
import torch
from espnet.nets.pytorch_backend.nets_utils import rename_state_dict
from espnet.nets.pytorch_backend.transformer.attention import RelPositionMultiHeadedAttention
from espnet.nets.pytorch_backend.transformer.embedding import RelPositionalEncoding
from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm
from espnet.nets.pytorch_backend.transformer.positionwise_feed_forward import PositionwiseFeedForward
from espnet.nets.pytorch_backend.transformer.repeat import repeat


class ConvolutionModule(torch.nn.Module):
    def __init__(self, channels, kernel_size, bias=True):
        super().__init__()
        assert (kernel_size - 1) % 2 == 0

        self.pointwise_cov1 = torch.nn.Conv1d(channels, 2 * channels, 1, bias=bias)
        self.depthwise_conv = torch.nn.Conv1d(channels, channels, kernel_size, padding=(kernel_size - 1) // 2, groups=channels, bias=bias)
        self.norm = torch.nn.BatchNorm1d(channels)
        self.pointwise_cov2 = torch.nn.Conv1d(channels, channels, 1, bias=bias)
        self.activation = torch.nn.SiLU(inplace=True)

    def forward(self, x):
        x = x.transpose(1, 2)
        x = torch.nn.functional.glu(self.pointwise_cov1(x), dim=1)
        x = self.activation(self.norm(self.depthwise_conv(x)))
        x = self.pointwise_cov2(x)
        return x.transpose(1, 2)


class EncoderLayer(torch.nn.Module):
    """Encoder layer module.

    :param int size: input dim
    :param espnet.nets.pytorch_backend.transformer.attention.
        RelPositionMultiHeadedAttention self_attn: self attention module
    :param espnet.nets.pytorch_backend.transformer.positionwise_feed_forward.
        PositionwiseFeedForward feed_forward: feed forward module
    :param ConvolutionModule conv_module: convolution module
    :param float dropout_rate: dropout rate
    :param bool normalize_before: whether to use layer_norm before the first block
    :param bool concat_after: whether to concat attention layer's input and output
        if True, additional linear will be applied.
        i.e. x -> x + linear(concat(x, att(x)))
        if False, no additional linear will be applied. i.e. x -> x + att(x)
    :param bool macaron_style: whether to use macaron style for PositionwiseFeedForward
    """

    def __init__(
        self,
        size,
        self_attn,
        feed_forward,
        conv_module,
        dropout_rate,
        normalize_before=True,
        concat_after=False,
        macaron_style=False,
    ):
        """Construct an EncoderLayer object."""
        super(EncoderLayer, self).__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.ff_scale = 1.0
        self.conv_module = conv_module
        self.macaron_style = macaron_style
        self.norm_ff = LayerNorm(size)  # for the FNN module
        self.norm_mha = LayerNorm(size)  # for the MHA module
        if self.macaron_style:
            self.feed_forward_macaron = copy.deepcopy(feed_forward)
            self.ff_scale = 0.5
            # for another FNN module in macaron style
            self.norm_ff_macaron = LayerNorm(size)
        if self.conv_module is not None:
            self.norm_conv = LayerNorm(size)  # for the CNN module
            self.norm_final = LayerNorm(size)  # for the final output of the block
        self.dropout = torch.nn.Dropout(dropout_rate)
        self.size = size
        self.normalize_before = normalize_before
        self.concat_after = concat_after
        if self.concat_after:
            self.concat_linear = torch.nn.Linear(size + size, size)

    def forward(self, x_input, mask, cache=None):
        """Compute encoded features.

        :param torch.Tensor x_input: encoded source features (batch, max_time_in, size)
        :param torch.Tensor mask: mask for x (batch, max_time_in)
        :param torch.Tensor cache: cache for x (batch, max_time_in - 1, size)
        :rtype: Tuple[torch.Tensor, torch.Tensor]
        """
        if isinstance(x_input, tuple):
            x, pos_emb = x_input[0], x_input[1]
        else:
            x, pos_emb = x_input, None

        # whether to use macaron style
        if self.macaron_style:
            residual = x
            if self.normalize_before:
                x = self.norm_ff_macaron(x)
            x = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(x))
            if not self.normalize_before:
                x = self.norm_ff_macaron(x)

        # multi-headed self-attention module
        residual = x
        if self.normalize_before:
            x = self.norm_mha(x)

        if cache is None:
            x_q = x
        else:
            assert cache.shape == (x.shape[0], x.shape[1] - 1, self.size)
            x_q = x[:, -1:, :]
            residual = residual[:, -1:, :]
            mask = None if mask is None else mask[:, -1:, :]

        if pos_emb is not None:
            x_att = self.self_attn(x_q, x, x, pos_emb, mask)
        else:
            x_att = self.self_attn(x_q, x, x, mask)

        if self.concat_after:
            x_concat = torch.cat((x, x_att), dim=-1)
            x = residual + self.concat_linear(x_concat)
        else:
            x = residual + self.dropout(x_att)
        if not self.normalize_before:
            x = self.norm_mha(x)

        # convolution module
        if self.conv_module is not None:
            residual = x
            if self.normalize_before:
                x = self.norm_conv(x)
            x = residual + self.dropout(self.conv_module(x))
            if not self.normalize_before:
                x = self.norm_conv(x)

        # feed forward module
        residual = x
        if self.normalize_before:
            x = self.norm_ff(x)
        x = residual + self.ff_scale * self.dropout(self.feed_forward(x))
        if not self.normalize_before:
            x = self.norm_ff(x)

        if self.conv_module is not None:
            x = self.norm_final(x)

        if cache is not None:
            x = torch.cat([cache, x], dim=1)

        if pos_emb is not None:
            return (x, pos_emb), mask
        else:
            return x, mask


def _pre_hook(
    state_dict,
    prefix,
    local_metadata,
    strict,
    missing_keys,
    unexpected_keys,
    error_msgs,
):
    rename_state_dict(prefix + "input_layer.", prefix + "embed.", state_dict)
    rename_state_dict(prefix + "norm.", prefix + "after_norm.", state_dict)


class ConformerEncoder(torch.nn.Module):
    """Conformer encoder module.

    :param int idim: input dim
    :param int attention_dim: dimension of attention
    :param int attention_heads: the number of heads of multi head attention
    :param int linear_units: the number of units of position-wise feed forward
    :param int num_blocks: the number of encoder blocks
    :param float dropout_rate: dropout rate
    :param float attention_dropout_rate: dropout rate in attention
    :param float positional_dropout_rate: dropout rate after adding positional encoding
    :param class pos_enc_class: PositionalEncoding or ScaledPositionalEncoding
    :param bool normalize_before: whether to use layer_norm before the first block
    :param bool concat_after: whether to concat attention layer's input and output
        if True, additional linear will be applied.
        i.e. x -> x + linear(concat(x, att(x)))
        if False, no additional linear will be applied. i.e. x -> x + att(x)
    :param str positionwise_layer_type: linear or conv1d
    :param int positionwise_conv_kernel_size: kernel size of positionwise conv1d layer
    :param bool macaron_style: whether to use macaron style for positionwise layer
    :param bool use_cnn_module: whether to use convolution module
    :param bool zero_triu: whether to zero the upper triangular part of attention matrix
    :param int cnn_module_kernel: kernel size of convolution module
    :param int padding_idx: padding_idx for input_layer=embed
    """

    def __init__(
        self,
+
attention_dim=768,
|
| 215 |
+
attention_heads=12,
|
| 216 |
+
linear_units=3072,
|
| 217 |
+
num_blocks=12,
|
| 218 |
+
dropout_rate=0.1,
|
| 219 |
+
positional_dropout_rate=0.1,
|
| 220 |
+
attention_dropout_rate=0.0,
|
| 221 |
+
normalize_before=True,
|
| 222 |
+
concat_after=False,
|
| 223 |
+
macaron_style=True,
|
| 224 |
+
use_cnn_module=True,
|
| 225 |
+
zero_triu=False,
|
| 226 |
+
cnn_module_kernel=31,
|
| 227 |
+
padding_idx=-1,
|
| 228 |
+
relu_type="swish",
|
| 229 |
+
layer_drop_rate=0.0,
|
| 230 |
+
):
|
| 231 |
+
"""Construct an Encoder object."""
|
| 232 |
+
super(ConformerEncoder, self).__init__()
|
| 233 |
+
self._register_load_state_dict_pre_hook(_pre_hook)
|
| 234 |
+
|
| 235 |
+
self.embed = torch.nn.Sequential(RelPositionalEncoding(attention_dim, positional_dropout_rate))
|
| 236 |
+
self.normalize_before = normalize_before
|
| 237 |
+
|
| 238 |
+
positionwise_layer = PositionwiseFeedForward
|
| 239 |
+
positionwise_layer_args = (attention_dim, linear_units, dropout_rate)
|
| 240 |
+
|
| 241 |
+
encoder_attn_layer = RelPositionMultiHeadedAttention
|
| 242 |
+
encoder_attn_layer_args = (attention_heads, attention_dim, attention_dropout_rate, zero_triu)
|
| 243 |
+
|
| 244 |
+
convolution_layer = ConvolutionModule
|
| 245 |
+
convolution_layer_args = (attention_dim, cnn_module_kernel)
|
| 246 |
+
|
| 247 |
+
self.encoders = repeat(
|
| 248 |
+
num_blocks,
|
| 249 |
+
lambda lnum: EncoderLayer(
|
| 250 |
+
attention_dim,
|
| 251 |
+
encoder_attn_layer(*encoder_attn_layer_args),
|
| 252 |
+
positionwise_layer(*positionwise_layer_args),
|
| 253 |
+
convolution_layer(*convolution_layer_args) if use_cnn_module else None,
|
| 254 |
+
dropout_rate,
|
| 255 |
+
normalize_before,
|
| 256 |
+
concat_after,
|
| 257 |
+
macaron_style,
|
| 258 |
+
),
|
| 259 |
+
layer_drop_rate=0.0,
|
| 260 |
+
)
|
| 261 |
+
if self.normalize_before:
|
| 262 |
+
self.after_norm = LayerNorm(attention_dim)
|
| 263 |
+
|
| 264 |
+
def forward(self, xs, masks):
|
| 265 |
+
"""Encode input sequence.
|
| 266 |
+
|
| 267 |
+
:param torch.Tensor xs: input tensor
|
| 268 |
+
:param torch.Tensor masks: input mask
|
| 269 |
+
:return: position embedded tensor and mask
|
| 270 |
+
:rtype Tuple[torch.Tensor, torch.Tensor]:
|
| 271 |
+
"""
|
| 272 |
+
xs = self.embed(xs)
|
| 273 |
+
|
| 274 |
+
xs, masks = self.encoders(xs, masks)
|
| 275 |
+
|
| 276 |
+
if isinstance(xs, tuple):
|
| 277 |
+
xs = xs[0]
|
| 278 |
+
|
| 279 |
+
if self.normalize_before:
|
| 280 |
+
xs = self.after_norm(xs)
|
| 281 |
+
|
| 282 |
+
return xs, masks
|
| 283 |
+
|
| 284 |
+
def forward_one_step(self, xs, masks, cache=None):
|
| 285 |
+
"""Encode input frame.
|
| 286 |
+
|
| 287 |
+
:param torch.Tensor xs: input tensor
|
| 288 |
+
:param torch.Tensor masks: input mask
|
| 289 |
+
:param List[torch.Tensor] cache: cache tensors
|
| 290 |
+
:return: position embedded tensor, mask and new cache
|
| 291 |
+
:rtype Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]:
|
| 292 |
+
"""
|
| 293 |
+
xs, masks = self.embed(xs, masks)
|
| 294 |
+
|
| 295 |
+
if cache is None:
|
| 296 |
+
cache = [None for _ in range(len(self.encoders))]
|
| 297 |
+
new_cache = []
|
| 298 |
+
for c, e in zip(cache, self.encoders):
|
| 299 |
+
xs, masks = e(xs, masks, cache=c)
|
| 300 |
+
new_cache.append(xs)
|
| 301 |
+
if self.normalize_before:
|
| 302 |
+
xs = self.after_norm(xs)
|
| 303 |
+
return xs, masks, new_cache
|
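For orientation, here is a minimal smoke test of the encoder defined above. It is a sketch, not part of the module: the batch size, frame count, and all-ones mask are illustrative assumptions, and it relies on RelPositionalEncoding's forward returning an (x, pos_emb) tuple, as in standard espnet.

import torch

# Sketch: 2 utterances of 50 frames with 768-dim features (shapes are assumptions).
encoder = ConformerEncoder()  # defaults: 12 blocks, attention_dim=768
xs = torch.randn(2, 50, 768)
masks = torch.ones(2, 1, 50, dtype=torch.bool)  # non-pad mask, (batch, 1, time)

out, out_masks = encoder(xs, masks)
print(out.shape)  # torch.Size([2, 50, 768]); the time resolution is unchanged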
espnet/nets/pytorch_backend/frontend/__pycache__/resnet.cpython-310.pyc
ADDED
Binary file (6.07 kB).

espnet/nets/pytorch_backend/frontend/__pycache__/resnet.cpython-311.pyc
ADDED
Binary file (10.4 kB).

espnet/nets/pytorch_backend/frontend/__pycache__/resnet1d.cpython-310.pyc
ADDED
Binary file (6.23 kB).

espnet/nets/pytorch_backend/frontend/__pycache__/resnet1d.cpython-311.pyc
ADDED
Binary file (10.3 kB).
espnet/nets/pytorch_backend/frontend/resnet.py
ADDED
@@ -0,0 +1,237 @@
import torch.nn as nn


def conv3x3(in_planes, out_planes, stride=1):
    """conv3x3.

    :param in_planes: int, number of channels in the input sequence.
    :param out_planes: int, number of channels produced by the convolution.
    :param stride: int, stride of the convolution.
    """
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )


def downsample_basic_block(inplanes, outplanes, stride):
    """downsample_basic_block.

    :param inplanes: int, number of channels in the input sequence.
    :param outplanes: int, number of channels produced by the convolution.
    :param stride: int, stride of the convolution.
    """
    return nn.Sequential(
        nn.Conv2d(
            inplanes,
            outplanes,
            kernel_size=1,
            stride=stride,
            bias=False,
        ),
        nn.BatchNorm2d(outplanes),
    )


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(
        self,
        inplanes,
        planes,
        stride=1,
        downsample=None,
        relu_type="swish",
    ):
        """__init__.

        :param inplanes: int, number of channels in the input sequence.
        :param planes: int, number of channels produced by the convolution.
        :param stride: int, stride of the convolution.
        :param downsample: nn.Module or None, applied to the residual branch
            when the spatial resolution or channel count changes.
        :param relu_type: str, type of activation function.
        """
        super(BasicBlock, self).__init__()

        assert relu_type in ["relu", "prelu", "swish"]

        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)

        if relu_type == "relu":
            self.relu1 = nn.ReLU(inplace=True)
            self.relu2 = nn.ReLU(inplace=True)
        elif relu_type == "prelu":
            self.relu1 = nn.PReLU(num_parameters=planes)
            self.relu2 = nn.PReLU(num_parameters=planes)
        elif relu_type == "swish":
            self.relu1 = nn.SiLU(inplace=True)
            self.relu2 = nn.SiLU(inplace=True)
        else:
            raise NotImplementedError
        # --------

        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)

        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """forward.

        :param x: torch.Tensor, input tensor of size (B*T, C, H, W);
            the 3D front-end folds time into the batch dimension.
        """
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu1(out)
        out = self.conv2(out)
        out = self.bn2(out)
        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu2(out)

        return out


class ResNet(nn.Module):
    def __init__(
        self,
        block,
        layers,
        relu_type="swish",
    ):
        super(ResNet, self).__init__()
        self.inplanes = 64
        self.relu_type = relu_type
        self.downsample_block = downsample_basic_block

        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d(1)

    def _make_layer(self, block, planes, blocks, stride=1):
        """_make_layer.

        :param block: torch.nn.Module, class of blocks.
        :param planes: int, number of channels produced by the convolution.
        :param blocks: int, number of blocks in the layer.
        :param stride: int, stride of the first block's convolution.
        """
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = self.downsample_block(
                inplanes=self.inplanes,
                outplanes=planes * block.expansion,
                stride=stride,
            )

        layers = []
        layers.append(
            block(
                self.inplanes,
                planes,
                stride,
                downsample,
                relu_type=self.relu_type,
            )
        )
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(
                block(
                    self.inplanes,
                    planes,
                    relu_type=self.relu_type,
                )
            )

        return nn.Sequential(*layers)

    def forward(self, x):
        """forward.

        :param x: torch.Tensor, input tensor of size (B*T, C, H, W).
        """
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        return x


# -- auxiliary functions
def threeD_to_2D_tensor(x):
    n_batch, n_channels, s_time, sx, sy = x.shape
    x = x.transpose(1, 2)
    return x.reshape(n_batch * s_time, n_channels, sx, sy)


class Conv3dResNet(nn.Module):
    """Conv3dResNet module"""

    def __init__(self, backbone_type="resnet", relu_type="swish"):
        """__init__.

        :param backbone_type: str, the type of the visual front-end.
        :param relu_type: str, activation function used in the visual front-end.
        """
        super(Conv3dResNet, self).__init__()

        self.backbone_type = backbone_type

        self.frontend_nout = 64
        self.trunk = ResNet(
            BasicBlock,
            [2, 2, 2, 2],
            relu_type=relu_type,
        )

        # -- frontend3D
        if relu_type == "relu":
            frontend_relu = nn.ReLU(True)
        elif relu_type == "prelu":
            frontend_relu = nn.PReLU(self.frontend_nout)
        elif relu_type == "swish":
            frontend_relu = nn.SiLU(inplace=True)
        else:
            raise NotImplementedError

        self.frontend3D = nn.Sequential(
            nn.Conv3d(
                in_channels=1,
                out_channels=self.frontend_nout,
                kernel_size=(5, 7, 7),
                stride=(1, 2, 2),
                padding=(2, 3, 3),
                bias=False,
            ),
            nn.BatchNorm3d(self.frontend_nout),
            frontend_relu,
            nn.MaxPool3d(
                kernel_size=(1, 3, 3),
                stride=(1, 2, 2),
                padding=(0, 1, 1),
            ),
        )

    def forward(self, xs_pad):
        """forward.

        :param xs_pad: torch.Tensor, batch of padded input sequences (B, T, C, H, W).
        """
        # -- include Channel dimension
        xs_pad = xs_pad.transpose(2, 1)  # (B, C, T, H, W)
        B, C, T, H, W = xs_pad.size()
        xs_pad = self.frontend3D(xs_pad)
        Tnew = xs_pad.shape[2]  # output should be B x C2 x Tnew x H x W
        xs_pad = threeD_to_2D_tensor(xs_pad)
        xs_pad = self.trunk(xs_pad)
        xs_pad = xs_pad.view(B, Tnew, xs_pad.size(1))
        return xs_pad


def video_resnet():
    return Conv3dResNet()
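A quick shape check for the visual front-end above, as a sketch; the clip dimensions are illustrative assumptions. The 3D convolution and max pooling keep the frame count while the 2D ResNet trunk collapses each frame to a single 512-dim vector.

import torch

# Sketch: 2 clips of 29 grayscale 88x88 frames, layout (B, T, C, H, W).
frontend = video_resnet()
clips = torch.randn(2, 29, 1, 88, 88)

feats = frontend(clips)
print(feats.shape)  # torch.Size([2, 29, 512]) -- one feature vector per frame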
espnet/nets/pytorch_backend/frontend/resnet1d.py
ADDED
@@ -0,0 +1,238 @@
import torch.nn as nn


def conv3x3(in_planes, out_planes, stride=1):
    """conv3x3.

    :param in_planes: int, number of channels in the input sequence.
    :param out_planes: int, number of channels produced by the convolution.
    :param stride: int, stride of the convolution.
    """
    return nn.Conv1d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )


def downsample_basic_block(inplanes, outplanes, stride):
    """downsample_basic_block.

    :param inplanes: int, number of channels in the input sequence.
    :param outplanes: int, number of channels produced by the convolution.
    :param stride: int, stride of the convolution.
    """
    return nn.Sequential(
        nn.Conv1d(
            inplanes,
            outplanes,
            kernel_size=1,
            stride=stride,
            bias=False,
        ),
        nn.BatchNorm1d(outplanes),
    )


class BasicBlock1D(nn.Module):
    expansion = 1

    def __init__(
        self,
        inplanes,
        planes,
        stride=1,
        downsample=None,
        relu_type="relu",
    ):
        """__init__.

        :param inplanes: int, number of channels in the input sequence.
        :param planes: int, number of channels produced by the convolution.
        :param stride: int, stride of the convolution.
        :param downsample: nn.Module or None, applied to the residual branch
            when the temporal resolution or channel count changes.
        :param relu_type: str, type of activation function.
        """
        super(BasicBlock1D, self).__init__()

        assert relu_type in ["relu", "prelu", "swish"]

        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm1d(planes)

        # type of ReLU is an input option
        if relu_type == "relu":
            self.relu1 = nn.ReLU(inplace=True)
            self.relu2 = nn.ReLU(inplace=True)
        elif relu_type == "prelu":
            self.relu1 = nn.PReLU(num_parameters=planes)
            self.relu2 = nn.PReLU(num_parameters=planes)
        elif relu_type == "swish":
            self.relu1 = nn.SiLU(inplace=True)
            self.relu2 = nn.SiLU(inplace=True)
        else:
            raise NotImplementedError
        # --------

        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm1d(planes)

        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """forward.

        :param x: torch.Tensor, input tensor with input size (B, C, T).
        """
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu1(out)
        out = self.conv2(out)
        out = self.bn2(out)
        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu2(out)

        return out


class ResNet1D(nn.Module):
    def __init__(
        self,
        block,
        layers,
        relu_type="swish",
        a_upsample_ratio=1,
    ):
        """__init__.

        :param block: torch.nn.Module, class of blocks.
        :param layers: List, number of blocks in each layer.
        :param relu_type: str, type of activation function.
        :param a_upsample_ratio: int, ratio controlling the temporal resolution
            of the front-end's output features; a_upsample_ratio=1 produces
            features at 25 fps.
        """
        super(ResNet1D, self).__init__()
        self.inplanes = 64
        self.relu_type = relu_type
        self.downsample_block = downsample_basic_block
        self.a_upsample_ratio = a_upsample_ratio

        self.conv1 = nn.Conv1d(
            in_channels=1,
            out_channels=self.inplanes,
            kernel_size=80,
            stride=4,
            padding=38,
            bias=False,
        )
        self.bn1 = nn.BatchNorm1d(self.inplanes)

        if relu_type == "relu":
            self.relu = nn.ReLU(inplace=True)
        elif relu_type == "prelu":
            self.relu = nn.PReLU(num_parameters=self.inplanes)
        elif relu_type == "swish":
            self.relu = nn.SiLU(inplace=True)
        else:
            raise NotImplementedError

        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AvgPool1d(
            kernel_size=20 // self.a_upsample_ratio,
            stride=20 // self.a_upsample_ratio,
        )

    def _make_layer(self, block, planes, blocks, stride=1):
        """_make_layer.

        :param block: torch.nn.Module, class of blocks.
        :param planes: int, number of channels produced by the convolution.
        :param blocks: int, number of blocks in the layer.
        :param stride: int, stride of the first block's convolution.
        """
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = self.downsample_block(
                inplanes=self.inplanes,
                outplanes=planes * block.expansion,
                stride=stride,
            )

        layers = []
        layers.append(
            block(
                self.inplanes,
                planes,
                stride,
                downsample,
                relu_type=self.relu_type,
            )
        )
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(
                block(
                    self.inplanes,
                    planes,
                    relu_type=self.relu_type,
                )
            )

        return nn.Sequential(*layers)

    def forward(self, x):
        """forward.

        :param x: torch.Tensor, input tensor with input size (B, C, T).
        """
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        return x


class Conv1dResNet(nn.Module):
    """Conv1dResNet"""

    def __init__(self, relu_type="swish", a_upsample_ratio=1):
        """__init__.

        :param relu_type: str, activation function used in the audio front-end.
        :param a_upsample_ratio: int, ratio controlling the temporal resolution
            of the front-end's output features; a_upsample_ratio=1 produces
            features at 25 fps.
        """
        super(Conv1dResNet, self).__init__()
        self.a_upsample_ratio = a_upsample_ratio
        self.trunk = ResNet1D(
            BasicBlock1D,
            [2, 2, 2, 2],
            relu_type=relu_type,
            a_upsample_ratio=a_upsample_ratio,
        )

    def forward(self, xs_pad):
        """forward.

        :param xs_pad: torch.Tensor, batch of padded input sequences (B, Tmax, idim).
        """
        B, T, C = xs_pad.size()
        # trim to a whole number of 640-sample (40 ms at 16 kHz) frames
        xs_pad = xs_pad[:, : T // 640 * 640, :]
        xs_pad = xs_pad.transpose(1, 2)
        xs_pad = self.trunk(xs_pad)
        # -- from B x C x T to B x T x C
        xs_pad = xs_pad.transpose(1, 2)
        return xs_pad


def audio_resnet():
    return Conv1dResNet()
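The audio front-end mirrors the visual one: conv1 plus three strided layers downsample the waveform by 32, and the 20-wide average pool brings 16 kHz input to 25 features per second, i.e. one feature per 640 samples. A shape-check sketch (the input sizes are illustrative assumptions):

import torch

# Sketch: 2 one-second waveforms at 16 kHz, layout (B, T, 1).
frontend = audio_resnet()
wavs = torch.randn(2, 16000, 1)

feats = frontend(wavs)
print(feats.shape)  # torch.Size([2, 25, 512]) -- 25 fps, 640 samples per step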
espnet/nets/pytorch_backend/nets_utils.py
ADDED
@@ -0,0 +1,306 @@
# -*- coding: utf-8 -*-

"""Network related utility tools."""

import logging
from typing import Dict

import torch


def to_device(m, x):
    """Send tensor into the device of the module.

    Args:
        m (torch.nn.Module): Torch module.
        x (Tensor): Torch tensor.

    Returns:
        Tensor: Torch tensor located in the same place as torch module.

    """
    if isinstance(m, torch.nn.Module):
        device = next(m.parameters()).device
    elif isinstance(m, torch.Tensor):
        device = m.device
    else:
        raise TypeError(
            f"Expected torch.nn.Module or torch.Tensor, but got: {type(m)}"
        )
    return x.to(device)


def pad_list(xs, pad_value):
    """Perform padding for the list of tensors.

    Args:
        xs (List): List of Tensors [(T_1, `*`), (T_2, `*`), ..., (T_B, `*`)].
        pad_value (float): Value for padding.

    Returns:
        Tensor: Padded tensor (B, Tmax, `*`).

    Examples:
        >>> x = [torch.ones(4), torch.ones(2), torch.ones(1)]
        >>> x
        [tensor([1., 1., 1., 1.]), tensor([1., 1.]), tensor([1.])]
        >>> pad_list(x, 0)
        tensor([[1., 1., 1., 1.],
                [1., 1., 0., 0.],
                [1., 0., 0., 0.]])

    """
    n_batch = len(xs)
    max_len = max(x.size(0) for x in xs)
    pad = xs[0].new(n_batch, max_len, *xs[0].size()[1:]).fill_(pad_value)

    for i in range(n_batch):
        pad[i, : xs[i].size(0)] = xs[i]

    return pad


def make_pad_mask(lengths, xs=None, length_dim=-1, maxlen=None):
    """Make mask tensor containing indices of padded part.

    Args:
        lengths (LongTensor or List): Batch of lengths (B,).
        xs (Tensor, optional): The reference tensor.
            If set, masks will be the same shape as this tensor.
        length_dim (int, optional): Dimension indicator of the above tensor.
            See the example.
        maxlen (int, optional): Maximum length of the mask; only valid
            when xs is not given.

    Returns:
        Tensor: Mask tensor containing indices of padded part.
            dtype=torch.uint8 in PyTorch 1.2-
            dtype=torch.bool in PyTorch 1.2+ (including 1.2)

    Examples:
        With only lengths.

        >>> lengths = [5, 3, 2]
        >>> make_pad_mask(lengths)
        masks = [[0, 0, 0, 0, 0],
                 [0, 0, 0, 1, 1],
                 [0, 0, 1, 1, 1]]

        With the reference tensor.

        >>> xs = torch.zeros((3, 2, 4))
        >>> make_pad_mask(lengths, xs)
        tensor([[[0, 0, 0, 0],
                 [0, 0, 0, 0]],
                [[0, 0, 0, 1],
                 [0, 0, 0, 1]],
                [[0, 0, 1, 1],
                 [0, 0, 1, 1]]], dtype=torch.uint8)
        >>> xs = torch.zeros((3, 2, 6))
        >>> make_pad_mask(lengths, xs)
        tensor([[[0, 0, 0, 0, 0, 1],
                 [0, 0, 0, 0, 0, 1]],
                [[0, 0, 0, 1, 1, 1],
                 [0, 0, 0, 1, 1, 1]],
                [[0, 0, 1, 1, 1, 1],
                 [0, 0, 1, 1, 1, 1]]], dtype=torch.uint8)

        With the reference tensor and dimension indicator.

        >>> xs = torch.zeros((3, 6, 6))
        >>> make_pad_mask(lengths, xs, 1)
        tensor([[[0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [1, 1, 1, 1, 1, 1]],
                [[0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1]],
                [[0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1]]], dtype=torch.uint8)
        >>> make_pad_mask(lengths, xs, 2)
        tensor([[[0, 0, 0, 0, 0, 1],
                 [0, 0, 0, 0, 0, 1],
                 [0, 0, 0, 0, 0, 1],
                 [0, 0, 0, 0, 0, 1],
                 [0, 0, 0, 0, 0, 1],
                 [0, 0, 0, 0, 0, 1]],
                [[0, 0, 0, 1, 1, 1],
                 [0, 0, 0, 1, 1, 1],
                 [0, 0, 0, 1, 1, 1],
                 [0, 0, 0, 1, 1, 1],
                 [0, 0, 0, 1, 1, 1],
                 [0, 0, 0, 1, 1, 1]],
                [[0, 0, 1, 1, 1, 1],
                 [0, 0, 1, 1, 1, 1],
                 [0, 0, 1, 1, 1, 1],
                 [0, 0, 1, 1, 1, 1],
                 [0, 0, 1, 1, 1, 1],
                 [0, 0, 1, 1, 1, 1]]], dtype=torch.uint8)

    """
    if length_dim == 0:
        raise ValueError("length_dim cannot be 0: {}".format(length_dim))

    if not isinstance(lengths, list):
        lengths = lengths.tolist()
    bs = int(len(lengths))
    if maxlen is None:
        if xs is None:
            maxlen = int(max(lengths))
        else:
            maxlen = xs.size(length_dim)
    else:
        assert xs is None
        assert maxlen >= int(max(lengths))

    seq_range = torch.arange(0, maxlen, dtype=torch.int64)
    seq_range_expand = seq_range.unsqueeze(0).expand(bs, maxlen)
    seq_length_expand = seq_range_expand.new(lengths).unsqueeze(-1)
    mask = seq_range_expand >= seq_length_expand

    if xs is not None:
        assert xs.size(0) == bs, (xs.size(0), bs)

        if length_dim < 0:
            length_dim = xs.dim() + length_dim
        # ind = (:, None, ..., None, :, None, ..., None)
        ind = tuple(
            slice(None) if i in (0, length_dim) else None for i in range(xs.dim())
        )
        mask = mask[ind].expand_as(xs).to(xs.device)
    return mask


def make_non_pad_mask(lengths, xs=None, length_dim=-1):
    """Make mask tensor containing indices of non-padded part.

    Args:
        lengths (LongTensor or List): Batch of lengths (B,).
        xs (Tensor, optional): The reference tensor.
            If set, masks will be the same shape as this tensor.
        length_dim (int, optional): Dimension indicator of the above tensor.
            See the example.

    Returns:
        ByteTensor: mask tensor containing indices of non-padded part.
            dtype=torch.uint8 in PyTorch 1.2-
            dtype=torch.bool in PyTorch 1.2+ (including 1.2)

    Examples:
        With only lengths.

        >>> lengths = [5, 3, 2]
        >>> make_non_pad_mask(lengths)
        masks = [[1, 1, 1, 1, 1],
                 [1, 1, 1, 0, 0],
                 [1, 1, 0, 0, 0]]

        With the reference tensor.

        >>> xs = torch.zeros((3, 2, 4))
        >>> make_non_pad_mask(lengths, xs)
        tensor([[[1, 1, 1, 1],
                 [1, 1, 1, 1]],
                [[1, 1, 1, 0],
                 [1, 1, 1, 0]],
                [[1, 1, 0, 0],
                 [1, 1, 0, 0]]], dtype=torch.uint8)
        >>> xs = torch.zeros((3, 2, 6))
        >>> make_non_pad_mask(lengths, xs)
        tensor([[[1, 1, 1, 1, 1, 0],
                 [1, 1, 1, 1, 1, 0]],
                [[1, 1, 1, 0, 0, 0],
                 [1, 1, 1, 0, 0, 0]],
                [[1, 1, 0, 0, 0, 0],
                 [1, 1, 0, 0, 0, 0]]], dtype=torch.uint8)

        With the reference tensor and dimension indicator.

        >>> xs = torch.zeros((3, 6, 6))
        >>> make_non_pad_mask(lengths, xs, 1)
        tensor([[[1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [0, 0, 0, 0, 0, 0]],
                [[1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0]],
                [[1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0]]], dtype=torch.uint8)
        >>> make_non_pad_mask(lengths, xs, 2)
        tensor([[[1, 1, 1, 1, 1, 0],
                 [1, 1, 1, 1, 1, 0],
                 [1, 1, 1, 1, 1, 0],
                 [1, 1, 1, 1, 1, 0],
                 [1, 1, 1, 1, 1, 0],
                 [1, 1, 1, 1, 1, 0]],
                [[1, 1, 1, 0, 0, 0],
                 [1, 1, 1, 0, 0, 0],
                 [1, 1, 1, 0, 0, 0],
                 [1, 1, 1, 0, 0, 0],
                 [1, 1, 1, 0, 0, 0],
                 [1, 1, 1, 0, 0, 0]],
                [[1, 1, 0, 0, 0, 0],
                 [1, 1, 0, 0, 0, 0],
                 [1, 1, 0, 0, 0, 0],
                 [1, 1, 0, 0, 0, 0],
                 [1, 1, 0, 0, 0, 0],
                 [1, 1, 0, 0, 0, 0]]], dtype=torch.uint8)

    """
    return ~make_pad_mask(lengths, xs, length_dim)


def th_accuracy(pad_outputs, pad_targets, ignore_label):
    """Calculate accuracy.

    Args:
        pad_outputs (Tensor): Prediction tensors (B * Lmax, D).
        pad_targets (LongTensor): Target label tensors (B, Lmax).
        ignore_label (int): Ignore label id.

    Returns:
        float: Accuracy value (0.0 - 1.0).

    """
    pad_pred = pad_outputs.view(
        pad_targets.size(0), pad_targets.size(1), pad_outputs.size(1)
    ).argmax(2)
    mask = pad_targets != ignore_label
    numerator = torch.sum(
        pad_pred.masked_select(mask) == pad_targets.masked_select(mask)
    )
    denominator = torch.sum(mask)
    return float(numerator) / float(denominator)


def rename_state_dict(
    old_prefix: str, new_prefix: str, state_dict: Dict[str, torch.Tensor]
):
    """Replace keys of old prefix with new prefix in state dict."""
    # need this list not to break the dict iterator
    old_keys = [k for k in state_dict if k.startswith(old_prefix)]
    if len(old_keys) > 0:
        logging.warning(f"Rename: {old_prefix} -> {new_prefix}")
    for k in old_keys:
        v = state_dict.pop(k)
        new_k = k.replace(old_prefix, new_prefix)
        state_dict[new_k] = v
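A sketch of how these helpers are typically combined (the shapes and vocabulary size are illustrative assumptions): make_non_pad_mask turns per-utterance lengths into the (B, 1, Tmax) attention mask the encoder expects, and th_accuracy scores flattened decoder logits against padded targets.

import torch

lengths = torch.tensor([50, 37, 12])
masks = make_non_pad_mask(lengths).unsqueeze(1)  # (3, 1, 50), True on real frames
print(masks.shape, masks.dtype)

logits = torch.randn(3 * 5, 7)          # (B * Lmax, vocab) decoder outputs
targets = torch.randint(0, 7, (3, 5))   # (B, Lmax) padded targets
targets[-1, -2:] = -1                   # padded positions carry ignore_label
print(th_accuracy(logits, targets, ignore_label=-1))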
espnet/nets/pytorch_backend/transformer/__init__.py
ADDED
@@ -0,0 +1 @@
"""Initialize sub package."""
espnet/nets/pytorch_backend/transformer/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (205 Bytes).

espnet/nets/pytorch_backend/transformer/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (236 Bytes).

espnet/nets/pytorch_backend/transformer/__pycache__/add_sos_eos.cpython-310.pyc
ADDED
Binary file (1.32 kB).

espnet/nets/pytorch_backend/transformer/__pycache__/add_sos_eos.cpython-311.pyc
ADDED
Binary file (1.96 kB).

espnet/nets/pytorch_backend/transformer/__pycache__/attention.cpython-310.pyc
ADDED
Binary file (7.08 kB).

espnet/nets/pytorch_backend/transformer/__pycache__/attention.cpython-311.pyc
ADDED
Binary file (12 kB).

espnet/nets/pytorch_backend/transformer/__pycache__/embedding.cpython-310.pyc
ADDED
Binary file (5.9 kB).

espnet/nets/pytorch_backend/transformer/__pycache__/embedding.cpython-311.pyc
ADDED
Binary file (11.3 kB).

espnet/nets/pytorch_backend/transformer/__pycache__/label_smoothing_loss.cpython-310.pyc
ADDED
Binary file (2.14 kB).

espnet/nets/pytorch_backend/transformer/__pycache__/label_smoothing_loss.cpython-311.pyc
ADDED
Binary file (3.71 kB).

espnet/nets/pytorch_backend/transformer/__pycache__/layer_norm.cpython-310.pyc
ADDED
Binary file (1.17 kB).

espnet/nets/pytorch_backend/transformer/__pycache__/layer_norm.cpython-311.pyc
ADDED
Binary file (1.84 kB).

espnet/nets/pytorch_backend/transformer/__pycache__/mask.cpython-310.pyc
ADDED
Binary file (1.18 kB).

espnet/nets/pytorch_backend/transformer/__pycache__/mask.cpython-311.pyc
ADDED
Binary file (1.59 kB).

espnet/nets/pytorch_backend/transformer/__pycache__/positionwise_feed_forward.cpython-310.pyc
ADDED
Binary file (1.21 kB).

espnet/nets/pytorch_backend/transformer/__pycache__/positionwise_feed_forward.cpython-311.pyc
ADDED
Binary file (1.92 kB).

espnet/nets/pytorch_backend/transformer/__pycache__/repeat.cpython-310.pyc
ADDED
Binary file (1.73 kB).

espnet/nets/pytorch_backend/transformer/__pycache__/repeat.cpython-311.pyc
ADDED
Binary file (2.43 kB).
espnet/nets/pytorch_backend/transformer/add_sos_eos.py
ADDED
@@ -0,0 +1,31 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# Copyright 2019 Shigeki Karita
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

"""Utility functions for Transformer."""

import torch


def add_sos_eos(ys_pad, sos, eos, ignore_id):
    """Add <sos> and <eos> labels.

    :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
    :param int sos: index of <sos>
    :param int eos: index of <eos>
    :param int ignore_id: index of padding
    :return: decoder input tensor with <sos> prepended (B, Lmax + 1), padded
        with eos, and decoder target tensor with <eos> appended (B, Lmax + 1),
        padded with ignore_id
    :rtype: Tuple[torch.Tensor, torch.Tensor]
    """
    from espnet.nets.pytorch_backend.nets_utils import pad_list

    _sos = ys_pad.new([sos])
    _eos = ys_pad.new([eos])
    ys = [y[y != ignore_id] for y in ys_pad]  # parse padded ys
    ys_in = [torch.cat([_sos, y], dim=0) for y in ys]
    ys_out = [torch.cat([y, _eos], dim=0) for y in ys]
    return pad_list(ys_in, eos), pad_list(ys_out, ignore_id)
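A worked example of the teacher-forcing pair this produces (values are illustrative): ys_in is padded with eos so that padding never introduces a token the decoder must attend past, while ys_out is padded with ignore_id so the loss skips those positions.

import torch

ys_pad = torch.tensor([[11, 12, 13], [21, 22, -1]])  # second sequence is shorter
ys_in, ys_out = add_sos_eos(ys_pad, sos=1, eos=2, ignore_id=-1)
print(ys_in)   # tensor([[ 1, 11, 12, 13], [ 1, 21, 22,  2]])
print(ys_out)  # tensor([[11, 12, 13,  2], [21, 22,  2, -1]])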
espnet/nets/pytorch_backend/transformer/attention.py
ADDED
@@ -0,0 +1,193 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# Copyright 2019 Shigeki Karita
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

"""Multi-Head Attention layer definition."""

import math

import torch
from torch import nn


class MultiHeadedAttention(nn.Module):
    """Multi-Head Attention layer.

    Args:
        n_head (int): The number of heads.
        n_feat (int): The number of features.
        dropout_rate (float): Dropout rate.
    """

    def __init__(self, n_head, n_feat, dropout_rate):
        """Construct a MultiHeadedAttention object."""
        super(MultiHeadedAttention, self).__init__()
        assert n_feat % n_head == 0
        # We assume d_v always equals d_k
        self.d_k = n_feat // n_head
        self.h = n_head
        self.linear_q = nn.Linear(n_feat, n_feat)
        self.linear_k = nn.Linear(n_feat, n_feat)
        self.linear_v = nn.Linear(n_feat, n_feat)
        self.linear_out = nn.Linear(n_feat, n_feat)
        self.attn = None
        self.dropout = nn.Dropout(p=dropout_rate)

    def forward_qkv(self, query, key, value):
        """Transform query, key and value.

        Args:
            query (torch.Tensor): Query tensor (#batch, time1, size).
            key (torch.Tensor): Key tensor (#batch, time2, size).
            value (torch.Tensor): Value tensor (#batch, time2, size).

        Returns:
            torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k).
            torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k).
            torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k).
        """
        n_batch = query.size(0)
        q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
        k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
        v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
        q = q.transpose(1, 2)  # (batch, head, time1, d_k)
        k = k.transpose(1, 2)  # (batch, head, time2, d_k)
        v = v.transpose(1, 2)  # (batch, head, time2, d_k)

        return q, k, v

    def forward_attention(self, value, scores, mask, rtn_attn=False):
        """Compute attention context vector.

        Args:
            value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k).
            scores (torch.Tensor): Attention score (#batch, n_head, time1, time2).
            mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2).
            rtn_attn (bool): Whether to also return the attention weights.

        Returns:
            torch.Tensor: Transformed value (#batch, time1, d_model)
                weighted by the attention score (#batch, time1, time2).
        """
        n_batch = value.size(0)
        if mask is not None:
            mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
            min_value = torch.finfo(scores.dtype).min
            scores = scores.masked_fill(mask, min_value)
            self.attn = torch.softmax(scores, dim=-1).masked_fill(
                mask, 0.0
            )  # (batch, head, time1, time2)
        else:
            self.attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)

        p_attn = self.dropout(self.attn)
        x = torch.matmul(p_attn, value)  # (batch, head, time1, d_k)
        x = (
            x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
        )  # (batch, time1, d_model)
        if rtn_attn:
            return self.linear_out(x), self.attn
        return self.linear_out(x)  # (batch, time1, d_model)

    def forward(self, query, key, value, mask, rtn_attn=False):
        """Compute scaled dot product attention.

        Args:
            query (torch.Tensor): Query tensor (#batch, time1, size).
            key (torch.Tensor): Key tensor (#batch, time2, size).
            value (torch.Tensor): Value tensor (#batch, time2, size).
            mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
                (#batch, time1, time2).
            rtn_attn (bool): Whether to also return the attention weights.

        Returns:
            torch.Tensor: Output tensor (#batch, time1, d_model).
        """
        q, k, v = self.forward_qkv(query, key, value)
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        return self.forward_attention(v, scores, mask, rtn_attn)


class RelPositionMultiHeadedAttention(MultiHeadedAttention):
    """Multi-Head Attention layer with relative position encoding (new implementation).

    Details can be found in https://github.com/espnet/espnet/pull/2816.
    Paper: https://arxiv.org/abs/1901.02860

    Args:
        n_head (int): The number of heads.
        n_feat (int): The number of features.
        dropout_rate (float): Dropout rate.
        zero_triu (bool): Whether to zero the upper triangular part of attention matrix.
    """

    def __init__(self, n_head, n_feat, dropout_rate, zero_triu=False):
        """Construct a RelPositionMultiHeadedAttention object."""
        super().__init__(n_head, n_feat, dropout_rate)
        self.zero_triu = zero_triu
        # linear transformation for positional encoding
        self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
        # these two learnable biases are used in matrix c and matrix d
        # as described in https://arxiv.org/abs/1901.02860 Section 3.3
        self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
        self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
        torch.nn.init.xavier_uniform_(self.pos_bias_u)
        torch.nn.init.xavier_uniform_(self.pos_bias_v)

    def rel_shift(self, x):
        """Compute relative positional encoding.

        Args:
            x (torch.Tensor): Input tensor (batch, head, time1, 2*time1-1).
                time1 means the length of query vector.

        Returns:
            torch.Tensor: Output tensor.
        """
        zero_pad = torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype)
        x_padded = torch.cat([zero_pad, x], dim=-1)

        x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2))
        x = x_padded[:, :, 1:].view_as(x)[
            :, :, :, : x.size(-1) // 2 + 1
        ]  # only keep the positions from 0 to time2

        if self.zero_triu:
            ones = torch.ones((x.size(2), x.size(3)), device=x.device)
            x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :]

        return x

    def forward(self, query, key, value, pos_emb, mask):
        """Compute 'Scaled Dot Product Attention' with rel. positional encoding.

        Args:
            query (torch.Tensor): Query tensor (#batch, time1, size).
            key (torch.Tensor): Key tensor (#batch, time2, size).
            value (torch.Tensor): Value tensor (#batch, time2, size).
            pos_emb (torch.Tensor): Positional embedding tensor
                (#batch, 2*time1-1, size).
            mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
                (#batch, time1, time2).

        Returns:
            torch.Tensor: Output tensor (#batch, time1, d_model).
        """
        q, k, v = self.forward_qkv(query, key, value)
        q = q.transpose(1, 2)  # (batch, time1, head, d_k)

        n_batch_pos = pos_emb.size(0)
        p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
        p = p.transpose(1, 2)  # (batch, head, 2*time1-1, d_k)

        # (batch, head, time1, d_k)
        q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
        # (batch, head, time1, d_k)
        q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)

        # compute attention score
        # first compute matrix a and matrix c
        # as described in https://arxiv.org/abs/1901.02860 Section 3.3
        # (batch, head, time1, time2)
        matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))

        # compute matrix b and matrix d
        # (batch, head, time1, 2*time1-1)
        matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
        matrix_bd = self.rel_shift(matrix_bd)

        scores = (matrix_ac + matrix_bd) / math.sqrt(
            self.d_k
        )  # (batch, head, time1, time2)

        return self.forward_attention(v, scores, mask)
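A sketch of how this attention variant is wired to RelPositionalEncoding (defined in embedding.py below). It assumes the standard espnet behavior where that module's forward returns the scaled input together with a (1, 2*time-1, d_model) pos_emb; the sizes are illustrative.

import torch

d_model, n_head, T = 256, 4, 10
pos_enc = RelPositionalEncoding(d_model, dropout_rate=0.0)
attn = RelPositionMultiHeadedAttention(n_head, d_model, dropout_rate=0.0)

x = torch.randn(2, T, d_model)
x_scaled, pos_emb = pos_enc(x)                    # pos_emb: (1, 2*T-1, d_model)
out = attn(x_scaled, x_scaled, x_scaled, pos_emb, mask=None)
print(out.shape)                                  # torch.Size([2, 10, 256])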
espnet/nets/pytorch_backend/transformer/embedding.py
ADDED
@@ -0,0 +1,184 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# Copyright 2019 Shigeki Karita
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

"""Positional Encoding Module."""

import math

import torch


def _pre_hook(
    state_dict,
    prefix,
    local_metadata,
    strict,
    missing_keys,
    unexpected_keys,
    error_msgs,
):
    """Perform pre-hook in load_state_dict for backward compatibility.
    Note:
        We saved self.pe until v.0.5.2 but omitted it later.
        Therefore, we remove the item "pe" from `state_dict` for backward compatibility.
    """
    k = prefix + "pe"
    if k in state_dict:
        state_dict.pop(k)


class PositionalEncoding(torch.nn.Module):
    """Positional encoding.
    Args:
        d_model (int): Embedding dimension.
        dropout_rate (float): Dropout rate.
        max_len (int): Maximum input length.
        reverse (bool): Whether to reverse the input position. Only for
            the class LegacyRelPositionalEncoding. We remove it in the current
            class RelPositionalEncoding.
    """

    def __init__(self, d_model, dropout_rate, max_len=5000, reverse=False):
        """Construct a PositionalEncoding object."""
        super(PositionalEncoding, self).__init__()
        self.d_model = d_model
        self.reverse = reverse
        self.xscale = math.sqrt(self.d_model)
        self.dropout = torch.nn.Dropout(p=dropout_rate)
        self.pe = None
        self.extend_pe(torch.tensor(0.0).expand(1, max_len))
        self._register_load_state_dict_pre_hook(_pre_hook)

    def extend_pe(self, x):
        """Reset the positional encodings."""
        if self.pe is not None:
            if self.pe.size(1) >= x.size(1):
                if self.pe.dtype != x.dtype or self.pe.device != x.device:
                    self.pe = self.pe.to(dtype=x.dtype, device=x.device)
                return
        pe = torch.zeros(x.size(1), self.d_model)
        if self.reverse:
            position = torch.arange(
                x.size(1) - 1, -1, -1.0, dtype=torch.float32
            ).unsqueeze(1)
        else:
            position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, self.d_model, 2, dtype=torch.float32)
            * -(math.log(10000.0) / self.d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.pe = pe.to(device=x.device, dtype=x.dtype)

    def forward(self, x: torch.Tensor):
        """Add positional encoding.
        Args:
            x (torch.Tensor): Input tensor (batch, time, `*`).
        Returns:
            torch.Tensor: Encoded tensor (batch, time, `*`).
        """
        self.extend_pe(x)
        x = x * self.xscale + self.pe[:, : x.size(1)]
        return self.dropout(x)


class ScaledPositionalEncoding(PositionalEncoding):
    """Scaled positional encoding module.
    See Sec. 3.2 https://arxiv.org/abs/1809.08895
    Args:
        d_model (int): Embedding dimension.
        dropout_rate (float): Dropout rate.
        max_len (int): Maximum input length.
    """

    def __init__(self, d_model, dropout_rate, max_len=5000):
        """Initialize class."""
        super().__init__(d_model=d_model, dropout_rate=dropout_rate, max_len=max_len)
        self.alpha = torch.nn.Parameter(torch.tensor(1.0))

    def reset_parameters(self):
        """Reset parameters."""
        self.alpha.data = torch.tensor(1.0)

    def forward(self, x):
        """Add positional encoding.
        Args:
            x (torch.Tensor): Input tensor (batch, time, `*`).
        Returns:
            torch.Tensor: Encoded tensor (batch, time, `*`).
        """
        self.extend_pe(x)
        x = x + self.alpha * self.pe[:, : x.size(1)]
        return self.dropout(x)


class RelPositionalEncoding(torch.nn.Module):
    """Relative positional encoding module (new implementation).
    Details can be found in https://github.com/espnet/espnet/pull/2816.
    See: Appendix B in https://arxiv.org/abs/1901.02860
    Args:
        d_model (int): Embedding dimension.
        dropout_rate (float): Dropout rate.
        max_len (int): Maximum input length.
    """

    def __init__(self, d_model, dropout_rate, max_len=5000):
        """Construct a RelPositionalEncoding object."""
        super(RelPositionalEncoding, self).__init__()
        self.d_model = d_model
        self.xscale = math.sqrt(self.d_model)
        self.dropout = torch.nn.Dropout(p=dropout_rate)
        self.pe = None
        self.extend_pe(torch.tensor(0.0).expand(1, max_len))

    def extend_pe(self, x):
        """Reset the positional encodings."""
        if self.pe is not None:
            # self.pe contains both positive and negative parts;
            # the length of self.pe is 2 * input_len - 1.
            if self.pe.size(1) >= x.size(1) * 2 - 1:
                if self.pe.dtype != x.dtype or self.pe.device != x.device:
                    self.pe = self.pe.to(dtype=x.dtype, device=x.device)
                return
        # Suppose `i` is the position of the query vector and `j` the position
        # of the key vector. We use positive relative positions when keys are
        # to the left (i > j) and negative relative positions otherwise (i < j).
        pe_positive = torch.zeros(x.size(1), self.d_model)
        pe_negative = torch.zeros(x.size(1), self.d_model)
        position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, self.d_model, 2, dtype=torch.float32)
            * -(math.log(10000.0) / self.d_model)
        )
        pe_positive[:, 0::2] = torch.sin(position * div_term)
        pe_positive[:, 1::2] = torch.cos(position * div_term)
        pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
        pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)

        # Reverse the order of the positive indices and concatenate the
        # positive and negative parts. This is used to support the shifting
        # trick as in https://arxiv.org/abs/1901.02860
        pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
        pe_negative = pe_negative[1:].unsqueeze(0)
        pe = torch.cat([pe_positive, pe_negative], dim=1)
        self.pe = pe.to(device=x.device, dtype=x.dtype)

    def forward(self, x: torch.Tensor):
        """Add positional encoding.
        Args:
            x (torch.Tensor): Input tensor (batch, time, `*`).
        Returns:
            torch.Tensor: Encoded tensor (batch, time, `*`).
        """
        self.extend_pe(x)
        x = x * self.xscale
        pos_emb = self.pe[
            :,
            self.pe.size(1) // 2 - x.size(1) + 1 : self.pe.size(1) // 2 + x.size(1),
        ]
        return self.dropout(x), self.dropout(pos_emb)
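A usage sketch for the two main classes above (not part of the diff; d_model=256, dropout 0.1, and the input shapes are illustrative). Note that PositionalEncoding returns a single encoded tensor, while RelPositionalEncoding returns the scaled input together with a separate relative position embedding of length 2 * time - 1 for the relative-attention layer:

import torch

pos_enc = PositionalEncoding(d_model=256, dropout_rate=0.1)
x = torch.randn(4, 100, 256)      # (batch, time, d_model)
y = pos_enc(x)                    # (4, 100, 256): x * sqrt(d_model) + sinusoids

rel_pos_enc = RelPositionalEncoding(d_model=256, dropout_rate=0.1)
y, pos_emb = rel_pos_enc(x)       # pos_emb: (1, 199, 256), i.e. 2 * time - 1 positions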
espnet/nets/pytorch_backend/transformer/label_smoothing_loss.py
ADDED
@@ -0,0 +1,63 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# Copyright 2019 Shigeki Karita
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

"""Label smoothing module."""

import torch
from torch import nn


class LabelSmoothingLoss(nn.Module):
    """Label-smoothing loss.

    :param int size: the number of classes
    :param int padding_idx: ignored class id
    :param float smoothing: smoothing rate (0.0 means conventional CE)
    :param bool normalize_length: normalize loss by sequence length if True
    :param torch.nn.Module criterion: loss function to be smoothed
    """

    def __init__(
        self,
        size,
        padding_idx,
        smoothing,
        normalize_length=False,
        criterion=nn.KLDivLoss(reduction="none"),
    ):
        """Construct a LabelSmoothingLoss object."""
        super(LabelSmoothingLoss, self).__init__()
        self.criterion = criterion
        self.padding_idx = padding_idx
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.size = size
        self.true_dist = None
        self.normalize_length = normalize_length

    def forward(self, x, target):
        """Compute loss between x and target.

        :param torch.Tensor x: prediction (batch, seqlen, class)
        :param torch.Tensor target:
            target signal masked with self.padding_idx (batch, seqlen)
        :return: scalar float value
        :rtype: torch.Tensor
        """
        assert x.size(2) == self.size
        batch_size = x.size(0)
        x = x.view(-1, self.size)
        target = target.view(-1)
        with torch.no_grad():
            true_dist = x.clone()
            true_dist.fill_(self.smoothing / (self.size - 1))
            ignore = target == self.padding_idx  # (B,)
            total = len(target) - ignore.sum().item()
            target = target.masked_fill(ignore, 0)  # avoid -1 index
            true_dist.scatter_(1, target.unsqueeze(1), self.confidence)
        kl = self.criterion(torch.log_softmax(x, dim=1), true_dist)
        denom = total if self.normalize_length else batch_size
        return kl.masked_fill(ignore.unsqueeze(1), 0).sum() / denom
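A usage sketch for LabelSmoothingLoss (not part of the diff; the vocabulary size, sequence length, and smoothing value are illustrative, and -1 as the ignore id follows the usual ESPnet convention):

import torch

criterion = LabelSmoothingLoss(size=5, padding_idx=-1, smoothing=0.1)
logits = torch.randn(2, 7, 5)           # (batch, seqlen, class)
target = torch.randint(0, 5, (2, 7))    # (batch, seqlen)
target[0, -2:] = -1                     # padded positions contribute no loss
loss = criterion(logits, target)        # scalar; divided by batch size here
                                        # (by token count if normalize_length=True)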