Shikhar committed on
Commit 84f8437 · 1 Parent(s): d876521

Deploy PhoneticXeus Gradio demo (CPU)

README.md CHANGED
@@ -1,12 +1,22 @@
  ---
  title: PhoneticXeus
- emoji: 🚀
+ emoji: "\U0001F4DE"
- colorFrom: yellow
+ colorFrom: blue
- colorTo: pink
+ colorTo: purple
  sdk: gradio
- sdk_version: 6.11.0
+ sdk_version: "5.0"
  app_file: app.py
  pinned: false
+ license: apache-2.0
+ models:
+ - changelinglab/PhoneticXeus
+ hardware: cpu-basic
  ---
  
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # PhoneticXeus -- Multilingual Phone Recognition
+ 
+ Record or upload audio to get an IPA phone transcription.
+ 
+ Based on [PhoneticXeus](https://huggingface.co/changelinglab/PhoneticXeus), a multilingual phone recognition model using self-conditioned CTC on the XEUS speech encoder.
+ 
+ Paper: [An Empirical Recipe for Universal Phone Recognition](https://arxiv.org/abs/2603.29042)
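A minimal sketch of calling the demo programmatically with `gradio_client`; the Space id below is a placeholder, and the `/transcribe` endpoint name assumes Gradio's default naming for the `transcribe` function in app.py:

```python
from gradio_client import Client, handle_file

# Placeholder Space id -- substitute the actual owner/name of this Space.
client = Client("OWNER/PhoneticXeus")

# Returns the two textbox values: space-separated IPA phones and the raw transcript.
phones, raw = client.predict(
    handle_file("sample.wav"),   # any local audio file
    api_name="/transcribe",      # assumed default endpoint name
)
print(phones)
```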
app.py ADDED
@@ -0,0 +1,110 @@
+ import os
+ import sys
+
+ # Ensure vendored src/ is importable
+ sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
+
+ import gradio as gr
+ import torch
+ import torchaudio
+ from huggingface_hub import hf_hub_download
+
+ from src.model.xeusphoneme.builders import build_xeus_pr_inference
+
+ MAX_SECONDS = 60
+ SAMPLE_RATE = 16000
+
+ inference = None
+
+
+ def load_model():
+     ckpt = hf_hub_download(
+         "changelinglab/PhoneticXeus", "checkpoint-22000.ckpt"
+     )
+     vocab = os.path.join(
+         os.path.dirname(__file__),
+         "src", "model", "xeusphoneme", "resources", "ipa_vocab.json",
+     )
+     return build_xeus_pr_inference(
+         work_dir="/tmp/cache/xeus",
+         checkpoint=ckpt,
+         vocab_file=vocab,
+         hf_repo="espnet/xeus",
+         device="cpu",
+     )
+
+
+ def transcribe(audio_path):
+     """Run phone recognition on uploaded/recorded audio."""
+     global inference
+     if audio_path is None:
+         return "", ""
+
+     if inference is None:
+         inference = load_model()
+
+     waveform, sr = torchaudio.load(audio_path)
+     if sr != SAMPLE_RATE:
+         waveform = torchaudio.functional.resample(waveform, sr, SAMPLE_RATE)
+     waveform = waveform.mean(dim=0)  # mono
+     waveform = waveform[: SAMPLE_RATE * MAX_SECONDS]
+
+     if waveform.numel() == 0:
+         return "", ""
+
+     results = inference(waveform)
+
+     processed = results[0]["processed_transcript"]
+     predicted = results[0]["predicted_transcript"]
+     spaced = " ".join(
+         t for t in predicted.split("/")
+         if not (t.startswith("<") and t.endswith(">"))
+     )
+     return spaced, processed
+
+
+ with gr.Blocks(title="PhoneticXeus") as demo:
+     gr.Markdown(
+         "# PhoneticXeus\n"
+         "Multilingual phone recognition -- record or upload audio "
+         "to get an IPA transcription.\n\n"
+         "Model: [changelinglab/PhoneticXeus]"
+         "(https://huggingface.co/changelinglab/PhoneticXeus) "
+         "| Paper: [arXiv 2603.29042]"
+         "(https://arxiv.org/abs/2603.29042)"
+     )
+
+     with gr.Row():
+         audio_input = gr.Audio(
+             sources=["microphone", "upload"],
+             type="filepath",
+             label="Input Audio",
+         )
+
+     btn = gr.Button("Transcribe", variant="primary")
+
+     with gr.Row():
+         phones_output = gr.Textbox(
+             label="IPA Phones (space-separated)",
+             lines=3,
+             show_copy_button=True,
+         )
+         raw_output = gr.Textbox(
+             label="Raw output (concatenated)",
+             lines=3,
+             show_copy_button=True,
+         )
+
+     btn.click(
+         fn=transcribe,
+         inputs=[audio_input],
+         outputs=[phones_output, raw_output],
+     )
+
+     gr.Markdown(
+         "---\n"
+         f"Max audio length: {MAX_SECONDS}s. "
+         "Audio is resampled to 16 kHz mono."
+     )
+
+ demo.launch()
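As a quick illustration of the post-processing in `transcribe()`: the raw transcript is assumed to be slash-delimited with angle-bracket special tokens (the example string below is made up), and the space-separated output simply drops those tokens:

```python
# Hypothetical raw transcript in the slash-delimited format that transcribe() splits on.
predicted = "<sos>/h/ə/l/oʊ/<eos>"

spaced = " ".join(
    t for t in predicted.split("/")
    if not (t.startswith("<") and t.endswith(">"))
)
print(spaced)  # h ə l oʊ
```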
requirements.txt ADDED
@@ -0,0 +1,8 @@
+ torch
+ torchaudio
+ huggingface_hub
+ pyyaml
+ typeguard
+ packaging
+ numpy
+ gradio>=5.0
src/__init__.py ADDED
File without changes
src/core/__init__.py ADDED
@@ -0,0 +1 @@
+ """Core modules for PhoneticXeus."""
src/core/utils.py ADDED
@@ -0,0 +1,59 @@
+ import logging
+ from huggingface_hub import snapshot_download
+ from huggingface_hub.utils import LocalEntryNotFoundError
+
+
+ def download_hf_snapshot(
+     repo_id: str,
+     work_dir: str,
+     force_download: bool = False,
+     **kwargs,
+ ) -> str:
+     """Download a snapshot from the Hugging Face Hub into `work_dir`.
+
+     Args:
+         repo_id: e.g. "espnet/xeus"
+         work_dir: path to the local directory where the snapshot is stored
+         force_download: if True, force a re-download
+         **kwargs: other snapshot_download arguments
+
+     Returns:
+         The path to the local snapshot folder.
+     """
+     if force_download:
+         logging.info(
+             f"Force-downloading snapshot for {repo_id} into {work_dir}..."
+         )
+         path = snapshot_download(
+             repo_id=repo_id,
+             local_dir=work_dir,
+             force_download=True,
+             local_files_only=False,
+             **kwargs,
+         )
+         logging.info(f"Downloaded snapshot for {repo_id} to {path}")
+         return path
+
+     try:
+         path = snapshot_download(
+             repo_id=repo_id,
+             local_dir=work_dir,
+             local_files_only=True,
+             **kwargs,
+         )
+         logging.info(
+             f"Using existing local snapshot for {repo_id} at {path}"
+         )
+         return path
+     except LocalEntryNotFoundError:
+         logging.info(
+             f"No local snapshot found for {repo_id}. Downloading now..."
+         )
+         path = snapshot_download(
+             repo_id=repo_id,
+             local_dir=work_dir,
+             local_files_only=False,
+             **kwargs,
+         )
+         logging.info(f"Downloaded snapshot for {repo_id} to {path}")
+         return path
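A usage sketch for the helper above (assuming the repo root is on `sys.path`, as app.py arranges): the first call downloads the snapshot, later calls reuse the cached copy.

```python
from src.core.utils import download_hf_snapshot

# Downloads espnet/xeus into the work dir once, then reuses the local snapshot.
xeus_dir = download_hf_snapshot("espnet/xeus", work_dir="/tmp/cache/xeus")
print(xeus_dir)
```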
src/espnet_import/__init__.py ADDED
File without changes
src/espnet_import/attention.py ADDED
@@ -0,0 +1,457 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+ # Copyright 2019 Shigeki Karita
5
+ # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
6
+
7
+ """Multi-Head Attention layer definition."""
8
+
9
+ import logging
10
+ import math
11
+
12
+ import torch
13
+ from torch import nn
14
+
15
+ from src.espnet_import.layer_norm import LayerNorm
16
+
17
+ try:
18
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
19
+ from flash_attn.bert_padding import pad_input, unpad_input
20
+ except Exception:
21
+ pass
22
+
23
+
24
+ class MultiHeadedAttention(nn.Module):
25
+ """Multi-Head Attention layer.
26
+
27
+ Args:
28
+ n_head (int): The number of heads.
29
+ n_feat (int): The number of features.
30
+ dropout_rate (float): Dropout rate.
31
+ qk_norm (bool): Normalize q and k before dot product.
32
+ use_flash_attn (bool): Use flash_attn implementation.
33
+ causal (bool): Apply causal attention.
34
+ cross_attn (bool): Cross attention instead of self attention.
35
+ use_sdpa (bool): Use PyTorch's scaled dot product attention.
36
+
37
+ """
38
+
39
+ def __init__(
40
+ self,
41
+ n_head,
42
+ n_feat,
43
+ dropout_rate,
44
+ qk_norm=False,
45
+ use_flash_attn=False,
46
+ causal=False,
47
+ cross_attn=False,
48
+ use_sdpa=False,
49
+ ):
50
+ """Construct an MultiHeadedAttention object."""
51
+ super(MultiHeadedAttention, self).__init__()
52
+
53
+ assert n_feat % n_head == 0
54
+ # We assume d_v always equals d_k
55
+ self.d_k = n_feat // n_head
56
+ self.h = n_head
57
+ self.linear_q = nn.Linear(n_feat, n_feat)
58
+ self.linear_k = nn.Linear(n_feat, n_feat)
59
+ self.linear_v = nn.Linear(n_feat, n_feat)
60
+ self.linear_out = nn.Linear(n_feat, n_feat)
61
+ self.attn = None
62
+ self.dropout = (
63
+ nn.Dropout(p=dropout_rate) if not use_flash_attn else nn.Identity()
64
+ )
65
+ self.dropout_rate = dropout_rate
66
+
67
+ # LayerNorm for q and k
68
+ self.q_norm = LayerNorm(self.d_k) if qk_norm else nn.Identity()
69
+ self.k_norm = LayerNorm(self.d_k) if qk_norm else nn.Identity()
70
+
71
+ self.use_flash_attn = use_flash_attn
72
+ self.causal = causal # only used with flash_attn
73
+ self.cross_attn = cross_attn # only used with flash_attn
74
+
75
+ self.use_sdpa = use_sdpa
76
+
77
+ def forward_qkv(self, query, key, value, expand_kv=False):
78
+ """Transform query, key and value.
79
+
80
+ Args:
81
+ query (torch.Tensor): Query tensor (#batch, time1, size).
82
+ key (torch.Tensor): Key tensor (#batch, time2, size).
83
+ value (torch.Tensor): Value tensor (#batch, time2, size).
84
+ expand_kv (bool): Used only for partially autoregressive (PAR) decoding.
85
+
86
+ Returns:
87
+ torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k).
88
+ torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k).
89
+ torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k).
90
+
91
+ """
92
+ n_batch = query.size(0)
93
+ q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
94
+
95
+ if expand_kv:
96
+ k_shape = key.shape
97
+ k = (
98
+ self.linear_k(key[:1, :, :])
99
+ .expand(n_batch, k_shape[1], k_shape[2])
100
+ .view(n_batch, -1, self.h, self.d_k)
101
+ )
102
+ v_shape = value.shape
103
+ v = (
104
+ self.linear_v(value[:1, :, :])
105
+ .expand(n_batch, v_shape[1], v_shape[2])
106
+ .view(n_batch, -1, self.h, self.d_k)
107
+ )
108
+ else:
109
+ k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
110
+ v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
111
+
112
+ q = q.transpose(1, 2) # (batch, head, time1, d_k)
113
+ k = k.transpose(1, 2) # (batch, head, time2, d_k)
114
+ v = v.transpose(1, 2) # (batch, head, time2, d_k)
115
+
116
+ q = self.q_norm(q)
117
+ k = self.k_norm(k)
118
+
119
+ return q, k, v
120
+
121
+ def forward_attention(self, value, scores, mask):
122
+ """Compute attention context vector.
123
+
124
+ Args:
125
+ value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k).
126
+ scores (torch.Tensor): Attention score (#batch, n_head, time1, time2).
127
+ mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2).
128
+
129
+ Returns:
130
+ torch.Tensor: Transformed value (#batch, time1, d_model)
131
+ weighted by the attention score (#batch, time1, time2).
132
+
133
+ """
134
+ n_batch = value.size(0)
135
+ if mask is not None:
136
+ mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
137
+ min_value = torch.finfo(scores.dtype).min
138
+ scores = scores.masked_fill(mask, min_value)
139
+ self.attn = torch.softmax(scores, dim=-1).masked_fill(
140
+ mask, 0.0
141
+ ) # (batch, head, time1, time2)
142
+ else:
143
+ self.attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
144
+
145
+ p_attn = self.dropout(self.attn)
146
+ x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
147
+ x = (
148
+ x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
149
+ ) # (batch, time1, d_model)
150
+
151
+ return self.linear_out(x) # (batch, time1, d_model)
152
+
153
+ def forward(self, query, key, value, mask, expand_kv=False):
154
+ """Compute scaled dot product attention.
155
+
156
+ Args:
157
+ query (torch.Tensor): Query tensor (#batch, time1, size).
158
+ key (torch.Tensor): Key tensor (#batch, time2, size).
159
+ value (torch.Tensor): Value tensor (#batch, time2, size).
160
+ mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
161
+ (#batch, time1, time2).
162
+ expand_kv (bool): Used only for partially autoregressive (PAR) decoding.
163
+ When set to `True`, `Linear` layers are computed only for the first
164
+ batch. This is useful to reduce the memory usage during decoding
165
+ when the batch size is #beam_size x #mask_count, which can be large.
166
+ Typically, in single waveform inference of PAR, `Linear` layers
167
+ should not be computed for all batches for source-attention.
168
+
169
+ Returns:
170
+ torch.Tensor: Output tensor (#batch, time1, d_model).
171
+ """
172
+ # Use PyTorch's Scaled Dot Product Attention implementation
173
+ if getattr(self, "use_sdpa", False):
174
+ q, k, v = self.forward_qkv(query, key, value, expand_kv)
175
+
176
+ # The shape of mask must be broadcastable to the shape of attention weights
177
+ out = torch.nn.functional.scaled_dot_product_attention(
178
+ q,
179
+ k,
180
+ v,
181
+ mask.unsqueeze(1) if mask is not None else None,
182
+ dropout_p=self.dropout_rate if self.training else 0.0,
183
+ ) # (batch, head, time1, d_k)
184
+
185
+ out = out.transpose(1, 2) # (batch, time1, head, d_k)
186
+ out = out.reshape(out.shape[0], out.shape[1], -1) # (batch, time1, d_model)
187
+ return self.linear_out(out) # (batch, time1, d_model)
188
+
189
+ # Use Flash Attention implementation
190
+ if self.use_flash_attn:
191
+ try:
192
+ # In the causal case, the last row will be the key mask
193
+ key_nonpad_mask = mask[:, -1, :] # (#batch, time2)
194
+ if self.cross_attn:
195
+ # For cross attention, we do not know the query padding
196
+ query_nonpad_mask = torch.ones(
197
+ size=query.shape[:2], dtype=torch.bool, device=query.device
198
+ )
199
+ else:
200
+ query_nonpad_mask = key_nonpad_mask
201
+
202
+ if key_nonpad_mask.eq(0).any():
203
+ # Use variable length implementation if padded
204
+ q, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(
205
+ query, query_nonpad_mask
206
+ )[:4]
207
+ k, indices_k, cu_seqlens_k, max_seqlen_k = unpad_input(
208
+ key, key_nonpad_mask
209
+ )[:4]
210
+ v, _, _, _ = unpad_input(value, key_nonpad_mask)[:4]
211
+
212
+ q = self.linear_q(q).reshape(-1, self.h, self.d_k)
213
+ k = self.linear_k(k).reshape(-1, self.h, self.d_k)
214
+ v = self.linear_v(v).reshape(-1, self.h, self.d_k)
215
+
216
+ q = self.q_norm(q)
217
+ k = self.k_norm(k)
218
+
219
+ out = flash_attn_varlen_func(
220
+ q,
221
+ k,
222
+ v,
223
+ cu_seqlens_q,
224
+ cu_seqlens_k,
225
+ max_seqlen_q,
226
+ max_seqlen_k,
227
+ dropout_p=self.dropout_rate if self.training else 0.0,
228
+ causal=self.causal,
229
+ ) # (total, nheads, headdim)
230
+
231
+ out = out.reshape(out.shape[0], -1)
232
+ out = self.linear_out(out)
233
+
234
+ out = pad_input(out, indices_q, query.shape[0], query.shape[1])
235
+ return out
236
+
237
+ else:
238
+ # Use fixed length implementation if not padded,
239
+ # which is faster than the variable length implementation
240
+ del key_nonpad_mask
241
+ q, k, v = self.forward_qkv(query, key, value)
242
+
243
+ out = flash_attn_func(
244
+ q.transpose(1, 2),
245
+ k.transpose(1, 2),
246
+ v.transpose(1, 2),
247
+ dropout_p=self.dropout_rate if self.training else 0.0,
248
+ causal=self.causal,
249
+ ) # (batch_size, seqlen, nheads, headdim)
250
+ del q, k, v
251
+
252
+ out = out.reshape(out.shape[0], out.shape[1], -1)
253
+ out = self.linear_out(out)
254
+ return out
255
+
256
+ except Exception as e:
257
+ pass
258
+ self.use_flash_attn = False
259
+
260
+ # Fall back to the default implementation
261
+ q, k, v = self.forward_qkv(query, key, value, expand_kv)
262
+ scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
263
+ return self.forward_attention(v, scores, mask)
264
+
265
+
266
+ class LegacyRelPositionMultiHeadedAttention(MultiHeadedAttention):
267
+ """Multi-Head Attention layer with relative position encoding (old version).
268
+
269
+ Details can be found in https://github.com/espnet/espnet/pull/2816.
270
+
271
+ Paper: https://arxiv.org/abs/1901.02860
272
+
273
+ Args:
274
+ n_head (int): The number of heads.
275
+ n_feat (int): The number of features.
276
+ dropout_rate (float): Dropout rate.
277
+ zero_triu (bool): Whether to zero the upper triangular part of attention matrix.
278
+
279
+ """
280
+
281
+ def __init__(self, n_head, n_feat, dropout_rate, zero_triu=False):
282
+ """Construct an RelPositionMultiHeadedAttention object."""
283
+ super().__init__(n_head, n_feat, dropout_rate)
284
+ self.zero_triu = zero_triu
285
+ # linear transformation for positional encoding
286
+ self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
287
+ # these two learnable bias are used in matrix c and matrix d
288
+ # as described in https://arxiv.org/abs/1901.02860 Section 3.3
289
+ self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
290
+ self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
291
+ torch.nn.init.xavier_uniform_(self.pos_bias_u)
292
+ torch.nn.init.xavier_uniform_(self.pos_bias_v)
293
+
294
+ def rel_shift(self, x):
295
+ """Compute relative positional encoding.
296
+
297
+ Args:
298
+ x (torch.Tensor): Input tensor (batch, head, time1, time2).
299
+
300
+ Returns:
301
+ torch.Tensor: Output tensor.
302
+
303
+ """
304
+ zero_pad = torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype)
305
+ x_padded = torch.cat([zero_pad, x], dim=-1)
306
+
307
+ x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2))
308
+ x = x_padded[:, :, 1:].view_as(x)
309
+
310
+ if self.zero_triu:
311
+ ones = torch.ones((x.size(2), x.size(3)))
312
+ x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :]
313
+
314
+ return x
315
+
316
+ def forward(self, query, key, value, pos_emb, mask):
317
+ """Compute 'Scaled Dot Product Attention' with rel. positional encoding.
318
+
319
+ Args:
320
+ query (torch.Tensor): Query tensor (#batch, time1, size).
321
+ key (torch.Tensor): Key tensor (#batch, time2, size).
322
+ value (torch.Tensor): Value tensor (#batch, time2, size).
323
+ pos_emb (torch.Tensor): Positional embedding tensor (#batch, time1, size).
324
+ mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
325
+ (#batch, time1, time2).
326
+
327
+ Returns:
328
+ torch.Tensor: Output tensor (#batch, time1, d_model).
329
+
330
+ """
331
+ q, k, v = self.forward_qkv(query, key, value)
332
+ q = q.transpose(1, 2) # (batch, time1, head, d_k)
333
+
334
+ n_batch_pos = pos_emb.size(0)
335
+ p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
336
+ p = p.transpose(1, 2) # (batch, head, time1, d_k)
337
+
338
+ # (batch, head, time1, d_k)
339
+ q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
340
+ # (batch, head, time1, d_k)
341
+ q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
342
+
343
+ # compute attention score
344
+ # first compute matrix a and matrix c
345
+ # as described in https://arxiv.org/abs/1901.02860 Section 3.3
346
+ # (batch, head, time1, time2)
347
+ matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
348
+
349
+ # compute matrix b and matrix d
350
+ # (batch, head, time1, time1)
351
+ matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
352
+ matrix_bd = self.rel_shift(matrix_bd)
353
+
354
+ scores = (matrix_ac + matrix_bd) / math.sqrt(
355
+ self.d_k
356
+ ) # (batch, head, time1, time2)
357
+
358
+ return self.forward_attention(v, scores, mask)
359
+
360
+
361
+ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
362
+ """Multi-Head Attention layer with relative position encoding (new implementation).
363
+
364
+ Details can be found in https://github.com/espnet/espnet/pull/2816.
365
+
366
+ Paper: https://arxiv.org/abs/1901.02860
367
+
368
+ Args:
369
+ n_head (int): The number of heads.
370
+ n_feat (int): The number of features.
371
+ dropout_rate (float): Dropout rate.
372
+ zero_triu (bool): Whether to zero the upper triangular part of attention matrix.
373
+
374
+ """
375
+
376
+ def __init__(self, n_head, n_feat, dropout_rate, zero_triu=False):
377
+ """Construct an RelPositionMultiHeadedAttention object."""
378
+ super().__init__(n_head, n_feat, dropout_rate)
379
+ self.zero_triu = zero_triu
380
+ # linear transformation for positional encoding
381
+ self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
382
+ # these two learnable bias are used in matrix c and matrix d
383
+ # as described in https://arxiv.org/abs/1901.02860 Section 3.3
384
+ self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
385
+ self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
386
+ torch.nn.init.xavier_uniform_(self.pos_bias_u)
387
+ torch.nn.init.xavier_uniform_(self.pos_bias_v)
388
+
389
+ def rel_shift(self, x):
390
+ """Compute relative positional encoding.
391
+
392
+ Args:
393
+ x (torch.Tensor): Input tensor (batch, head, time1, 2*time1-1).
394
+ time1 means the length of query vector.
395
+
396
+ Returns:
397
+ torch.Tensor: Output tensor.
398
+
399
+ """
400
+ zero_pad = torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype)
401
+ x_padded = torch.cat([zero_pad, x], dim=-1)
402
+
403
+ x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2))
404
+ x = x_padded[:, :, 1:].view_as(x)[
405
+ :, :, :, : x.size(-1) // 2 + 1
406
+ ] # only keep the positions from 0 to time2
407
+
408
+ if self.zero_triu:
409
+ ones = torch.ones((x.size(2), x.size(3)), device=x.device)
410
+ x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :]
411
+
412
+ return x
413
+
414
+ def forward(self, query, key, value, pos_emb, mask):
415
+ """Compute 'Scaled Dot Product Attention' with rel. positional encoding.
416
+
417
+ Args:
418
+ query (torch.Tensor): Query tensor (#batch, time1, size).
419
+ key (torch.Tensor): Key tensor (#batch, time2, size).
420
+ value (torch.Tensor): Value tensor (#batch, time2, size).
421
+ pos_emb (torch.Tensor): Positional embedding tensor
422
+ (#batch, 2*time1-1, size).
423
+ mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
424
+ (#batch, time1, time2).
425
+
426
+ Returns:
427
+ torch.Tensor: Output tensor (#batch, time1, d_model).
428
+
429
+ """
430
+ q, k, v = self.forward_qkv(query, key, value)
431
+ q = q.transpose(1, 2) # (batch, time1, head, d_k)
432
+
433
+ n_batch_pos = pos_emb.size(0)
434
+ p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
435
+ p = p.transpose(1, 2) # (batch, head, 2*time1-1, d_k)
436
+
437
+ # (batch, head, time1, d_k)
438
+ q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
439
+ # (batch, head, time1, d_k)
440
+ q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
441
+
442
+ # compute attention score
443
+ # first compute matrix a and matrix c
444
+ # as described in https://arxiv.org/abs/1901.02860 Section 3.3
445
+ # (batch, head, time1, time2)
446
+ matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
447
+
448
+ # compute matrix b and matrix d
449
+ # (batch, head, time1, 2*time1-1)
450
+ matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
451
+ matrix_bd = self.rel_shift(matrix_bd)
452
+
453
+ scores = (matrix_ac + matrix_bd) / math.sqrt(
454
+ self.d_k
455
+ ) # (batch, head, time1, time2)
456
+
457
+ return self.forward_attention(v, scores, mask)
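A minimal shape check for the relative-position attention defined above (a sketch, not part of the vendored file; assumes torch is installed and the repo root is importable). The positional embedding has length 2*time1-1, matching the docstrings:

```python
import torch

from src.espnet_import.attention import RelPositionMultiHeadedAttention

mha = RelPositionMultiHeadedAttention(n_head=4, n_feat=64, dropout_rate=0.0)
x = torch.randn(2, 10, 64)                # (batch, time1, size)
pos_emb = torch.randn(1, 2 * 10 - 1, 64)  # (1, 2*time1-1, size)
out = mha(x, x, x, pos_emb, mask=None)
print(out.shape)                          # torch.Size([2, 10, 64])
```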
src/espnet_import/cgmlp.py ADDED
@@ -0,0 +1,123 @@
1
+ """MLP with convolutional gating (cgMLP) definition.
2
+
3
+ References:
4
+ https://openreview.net/forum?id=RA-zVvZLYIy
5
+ https://arxiv.org/abs/2105.08050
6
+
7
+ """
8
+
9
+ import torch
10
+
11
+ from src.espnet_import.nets_utils import get_activation
12
+ from src.espnet_import.layer_norm import LayerNorm
13
+
14
+
15
+ class ConvolutionalSpatialGatingUnit(torch.nn.Module):
16
+ """Convolutional Spatial Gating Unit (CSGU)."""
17
+
18
+ def __init__(
19
+ self,
20
+ size: int,
21
+ kernel_size: int,
22
+ dropout_rate: float,
23
+ use_linear_after_conv: bool,
24
+ gate_activation: str,
25
+ ):
26
+ super().__init__()
27
+
28
+ n_channels = size // 2 # split input channels
29
+ self.norm = LayerNorm(n_channels)
30
+ self.conv = torch.nn.Conv1d(
31
+ n_channels,
32
+ n_channels,
33
+ kernel_size,
34
+ 1,
35
+ (kernel_size - 1) // 2,
36
+ groups=n_channels,
37
+ )
38
+ if use_linear_after_conv:
39
+ self.linear = torch.nn.Linear(n_channels, n_channels)
40
+ else:
41
+ self.linear = None
42
+
43
+ if gate_activation == "identity":
44
+ self.act = torch.nn.Identity()
45
+ else:
46
+ self.act = get_activation(gate_activation)
47
+
48
+ self.dropout = torch.nn.Dropout(dropout_rate)
49
+
50
+ def espnet_initialization_fn(self):
51
+ torch.nn.init.normal_(self.conv.weight, std=1e-6)
52
+ torch.nn.init.ones_(self.conv.bias)
53
+ if self.linear is not None:
54
+ torch.nn.init.normal_(self.linear.weight, std=1e-6)
55
+ torch.nn.init.ones_(self.linear.bias)
56
+
57
+ def forward(self, x, gate_add=None):
58
+ """Forward method
59
+
60
+ Args:
61
+ x (torch.Tensor): (N, T, D)
62
+ gate_add (torch.Tensor): (N, T, D/2)
63
+
64
+ Returns:
65
+ out (torch.Tensor): (N, T, D/2)
66
+ """
67
+ x_r, x_g = x.chunk(2, dim=-1)
68
+
69
+ x_g = self.norm(x_g) # (N, T, D/2)
70
+ x_g = self.conv(x_g.transpose(1, 2)).transpose(1, 2) # (N, T, D/2)
71
+ if self.linear is not None:
72
+ x_g = self.linear(x_g)
73
+
74
+ if gate_add is not None:
75
+ x_g = x_g + gate_add
76
+
77
+ x_g = self.act(x_g)
78
+ out = x_r * x_g # (N, T, D/2)
79
+ out = self.dropout(out)
80
+ return out
81
+
82
+
83
+ class ConvolutionalGatingMLP(torch.nn.Module):
84
+ """Convolutional Gating MLP (cgMLP)."""
85
+
86
+ def __init__(
87
+ self,
88
+ size: int,
89
+ linear_units: int,
90
+ kernel_size: int,
91
+ dropout_rate: float,
92
+ use_linear_after_conv: bool,
93
+ gate_activation: str,
94
+ ):
95
+ super().__init__()
96
+
97
+ self.channel_proj1 = torch.nn.Sequential(
98
+ torch.nn.Linear(size, linear_units), torch.nn.GELU()
99
+ )
100
+ self.csgu = ConvolutionalSpatialGatingUnit(
101
+ size=linear_units,
102
+ kernel_size=kernel_size,
103
+ dropout_rate=dropout_rate,
104
+ use_linear_after_conv=use_linear_after_conv,
105
+ gate_activation=gate_activation,
106
+ )
107
+ self.channel_proj2 = torch.nn.Linear(linear_units // 2, size)
108
+
109
+ def forward(self, x, mask):
110
+ if isinstance(x, tuple):
111
+ xs_pad, pos_emb = x
112
+ else:
113
+ xs_pad, pos_emb = x, None
114
+
115
+ xs_pad = self.channel_proj1(xs_pad) # size -> linear_units
116
+ xs_pad = self.csgu(xs_pad) # linear_units -> linear_units/2
117
+ xs_pad = self.channel_proj2(xs_pad) # linear_units/2 -> size
118
+
119
+ if pos_emb is not None:
120
+ out = (xs_pad, pos_emb)
121
+ else:
122
+ out = xs_pad
123
+ return out
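A small shape check for the cgMLP block (a sketch, not part of the vendored file; assumes the vendored nets_utils used above is importable): the input is projected up, gated down to half width by the CSGU, and projected back to `size`.

```python
import torch

from src.espnet_import.cgmlp import ConvolutionalGatingMLP

cgmlp = ConvolutionalGatingMLP(
    size=64,
    linear_units=256,
    kernel_size=31,
    dropout_rate=0.0,
    use_linear_after_conv=False,
    gate_activation="identity",
)
x = torch.randn(2, 50, 64)   # (batch, time, size)
y = cgmlp(x, mask=None)      # mask is accepted but unused in forward
print(y.shape)               # torch.Size([2, 50, 64])
```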
src/espnet_import/embedding.py ADDED
@@ -0,0 +1,523 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+ # Copyright 2019 Shigeki Karita
5
+ # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
6
+
7
+ """Positional Encoding Module."""
8
+
9
+ import logging
10
+ import math
11
+
12
+ import torch
13
+ from packaging.version import parse as V
14
+
15
+
16
+ # from espnet2.asr.frontend.cnn import dim_1_layer_norm
17
+ def dim_1_layer_norm(x, eps=1e-05, gamma=None, beta=None):
18
+ """Functional version of Dim1LayerNorm."""
19
+
20
+ B, D, T = x.shape
21
+ mean = torch.mean(x, 1, keepdim=True)
22
+ variance = torch.mean((x - mean) ** 2, 1, keepdim=True)
23
+
24
+ x = (x - mean) * torch.rsqrt(variance + eps)
25
+
26
+ if gamma is not None:
27
+ x = x * gamma.view(1, -1, 1)
28
+ if beta is not None:
29
+ x = x + beta.view(1, -1, 1)
30
+ return x
31
+
32
+
33
+ def _pre_hook(
34
+ state_dict,
35
+ prefix,
36
+ local_metadata,
37
+ strict,
38
+ missing_keys,
39
+ unexpected_keys,
40
+ error_msgs,
41
+ ):
42
+ """Perform pre-hook in load_state_dict for backward compatibility.
43
+
44
+ Note:
45
+ We saved self.pe until v.0.5.2 but we have omitted it later.
46
+ Therefore, we remove the item "pe" from `state_dict` for backward compatibility.
47
+
48
+ """
49
+ k = prefix + "pe"
50
+ if k in state_dict:
51
+ state_dict.pop(k)
52
+
53
+
54
+ class PositionalEncoding(torch.nn.Module):
55
+ """Positional encoding.
56
+
57
+ Args:
58
+ d_model (int): Embedding dimension.
59
+ dropout_rate (float): Dropout rate.
60
+ max_len (int): Maximum input length.
61
+ reverse (bool): Whether to reverse the input position. Only for
62
+ the class LegacyRelPositionalEncoding. We remove it in the current
63
+ class RelPositionalEncoding.
64
+ """
65
+
66
+ def __init__(self, d_model, dropout_rate, max_len=5000, reverse=False):
67
+ """Construct an PositionalEncoding object."""
68
+ super(PositionalEncoding, self).__init__()
69
+ self.d_model = d_model
70
+ self.reverse = reverse
71
+ self.xscale = math.sqrt(self.d_model)
72
+ self.dropout = torch.nn.Dropout(p=dropout_rate)
73
+ self.pe = None
74
+ self.extend_pe(torch.tensor(0.0).expand(1, max_len))
75
+ self._register_load_state_dict_pre_hook(_pre_hook)
76
+
77
+ def extend_pe(self, x):
78
+ """Reset the positional encodings."""
79
+ if self.pe is not None:
80
+ if self.pe.size(1) >= x.size(1):
81
+ if self.pe.dtype != x.dtype or self.pe.device != x.device:
82
+ self.pe = self.pe.to(dtype=x.dtype, device=x.device)
83
+ return
84
+ pe = torch.zeros(x.size(1), self.d_model)
85
+ if self.reverse:
86
+ position = torch.arange(
87
+ x.size(1) - 1, -1, -1.0, dtype=torch.float32
88
+ ).unsqueeze(1)
89
+ else:
90
+ position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
91
+ div_term = torch.exp(
92
+ torch.arange(0, self.d_model, 2, dtype=torch.float32)
93
+ * -(math.log(10000.0) / self.d_model)
94
+ )
95
+ pe[:, 0::2] = torch.sin(position * div_term)
96
+ pe[:, 1::2] = torch.cos(position * div_term)
97
+ pe = pe.unsqueeze(0)
98
+ self.pe = pe.to(device=x.device, dtype=x.dtype)
99
+
100
+ def forward(self, x: torch.Tensor):
101
+ """Add positional encoding.
102
+
103
+ Args:
104
+ x (torch.Tensor): Input tensor (batch, time, `*`).
105
+
106
+ Returns:
107
+ torch.Tensor: Encoded tensor (batch, time, `*`).
108
+ """
109
+ self.extend_pe(x)
110
+ x = x * self.xscale + self.pe[:, : x.size(1)]
111
+ return self.dropout(x)
112
+
113
+
114
+ class ScaledPositionalEncoding(PositionalEncoding):
115
+ """Scaled positional encoding module.
116
+
117
+ See Sec. 3.2 https://arxiv.org/abs/1809.08895
118
+
119
+ Args:
120
+ d_model (int): Embedding dimension.
121
+ dropout_rate (float): Dropout rate.
122
+ max_len (int): Maximum input length.
123
+
124
+ """
125
+
126
+ def __init__(self, d_model, dropout_rate, max_len=5000):
127
+ """Initialize class."""
128
+ super().__init__(d_model=d_model, dropout_rate=dropout_rate, max_len=max_len)
129
+ self.alpha = torch.nn.Parameter(torch.tensor(1.0))
130
+
131
+ def reset_parameters(self):
132
+ """Reset parameters."""
133
+ self.alpha.data = torch.tensor(1.0)
134
+
135
+ def forward(self, x):
136
+ """Add positional encoding.
137
+
138
+ Args:
139
+ x (torch.Tensor): Input tensor (batch, time, `*`).
140
+
141
+ Returns:
142
+ torch.Tensor: Encoded tensor (batch, time, `*`).
143
+
144
+ """
145
+ self.extend_pe(x)
146
+ x = x + self.alpha * self.pe[:, : x.size(1)]
147
+ return self.dropout(x)
148
+
149
+
150
+ class LearnableFourierPosEnc(torch.nn.Module):
151
+ """Learnable Fourier Features for Positional Encoding.
152
+
153
+ See https://arxiv.org/pdf/2106.02795.pdf
154
+
155
+ Args:
156
+ d_model (int): Embedding dimension.
157
+ dropout_rate (float): Dropout rate.
158
+ max_len (int): Maximum input length.
159
+ gamma (float): init parameter for the positional kernel variance
160
+ see https://arxiv.org/pdf/2106.02795.pdf.
161
+ apply_scaling (bool): Whether to scale the input before adding the pos encoding.
162
+ hidden_dim (int): if not None, we modulate the pos encodings with
163
+ an MLP whose hidden layer has hidden_dim neurons.
164
+ """
165
+
166
+ def __init__(
167
+ self,
168
+ d_model,
169
+ dropout_rate=0.0,
170
+ max_len=5000,
171
+ gamma=1.0,
172
+ apply_scaling=False,
173
+ hidden_dim=None,
174
+ ):
175
+ """Initialize class."""
176
+ super(LearnableFourierPosEnc, self).__init__()
177
+
178
+ self.d_model = d_model
179
+
180
+ if apply_scaling:
181
+ self.xscale = math.sqrt(self.d_model)
182
+ else:
183
+ self.xscale = 1.0
184
+
185
+ self.dropout = torch.nn.Dropout(dropout_rate)
186
+ self.max_len = max_len
187
+
188
+ self.gamma = gamma
189
+ if self.gamma is None:
190
+ self.gamma = self.d_model // 2
191
+
192
+ assert (
193
+ d_model % 2 == 0
194
+ ), "d_model should be divisible by two in order to use this layer."
195
+ self.w_r = torch.nn.Parameter(torch.empty(1, d_model // 2))
196
+ self._reset() # init the weights
197
+
198
+ self.hidden_dim = hidden_dim
199
+ if self.hidden_dim is not None:
200
+ self.mlp = torch.nn.Sequential(
201
+ torch.nn.Linear(d_model, hidden_dim),
202
+ torch.nn.GELU(),
203
+ torch.nn.Linear(hidden_dim, d_model),
204
+ )
205
+
206
+ def _reset(self):
207
+ self.w_r.data = torch.normal(
208
+ 0, (1 / math.sqrt(self.gamma)), (1, self.d_model // 2)
209
+ )
210
+
211
+ def extend_pe(self, x):
212
+ """Reset the positional encodings."""
213
+ position_v = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1).to(x)
214
+
215
+ cosine = torch.cos(torch.matmul(position_v, self.w_r))
216
+ sine = torch.sin(torch.matmul(position_v, self.w_r))
217
+ pos_enc = torch.cat((cosine, sine), -1)
218
+ pos_enc /= math.sqrt(self.d_model)
219
+
220
+ if self.hidden_dim is None:
221
+ return pos_enc.unsqueeze(0)
222
+ else:
223
+ return self.mlp(pos_enc.unsqueeze(0))
224
+
225
+ def forward(self, x: torch.Tensor):
226
+ """Add positional encoding.
227
+
228
+ Args:
229
+ x (torch.Tensor): Input tensor (batch, time, `*`).
230
+
231
+ Returns:
232
+ torch.Tensor: Encoded tensor (batch, time, `*`).
233
+ """
234
+ pe = self.extend_pe(x)
235
+ x = x * self.xscale + pe
236
+ return self.dropout(x)
237
+
238
+
239
+ class LegacyRelPositionalEncoding(PositionalEncoding):
240
+ """Relative positional encoding module (old version).
241
+
242
+ Details can be found in https://github.com/espnet/espnet/pull/2816.
243
+
244
+ See : Appendix B in https://arxiv.org/abs/1901.02860
245
+
246
+ Args:
247
+ d_model (int): Embedding dimension.
248
+ dropout_rate (float): Dropout rate.
249
+ max_len (int): Maximum input length.
250
+
251
+ """
252
+
253
+ def __init__(self, d_model, dropout_rate, max_len=5000):
254
+ """Initialize class."""
255
+ super().__init__(
256
+ d_model=d_model,
257
+ dropout_rate=dropout_rate,
258
+ max_len=max_len,
259
+ reverse=True,
260
+ )
261
+
262
+ def forward(self, x):
263
+ """Compute positional encoding.
264
+
265
+ Args:
266
+ x (torch.Tensor): Input tensor (batch, time, `*`).
267
+
268
+ Returns:
269
+ torch.Tensor: Encoded tensor (batch, time, `*`).
270
+ torch.Tensor: Positional embedding tensor (1, time, `*`).
271
+
272
+ """
273
+ self.extend_pe(x)
274
+ x = x * self.xscale
275
+ pos_emb = self.pe[:, : x.size(1)]
276
+ return self.dropout(x), self.dropout(pos_emb)
277
+
278
+
279
+ class RelPositionalEncoding(torch.nn.Module):
280
+ """Relative positional encoding module (new implementation).
281
+
282
+ Details can be found in https://github.com/espnet/espnet/pull/2816.
283
+
284
+ See : Appendix B in https://arxiv.org/abs/1901.02860
285
+
286
+ Args:
287
+ d_model (int): Embedding dimension.
288
+ dropout_rate (float): Dropout rate.
289
+ max_len (int): Maximum input length.
290
+
291
+ """
292
+
293
+ def __init__(self, d_model, dropout_rate, max_len=5000):
294
+ """Construct an PositionalEncoding object."""
295
+ super(RelPositionalEncoding, self).__init__()
296
+ self.d_model = d_model
297
+ self.xscale = math.sqrt(self.d_model)
298
+ self.dropout = torch.nn.Dropout(p=dropout_rate)
299
+ self.pe = None
300
+ self.extend_pe(torch.tensor(0.0).expand(1, max_len))
301
+
302
+ def extend_pe(self, x):
303
+ """Reset the positional encodings."""
304
+ if self.pe is not None:
305
+ # self.pe contains both positive and negative parts
306
+ # the length of self.pe is 2 * input_len - 1
307
+ if self.pe.size(1) >= x.size(1) * 2 - 1:
308
+ if self.pe.dtype != x.dtype or self.pe.device != x.device:
309
+ self.pe = self.pe.to(dtype=x.dtype, device=x.device)
310
+ return
311
+ # Suppose `i` means the position of the query vector and `j` means the
312
+ # position of the key vector. We use positive relative positions when keys
313
+ # are to the left (i>j) and negative relative positions otherwise (i<j).
314
+ pe_positive = torch.zeros(x.size(1), self.d_model)
315
+ pe_negative = torch.zeros(x.size(1), self.d_model)
316
+ position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
317
+ div_term = torch.exp(
318
+ torch.arange(0, self.d_model, 2, dtype=torch.float32)
319
+ * -(math.log(10000.0) / self.d_model)
320
+ )
321
+ pe_positive[:, 0::2] = torch.sin(position * div_term)
322
+ pe_positive[:, 1::2] = torch.cos(position * div_term)
323
+ pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
324
+ pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
325
+
326
+ # Reverse the order of positive indices and concat both positive and
327
+ # negative indices. This is used to support the shifting trick
328
+ # as in https://arxiv.org/abs/1901.02860
329
+ pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
330
+ pe_negative = pe_negative[1:].unsqueeze(0)
331
+ pe = torch.cat([pe_positive, pe_negative], dim=1)
332
+ self.pe = pe.to(device=x.device, dtype=x.dtype)
333
+
334
+ def forward(self, x: torch.Tensor):
335
+ """Add positional encoding.
336
+
337
+ Args:
338
+ x (torch.Tensor): Input tensor (batch, time, `*`).
339
+
340
+ Returns:
341
+ torch.Tensor: Encoded tensor (batch, time, `*`).
342
+
343
+ """
344
+ self.extend_pe(x)
345
+ x = x * self.xscale
346
+ pos_emb = self.pe[
347
+ :,
348
+ self.pe.size(1) // 2 - x.size(1) + 1 : self.pe.size(1) // 2 + x.size(1),
349
+ ]
350
+ return self.dropout(x), self.dropout(pos_emb)
351
+
352
+
353
+ class StreamPositionalEncoding(torch.nn.Module):
354
+ """Streaming Positional encoding.
355
+
356
+ Args:
357
+ d_model (int): Embedding dimension.
358
+ dropout_rate (float): Dropout rate.
359
+ max_len (int): Maximum input length.
360
+
361
+ """
362
+
363
+ def __init__(self, d_model, dropout_rate, max_len=5000):
364
+ """Construct an PositionalEncoding object."""
365
+ super(StreamPositionalEncoding, self).__init__()
366
+ self.d_model = d_model
367
+ self.xscale = math.sqrt(self.d_model)
368
+ self.dropout = torch.nn.Dropout(p=dropout_rate)
369
+ self.pe = None
370
+ self.tmp = torch.tensor(0.0).expand(1, max_len)
371
+ self.extend_pe(self.tmp.size(1), self.tmp.device, self.tmp.dtype)
372
+ self._register_load_state_dict_pre_hook(_pre_hook)
373
+
374
+ def extend_pe(self, length, device, dtype):
375
+ """Reset the positional encodings."""
376
+ if self.pe is not None:
377
+ if self.pe.size(1) >= length:
378
+ if self.pe.dtype != dtype or self.pe.device != device:
379
+ self.pe = self.pe.to(dtype=dtype, device=device)
380
+ return
381
+ pe = torch.zeros(length, self.d_model)
382
+ position = torch.arange(0, length, dtype=torch.float32).unsqueeze(1)
383
+ div_term = torch.exp(
384
+ torch.arange(0, self.d_model, 2, dtype=torch.float32)
385
+ * -(math.log(10000.0) / self.d_model)
386
+ )
387
+ pe[:, 0::2] = torch.sin(position * div_term)
388
+ pe[:, 1::2] = torch.cos(position * div_term)
389
+ pe = pe.unsqueeze(0)
390
+ self.pe = pe.to(device=device, dtype=dtype)
391
+
392
+ def forward(self, x: torch.Tensor, start_idx: int = 0):
393
+ """Add positional encoding.
394
+
395
+ Args:
396
+ x (torch.Tensor): Input tensor (batch, time, `*`).
397
+
398
+ Returns:
399
+ torch.Tensor: Encoded tensor (batch, time, `*`).
400
+
401
+ """
402
+ self.extend_pe(x.size(1) + start_idx, x.device, x.dtype)
403
+ x = x * self.xscale + self.pe[:, start_idx : start_idx + x.size(1)]
404
+ return self.dropout(x)
405
+
406
+
407
+ class ConvolutionalPositionalEmbedding(torch.nn.Module):
408
+ """Convolutional positional embedding.
409
+
410
+ Used in wav2vec2/HuBERT SSL models.
411
+ https://arxiv.org/abs/1904.11660
412
+
413
+ Args:
414
+ embed_dim (int): Feature dimension of the input Tensor.
415
+ dropout (float): unused
416
+ max_len (int): unused
417
+ num_layers (int): number of conv layers
418
+ kernel_size (int): The number of frames to be use.
419
+ groups (int): The number of groups in feature dimensions.
420
+ weight_norm (str): [new, legacy, none].
421
+ How to init conv weights. Recommended setting is
422
+ none if num_layers > 1.
423
+ """
424
+
425
+ def __init__(
426
+ self,
427
+ embed_dim: int,
428
+ dropout: float,
429
+ max_len: int = 5000,
430
+ num_layers: int = 1,
431
+ kernel_size: int = 128,
432
+ groups: int = 16,
433
+ weight_norm: str = "new",
434
+ use_residual: bool = False,
435
+ ):
436
+ """Initialize Convoluational Positional Embedding."""
437
+ super().__init__()
438
+ self.embed_dim = embed_dim
439
+ self.kernel_size = kernel_size
440
+ self.weight_norm = weight_norm
441
+
442
+ convs = []
443
+ for layer in range(num_layers):
444
+ conv = torch.nn.Conv1d(
445
+ in_channels=embed_dim,
446
+ out_channels=embed_dim,
447
+ kernel_size=kernel_size,
448
+ padding=kernel_size // 2,
449
+ groups=groups,
450
+ )
451
+ if weight_norm != "none" and weight_norm is not None:
452
+ std = math.sqrt((4 * (1.0)) / (kernel_size * embed_dim))
453
+ torch.nn.init.normal_(conv.weight, mean=0, std=std)
454
+ torch.nn.init.constant_(conv.bias, 0)
455
+ # torch.nn.utils.weight_norm leads to weird behavior
456
+ # with copy.deepcopy(). Usually isn't needed,
457
+ # but it's important for models that use EMA
458
+ if weight_norm == "new":
459
+ if V(torch.__version__) >= V("2.2.0"):
460
+ conv = torch.nn.utils.parametrizations.weight_norm(
461
+ conv, name="weight", dim=2
462
+ )
463
+ else:
464
+ weight_norm = "legacy"
465
+ logging.warning(
466
+ "torch.nn.utils.parametrizations.weight_norm is only "
467
+ + "supported for pytorch versions >= 2.2.0. "
468
+ + "Defaulting to torch.nn.utils.weight_norm."
469
+ )
470
+ if weight_norm == "legacy":
471
+ conv = torch.nn.utils.weight_norm(conv, name="weight", dim=2)
472
+ convs.append(conv)
473
+ self.convs = torch.nn.ModuleList(convs)
474
+ self.num_remove: int = 1 if kernel_size % 2 == 0 else 0
475
+ self.use_residual = use_residual
476
+
477
+ def __prepare_scriptable__(self):
478
+ """Prepare Scriptable method."""
479
+ for hook in self.conv._forward_pre_hooks.values():
480
+ # The hook we want to remove is an instance of WeightNorm class, so
481
+ # normally we would do `if isinstance(...)` but this class is not accessible
482
+ # because of shadowing, so we check the module name directly.
483
+ # https://github.com/pytorch/pytorch/blob/be0ca00c5ce260eb5bcec3237357f7a30cc08983/torch/nn/utils/__init__.py#L3
484
+ if (
485
+ hook.__module__ == "torch.nn.utils.weight_norm"
486
+ and hook.__class__.__name__ == "WeightNorm"
487
+ ):
488
+ logging.warning("Removing weight_norm from %s", self.__class__.__name__)
489
+ torch.nn.utils.remove_weight_norm(self.conv)
490
+ return self
491
+
492
+ def forward(self, x):
493
+ """Forward Method.
494
+
495
+ Args:
496
+ x (Tensor): shape ``[batch, frame, feature]``.
497
+
498
+ Returns:
499
+ Tensor: The resulting feature. Shape ``[batch, frame, feature]``.
500
+ """
501
+ if self.use_residual:
502
+ residual = x
503
+
504
+ x = x.transpose(-2, -1)
505
+ for conv in self.convs:
506
+ x = conv(x)
507
+
508
+ # remove extra padding
509
+ if self.num_remove > 0:
510
+ x = x[..., : -self.num_remove]
511
+
512
+ x = torch.nn.functional.gelu(x)
513
+
514
+ # manually normalize if the conv is not parameterized
515
+ # with weight norm
516
+ if self.weight_norm is None or self.weight_norm == "none":
517
+ x = dim_1_layer_norm(x)
518
+
519
+ x = x.transpose(-2, -1)
520
+
521
+ if self.use_residual:
522
+ x = x + residual
523
+ return x
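A quick sketch (not part of the vendored file) of what `RelPositionalEncoding` returns: the scaled input plus a positional embedding of length 2*time-1, as consumed by `RelPositionMultiHeadedAttention`.

```python
import torch

from src.espnet_import.embedding import RelPositionalEncoding

pe = RelPositionalEncoding(d_model=64, dropout_rate=0.0)
x = torch.randn(2, 10, 64)
x_scaled, pos_emb = pe(x)
print(x_scaled.shape, pos_emb.shape)  # torch.Size([2, 10, 64]) torch.Size([1, 19, 64])
```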
src/espnet_import/fastformer.py ADDED
@@ -0,0 +1,153 @@
1
+ """Fastformer attention definition.
2
+
3
+ Reference:
4
+ Wu et al., "Fastformer: Additive Attention Can Be All You Need"
5
+ https://arxiv.org/abs/2108.09084
6
+ https://github.com/wuch15/Fastformer
7
+
8
+ """
9
+
10
+ import numpy
11
+ import torch
12
+
13
+
14
+ class FastSelfAttention(torch.nn.Module):
15
+ """Fast self-attention used in Fastformer."""
16
+
17
+ def __init__(
18
+ self,
19
+ size,
20
+ attention_heads,
21
+ dropout_rate,
22
+ ):
23
+ super().__init__()
24
+ if size % attention_heads != 0:
25
+ raise ValueError(
26
+ f"Hidden size ({size}) is not an integer multiple "
27
+ f"of attention heads ({attention_heads})"
28
+ )
29
+ self.attention_head_size = size // attention_heads
30
+ self.num_attention_heads = attention_heads
31
+
32
+ self.query = torch.nn.Linear(size, size)
33
+ self.query_att = torch.nn.Linear(size, attention_heads)
34
+ self.key = torch.nn.Linear(size, size)
35
+ self.key_att = torch.nn.Linear(size, attention_heads)
36
+ self.transform = torch.nn.Linear(size, size)
37
+ self.dropout = torch.nn.Dropout(dropout_rate)
38
+
39
+ def espnet_initialization_fn(self):
40
+ self.apply(self.init_weights)
41
+
42
+ def init_weights(self, module):
43
+ if isinstance(module, torch.nn.Linear):
44
+ module.weight.data.normal_(mean=0.0, std=0.02)
45
+ if isinstance(module, torch.nn.Linear) and module.bias is not None:
46
+ module.bias.data.zero_()
47
+
48
+ def transpose_for_scores(self, x):
49
+ """Reshape and transpose to compute scores.
50
+
51
+ Args:
52
+ x: (batch, time, size = n_heads * attn_dim)
53
+
54
+ Returns:
55
+ (batch, n_heads, time, attn_dim)
56
+ """
57
+
58
+ new_x_shape = x.shape[:-1] + (
59
+ self.num_attention_heads,
60
+ self.attention_head_size,
61
+ )
62
+ return x.reshape(*new_x_shape).transpose(1, 2)
63
+
64
+ def forward(self, xs_pad, mask):
65
+ """Forward method.
66
+
67
+ Args:
68
+ xs_pad: (batch, time, size = n_heads * attn_dim)
69
+ mask: (batch, 1, time), nonpadding is 1, padding is 0
70
+
71
+ Returns:
72
+ torch.Tensor: (batch, time, size)
73
+ """
74
+
75
+ batch_size, seq_len, _ = xs_pad.shape
76
+ mixed_query_layer = self.query(xs_pad) # (batch, time, size)
77
+ mixed_key_layer = self.key(xs_pad) # (batch, time, size)
78
+
79
+ if mask is not None:
80
+ mask = mask.eq(0) # padding is 1, nonpadding is 0
81
+
82
+ # (batch, n_heads, time)
83
+ query_for_score = (
84
+ self.query_att(mixed_query_layer).transpose(1, 2)
85
+ / self.attention_head_size**0.5
86
+ )
87
+ if mask is not None:
88
+ min_value = float(
89
+ numpy.finfo(
90
+ torch.tensor(0, dtype=query_for_score.dtype).numpy().dtype
91
+ ).min
92
+ )
93
+ query_for_score = query_for_score.masked_fill(mask, min_value)
94
+ query_weight = torch.softmax(query_for_score, dim=-1).masked_fill(mask, 0.0)
95
+ else:
96
+ query_weight = torch.softmax(query_for_score, dim=-1)
97
+
98
+ query_weight = query_weight.unsqueeze(2) # (batch, n_heads, 1, time)
99
+ query_layer = self.transpose_for_scores(
100
+ mixed_query_layer
101
+ ) # (batch, n_heads, time, attn_dim)
102
+
103
+ pooled_query = (
104
+ torch.matmul(query_weight, query_layer)
105
+ .transpose(1, 2)
106
+ .reshape(-1, 1, self.num_attention_heads * self.attention_head_size)
107
+ ) # (batch, 1, size = n_heads * attn_dim)
108
+ pooled_query = self.dropout(pooled_query)
109
+ pooled_query_repeat = pooled_query.repeat(1, seq_len, 1) # (batch, time, size)
110
+
111
+ mixed_query_key_layer = (
112
+ mixed_key_layer * pooled_query_repeat
113
+ ) # (batch, time, size)
114
+
115
+ # (batch, n_heads, time)
116
+ query_key_score = (
117
+ self.key_att(mixed_query_key_layer) / self.attention_head_size**0.5
118
+ ).transpose(1, 2)
119
+ if mask is not None:
120
+ min_value = float(
121
+ numpy.finfo(
122
+ torch.tensor(0, dtype=query_key_score.dtype).numpy().dtype
123
+ ).min
124
+ )
125
+ query_key_score = query_key_score.masked_fill(mask, min_value)
126
+ query_key_weight = torch.softmax(query_key_score, dim=-1).masked_fill(
127
+ mask, 0.0
128
+ )
129
+ else:
130
+ query_key_weight = torch.softmax(query_key_score, dim=-1)
131
+
132
+ query_key_weight = query_key_weight.unsqueeze(2) # (batch, n_heads, 1, time)
133
+ key_layer = self.transpose_for_scores(
134
+ mixed_query_key_layer
135
+ ) # (batch, n_heads, time, attn_dim)
136
+ pooled_key = torch.matmul(
137
+ query_key_weight, key_layer
138
+ ) # (batch, n_heads, 1, attn_dim)
139
+ pooled_key = self.dropout(pooled_key)
140
+
141
+ # NOTE: value = query, due to param sharing
142
+ weighted_value = (pooled_key * query_layer).transpose(
143
+ 1, 2
144
+ ) # (batch, time, n_heads, attn_dim)
145
+ weighted_value = weighted_value.reshape(
146
+ weighted_value.shape[:-2]
147
+ + (self.num_attention_heads * self.attention_head_size,)
148
+ ) # (batch, time, size)
149
+ weighted_value = (
150
+ self.dropout(self.transform(weighted_value)) + mixed_query_layer
151
+ )
152
+
153
+ return weighted_value
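A minimal shape check for the additive attention above (a sketch, not part of the vendored file); the mask marks valid frames with 1, as the forward docstring describes.

```python
import torch

from src.espnet_import.fastformer import FastSelfAttention

attn = FastSelfAttention(size=64, attention_heads=4, dropout_rate=0.0)
x = torch.randn(2, 50, 64)                     # (batch, time, size)
mask = torch.ones(2, 1, 50, dtype=torch.bool)  # all frames are non-padding
y = attn(x, mask)
print(y.shape)                                 # torch.Size([2, 50, 64])
```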
src/espnet_import/label_smoothing_loss.py ADDED
@@ -0,0 +1,64 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+ # Copyright 2019 Shigeki Karita
5
+ # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
6
+
7
+ # from espnet/nets/pytorch_backend/transformer/label_smoothing_loss.py
8
+ """Label smoothing module."""
9
+
10
+ import torch
11
+ from torch import nn
12
+
13
+
14
+ class LabelSmoothingLoss(nn.Module):
15
+ """Label-smoothing loss.
16
+
17
+ :param int size: the number of class
18
+ :param int padding_idx: ignored class id
19
+ :param float smoothing: smoothing rate (0.0 means the conventional CE)
20
+ :param bool normalize_length: normalize loss by sequence length if True
21
+ :param torch.nn.Module criterion: loss function to be smoothed
22
+ """
23
+
24
+ def __init__(
25
+ self,
26
+ size,
27
+ padding_idx,
28
+ smoothing,
29
+ normalize_length=False,
30
+ criterion=nn.KLDivLoss(reduction="none"),
31
+ ):
32
+ """Construct an LabelSmoothingLoss object."""
33
+ super(LabelSmoothingLoss, self).__init__()
34
+ self.criterion = criterion
35
+ self.padding_idx = padding_idx
36
+ self.confidence = 1.0 - smoothing
37
+ self.smoothing = smoothing
38
+ self.size = size
39
+ self.true_dist = None
40
+ self.normalize_length = normalize_length
41
+
42
+ def forward(self, x, target):
43
+ """Compute loss between x and target.
44
+
45
+ :param torch.Tensor x: prediction (batch, seqlen, class)
46
+ :param torch.Tensor target:
47
+ target signal masked with self.padding_id (batch, seqlen)
48
+ :return: scalar float value
49
+ :rtype torch.Tensor
50
+ """
51
+ assert x.size(2) == self.size
52
+ batch_size = x.size(0)
53
+ x = x.view(-1, self.size)
54
+ target = target.view(-1)
55
+ with torch.no_grad():
56
+ true_dist = x.clone()
57
+ true_dist.fill_(self.smoothing / (self.size - 1))
58
+ ignore = target == self.padding_idx # (B,)
59
+ total = len(target) - ignore.sum().item()
60
+ target = target.masked_fill(ignore, 0) # avoid -1 index
61
+ true_dist.scatter_(1, target.unsqueeze(1), self.confidence)
62
+ kl = self.criterion(torch.log_softmax(x, dim=1), true_dist)
63
+ denom = total if self.normalize_length else batch_size
64
+ return kl.masked_fill(ignore.unsqueeze(1), 0).sum() / denom
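A toy usage sketch (not part of the vendored file): with smoothing 0.1, each target distribution puts 0.9 on the reference class and spreads the remaining 0.1 over the other classes; positions equal to `padding_idx` are ignored.

```python
import torch

from src.espnet_import.label_smoothing_loss import LabelSmoothingLoss

criterion = LabelSmoothingLoss(size=5, padding_idx=-1, smoothing=0.1)
logits = torch.randn(2, 3, 5)                   # (batch, seqlen, class)
target = torch.tensor([[1, 2, -1], [0, 3, 4]])  # -1 marks a padded position
loss = criterion(logits, target)
print(loss)
```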
src/espnet_import/layer_norm.py ADDED
@@ -0,0 +1,43 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+ # Copyright 2019 Shigeki Karita
5
+ # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
6
+ # from https://github.com/espnet/espnet/blob/master/espnet2/legacy/nets/pytorch_backend/transformer/layer_norm.py
7
+
8
+ """Layer normalization module."""
9
+
10
+ import torch
11
+
12
+
13
+ class LayerNorm(torch.nn.LayerNorm):
14
+ """Layer normalization module.
15
+
16
+ Args:
17
+ nout (int): Output dim size.
18
+ dim (int): Dimension to be normalized.
19
+
20
+ """
21
+
22
+ def __init__(self, nout, dim=-1):
23
+ """Construct an LayerNorm object."""
24
+ super(LayerNorm, self).__init__(nout, eps=1e-12)
25
+ self.dim = dim
26
+
27
+ def forward(self, x):
28
+ """Apply layer normalization.
29
+
30
+ Args:
31
+ x (torch.Tensor): Input tensor.
32
+
33
+ Returns:
34
+ torch.Tensor: Normalized tensor.
35
+
36
+ """
37
+ if self.dim == -1:
38
+ return super(LayerNorm, self).forward(x)
39
+ return (
40
+ super(LayerNorm, self)
41
+ .forward(x.transpose(self.dim, -1))
42
+ .transpose(self.dim, -1)
43
+ )
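A short sketch (not part of the vendored file) of the `dim` argument: it lets the layer normalize a non-last dimension by transposing around the built-in LayerNorm.

```python
import torch

from src.espnet_import.layer_norm import LayerNorm

ln = LayerNorm(64, dim=1)  # normalize the channel dim of (batch, channel, time)
x = torch.randn(2, 64, 100)
print(ln(x).shape)         # torch.Size([2, 64, 100])
```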
src/espnet_import/nets_utils.py ADDED
@@ -0,0 +1,690 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ # from https://github.com/espnet/espnet/blob/master/espnet2/legacy/nets/pytorch_backend/nets_utils.py
4
+ """Network related utility tools."""
5
+
6
+ import logging
7
+ from typing import Dict, Optional
8
+
9
+ import numpy as np
10
+ import torch
11
+
12
+
13
+ def to_device(m, x):
14
+ """Send tensor into the device of the module.
15
+
16
+ Args:
17
+ m (torch.nn.Module): Torch module.
18
+ x (Tensor): Torch tensor.
19
+
20
+ Returns:
21
+ Tensor: Torch tensor located in the same place as torch module.
22
+
23
+ """
24
+ if isinstance(m, torch.nn.Module):
25
+ device = next(m.parameters()).device
26
+ elif isinstance(m, torch.Tensor):
27
+ device = m.device
28
+ else:
29
+ raise TypeError(
30
+ "Expected torch.nn.Module or torch.tensor, " f"bot got: {type(m)}"
31
+ )
32
+ return x.to(device)
33
+
34
+
35
+ def pad_list(xs, pad_value):
36
+ """Perform padding for the list of tensors.
37
+
38
+ Args:
39
+ xs (List): List of Tensors [(T_1, `*`), (T_2, `*`), ..., (T_B, `*`)].
40
+ pad_value (float): Value for padding.
41
+
42
+ Returns:
43
+ Tensor: Padded tensor (B, Tmax, `*`).
44
+
45
+ Examples:
46
+ >>> x = [torch.ones(4), torch.ones(2), torch.ones(1)]
47
+ >>> x
48
+ [tensor([1., 1., 1., 1.]), tensor([1., 1.]), tensor([1.])]
49
+ >>> pad_list(x, 0)
50
+ tensor([[1., 1., 1., 1.],
51
+ [1., 1., 0., 0.],
52
+ [1., 0., 0., 0.]])
53
+
54
+ """
55
+ n_batch = len(xs)
56
+ max_len = max(x.size(0) for x in xs)
57
+ pad = xs[0].new(n_batch, max_len, *xs[0].size()[1:]).fill_(pad_value)
58
+
59
+ for i in range(n_batch):
60
+ pad[i, : xs[i].size(0)] = xs[i]
61
+
62
+ return pad
63
+
64
+
65
+ @torch.compiler.disable
66
+ def make_pad_mask(lengths, xs=None, length_dim=-1, maxlen=None):
67
+ """Make mask tensor containing indices of padded part.
68
+
69
+ Args:
70
+ lengths (LongTensor or List): Batch of lengths (B,).
71
+ xs (Tensor, optional): The reference tensor.
72
+ If set, masks will be the same shape as this tensor.
73
+ length_dim (int, optional): Dimension indicator of the above tensor.
74
+ See the example.
75
+
76
+ Returns:
77
+ Tensor: Mask tensor containing indices of padded part.
78
+ dtype=torch.uint8 in PyTorch 1.2-
79
+ dtype=torch.bool in PyTorch 1.2+ (including 1.2)
80
+
81
+ Examples:
82
+ With only lengths.
83
+
84
+ >>> lengths = [5, 3, 2]
85
+ >>> make_pad_mask(lengths)
86
+ masks = [[0, 0, 0, 0 ,0],
87
+ [0, 0, 0, 1, 1],
88
+ [0, 0, 1, 1, 1]]
89
+
90
+ With the reference tensor.
91
+
92
+ >>> xs = torch.zeros((3, 2, 4))
93
+ >>> make_pad_mask(lengths, xs)
94
+ tensor([[[0, 0, 0, 0],
95
+ [0, 0, 0, 0]],
96
+ [[0, 0, 0, 1],
97
+ [0, 0, 0, 1]],
98
+ [[0, 0, 1, 1],
99
+ [0, 0, 1, 1]]], dtype=torch.uint8)
100
+ >>> xs = torch.zeros((3, 2, 6))
101
+ >>> make_pad_mask(lengths, xs)
102
+ tensor([[[0, 0, 0, 0, 0, 1],
103
+ [0, 0, 0, 0, 0, 1]],
104
+ [[0, 0, 0, 1, 1, 1],
105
+ [0, 0, 0, 1, 1, 1]],
106
+ [[0, 0, 1, 1, 1, 1],
107
+ [0, 0, 1, 1, 1, 1]]], dtype=torch.uint8)
108
+
109
+ With the reference tensor and dimension indicator.
110
+
111
+ >>> xs = torch.zeros((3, 6, 6))
112
+ >>> make_pad_mask(lengths, xs, 1)
113
+ tensor([[[0, 0, 0, 0, 0, 0],
114
+ [0, 0, 0, 0, 0, 0],
115
+ [0, 0, 0, 0, 0, 0],
116
+ [0, 0, 0, 0, 0, 0],
117
+ [0, 0, 0, 0, 0, 0],
118
+ [1, 1, 1, 1, 1, 1]],
119
+ [[0, 0, 0, 0, 0, 0],
120
+ [0, 0, 0, 0, 0, 0],
121
+ [0, 0, 0, 0, 0, 0],
122
+ [1, 1, 1, 1, 1, 1],
123
+ [1, 1, 1, 1, 1, 1],
124
+ [1, 1, 1, 1, 1, 1]],
125
+ [[0, 0, 0, 0, 0, 0],
126
+ [0, 0, 0, 0, 0, 0],
127
+ [1, 1, 1, 1, 1, 1],
128
+ [1, 1, 1, 1, 1, 1],
129
+ [1, 1, 1, 1, 1, 1],
130
+ [1, 1, 1, 1, 1, 1]]], dtype=torch.uint8)
131
+ >>> make_pad_mask(lengths, xs, 2)
132
+ tensor([[[0, 0, 0, 0, 0, 1],
133
+ [0, 0, 0, 0, 0, 1],
134
+ [0, 0, 0, 0, 0, 1],
135
+ [0, 0, 0, 0, 0, 1],
136
+ [0, 0, 0, 0, 0, 1],
137
+ [0, 0, 0, 0, 0, 1]],
138
+ [[0, 0, 0, 1, 1, 1],
139
+ [0, 0, 0, 1, 1, 1],
140
+ [0, 0, 0, 1, 1, 1],
141
+ [0, 0, 0, 1, 1, 1],
142
+ [0, 0, 0, 1, 1, 1],
143
+ [0, 0, 0, 1, 1, 1]],
144
+ [[0, 0, 1, 1, 1, 1],
145
+ [0, 0, 1, 1, 1, 1],
146
+ [0, 0, 1, 1, 1, 1],
147
+ [0, 0, 1, 1, 1, 1],
148
+ [0, 0, 1, 1, 1, 1],
149
+ [0, 0, 1, 1, 1, 1]]], dtype=torch.uint8)
150
+
151
+ """
152
+ if length_dim == 0:
153
+ raise ValueError("length_dim cannot be 0: {}".format(length_dim))
154
+
155
+ # If the input dimension is 2 or 3,
156
+ # then we use ESPnet-ONNX based implementation for traceable modeling.
157
+ # otherwise we use the traditional implementation for research use.
158
+ if isinstance(lengths, list):
159
+ logging.warning(
160
+ "Using make_pad_mask with a list of lengths is not traceable. "
161
+ + "If you try to trace this function with type(lengths) == list, "
162
+ + "please change the type of lengths to torch.LongTensor."
163
+ )
164
+
165
+ if (
166
+ (xs is None or xs.dim() in (2, 3))
167
+ and length_dim <= 2
168
+ and (not isinstance(lengths, list) and lengths.dim() == 1)
169
+ ):
170
+ return _make_pad_mask_traceable(lengths, xs, length_dim, maxlen)
171
+ else:
172
+ return _make_pad_mask(lengths, xs, length_dim, maxlen)
173
+
174
+
175
+ def _make_pad_mask(lengths, xs=None, length_dim=-1, maxlen=None):
176
+ if not isinstance(lengths, list):
177
+ lengths = lengths.long().tolist()
178
+
179
+ bs = int(len(lengths))
180
+ if maxlen is None:
181
+ if xs is None:
182
+ maxlen = int(max(lengths))
183
+ else:
184
+ maxlen = xs.size(length_dim)
185
+ else:
186
+ assert xs is None, "When maxlen is specified, xs must not be specified."
187
+ assert maxlen >= int(
188
+ max(lengths)
189
+ ), f"maxlen {maxlen} must be >= max(lengths) {max(lengths)}"
190
+
191
+ seq_range = torch.arange(0, maxlen, dtype=torch.int64)
192
+ seq_range_expand = seq_range.unsqueeze(0).expand(bs, maxlen)
193
+ seq_length_expand = seq_range_expand.new(lengths).unsqueeze(-1)
194
+ mask = seq_range_expand >= seq_length_expand
195
+
196
+ if xs is not None:
197
+ assert (
198
+ xs.size(0) == bs
199
+ ), f"The size of x.size(0) {xs.size(0)} must match the batch size {bs}"
200
+
201
+ if length_dim < 0:
202
+ length_dim = xs.dim() + length_dim
203
+ # ind = (:, None, ..., None, :, None, ..., None)
204
+ ind = tuple(
205
+ slice(None) if i in (0, length_dim) else None for i in range(xs.dim())
206
+ )
207
+ mask = mask[ind].expand_as(xs).to(xs.device)
208
+ return mask
209
+
210
+
211
+ def _make_pad_mask_traceable(lengths, xs, length_dim, maxlen=None):
212
+ """Make mask tensor containing indices of padded part.
213
+
214
+ This is a simplified implementation of make_pad_mask without the xs input
215
+ that supports JIT tracing for applications like exporting models to ONNX.
216
+ Dimension length of xs should be 2 or 3
217
+ This function will create torch.ones(maxlen, maxlen).triu(diagonal=1) and
218
+ select rows to create mask tensor.
219
+ """
220
+ if xs is None:
221
+ device = lengths.device
222
+ else:
223
+ device = xs.device
224
+
225
+ if xs is not None and len(xs.shape) == 3:
226
+ if length_dim == 1:
227
+ lengths = lengths.unsqueeze(1).expand(*xs.transpose(1, 2).shape[:2])
228
+ else:
229
+ # Then length_dim is 2 or -1.
230
+ if length_dim not in (-1, 2):
231
+ logging.warning(
232
+ f"Invalid length_dim {length_dim}."
233
+ + "We set it to -1, which is the default value."
234
+ )
235
+ length_dim = -1
236
+ lengths = lengths.unsqueeze(1).expand(*xs.shape[:2])
237
+
238
+ if maxlen is not None:
239
+ assert xs is None
240
+ assert maxlen >= lengths.max()
241
+ elif xs is not None:
242
+ maxlen = xs.shape[length_dim]
243
+ else:
244
+ maxlen = lengths.max()
245
+
246
+ # clip max(length) to maxlen
247
+ lengths = torch.clamp(lengths, max=maxlen).type(torch.long)
248
+
249
+ mask = torch.ones(maxlen + 1, maxlen + 1, dtype=torch.bool, device=device)
250
+ mask = triu_onnx(mask)[1:, :-1] # onnx cannot handle diagonal argument.
251
+ mask = mask[lengths - 1][..., :maxlen]
252
+
253
+ if xs is not None and len(xs.shape) == 3 and length_dim == 1:
254
+ return mask.transpose(1, 2)
255
+ else:
256
+ return mask
257
+
258
+
259
+ def triu_onnx(x):
260
+ """Make TriU for ONNX."""
261
+ arange = torch.arange(x.size(0), device=x.device)
262
+ mask = arange.unsqueeze(-1).expand(-1, x.size(0)) <= arange
263
+ return x * mask
264
+
265
+
266
+ def make_non_pad_mask(lengths, xs=None, length_dim=-1):
267
+ """Make mask tensor containing indices of non-padded part.
268
+
269
+ Args:
270
+ lengths (LongTensor or List): Batch of lengths (B,).
271
+ xs (Tensor, optional): The reference tensor.
272
+ If set, masks will be the same shape as this tensor.
273
+ length_dim (int, optional): Dimension indicator of the above tensor.
274
+ See the example.
275
+
276
+ Returns:
277
+ ByteTensor: mask tensor containing indices of non-padded part.
278
+ dtype=torch.uint8 in PyTorch 1.2-
279
+ dtype=torch.bool in PyTorch 1.2+ (including 1.2)
280
+
281
+ Examples:
282
+ With only lengths.
283
+
284
+ >>> lengths = [5, 3, 2]
285
+ >>> make_non_pad_mask(lengths)
286
+ masks = [[1, 1, 1, 1 ,1],
287
+ [1, 1, 1, 0, 0],
288
+ [1, 1, 0, 0, 0]]
289
+
290
+ With the reference tensor.
291
+
292
+ >>> xs = torch.zeros((3, 2, 4))
293
+ >>> make_non_pad_mask(lengths, xs)
294
+ tensor([[[1, 1, 1, 1],
295
+ [1, 1, 1, 1]],
296
+ [[1, 1, 1, 0],
297
+ [1, 1, 1, 0]],
298
+ [[1, 1, 0, 0],
299
+ [1, 1, 0, 0]]], dtype=torch.uint8)
300
+ >>> xs = torch.zeros((3, 2, 6))
301
+ >>> make_non_pad_mask(lengths, xs)
302
+ tensor([[[1, 1, 1, 1, 1, 0],
303
+ [1, 1, 1, 1, 1, 0]],
304
+ [[1, 1, 1, 0, 0, 0],
305
+ [1, 1, 1, 0, 0, 0]],
306
+ [[1, 1, 0, 0, 0, 0],
307
+ [1, 1, 0, 0, 0, 0]]], dtype=torch.uint8)
308
+
309
+ With the reference tensor and dimension indicator.
310
+
311
+ >>> xs = torch.zeros((3, 6, 6))
312
+ >>> make_non_pad_mask(lengths, xs, 1)
313
+ tensor([[[1, 1, 1, 1, 1, 1],
314
+ [1, 1, 1, 1, 1, 1],
315
+ [1, 1, 1, 1, 1, 1],
316
+ [1, 1, 1, 1, 1, 1],
317
+ [1, 1, 1, 1, 1, 1],
318
+ [0, 0, 0, 0, 0, 0]],
319
+ [[1, 1, 1, 1, 1, 1],
320
+ [1, 1, 1, 1, 1, 1],
321
+ [1, 1, 1, 1, 1, 1],
322
+ [0, 0, 0, 0, 0, 0],
323
+ [0, 0, 0, 0, 0, 0],
324
+ [0, 0, 0, 0, 0, 0]],
325
+ [[1, 1, 1, 1, 1, 1],
326
+ [1, 1, 1, 1, 1, 1],
327
+ [0, 0, 0, 0, 0, 0],
328
+ [0, 0, 0, 0, 0, 0],
329
+ [0, 0, 0, 0, 0, 0],
330
+ [0, 0, 0, 0, 0, 0]]], dtype=torch.uint8)
331
+ >>> make_non_pad_mask(lengths, xs, 2)
332
+ tensor([[[1, 1, 1, 1, 1, 0],
333
+ [1, 1, 1, 1, 1, 0],
334
+ [1, 1, 1, 1, 1, 0],
335
+ [1, 1, 1, 1, 1, 0],
336
+ [1, 1, 1, 1, 1, 0],
337
+ [1, 1, 1, 1, 1, 0]],
338
+ [[1, 1, 1, 0, 0, 0],
339
+ [1, 1, 1, 0, 0, 0],
340
+ [1, 1, 1, 0, 0, 0],
341
+ [1, 1, 1, 0, 0, 0],
342
+ [1, 1, 1, 0, 0, 0],
343
+ [1, 1, 1, 0, 0, 0]],
344
+ [[1, 1, 0, 0, 0, 0],
345
+ [1, 1, 0, 0, 0, 0],
346
+ [1, 1, 0, 0, 0, 0],
347
+ [1, 1, 0, 0, 0, 0],
348
+ [1, 1, 0, 0, 0, 0],
349
+ [1, 1, 0, 0, 0, 0]]], dtype=torch.uint8)
350
+
351
+ """
352
+ return ~make_pad_mask(lengths, xs, length_dim)
353
+
354
+
355
+ def mask_by_length(xs, lengths, fill=0):
356
+ """Mask tensor according to length.
357
+
358
+ Args:
359
+ xs (Tensor): Batch of input tensor (B, `*`).
360
+ lengths (LongTensor or List): Batch of lengths (B,).
361
+ fill (int or float): Value to fill masked part.
362
+
363
+ Returns:
364
+ Tensor: Batch of masked input tensor (B, `*`).
365
+
366
+ Examples:
367
+ >>> x = torch.arange(5).repeat(3, 1) + 1
368
+ >>> x
369
+ tensor([[1, 2, 3, 4, 5],
370
+ [1, 2, 3, 4, 5],
371
+ [1, 2, 3, 4, 5]])
372
+ >>> lengths = [5, 3, 2]
373
+ >>> mask_by_length(x, lengths)
374
+ tensor([[1, 2, 3, 4, 5],
375
+ [1, 2, 3, 0, 0],
376
+ [1, 2, 0, 0, 0]])
377
+
378
+ """
379
+ assert xs.size(0) == len(lengths)
380
+ ret = xs.data.new(*xs.size()).fill_(fill)
381
+ for i, l in enumerate(lengths):
382
+ ret[i, :l] = xs[i, :l]
383
+ return ret
384
+
385
+
386
+ def th_accuracy(pad_outputs, pad_targets, ignore_label):
387
+ """Calculate accuracy.
388
+
389
+ Args:
390
+ pad_outputs (Tensor): Prediction tensors (B * Lmax, D).
391
+ pad_targets (LongTensor): Target label tensors (B, Lmax, D).
392
+ ignore_label (int): Ignore label id.
393
+
394
+ Returns:
395
+ float: Accuracy value (0.0 - 1.0).
396
+
397
+ """
398
+ pad_pred = pad_outputs.view(
399
+ pad_targets.size(0), pad_targets.size(1), pad_outputs.size(1)
400
+ ).argmax(2)
401
+ mask = pad_targets != ignore_label
402
+ numerator = torch.sum(
403
+ pad_pred.masked_select(mask) == pad_targets.masked_select(mask)
404
+ )
405
+ denominator = torch.sum(mask)
406
+ return float(numerator) / float(denominator)
407
+
408
+
409
+ def to_torch_tensor(x):
410
+ """Change to torch.Tensor or ComplexTensor from numpy.ndarray.
411
+
412
+ Args:
413
+ x: Inputs. It should be one of numpy.ndarray, Tensor, ComplexTensor, and dict.
414
+
415
+ Returns:
416
+ Tensor or ComplexTensor: Type converted inputs.
417
+
418
+ Examples:
419
+ >>> xs = np.ones(3, dtype=np.float32)
420
+ >>> xs = to_torch_tensor(xs)
421
+ tensor([1., 1., 1.])
422
+ >>> xs = torch.ones(3, 4, 5)
423
+ >>> assert to_torch_tensor(xs) is xs
424
+ >>> xs = {'real': xs, 'imag': xs}
425
+ >>> to_torch_tensor(xs)
426
+ ComplexTensor(
427
+ Real:
428
+ tensor([1., 1., 1.])
429
+ Imag;
430
+ tensor([1., 1., 1.])
431
+ )
432
+
433
+ """
434
+ # If numpy, change to torch tensor
435
+ if isinstance(x, np.ndarray):
436
+ if x.dtype.kind == "c":
437
+ # Dynamically importing because torch_complex requires python3
438
+ from torch_complex.tensor import ComplexTensor
439
+
440
+ return ComplexTensor(x)
441
+ else:
442
+ return torch.from_numpy(x)
443
+
444
+ # If {'real': ..., 'imag': ...}, convert to ComplexTensor
445
+ elif isinstance(x, dict):
446
+ # Dynamically importing because torch_complex requires python3
447
+ from torch_complex.tensor import ComplexTensor
448
+
449
+ if "real" not in x or "imag" not in x:
450
+ raise ValueError("x must have 'real' and 'imag' keys, but got: {}".format(list(x)))
451
+ # Relative importing because of using python3 syntax
452
+ return ComplexTensor(x["real"], x["imag"])
453
+
454
+ # If torch.Tensor, as it is
455
+ elif isinstance(x, torch.Tensor):
456
+ return x
457
+
458
+ else:
459
+ error = (
460
+ "x must be numpy.ndarray, torch.Tensor or a dict like "
461
+ "{{'real': torch.Tensor, 'imag': torch.Tensor}}, "
462
+ "but got {}".format(type(x))
463
+ )
464
+ try:
465
+ from torch_complex.tensor import ComplexTensor
466
+ except Exception:
467
+ # If PY2
468
+ raise ValueError(error)
469
+ else:
470
+ # If PY3
471
+ if isinstance(x, ComplexTensor):
472
+ return x
473
+ else:
474
+ raise ValueError(error)
475
+
476
+
477
+ def get_subsample(train_args, mode, arch):
478
+ """Parse the subsampling factors from the args for the specified `mode` and `arch`.
479
+
480
+ Args:
481
+ train_args: argument Namespace containing options.
482
+ mode: one of ('asr', 'mt', 'st')
483
+ arch: one of ('rnn', 'rnn-t', 'rnn_mix', 'rnn_mulenc', 'transformer')
484
+
485
+ Returns:
486
+ np.ndarray / List[np.ndarray]: subsampling factors.
487
+ """
488
+ if arch == "transformer":
489
+ return np.array([1])
490
+
491
+ elif mode == "mt" and arch == "rnn":
492
+ # +1 means input (+1) and layer outputs (train_args.elayers)
493
+ subsample = np.ones(train_args.elayers + 1, dtype=np.int64)
494
+ logging.warning("Subsampling is not performed for machine translation.")
495
+ logging.info("subsample: " + " ".join([str(x) for x in subsample]))
496
+ return subsample
497
+
498
+ elif (
499
+ (mode == "asr" and arch in ("rnn", "rnn-t"))
500
+ or (mode == "mt" and arch == "rnn")
501
+ or (mode == "st" and arch == "rnn")
502
+ ):
503
+ subsample = np.ones(train_args.elayers + 1, dtype=np.int64)
504
+ if train_args.etype.endswith("p") and not train_args.etype.startswith("vgg"):
505
+ ss = train_args.subsample.split("_")
506
+ for j in range(min(train_args.elayers + 1, len(ss))):
507
+ subsample[j] = int(ss[j])
508
+ else:
509
+ logging.warning(
510
+ "Subsampling is not performed for vgg*. "
511
+ "It is performed in max pooling layers at CNN."
512
+ )
513
+ logging.info("subsample: " + " ".join([str(x) for x in subsample]))
514
+ return subsample
515
+
516
+ elif mode == "asr" and arch == "rnn_mix":
517
+ subsample = np.ones(
518
+ train_args.elayers_sd + train_args.elayers + 1, dtype=np.int64
519
+ )
520
+ if train_args.etype.endswith("p") and not train_args.etype.startswith("vgg"):
521
+ ss = train_args.subsample.split("_")
522
+ for j in range(
523
+ min(train_args.elayers_sd + train_args.elayers + 1, len(ss))
524
+ ):
525
+ subsample[j] = int(ss[j])
526
+ else:
527
+ logging.warning(
528
+ "Subsampling is not performed for vgg*. "
529
+ "It is performed in max pooling layers at CNN."
530
+ )
531
+ logging.info("subsample: " + " ".join([str(x) for x in subsample]))
532
+ return subsample
533
+
534
+ elif mode == "asr" and arch == "rnn_mulenc":
535
+ subsample_list = []
536
+ for idx in range(train_args.num_encs):
537
+ subsample = np.ones(train_args.elayers[idx] + 1, dtype=np.int64)
538
+ if train_args.etype[idx].endswith("p") and not train_args.etype[
539
+ idx
540
+ ].startswith("vgg"):
541
+ ss = train_args.subsample[idx].split("_")
542
+ for j in range(min(train_args.elayers[idx] + 1, len(ss))):
543
+ subsample[j] = int(ss[j])
544
+ else:
545
+ logging.warning(
546
+ "Encoder %d: Subsampling is not performed for vgg*. "
547
+ "It is performed in max pooling layers at CNN.",
548
+ idx + 1,
549
+ )
550
+ logging.info("subsample: " + " ".join([str(x) for x in subsample]))
551
+ subsample_list.append(subsample)
552
+ return subsample_list
553
+
554
+ else:
555
+ raise ValueError("Invalid options: mode={}, arch={}".format(mode, arch))
556
+
557
+
558
+ def rename_state_dict(
559
+ old_prefix: str, new_prefix: str, state_dict: Dict[str, torch.Tensor]
560
+ ):
561
+ """Replace keys of old prefix with new prefix in state dict."""
562
+ # need this list not to break the dict iterator
563
+ old_keys = [k for k in state_dict if k.startswith(old_prefix)]
564
+ if len(old_keys) > 0:
565
+ logging.warning(f"Rename: {old_prefix} -> {new_prefix}")
566
+ for k in old_keys:
567
+ v = state_dict.pop(k)
568
+ new_k = k.replace(old_prefix, new_prefix)
569
+ state_dict[new_k] = v
570
+
571
+ import torch
572
+
573
+
574
+ # from espnet2.legacy.nets.pytorch_backend.conformer.swish import Swish
575
+ class Swish(torch.nn.Module):
576
+ """Construct a Swish object."""
577
+
578
+ def forward(self, x):
579
+ """Return Swish activation function."""
580
+ return x * torch.sigmoid(x)
581
+
582
+
583
+ def get_activation(act):
584
+ """Return activation function."""
585
+
586
+ activation_funcs = {
587
+ "hardtanh": torch.nn.Hardtanh,
588
+ "tanh": torch.nn.Tanh,
589
+ "relu": torch.nn.ReLU,
590
+ "selu": torch.nn.SELU,
591
+ "swish": Swish,
592
+ }
593
+
594
+ return activation_funcs[act]()
595
+
596
+
597
+ def trim_by_ctc_posterior(
598
+ h: torch.Tensor,
599
+ ctc_probs: torch.Tensor,
600
+ masks: torch.Tensor,
601
+ pos_emb: torch.Tensor = None,
602
+ ):
603
+ """Trim the encoder hidden output using CTC posterior.
604
+
605
+ The continuous frames in the tail that confidently represent
606
+ blank symbols are trimmed.
607
+ """
608
+ # Empirical settings
609
+ frame_tolerance = 5
610
+ conf_tolerance = 0.95
611
+ blank_id = 0
612
+
613
+ assert masks.size(1) == 1
614
+ masks = masks.squeeze(1)
615
+ hlens = masks.sum(dim=1)
616
+ assert h.size()[:2] == ctc_probs.size()[:2]
617
+ assert h.size(0) == hlens.size(0)
618
+
619
+ # blank frames
620
+ max_values, max_indices = ctc_probs.max(dim=2)
621
+ blank_masks = torch.logical_and(
622
+ max_values > conf_tolerance, max_indices == blank_id
623
+ )
624
+
625
+ # plus ignored frames
626
+ joint_masks = torch.logical_or(blank_masks, ~masks)
627
+
628
+ # lengths after the trimming
629
+ B, T, _ = h.size()
630
+ frame_idx = torch.where(
631
+ joint_masks, -1, torch.arange(T).unsqueeze(0).repeat(B, 1).to(h.device)
632
+ )
633
+ after_lens = torch.where(
634
+ frame_idx.max(dim=-1)[0] + frame_tolerance + 1 < hlens,
635
+ frame_idx.max(dim=-1)[0] + frame_tolerance + 1,
636
+ hlens,
637
+ )
638
+
639
+ h = h[:, : max(after_lens)]
640
+ masks = ~make_pad_mask(after_lens).to(h.device).unsqueeze(1)
641
+
642
+ if pos_emb is None:
643
+ pos_emb = None
644
+ elif (hlens.max() * 2 - 1).item() == pos_emb.size(1): # RelPositionalEncoding
645
+ pos_emb = pos_emb[
646
+ :, pos_emb.size(1) // 2 - h.size(1) + 1 : pos_emb.size(1) // 2 + h.size(1)
647
+ ]
648
+ else:
649
+ pos_emb = pos_emb[:, : h.size(1)]
650
+
651
+ return h, masks, pos_emb
652
+
653
+
654
+ def roll_tensor(
655
+ x: torch.Tensor,
656
+ lengths: torch.Tensor,
657
+ roll_amounts: Optional[torch.Tensor] = None,
658
+ fixed_intervals: Optional[int] = None,
659
+ ) -> torch.Tensor:
660
+ """Left-roll tensor x by roll_amounts, only within lengths and optionally quantized.
661
+
662
+ Args:
663
+ x: input tensor (B, T, D)
664
+ lengths: lengths of each sequence (B,)
665
+ roll_amounts: random shift amounts (B,). If None, random shift
666
+ amounts are generated.
667
+ fixed_intervals: if not None, roll_amounts are quantized to
668
+ multiples of this.
669
+ Returns:
670
+ rolled_x: rolled tensor (B, T, D)
671
+ Useful to apply roll augmentation to the input, while considering
672
+ the input length for each sample.
673
+ """
674
+ B, T, D = x.shape
675
+
676
+ indices = torch.arange(T).unsqueeze(0).expand(B, T).to(x.device) # (B, T)
677
+ lengths = lengths.unsqueeze(1) # (B, 1)
678
+
679
+ if roll_amounts is None:
680
+ roll_amounts = torch.randint(0, lengths.max(), (B,), device=x.device)
681
+ if fixed_intervals is not None:
682
+ roll_amounts = (roll_amounts // fixed_intervals) * fixed_intervals
683
+ roll_indices = (indices - roll_amounts.unsqueeze(1)) % lengths # (B, T)
684
+ roll_indices = roll_indices.unsqueeze(2).expand(-1, -1, D) # (B, T, D)
685
+
686
+ mask = indices < lengths # (B, T), True if position is valid
687
+ rolled_x = torch.empty_like(x)
688
+ rolled_x[mask] = x.gather(1, roll_indices)[mask]
689
+ rolled_x[~mask] = x[~mask]
690
+ return rolled_x
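Not part of the committed file -- a minimal sketch of the padding utilities above, using list-valued lengths as in the docstring examples:

import torch

xs = [torch.ones(5, 3), torch.ones(3, 3), torch.ones(2, 3)]
lengths = [5, 3, 2]
padded = pad_list(xs, 0.0)               # (3, 5, 3), shorter sequences padded with 0
pad_mask = make_pad_mask(lengths)        # (3, 5), True at padded positions
valid_mask = make_non_pad_mask(lengths)  # elementwise inverse of pad_mask
assert padded.shape == (3, 5, 3)
assert bool(pad_mask[1, 3]) and bool(valid_mask[1, 2])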
src/espnet_import/positionwise_feed_forward.py ADDED
@@ -0,0 +1,32 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+ # Copyright 2019 Shigeki Karita
5
+ # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
6
+
7
+ """Positionwise feed forward layer definition."""
8
+
9
+ import torch
10
+
11
+
12
+ class PositionwiseFeedForward(torch.nn.Module):
13
+ """Positionwise feed forward layer.
14
+
15
+ Args:
16
+ idim (int): Input dimension.
17
+ hidden_units (int): The number of hidden units.
18
+ dropout_rate (float): Dropout rate.
19
+
20
+ """
21
+
22
+ def __init__(self, idim, hidden_units, dropout_rate, activation=torch.nn.ReLU()):
23
+ """Construct a PositionwiseFeedForward object."""
24
+ super(PositionwiseFeedForward, self).__init__()
25
+ self.w_1 = torch.nn.Linear(idim, hidden_units)
26
+ self.w_2 = torch.nn.Linear(hidden_units, idim)
27
+ self.dropout = torch.nn.Dropout(dropout_rate)
28
+ self.activation = activation
29
+
30
+ def forward(self, x):
31
+ """Forward function."""
32
+ return self.w_2(self.dropout(self.activation(self.w_1(x))))
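Not part of the committed file -- the layer keeps the model dimension, so a quick shape check looks like:

import torch

ffn = PositionwiseFeedForward(idim=256, hidden_units=1024, dropout_rate=0.1)
x = torch.randn(2, 50, 256)          # (batch, time, idim)
assert ffn(x).shape == (2, 50, 256)  # w_1 expands to 1024, w_2 projects back to idim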
src/espnet_import/repeat.py ADDED
@@ -0,0 +1,46 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+ # Copyright 2019 Shigeki Karita
5
+ # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
6
+
7
+ """Repeat the same layer definition."""
8
+
9
+ import torch
10
+
11
+
12
+ class MultiSequential(torch.nn.Sequential):
13
+ """Multi-input multi-output torch.nn.Sequential."""
14
+
15
+ def __init__(self, *args, layer_drop_rate=0.0):
16
+ """Initialize MultiSequential with layer_drop.
17
+
18
+ Args:
19
+ layer_drop_rate (float): Probability of dropping out each fn (layer).
20
+
21
+ """
22
+ super(MultiSequential, self).__init__(*args)
23
+ self.layer_drop_rate = layer_drop_rate
24
+
25
+ def forward(self, *args):
26
+ """Repeat."""
27
+ _probs = torch.empty(len(self)).uniform_()
28
+ for idx, m in enumerate(self):
29
+ if not self.training or (_probs[idx] >= self.layer_drop_rate):
30
+ args = m(*args)
31
+ return args
32
+
33
+
34
+ def repeat(N, fn, layer_drop_rate=0.0):
35
+ """Repeat module N times.
36
+
37
+ Args:
38
+ N (int): Number of repeat times.
39
+ fn (Callable): Function to generate module.
40
+ layer_drop_rate (float): Probability of dropping out each fn (layer).
41
+
42
+ Returns:
43
+ MultiSequential: Repeated model instance.
44
+
45
+ """
46
+ return MultiSequential(*[fn(n) for n in range(N)], layer_drop_rate=layer_drop_rate)
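Not part of the committed file -- a minimal sketch of stacking layers with repeat(); the toy layer below is hypothetical and only mirrors the (x, mask) -> (x, mask) signature that MultiSequential forwards between layers:

import torch

class TinyLayer(torch.nn.Module):
    # Hypothetical stand-in for an encoder layer: takes and returns (x, mask).
    def __init__(self, size):
        super().__init__()
        self.proj = torch.nn.Linear(size, size)

    def forward(self, x, mask):
        return self.proj(x), mask

stack = repeat(3, lambda layer_id: TinyLayer(256), layer_drop_rate=0.0)
x, mask = stack(torch.randn(2, 50, 256), torch.ones(2, 1, 50, dtype=torch.bool))
assert x.shape == (2, 50, 256)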
src/espnet_import/subsampling.py ADDED
@@ -0,0 +1,873 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+ # Copyright 2019 Shigeki Karita
5
+ # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
6
+
7
+ """Subsampling layer definition."""
8
+
9
+ import torch
10
+
11
+ from src.espnet_import.embedding import PositionalEncoding
12
+
13
+
14
+ class TooShortUttError(Exception):
15
+ """Raised when the utt is too short for subsampling.
16
+
17
+ Args:
18
+ message (str): Message for error catch
19
+ actual_size (int): the short size that cannot pass the subsampling
20
+ limit (int): the limit size for subsampling
21
+
22
+ """
23
+
24
+ def __init__(self, message, actual_size, limit):
25
+ """Construct a TooShortUttError for error handler."""
26
+ super().__init__(message)
27
+ self.actual_size = actual_size
28
+ self.limit = limit
29
+
30
+
31
+ def check_short_utt(ins, size):
32
+ """Check if the utterance is too short for subsampling."""
33
+ if isinstance(ins, Conv1dSubsampling1) and size < 5:
34
+ return True, 5
35
+ if isinstance(ins, Conv1dSubsampling2) and size < 5:
36
+ return True, 5
37
+ if isinstance(ins, Conv1dSubsampling3) and size < 7:
38
+ return True, 7
39
+ if isinstance(ins, Conv2dSubsampling1) and size < 5:
40
+ return True, 5
41
+ if isinstance(ins, Conv2dSubsampling2) and size < 7:
42
+ return True, 7
43
+ if isinstance(ins, Conv2dSubsampling) and size < 7:
44
+ return True, 7
45
+ if isinstance(ins, Conv2dSubsampling6) and size < 11:
46
+ return True, 11
47
+ if isinstance(ins, Conv2dSubsampling8) and size < 15:
48
+ return True, 15
49
+ return False, -1
50
+
51
+
52
+ def _upgrade_legacy_subsampling_state_dict(state_dict, prefix):
53
+ """Remap legacy nn.Sequential keys for subsampling modules."""
54
+ w_new = prefix + "out.weight"
55
+ b_new = prefix + "out.bias"
56
+ w_old = prefix + "out.0.weight"
57
+ b_old = prefix + "out.0.bias"
58
+
59
+ if w_new not in state_dict and w_old in state_dict:
60
+ state_dict[w_new] = state_dict.pop(w_old)
61
+ elif w_new in state_dict and w_old in state_dict:
62
+ state_dict.pop(w_old)
63
+
64
+ if b_new not in state_dict and b_old in state_dict:
65
+ state_dict[b_new] = state_dict.pop(b_old)
66
+ elif b_new in state_dict and b_old in state_dict:
67
+ state_dict.pop(b_old)
68
+
69
+ old_pos_prefix = prefix + "out.1."
70
+ new_pos_prefix = prefix + "pos_enc."
71
+ for k in list(state_dict.keys()):
72
+ if not k.startswith(old_pos_prefix):
73
+ continue
74
+ new_k = new_pos_prefix + k[len(old_pos_prefix) :]
75
+ if new_k not in state_dict:
76
+ state_dict[new_k] = state_dict[k]
77
+ state_dict.pop(k, None)
78
+
79
+
80
+ class Conv1dSubsampling1(torch.nn.Module):
81
+ """Convolutional 1D subsampling.
82
+
83
+ Args:
84
+ idim (int): Input dimension.
85
+ odim (int): Output dimension.
86
+ dropout_rate (float): Dropout rate.
87
+ pos_enc (torch.nn.Module): Custom position encoding layer.
88
+
89
+ """
90
+
91
+ def __init__(self, idim, odim, dropout_rate, pos_enc=None):
92
+ """Construct a Conv1dSubsampling1 object."""
93
+ super(Conv1dSubsampling1, self).__init__()
94
+ self.conv = torch.nn.Sequential(
95
+ torch.nn.Conv1d(idim, odim, 3, 1),
96
+ torch.nn.ReLU(),
97
+ torch.nn.Conv1d(odim, odim, 3, 1),
98
+ torch.nn.ReLU(),
99
+ )
100
+ self.out = torch.nn.Linear(odim, odim)
101
+ self.pos_enc = (
102
+ pos_enc if pos_enc is not None else PositionalEncoding(odim, dropout_rate)
103
+ )
104
+
105
+ def _load_from_state_dict(
106
+ self,
107
+ state_dict,
108
+ prefix,
109
+ local_metadata,
110
+ strict,
111
+ missing_keys,
112
+ unexpected_keys,
113
+ error_msgs,
114
+ ):
115
+ _upgrade_legacy_subsampling_state_dict(state_dict, prefix)
116
+ super()._load_from_state_dict(
117
+ state_dict,
118
+ prefix,
119
+ local_metadata,
120
+ strict,
121
+ missing_keys,
122
+ unexpected_keys,
123
+ error_msgs,
124
+ )
125
+
126
+ def forward(self, x, x_mask, prefix_embeds=None):
127
+ """Subsample x.
128
+
129
+ Args:
130
+ x (torch.Tensor): Input tensor (#batch, time, idim).
131
+ x_mask (torch.Tensor): Input mask (#batch, 1, time).
132
+ prefix_embeds (torch.Tensor or None): Prefix token embeddings
133
+ (#batch, prefix_len, odim).
134
+
135
+ Returns:
136
+ torch.Tensor: Subsampled tensor (#batch, time', odim),
137
+ where time' = time - 4.
138
+ torch.Tensor: Subsampled mask (#batch, 1, time'),
139
+ where time' = time - 4.
140
+
141
+ """
142
+ x = x.transpose(2, 1) # (#batch, idim, time)
143
+ x = self.conv(x)
144
+ b, c, t = x.size()
145
+ x = self.out(x.transpose(1, 2).contiguous())
146
+ if x_mask is not None:
147
+ x_mask = x_mask[:, :, :-2:1][:, :, :-2:1]
148
+
149
+ if prefix_embeds is not None:
150
+ x = torch.cat([prefix_embeds, x], dim=1)
151
+ if x_mask is not None:
152
+ x_mask = torch.cat(
153
+ [
154
+ torch.ones(
155
+ x_mask.shape[0],
156
+ 1,
157
+ prefix_embeds.size(1),
158
+ dtype=x_mask.dtype,
159
+ device=x_mask.device,
160
+ ),
161
+ x_mask,
162
+ ],
163
+ dim=-1,
164
+ )
165
+
166
+ x = self.pos_enc(x)
167
+
168
+ return x, x_mask
169
+
170
+ def __getitem__(self, key):
171
+ """Get item.
172
+
173
+ When reset_parameters() is called, if use_scaled_pos_enc is used,
174
+ return the positioning encoding.
175
+
176
+ """
177
+ if key != -1:
178
+ raise NotImplementedError("Support only `-1` (for `reset_parameters`).")
179
+ return self.pos_enc
180
+
181
+
182
+ class Conv1dSubsampling2(torch.nn.Module):
183
+ """Convolutional 1D subsampling (to 1/2 length).
184
+
185
+ Args:
186
+ idim (int): Input dimension.
187
+ odim (int): Output dimension.
188
+ dropout_rate (float): Dropout rate.
189
+ pos_enc (torch.nn.Module): Custom position encoding layer.
190
+
191
+ """
192
+
193
+ def __init__(self, idim, odim, dropout_rate, pos_enc=None):
194
+ """Construct a Conv1dSubsampling2 object."""
195
+ super(Conv1dSubsampling2, self).__init__()
196
+ self.conv = torch.nn.Sequential(
197
+ torch.nn.Conv1d(idim, odim, 3, 1),
198
+ torch.nn.ReLU(),
199
+ torch.nn.Conv1d(odim, odim, 3, 2),
200
+ torch.nn.ReLU(),
201
+ )
202
+ self.out = torch.nn.Linear(odim, odim)
203
+ self.pos_enc = (
204
+ pos_enc if pos_enc is not None else PositionalEncoding(odim, dropout_rate)
205
+ )
206
+
207
+ def _load_from_state_dict(
208
+ self,
209
+ state_dict,
210
+ prefix,
211
+ local_metadata,
212
+ strict,
213
+ missing_keys,
214
+ unexpected_keys,
215
+ error_msgs,
216
+ ):
217
+ _upgrade_legacy_subsampling_state_dict(state_dict, prefix)
218
+ super()._load_from_state_dict(
219
+ state_dict,
220
+ prefix,
221
+ local_metadata,
222
+ strict,
223
+ missing_keys,
224
+ unexpected_keys,
225
+ error_msgs,
226
+ )
227
+
228
+ def forward(self, x, x_mask, prefix_embeds=None):
229
+ """Subsample x.
230
+
231
+ Args:
232
+ x (torch.Tensor): Input tensor (#batch, time, idim).
233
+ x_mask (torch.Tensor): Input mask (#batch, 1, time).
234
+ prefix_embeds (torch.Tensor or None): Prefix token embeddings
235
+ (#batch, prefix_len, odim).
236
+
237
+ Returns:
238
+ torch.Tensor: Subsampled tensor (#batch, time', odim),
239
+ where time' = time // 2.
240
+ torch.Tensor: Subsampled mask (#batch, 1, time'),
241
+ where time' = time // 2.
242
+
243
+ """
244
+ x = x.transpose(2, 1) # (#batch, idim, time)
245
+ x = self.conv(x)
246
+ b, c, t = x.size()
247
+ x = self.out(x.transpose(1, 2).contiguous())
248
+ if x_mask is not None:
249
+ x_mask = x_mask[:, :, :-2:1][:, :, :-2:2]
250
+
251
+ if prefix_embeds is not None:
252
+ x = torch.cat([prefix_embeds, x], dim=1)
253
+ if x_mask is not None:
254
+ x_mask = torch.cat(
255
+ [
256
+ torch.ones(
257
+ x_mask.shape[0],
258
+ 1,
259
+ prefix_embeds.size(1),
260
+ dtype=x_mask.dtype,
261
+ device=x_mask.device,
262
+ ),
263
+ x_mask,
264
+ ],
265
+ dim=-1,
266
+ )
267
+
268
+ x = self.pos_enc(x)
269
+
270
+ return x, x_mask
271
+
272
+ def __getitem__(self, key):
273
+ """Get item.
274
+
275
+ When reset_parameters() is called, if use_scaled_pos_enc is used,
276
+ return the positioning encoding.
277
+
278
+ """
279
+ if key != -1:
280
+ raise NotImplementedError("Support only `-1` (for `reset_parameters`).")
281
+ return self.pos_enc
282
+
283
+
284
+ class Conv1dSubsampling3(torch.nn.Module):
285
+ """Convolutional 1D subsampling (to 1/3 length).
286
+
287
+ Args:
288
+ idim (int): Input dimension.
289
+ odim (int): Output dimension.
290
+ dropout_rate (float): Dropout rate.
291
+ pos_enc (torch.nn.Module): Custom position encoding layer.
292
+
293
+ """
294
+
295
+ def __init__(self, idim, odim, dropout_rate, pos_enc=None):
296
+ """Construct a Conv1dSubsampling3 object."""
297
+ super(Conv1dSubsampling3, self).__init__()
298
+ self.conv = torch.nn.Sequential(
299
+ torch.nn.Conv1d(idim, odim, 3, 1),
300
+ torch.nn.ReLU(),
301
+ torch.nn.Conv1d(odim, odim, 5, 3),
302
+ torch.nn.ReLU(),
303
+ )
304
+ self.out = torch.nn.Linear(odim, odim)
305
+ self.pos_enc = (
306
+ pos_enc if pos_enc is not None else PositionalEncoding(odim, dropout_rate)
307
+ )
308
+
309
+ def _load_from_state_dict(
310
+ self,
311
+ state_dict,
312
+ prefix,
313
+ local_metadata,
314
+ strict,
315
+ missing_keys,
316
+ unexpected_keys,
317
+ error_msgs,
318
+ ):
319
+ _upgrade_legacy_subsampling_state_dict(state_dict, prefix)
320
+ super()._load_from_state_dict(
321
+ state_dict,
322
+ prefix,
323
+ local_metadata,
324
+ strict,
325
+ missing_keys,
326
+ unexpected_keys,
327
+ error_msgs,
328
+ )
329
+
330
+ def forward(self, x, x_mask, prefix_embeds=None):
331
+ """Subsample x.
332
+
333
+ Args:
334
+ x (torch.Tensor): Input tensor (#batch, time, idim).
335
+ x_mask (torch.Tensor): Input mask (#batch, 1, time).
336
+ prefix_embeds (torch.Tensor or None): Prefix token embeddings
337
+ (#batch, prefix_len, odim).
338
+
339
+ Returns:
340
+ torch.Tensor: Subsampled tensor (#batch, time', odim),
341
+ where time' = time // 3.
342
+ torch.Tensor: Subsampled mask (#batch, 1, time'),
343
+ where time' = time // 3.
344
+
345
+ """
346
+ x = x.transpose(2, 1) # (#batch, idim, time)
347
+ x = self.conv(x)
348
+ b, c, t = x.size()
349
+ x = self.out(x.transpose(1, 2).contiguous())
350
+ if x_mask is not None:
351
+ x_mask = x_mask[:, :, :-2:1][:, :, :-4:3]
352
+
353
+ if prefix_embeds is not None:
354
+ x = torch.cat([prefix_embeds, x], dim=1)
355
+ if x_mask is not None:
356
+ x_mask = torch.cat(
357
+ [
358
+ torch.ones(
359
+ x_mask.shape[0],
360
+ 1,
361
+ prefix_embeds.size(1),
362
+ dtype=x_mask.dtype,
363
+ device=x_mask.device,
364
+ ),
365
+ x_mask,
366
+ ],
367
+ dim=-1,
368
+ )
369
+
370
+ x = self.pos_enc(x)
371
+
372
+ return x, x_mask
373
+
374
+ def __getitem__(self, key):
375
+ """Get item.
376
+
377
+ When reset_parameters() is called, if use_scaled_pos_enc is used,
378
+ return the positioning encoding.
379
+
380
+ """
381
+ if key != -1:
382
+ raise NotImplementedError("Support only `-1` (for `reset_parameters`).")
383
+ return self.pos_enc
384
+
385
+
386
+ class Conv2dSubsampling(torch.nn.Module):
387
+ """Convolutional 2D subsampling (to 1/4 length).
388
+
389
+ Args:
390
+ idim (int): Input dimension.
391
+ odim (int): Output dimension.
392
+ dropout_rate (float): Dropout rate.
393
+ pos_enc (torch.nn.Module): Custom position encoding layer.
394
+
395
+ """
396
+
397
+ def __init__(self, idim, odim, dropout_rate, pos_enc=None):
398
+ """Construct a Conv2dSubsampling object."""
399
+ super(Conv2dSubsampling, self).__init__()
400
+ self.conv = torch.nn.Sequential(
401
+ torch.nn.Conv2d(1, odim, 3, 2),
402
+ torch.nn.ReLU(),
403
+ torch.nn.Conv2d(odim, odim, 3, 2),
404
+ torch.nn.ReLU(),
405
+ )
406
+ self.out = torch.nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim)
407
+ self.pos_enc = (
408
+ pos_enc if pos_enc is not None else PositionalEncoding(odim, dropout_rate)
409
+ )
410
+
411
+ def _load_from_state_dict(
412
+ self,
413
+ state_dict,
414
+ prefix,
415
+ local_metadata,
416
+ strict,
417
+ missing_keys,
418
+ unexpected_keys,
419
+ error_msgs,
420
+ ):
421
+ _upgrade_legacy_subsampling_state_dict(state_dict, prefix)
422
+ super()._load_from_state_dict(
423
+ state_dict,
424
+ prefix,
425
+ local_metadata,
426
+ strict,
427
+ missing_keys,
428
+ unexpected_keys,
429
+ error_msgs,
430
+ )
431
+
432
+ def forward(self, x, x_mask, prefix_embeds=None):
433
+ """Subsample x.
434
+
435
+ Args:
436
+ x (torch.Tensor): Input tensor (#batch, time, idim).
437
+ x_mask (torch.Tensor): Input mask (#batch, 1, time).
438
+ prefix_embeds (torch.Tensor or None): Prefix token embeddings
439
+ (#batch, prefix_len, odim).
440
+
441
+ Returns:
442
+ torch.Tensor: Subsampled tensor (#batch, time', odim),
443
+ where time' = time // 4.
444
+ torch.Tensor: Subsampled mask (#batch, 1, time'),
445
+ where time' = time // 4.
446
+
447
+ """
448
+ x = x.unsqueeze(1) # (b, c, t, f)
449
+ x = self.conv(x)
450
+ b, c, t, f = x.size()
451
+ x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
452
+ if x_mask is not None:
453
+ x_mask = x_mask[:, :, :-2:2][:, :, :-2:2]
454
+
455
+ if prefix_embeds is not None:
456
+ x = torch.cat([prefix_embeds, x], dim=1)
457
+ if x_mask is not None:
458
+ x_mask = torch.cat(
459
+ [
460
+ torch.ones(
461
+ x_mask.shape[0],
462
+ 1,
463
+ prefix_embeds.size(1),
464
+ dtype=x_mask.dtype,
465
+ device=x_mask.device,
466
+ ),
467
+ x_mask,
468
+ ],
469
+ dim=-1,
470
+ )
471
+
472
+ x = self.pos_enc(x)
473
+
474
+ return x, x_mask
475
+
476
+ # def __getitem__(self, key):
477
+ # """Get item.
478
+
479
+ # When reset_parameters() is called, if use_scaled_pos_enc is used,
480
+ # return the positioning encoding.
481
+
482
+ # """
483
+ # if key != -1:
484
+ # raise NotImplementedError("Support only `-1` (for `reset_parameters`).")
485
+ # return self.out[key]
486
+
487
+
488
+ class Conv2dSubsampling1(torch.nn.Module):
489
+ """Similar to Conv2dSubsampling module, but without any subsampling performed.
490
+
491
+ Args:
492
+ idim (int): Input dimension.
493
+ odim (int): Output dimension.
494
+ dropout_rate (float): Dropout rate.
495
+ pos_enc (torch.nn.Module): Custom position encoding layer.
496
+
497
+ """
498
+
499
+ def __init__(self, idim, odim, dropout_rate, pos_enc=None):
500
+ """Construct a Conv2dSubsampling1 object."""
501
+ super(Conv2dSubsampling1, self).__init__()
502
+ self.conv = torch.nn.Sequential(
503
+ torch.nn.Conv2d(1, odim, 3, 1),
504
+ torch.nn.ReLU(),
505
+ torch.nn.Conv2d(odim, odim, 3, 1),
506
+ torch.nn.ReLU(),
507
+ )
508
+ self.out = torch.nn.Linear(odim * (idim - 4), odim)
509
+ self.pos_enc = (
510
+ pos_enc if pos_enc is not None else PositionalEncoding(odim, dropout_rate)
511
+ )
512
+
513
+ def _load_from_state_dict(
514
+ self,
515
+ state_dict,
516
+ prefix,
517
+ local_metadata,
518
+ strict,
519
+ missing_keys,
520
+ unexpected_keys,
521
+ error_msgs,
522
+ ):
523
+ _upgrade_legacy_subsampling_state_dict(state_dict, prefix)
524
+ super()._load_from_state_dict(
525
+ state_dict,
526
+ prefix,
527
+ local_metadata,
528
+ strict,
529
+ missing_keys,
530
+ unexpected_keys,
531
+ error_msgs,
532
+ )
533
+
534
+ def forward(self, x, x_mask, prefix_embeds=None):
535
+ """Pass x through 2 Conv2d layers without subsampling.
536
+
537
+ Args:
538
+ x (torch.Tensor): Input tensor (#batch, time, idim).
539
+ x_mask (torch.Tensor): Input mask (#batch, 1, time).
540
+ prefix_embeds (torch.Tensor or None): Prefix token embeddings
541
+ (#batch, prefix_len, odim).
542
+
543
+ Returns:
544
+ torch.Tensor: Subsampled tensor (#batch, time', odim).
545
+ where time' = time - 4.
546
+ torch.Tensor: Subsampled mask (#batch, 1, time').
547
+ where time' = time - 4.
548
+
549
+ """
550
+ x = x.unsqueeze(1) # (b, c, t, f)
551
+ x = self.conv(x)
552
+ b, c, t, f = x.size()
553
+ x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
554
+ if x_mask is not None:
555
+ x_mask = x_mask[:, :, :-4]
556
+
557
+ if prefix_embeds is not None:
558
+ x = torch.cat([prefix_embeds, x], dim=1)
559
+ if x_mask is not None:
560
+ x_mask = torch.cat(
561
+ [
562
+ torch.ones(
563
+ x_mask.shape[0],
564
+ 1,
565
+ prefix_embeds.size(1),
566
+ dtype=x_mask.dtype,
567
+ device=x_mask.device,
568
+ ),
569
+ x_mask,
570
+ ],
571
+ dim=-1,
572
+ )
573
+
574
+ x = self.pos_enc(x)
575
+
576
+ return x, x_mask
577
+
578
+ def __getitem__(self, key):
579
+ """Get item.
580
+
581
+ When reset_parameters() is called, if use_scaled_pos_enc is used,
582
+ return the positioning encoding.
583
+
584
+ """
585
+ if key != -1:
586
+ raise NotImplementedError("Support only `-1` (for `reset_parameters`).")
587
+ return self.pos_enc
588
+
589
+
590
+ class Conv2dSubsampling2(torch.nn.Module):
591
+ """Convolutional 2D subsampling (to 1/2 length).
592
+
593
+ Args:
594
+ idim (int): Input dimension.
595
+ odim (int): Output dimension.
596
+ dropout_rate (float): Dropout rate.
597
+ pos_enc (torch.nn.Module): Custom position encoding layer.
598
+
599
+ """
600
+
601
+ def __init__(self, idim, odim, dropout_rate, pos_enc=None):
602
+ """Construct a Conv2dSubsampling2 object."""
603
+ super(Conv2dSubsampling2, self).__init__()
604
+ self.conv = torch.nn.Sequential(
605
+ torch.nn.Conv2d(1, odim, 3, 2),
606
+ torch.nn.ReLU(),
607
+ torch.nn.Conv2d(odim, odim, 3, 1),
608
+ torch.nn.ReLU(),
609
+ )
610
+ self.out = torch.nn.Linear(odim * (((idim - 1) // 2 - 2)), odim)
611
+ self.pos_enc = (
612
+ pos_enc if pos_enc is not None else PositionalEncoding(odim, dropout_rate)
613
+ )
614
+
615
+ def _load_from_state_dict(
616
+ self,
617
+ state_dict,
618
+ prefix,
619
+ local_metadata,
620
+ strict,
621
+ missing_keys,
622
+ unexpected_keys,
623
+ error_msgs,
624
+ ):
625
+ _upgrade_legacy_subsampling_state_dict(state_dict, prefix)
626
+ super()._load_from_state_dict(
627
+ state_dict,
628
+ prefix,
629
+ local_metadata,
630
+ strict,
631
+ missing_keys,
632
+ unexpected_keys,
633
+ error_msgs,
634
+ )
635
+
636
+ def forward(self, x, x_mask, prefix_embeds=None):
637
+ """Subsample x.
638
+
639
+ Args:
640
+ x (torch.Tensor): Input tensor (#batch, time, idim).
641
+ x_mask (torch.Tensor): Input mask (#batch, 1, time).
642
+ prefix_embeds (torch.Tensor or None): Prefix token embeddings
643
+ (#batch, prefix_len, odim).
644
+
645
+ Returns:
646
+ torch.Tensor: Subsampled tensor (#batch, time', odim),
647
+ where time' = time // 2.
648
+ torch.Tensor: Subsampled mask (#batch, 1, time'),
649
+ where time' = time // 2.
650
+
651
+ """
652
+ x = x.unsqueeze(1) # (b, c, t, f)
653
+ x = self.conv(x)
654
+ b, c, t, f = x.size()
655
+ x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
656
+ if x_mask is not None:
657
+ x_mask = x_mask[:, :, :-2:2][:, :, :-2:1]
658
+
659
+ if prefix_embeds is not None:
660
+ x = torch.cat([prefix_embeds, x], dim=1)
661
+ if x_mask is not None:
662
+ x_mask = torch.cat(
663
+ [
664
+ torch.ones(
665
+ x_mask.shape[0],
666
+ 1,
667
+ prefix_embeds.size(1),
668
+ dtype=x_mask.dtype,
669
+ device=x_mask.device,
670
+ ),
671
+ x_mask,
672
+ ],
673
+ dim=-1,
674
+ )
675
+
676
+ x = self.pos_enc(x)
677
+
678
+ return x, x_mask
679
+
680
+ def __getitem__(self, key):
681
+ """Get item.
682
+
683
+ When reset_parameters() is called, if use_scaled_pos_enc is used,
684
+ return the positioning encoding.
685
+
686
+ """
687
+ if key != -1:
688
+ raise NotImplementedError("Support only `-1` (for `reset_parameters`).")
689
+ return self.pos_enc
690
+
691
+
692
+ class Conv2dSubsampling6(torch.nn.Module):
693
+ """Convolutional 2D subsampling (to 1/6 length).
694
+
695
+ Args:
696
+ idim (int): Input dimension.
697
+ odim (int): Output dimension.
698
+ dropout_rate (float): Dropout rate.
699
+ pos_enc (torch.nn.Module): Custom position encoding layer.
700
+
701
+ """
702
+
703
+ def __init__(self, idim, odim, dropout_rate, pos_enc=None):
704
+ """Construct a Conv2dSubsampling6 object."""
705
+ super(Conv2dSubsampling6, self).__init__()
706
+ self.conv = torch.nn.Sequential(
707
+ torch.nn.Conv2d(1, odim, 3, 2),
708
+ torch.nn.ReLU(),
709
+ torch.nn.Conv2d(odim, odim, 5, 3),
710
+ torch.nn.ReLU(),
711
+ )
712
+ self.out = torch.nn.Linear(odim * (((idim - 1) // 2 - 2) // 3), odim)
713
+ self.pos_enc = (
714
+ pos_enc if pos_enc is not None else PositionalEncoding(odim, dropout_rate)
715
+ )
716
+
717
+ def _load_from_state_dict(
718
+ self,
719
+ state_dict,
720
+ prefix,
721
+ local_metadata,
722
+ strict,
723
+ missing_keys,
724
+ unexpected_keys,
725
+ error_msgs,
726
+ ):
727
+ _upgrade_legacy_subsampling_state_dict(state_dict, prefix)
728
+ super()._load_from_state_dict(
729
+ state_dict,
730
+ prefix,
731
+ local_metadata,
732
+ strict,
733
+ missing_keys,
734
+ unexpected_keys,
735
+ error_msgs,
736
+ )
737
+
738
+ def forward(self, x, x_mask, prefix_embeds=None):
739
+ """Subsample x.
740
+
741
+ Args:
742
+ x (torch.Tensor): Input tensor (#batch, time, idim).
743
+ x_mask (torch.Tensor): Input mask (#batch, 1, time).
744
+ prefix_embeds (torch.Tensor or None): Prefix token embeddings
745
+ (#batch, prefix_len, odim).
746
+
747
+ Returns:
748
+ torch.Tensor: Subsampled tensor (#batch, time', odim),
749
+ where time' = time // 6.
750
+ torch.Tensor: Subsampled mask (#batch, 1, time'),
751
+ where time' = time // 6.
752
+
753
+ """
754
+ x = x.unsqueeze(1) # (b, c, t, f)
755
+ x = self.conv(x)
756
+ b, c, t, f = x.size()
757
+ x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
758
+ if x_mask is not None:
759
+ x_mask = x_mask[:, :, :-2:2][:, :, :-4:3]
760
+
761
+ if prefix_embeds is not None:
762
+ x = torch.cat([prefix_embeds, x], dim=1)
763
+ if x_mask is not None:
764
+ x_mask = torch.cat(
765
+ [
766
+ torch.ones(
767
+ x_mask.shape[0],
768
+ 1,
769
+ prefix_embeds.size(1),
770
+ dtype=x_mask.dtype,
771
+ device=x_mask.device,
772
+ ),
773
+ x_mask,
774
+ ],
775
+ dim=-1,
776
+ )
777
+
778
+ x = self.pos_enc(x)
779
+
780
+ return x, x_mask
781
+
782
+
783
+ class Conv2dSubsampling8(torch.nn.Module):
784
+ """Convolutional 2D subsampling (to 1/8 length).
785
+
786
+ Args:
787
+ idim (int): Input dimension.
788
+ odim (int): Output dimension.
789
+ dropout_rate (float): Dropout rate.
790
+ pos_enc (torch.nn.Module): Custom position encoding layer.
791
+
792
+ """
793
+
794
+ def __init__(self, idim, odim, dropout_rate, pos_enc=None):
795
+ """Construct a Conv2dSubsampling8 object."""
796
+ super(Conv2dSubsampling8, self).__init__()
797
+ self.conv = torch.nn.Sequential(
798
+ torch.nn.Conv2d(1, odim, 3, 2),
799
+ torch.nn.ReLU(),
800
+ torch.nn.Conv2d(odim, odim, 3, 2),
801
+ torch.nn.ReLU(),
802
+ torch.nn.Conv2d(odim, odim, 3, 2),
803
+ torch.nn.ReLU(),
804
+ )
805
+ self.out = torch.nn.Linear(odim * ((((idim - 1) // 2 - 1) // 2 - 1) // 2), odim)
806
+ self.pos_enc = (
807
+ pos_enc if pos_enc is not None else PositionalEncoding(odim, dropout_rate)
808
+ )
809
+
810
+ def _load_from_state_dict(
811
+ self,
812
+ state_dict,
813
+ prefix,
814
+ local_metadata,
815
+ strict,
816
+ missing_keys,
817
+ unexpected_keys,
818
+ error_msgs,
819
+ ):
820
+ _upgrade_legacy_subsampling_state_dict(state_dict, prefix)
821
+ super()._load_from_state_dict(
822
+ state_dict,
823
+ prefix,
824
+ local_metadata,
825
+ strict,
826
+ missing_keys,
827
+ unexpected_keys,
828
+ error_msgs,
829
+ )
830
+
831
+ def forward(self, x, x_mask, prefix_embeds=None):
832
+ """Subsample x.
833
+
834
+ Args:
835
+ x (torch.Tensor): Input tensor (#batch, time, idim).
836
+ x_mask (torch.Tensor): Input mask (#batch, 1, time).
837
+ prefix_embeds (torch.Tensor or None): Prefix token embeddings
838
+ (#batch, prefix_len, odim).
839
+
840
+ Returns:
841
+ torch.Tensor: Subsampled tensor (#batch, time', odim),
842
+ where time' = time // 8.
843
+ torch.Tensor: Subsampled mask (#batch, 1, time'),
844
+ where time' = time // 8.
845
+
846
+ """
847
+ x = x.unsqueeze(1) # (b, c, t, f)
848
+ x = self.conv(x)
849
+ b, c, t, f = x.size()
850
+ x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
851
+ if x_mask is not None:
852
+ x_mask = x_mask[:, :, :-2:2][:, :, :-2:2][:, :, :-2:2]
853
+
854
+ if prefix_embeds is not None:
855
+ x = torch.cat([prefix_embeds, x], dim=1)
856
+ if x_mask is not None:
857
+ x_mask = torch.cat(
858
+ [
859
+ torch.ones(
860
+ x_mask.shape[0],
861
+ 1,
862
+ prefix_embeds.size(1),
863
+ dtype=x_mask.dtype,
864
+ device=x_mask.device,
865
+ ),
866
+ x_mask,
867
+ ],
868
+ dim=-1,
869
+ )
870
+
871
+ x = self.pos_enc(x)
872
+
873
+ return x, x_mask
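Not part of the committed file -- a minimal sketch of the 1/4-length Conv2dSubsampling above, assuming the default PositionalEncoding (which returns a single tensor, as in upstream ESPnet):

import torch

sub = Conv2dSubsampling(idim=80, odim=256, dropout_rate=0.1)
feats = torch.randn(2, 100, 80)                  # (batch, time, feature)
mask = torch.ones(2, 1, 100, dtype=torch.bool)   # (batch, 1, time)
out, out_mask = sub(feats, mask)
# two kernel-3 / stride-2 convs: 100 -> 49 -> 24 frames, mask subsampled in step
assert out.shape == (2, 24, 256)
assert out_mask.shape == (2, 1, 24)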
src/model/__init__.py ADDED
File without changes
src/model/powsm/__init__.py ADDED
File without changes
src/model/powsm/ctc.py ADDED
@@ -0,0 +1,230 @@
1
+ from typing import Dict, List, Optional, Tuple, Union
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from typeguard import typechecked
6
+ from src.utils import RankedLogger
7
+
8
+ log = RankedLogger(__name__, rank_zero_only=True)
9
+
10
+
11
+ class CTC(torch.nn.Module):
12
+ """CTC module.
13
+
14
+ Args:
15
+ odim: dimension of outputs
16
+ encoder_output_size: number of encoder projection units
17
+ dropout_rate: dropout rate (0.0 ~ 1.0)
18
+ ctc_type: builtin or gtnctc
19
+ reduce: reduce the CTC loss into a scalar
20
+ ignore_nan_grad: Same as zero_infinity (kept for backward compatibility)
21
+ zero_infinity: Whether to zero infinite losses and the associated gradients.
22
+ """
23
+
24
+ @typechecked
25
+ def __init__(
26
+ self,
27
+ odim: int,
28
+ encoder_output_size: int,
29
+ dropout_rate: float = 0.0,
30
+ ctc_type: str = "builtin",
31
+ reduce: bool = True,
32
+ ignore_nan_grad: Optional[bool] = None,
33
+ zero_infinity: bool = True,
34
+ brctc_risk_strategy: str = "exp",
35
+ brctc_group_strategy: str = "end",
36
+ brctc_risk_factor: float = 0.0,
37
+ ):
38
+ super().__init__()
39
+ eprojs = encoder_output_size
40
+ self.dropout_rate = dropout_rate
41
+ self.ctc_lo = torch.nn.Linear(eprojs, odim)
42
+ self.ctc_type = ctc_type
43
+ if ignore_nan_grad is not None:
44
+ zero_infinity = ignore_nan_grad
45
+
46
+ if self.ctc_type == "builtin":
47
+ self.ctc_loss = torch.nn.CTCLoss(
48
+ reduction="none", zero_infinity=zero_infinity
49
+ )
50
+ elif self.ctc_type == "builtin2":
51
+ self.ignore_nan_grad = True
52
+ log.warning("builtin2")
53
+ self.ctc_loss = torch.nn.CTCLoss(reduction="none")
54
+
55
+ elif self.ctc_type == "gtnctc":
56
+ raise ImportError("gtnctc requires gtn_ctc which is not bundled here.")
57
+
58
+ elif self.ctc_type == "brctc":
59
+ try:
60
+ import k2 # noqa
61
+ except ImportError:
62
+ raise ImportError("You should install K2 to use Bayes Risk CTC")
63
+
64
+ raise ImportError("brctc requires BayesRiskCTC which is not bundled here.")
65
+ else:
66
+ raise ValueError(
67
+ f'ctc_type must be one of "builtin", "builtin2", "gtnctc", "brctc": {self.ctc_type}'
68
+ )
69
+
70
+ self.reduce = reduce
71
+
72
+ def loss_fn(
73
+ self,
74
+ th_pred,
75
+ th_target,
76
+ th_ilen,
77
+ th_olen,
78
+ lang_sym: Optional[Union[List[str], None]] = None,
79
+ accent_sym: Optional[Union[List[str], None]] = None,
80
+ ) -> torch.Tensor:
81
+ if self.ctc_type in ["builtin", "brctc"]:
82
+ th_pred = th_pred.log_softmax(2).float()
83
+ loss = self.ctc_loss(th_pred, th_target, th_ilen, th_olen)
84
+
85
+ if self.ctc_type == "builtin":
86
+ size = th_pred.size(1)
87
+ else:
88
+ size = loss.size(0) # some invalid examples will be excluded
89
+
90
+ if self.reduce:
91
+ # Batch-size average
92
+ loss = loss.sum() / size
93
+ else:
94
+ loss = loss / size
95
+ return loss
96
+
97
+ # builtin2 ignores nan losses using the logic below, while
98
+ # builtin relies on the zero_infinity flag in pytorch CTC
99
+ elif self.ctc_type == "builtin2":
100
+ th_pred = th_pred.log_softmax(2).float()
101
+ loss = self.ctc_loss(th_pred, th_target, th_ilen, th_olen)
102
+
103
+ if loss.requires_grad and self.ignore_nan_grad:
104
+ # ctc_grad: (L, B, O)
105
+ ctc_grad = loss.grad_fn(torch.ones_like(loss))
106
+ ctc_grad = ctc_grad.sum([0, 2])
107
+ indices = torch.isfinite(ctc_grad)
108
+ size = indices.long().sum()
109
+ if size == 0:
110
+ # Return as is
111
+ log.warning(
112
+ "All samples in this mini-batch got nan grad."
113
+ " Returning nan value instead of CTC loss"
114
+ )
115
+ elif size != th_pred.size(1):
116
+ log.warning(
117
+ f"{th_pred.size(1) - size}/{th_pred.size(1)}"
118
+ " samples got nan grad."
119
+ " These were ignored for CTC loss."
120
+ )
121
+
122
+ # Create mask for target
123
+ target_mask = torch.full(
124
+ [th_target.size(0)],
125
+ 1,
126
+ dtype=torch.bool,
127
+ device=th_target.device,
128
+ )
129
+ s = 0
130
+ for ind, le in enumerate(th_olen):
131
+ if not indices[ind]:
132
+ target_mask[s : s + le] = 0
133
+ s += le
134
+
135
+ # Calc loss again using masked data
136
+ loss = self.ctc_loss(
137
+ th_pred[:, indices, :],
138
+ th_target[target_mask],
139
+ th_ilen[indices],
140
+ th_olen[indices],
141
+ )
142
+ else:
143
+ size = th_pred.size(1)
144
+
145
+ if self.reduce:
146
+ # Batch-size average
147
+ loss = loss.sum() / size
148
+ else:
149
+ loss = loss / size
150
+ return loss
151
+
152
+ elif self.ctc_type == "gtnctc":
153
+ log_probs = torch.nn.functional.log_softmax(th_pred, dim=2)
154
+ return self.ctc_loss(log_probs, th_target, th_ilen, 0, "none")
155
+
156
+ else:
157
+ raise NotImplementedError
158
+
159
+ def forward(
160
+ self,
161
+ hs_pad,
162
+ hlens,
163
+ ys_pad,
164
+ ys_lens,
165
+ lang_sym: Optional[Union[List[str], None]] = None,
166
+ accent_sym: Optional[Union[List[str], None]] = None,
167
+ ):
168
+ """Calculate CTC loss.
169
+
170
+ Args:
171
+ hs_pad: batch of padded hidden state sequences (B, Tmax, D)
172
+ hlens: batch of lengths of hidden state sequences (B)
173
+ ys_pad: batch of padded character id sequence tensor (B, Lmax)
174
+ ys_lens: batch of lengths of character sequence (B)
175
+ lang_sym: optional list of language codes per utterance
176
+ accent_sym: optional list of accent codes per utterance
177
+ """
178
+ # hs_pad: (B, L, NProj) -> ys_hat: (B, L, Nvocab)
179
+ ys_hat = self.ctc_lo(F.dropout(hs_pad, p=self.dropout_rate))
180
+
181
+ if self.ctc_type == "brctc":
182
+ loss = self.loss_fn(
183
+ ys_hat, ys_pad, hlens, ys_lens, lang_sym=lang_sym, accent_sym=accent_sym
184
+ ).to(device=hs_pad.device, dtype=hs_pad.dtype)
185
+ return loss
186
+
187
+ elif self.ctc_type == "gtnctc":
188
+ # gtn expects list form for ys
189
+ ys_true = [y[y != -1] for y in ys_pad] # parse padded ys
190
+ else:
191
+ # ys_hat: (B, L, D) -> (L, B, D)
192
+ ys_hat = ys_hat.transpose(0, 1)
193
+ # (B, L) -> (BxL,)
194
+ ys_true = torch.cat([ys_pad[i, :l] for i, l in enumerate(ys_lens)])
195
+
196
+ loss = self.loss_fn(
197
+ ys_hat, ys_true, hlens, ys_lens, lang_sym=lang_sym, accent_sym=accent_sym
198
+ ).to(device=hs_pad.device, dtype=hs_pad.dtype)
199
+
200
+ return loss
201
+
202
+ def softmax(self, hs_pad):
203
+ """softmax of frame activations
204
+
205
+ Args:
206
+ Tensor hs_pad: 3d tensor (B, Tmax, eprojs)
207
+ Returns:
208
+ torch.Tensor: softmax applied 3d tensor (B, Tmax, odim)
209
+ """
210
+ return F.softmax(self.ctc_lo(hs_pad), dim=2)
211
+
212
+ def log_softmax(self, hs_pad):
213
+ """log_softmax of frame activations
214
+
215
+ Args:
216
+ Tensor hs_pad: 3d tensor (B, Tmax, eprojs)
217
+ Returns:
218
+ torch.Tensor: log softmax applied 3d tensor (B, Tmax, odim)
219
+ """
220
+ return F.log_softmax(self.ctc_lo(hs_pad), dim=2)
221
+
222
+ def argmax(self, hs_pad):
223
+ """argmax of frame activations
224
+
225
+ Args:
226
+ torch.Tensor hs_pad: 3d tensor (B, Tmax, eprojs)
227
+ Returns:
228
+ torch.Tensor: argmax applied 2d tensor (B, Tmax)
229
+ """
230
+ return torch.argmax(self.ctc_lo(hs_pad), dim=2)
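Note that the CTC module above only exposes frame-level activations (`softmax`, `log_softmax`, `argmax`); decoding happens elsewhere in the pipeline. As a reference point, a minimal greedy-decoding sketch over the `argmax` output could look like the following (the helper name is hypothetical and not part of this commit; blank id 0 follows the bundled ipa_vocab.json):

```python
import torch

def greedy_ctc_decode(ctc, hs_pad, blank_id=0):
    """Collapse repeated frame labels and drop blanks (hypothetical helper)."""
    frame_ids = ctc.argmax(hs_pad)              # (B, Tmax)
    hyps = []
    for ids in frame_ids:
        ids = torch.unique_consecutive(ids)     # merge repeated predictions
        hyps.append([int(i) for i in ids if int(i) != blank_id])
    return hyps
```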
src/model/powsm/e_branchformer.py ADDED
@@ -0,0 +1,555 @@
1
+ # Copyright 2022 Kwangyoun Kim (ASAPP inc.)
2
+ # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
3
+
4
+ """E-Branchformer encoder definition.
5
+
6
+ Reference:
7
+ Kwangyoun Kim, Felix Wu, Yifan Peng, Jing Pan,
8
+ Prashant Sridhar, Kyu J. Han, Shinji Watanabe,
9
+ "E-Branchformer: Branchformer with Enhanced merging
10
+ for speech recognition," in SLT 2022.
11
+ """
12
+
13
+ import logging
14
+ from typing import List, Optional, Tuple
15
+
16
+ import torch
17
+ from typeguard import typechecked
18
+
19
+ from src.model.powsm.ctc import CTC
20
+ from src.espnet_import.fastformer import FastSelfAttention
21
+ from src.espnet_import.cgmlp import ConvolutionalGatingMLP
22
+
23
+ from src.espnet_import.nets_utils import get_activation, make_pad_mask
24
+ from src.espnet_import.attention import (
25
+ LegacyRelPositionMultiHeadedAttention,
26
+ MultiHeadedAttention,
27
+ RelPositionMultiHeadedAttention,
28
+ )
29
+ from src.espnet_import.embedding import (
30
+ ConvolutionalPositionalEmbedding,
31
+ LegacyRelPositionalEncoding,
32
+ PositionalEncoding,
33
+ RelPositionalEncoding,
34
+ ScaledPositionalEncoding,
35
+ )
36
+ from src.espnet_import.layer_norm import LayerNorm
37
+ from src.espnet_import.positionwise_feed_forward import PositionwiseFeedForward
38
+ from src.espnet_import.repeat import repeat
39
+ from src.espnet_import.subsampling import (
40
+ Conv1dSubsampling1,
41
+ Conv1dSubsampling2,
42
+ Conv1dSubsampling3,
43
+ Conv2dSubsampling,
44
+ Conv2dSubsampling1,
45
+ Conv2dSubsampling2,
46
+ Conv2dSubsampling6,
47
+ Conv2dSubsampling8,
48
+ TooShortUttError,
49
+ check_short_utt,
50
+ )
51
+
52
+
53
+ class EBranchformerEncoderLayer(torch.nn.Module):
54
+ """E-Branchformer encoder layer module.
55
+
56
+ Args:
57
+ size (int): model dimension
58
+ attn: standard self-attention or efficient attention
59
+ cgmlp: ConvolutionalGatingMLP
60
+ feed_forward: feed-forward module, optional
61
+ feed_forward_macaron: macaron-style feed-forward module, optional
62
+ dropout_rate (float): dropout probability
63
+ merge_conv_kernel (int): kernel size of the depth-wise conv in merge module
64
+ """
65
+
66
+ def __init__(
67
+ self,
68
+ size: int,
69
+ attn: torch.nn.Module,
70
+ cgmlp: torch.nn.Module,
71
+ feed_forward: Optional[torch.nn.Module],
72
+ feed_forward_macaron: Optional[torch.nn.Module],
73
+ dropout_rate: float,
74
+ merge_conv_kernel: int = 3,
75
+ ):
76
+ super().__init__()
77
+
78
+ self.size = size
79
+ self.attn = attn
80
+ self.cgmlp = cgmlp
81
+
82
+ self.feed_forward = feed_forward
83
+ self.feed_forward_macaron = feed_forward_macaron
84
+ self.ff_scale = 1.0
85
+ if self.feed_forward is not None:
86
+ self.norm_ff = LayerNorm(size)
87
+ if self.feed_forward_macaron is not None:
88
+ self.ff_scale = 0.5
89
+ self.norm_ff_macaron = LayerNorm(size)
90
+
91
+ self.norm_mha = LayerNorm(size) # for the MHA module
92
+ self.norm_mlp = LayerNorm(size) # for the MLP module
93
+ self.norm_final = LayerNorm(size) # for the final output of the block
94
+
95
+ self.dropout = torch.nn.Dropout(dropout_rate)
96
+
97
+ self.depthwise_conv_fusion = torch.nn.Conv1d(
98
+ size + size,
99
+ size + size,
100
+ kernel_size=merge_conv_kernel,
101
+ stride=1,
102
+ padding=(merge_conv_kernel - 1) // 2,
103
+ groups=size + size,
104
+ bias=True,
105
+ )
106
+ self.merge_proj = torch.nn.Linear(size + size, size)
107
+
108
+ def forward(self, x_input, mask, cache=None):
109
+ """Compute encoded features.
110
+
111
+ Args:
112
+ x_input (Union[Tuple, torch.Tensor]): Input tensor w/ or w/o pos emb.
113
+ - w/ pos emb: Tuple of tensors [(#batch, time, size), (1, time, size)].
114
+ - w/o pos emb: Tensor (#batch, time, size).
115
+ mask (torch.Tensor): Mask tensor for the input (#batch, 1, time).
116
+ cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).
117
+ Returns:
118
+ torch.Tensor: Output tensor (#batch, time, size).
119
+ torch.Tensor: Mask tensor (#batch, time).
120
+ """
121
+
122
+ if cache is not None:
123
+ raise NotImplementedError("cache is not None, which is not tested")
124
+
125
+ if isinstance(x_input, tuple):
126
+ x, pos_emb = x_input[0], x_input[1]
127
+ else:
128
+ x, pos_emb = x_input, None
129
+
130
+ if self.feed_forward_macaron is not None:
131
+ residual = x
132
+ x = self.norm_ff_macaron(x)
133
+ x = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(x))
134
+
135
+ # Two branches
136
+ x1 = x
137
+ x2 = x
138
+
139
+ # Branch 1: multi-headed attention module
140
+ x1 = self.norm_mha(x1)
141
+
142
+ if isinstance(self.attn, FastSelfAttention):
143
+ x_att = self.attn(x1, mask)
144
+ else:
145
+ if pos_emb is not None:
146
+ x_att = self.attn(x1, x1, x1, pos_emb, mask)
147
+ else:
148
+ x_att = self.attn(x1, x1, x1, mask)
149
+
150
+ x1 = self.dropout(x_att)
151
+
152
+ # Branch 2: convolutional gating mlp
153
+ x2 = self.norm_mlp(x2)
154
+
155
+ if pos_emb is not None:
156
+ x2 = (x2, pos_emb)
157
+ x2 = self.cgmlp(x2, mask)
158
+ if isinstance(x2, tuple):
159
+ x2 = x2[0]
160
+
161
+ x2 = self.dropout(x2)
162
+
163
+ # Merge two branches
164
+ x_concat = torch.cat([x1, x2], dim=-1)
165
+ x_tmp = x_concat.transpose(1, 2)
166
+ x_tmp = self.depthwise_conv_fusion(x_tmp)
167
+ x_tmp = x_tmp.transpose(1, 2)
168
+ x = x + self.dropout(self.merge_proj(x_concat + x_tmp))
169
+
170
+ if self.feed_forward is not None:
171
+ # feed forward module
172
+ residual = x
173
+ x = self.norm_ff(x)
174
+ x = residual + self.ff_scale * self.dropout(self.feed_forward(x))
175
+
176
+ x = self.norm_final(x)
177
+
178
+ if pos_emb is not None:
179
+ return (x, pos_emb), mask
180
+
181
+ return x, mask
182
+
183
+
184
+ class EBranchformerEncoder(torch.nn.Module):
185
+ """E-Branchformer encoder module."""
186
+
187
+ @typechecked
188
+ def __init__(
189
+ self,
190
+ input_size: int,
191
+ output_size: int = 256,
192
+ attention_heads: int = 4,
193
+ attention_layer_type: str = "rel_selfattn",
194
+ pos_enc_layer_type: str = "rel_pos",
195
+ rel_pos_type: str = "latest",
196
+ cgmlp_linear_units: int = 2048,
197
+ cgmlp_conv_kernel: int = 31,
198
+ use_linear_after_conv: bool = False,
199
+ gate_activation: str = "identity",
200
+ num_blocks: int = 12,
201
+ dropout_rate: float = 0.1,
202
+ positional_dropout_rate: float = 0.1,
203
+ attention_dropout_rate: float = 0.0,
204
+ input_layer: Optional[str] = "conv2d",
205
+ zero_triu: bool = False,
206
+ padding_idx: int = -1,
207
+ layer_drop_rate: float = 0.0,
208
+ max_pos_emb_len: int = 5000,
209
+ use_ffn: bool = False,
210
+ macaron_ffn: bool = False,
211
+ ffn_activation_type: str = "swish",
212
+ linear_units: int = 2048,
213
+ positionwise_layer_type: str = "linear",
214
+ merge_conv_kernel: int = 3,
215
+ interctc_layer_idx=None,
216
+ interctc_use_conditioning: bool = False,
217
+ qk_norm: bool = False,
218
+ use_flash_attn: bool = True,
219
+ gradient_checkpoint_layers: List[int] = [],
220
+ ):
221
+ super().__init__()
222
+ self._output_size = output_size
223
+
224
+ if rel_pos_type == "legacy":
225
+ if pos_enc_layer_type == "rel_pos":
226
+ pos_enc_layer_type = "legacy_rel_pos"
227
+ if attention_layer_type == "rel_selfattn":
228
+ attention_layer_type = "legacy_rel_selfattn"
229
+ elif rel_pos_type == "latest":
230
+ assert attention_layer_type != "legacy_rel_selfattn"
231
+ assert pos_enc_layer_type != "legacy_rel_pos"
232
+ else:
233
+ raise ValueError("unknown rel_pos_type: " + rel_pos_type)
234
+
235
+ if pos_enc_layer_type == "abs_pos":
236
+ pos_enc_class = PositionalEncoding
237
+ elif pos_enc_layer_type == "conv":
238
+ pos_enc_class = ConvolutionalPositionalEmbedding
239
+ elif pos_enc_layer_type == "scaled_abs_pos":
240
+ pos_enc_class = ScaledPositionalEncoding
241
+ elif pos_enc_layer_type == "rel_pos":
242
+ assert attention_layer_type == "rel_selfattn"
243
+ pos_enc_class = RelPositionalEncoding
244
+ elif pos_enc_layer_type == "legacy_rel_pos":
245
+ assert attention_layer_type == "legacy_rel_selfattn"
246
+ pos_enc_class = LegacyRelPositionalEncoding
247
+ logging.warning(
248
+ "Using legacy_rel_pos and it will be deprecated in the future."
249
+ )
250
+ else:
251
+ raise ValueError("unknown pos_enc_layer: " + pos_enc_layer_type)
252
+
253
+ if input_layer == "linear":
254
+ self.embed = torch.nn.Sequential(
255
+ torch.nn.Linear(input_size, output_size),
256
+ torch.nn.LayerNorm(output_size),
257
+ torch.nn.Dropout(dropout_rate),
258
+ pos_enc_class(output_size, positional_dropout_rate, max_pos_emb_len),
259
+ )
260
+ elif input_layer == "conv1d1":
261
+ self.embed = Conv1dSubsampling1(
262
+ input_size,
263
+ output_size,
264
+ dropout_rate,
265
+ pos_enc_class(output_size, positional_dropout_rate, max_pos_emb_len),
266
+ )
267
+ elif input_layer == "conv1d2":
268
+ self.embed = Conv1dSubsampling2(
269
+ input_size,
270
+ output_size,
271
+ dropout_rate,
272
+ pos_enc_class(output_size, positional_dropout_rate, max_pos_emb_len),
273
+ )
274
+ elif input_layer == "conv1d3":
275
+ self.embed = Conv1dSubsampling3(
276
+ input_size,
277
+ output_size,
278
+ dropout_rate,
279
+ pos_enc_class(output_size, positional_dropout_rate, max_pos_emb_len),
280
+ )
281
+ elif input_layer == "conv2d":
282
+ self.embed = Conv2dSubsampling(
283
+ input_size,
284
+ output_size,
285
+ dropout_rate,
286
+ pos_enc_class(output_size, positional_dropout_rate, max_pos_emb_len),
287
+ )
288
+ elif input_layer == "conv2d1":
289
+ self.embed = Conv2dSubsampling1(
290
+ input_size,
291
+ output_size,
292
+ dropout_rate,
293
+ pos_enc_class(output_size, positional_dropout_rate, max_pos_emb_len),
294
+ )
295
+ elif input_layer == "conv2d2":
296
+ self.embed = Conv2dSubsampling2(
297
+ input_size,
298
+ output_size,
299
+ dropout_rate,
300
+ pos_enc_class(output_size, positional_dropout_rate, max_pos_emb_len),
301
+ )
302
+ elif input_layer == "conv2d6":
303
+ self.embed = Conv2dSubsampling6(
304
+ input_size,
305
+ output_size,
306
+ dropout_rate,
307
+ pos_enc_class(output_size, positional_dropout_rate, max_pos_emb_len),
308
+ )
309
+ elif input_layer == "conv2d8":
310
+ self.embed = Conv2dSubsampling8(
311
+ input_size,
312
+ output_size,
313
+ dropout_rate,
314
+ pos_enc_class(output_size, positional_dropout_rate, max_pos_emb_len),
315
+ )
316
+ elif input_layer == "embed":
317
+ self.embed = torch.nn.Sequential(
318
+ torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
319
+ pos_enc_class(output_size, positional_dropout_rate, max_pos_emb_len),
320
+ )
321
+ elif isinstance(input_layer, torch.nn.Module):
322
+ self.embed = torch.nn.Sequential(
323
+ input_layer,
324
+ pos_enc_class(output_size, positional_dropout_rate, max_pos_emb_len),
325
+ )
326
+ elif input_layer is None:
327
+ if input_size == output_size:
328
+ self.embed = torch.nn.Sequential(
329
+ pos_enc_class(output_size, positional_dropout_rate, max_pos_emb_len)
330
+ )
331
+ else:
332
+ self.embed = torch.nn.Linear(input_size, output_size)
333
+ else:
334
+ raise ValueError("unknown input_layer: " + input_layer)
335
+
336
+ activation = get_activation(ffn_activation_type)
337
+ if positionwise_layer_type == "linear":
338
+ positionwise_layer = PositionwiseFeedForward
339
+ positionwise_layer_args = (
340
+ output_size,
341
+ linear_units,
342
+ dropout_rate,
343
+ activation,
344
+ )
345
+ elif positionwise_layer_type is None:
346
+ logging.warning("no macaron ffn")
347
+ else:
348
+ raise ValueError("Support only linear.")
349
+
350
+ if attention_layer_type == "selfattn":
351
+ # Default to flash attention unless overridden by the user
352
+ if use_flash_attn:
353
+ try:
354
+ import flash_attn_interface # noqa
355
+ except Exception:
356
+ use_flash_attn = False
357
+ encoder_selfattn_layer = MultiHeadedAttention
358
+ encoder_selfattn_layer_args = (
359
+ attention_heads,
360
+ output_size,
361
+ attention_dropout_rate,
362
+ qk_norm,
363
+ use_flash_attn,
364
+ False,
365
+ False,
366
+ )
367
+ elif attention_layer_type == "legacy_rel_selfattn":
368
+ assert pos_enc_layer_type == "legacy_rel_pos"
369
+ encoder_selfattn_layer = LegacyRelPositionMultiHeadedAttention
370
+ encoder_selfattn_layer_args = (
371
+ attention_heads,
372
+ output_size,
373
+ attention_dropout_rate,
374
+ )
375
+ logging.warning(
376
+ "Using legacy_rel_selfattn and it will be deprecated in the future."
377
+ )
378
+ elif attention_layer_type == "rel_selfattn":
379
+ assert pos_enc_layer_type == "rel_pos"
380
+ encoder_selfattn_layer = RelPositionMultiHeadedAttention
381
+ encoder_selfattn_layer_args = (
382
+ attention_heads,
383
+ output_size,
384
+ attention_dropout_rate,
385
+ zero_triu,
386
+ )
387
+ elif attention_layer_type == "fast_selfattn":
388
+ assert pos_enc_layer_type in ["abs_pos", "scaled_abs_pos"]
389
+ encoder_selfattn_layer = FastSelfAttention
390
+ encoder_selfattn_layer_args = (
391
+ output_size,
392
+ attention_heads,
393
+ attention_dropout_rate,
394
+ )
395
+ else:
396
+ raise ValueError("unknown encoder_attn_layer: " + attention_layer_type)
397
+
398
+ cgmlp_layer = ConvolutionalGatingMLP
399
+ cgmlp_layer_args = (
400
+ output_size,
401
+ cgmlp_linear_units,
402
+ cgmlp_conv_kernel,
403
+ dropout_rate,
404
+ use_linear_after_conv,
405
+ gate_activation,
406
+ )
407
+
408
+ self.encoders = repeat(
409
+ num_blocks,
410
+ lambda lnum: EBranchformerEncoderLayer(
411
+ output_size,
412
+ encoder_selfattn_layer(*encoder_selfattn_layer_args),
413
+ cgmlp_layer(*cgmlp_layer_args),
414
+ positionwise_layer(*positionwise_layer_args) if use_ffn else None,
415
+ (
416
+ positionwise_layer(*positionwise_layer_args)
417
+ if use_ffn and macaron_ffn
418
+ else None
419
+ ),
420
+ dropout_rate,
421
+ merge_conv_kernel,
422
+ ),
423
+ layer_drop_rate,
424
+ )
425
+ self.after_norm = LayerNorm(output_size)
426
+
427
+ self.layer_drop_rate = layer_drop_rate
428
+
429
+ if interctc_layer_idx is None:
430
+ interctc_layer_idx = []
431
+ self.interctc_layer_idx = interctc_layer_idx
432
+ if len(interctc_layer_idx) > 0:
433
+ assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks
434
+ self.interctc_use_conditioning = interctc_use_conditioning
435
+ self.conditioning_layer = None
436
+
437
+ # For gradient checkpointing
438
+ # 0 is the embedding layer, 1 is the first encoder layer, etc.
439
+ self.gradient_checkpoint_layers = gradient_checkpoint_layers
440
+ # logging.info(f"Gradient checkpoint layers: {self.gradient_checkpoint_layers}")
441
+
442
+ def output_size(self) -> int:
443
+ return self._output_size
444
+
445
+ def forward(
446
+ self,
447
+ xs_pad: torch.Tensor,
448
+ ilens: torch.Tensor,
449
+ prev_states: torch.Tensor = None,
450
+ masks: torch.Tensor = None,
451
+ ctc: CTC = None,
452
+ max_layer: int = None,
453
+ return_all_hs: bool = False,
454
+ ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
455
+ """Calculate forward propagation.
456
+
457
+ Args:
458
+ xs_pad (torch.Tensor): Input tensor (#batch, L, input_size).
459
+ ilens (torch.Tensor): Input length (#batch).
460
+ prev_states (torch.Tensor): Not to be used now.
461
+ ctc (CTC): Intermediate CTC module.
462
+ max_layer (int): If given, only the first max_layer encoder layers are run.
463
+ Returns:
464
+ torch.Tensor: Output tensor (#batch, L, output_size).
465
+ torch.Tensor: Output length (#batch).
466
+ torch.Tensor: Not to be used now.
467
+ """
468
+
469
+ if masks is None:
470
+ masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
471
+ else:
472
+ masks = ~masks[:, None, :]
473
+
474
+ if (
475
+ isinstance(self.embed, Conv2dSubsampling)
476
+ or isinstance(self.embed, Conv1dSubsampling1)
477
+ or isinstance(self.embed, Conv1dSubsampling2)
478
+ or isinstance(self.embed, Conv1dSubsampling3)
479
+ or isinstance(self.embed, Conv2dSubsampling1)
480
+ or isinstance(self.embed, Conv2dSubsampling2)
481
+ or isinstance(self.embed, Conv2dSubsampling6)
482
+ or isinstance(self.embed, Conv2dSubsampling8)
483
+ ):
484
+ short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
485
+ if short_status:
486
+ raise TooShortUttError(
487
+ f"has {xs_pad.size(1)} frames and is too short for subsampling "
488
+ + f"(it needs more than {limit_size} frames), return empty results",
489
+ xs_pad.size(1),
490
+ limit_size,
491
+ )
492
+ if 0 in self.gradient_checkpoint_layers:
493
+ xs_pad, masks = torch.utils.checkpoint.checkpoint(
494
+ self.embed, xs_pad, masks, use_reentrant=False
495
+ )
496
+ else:
497
+ xs_pad, masks = self.embed(xs_pad, masks)
498
+ elif self.embed is not None:
499
+ if 0 in self.gradient_checkpoint_layers:
500
+ xs_pad = torch.utils.checkpoint.checkpoint(
501
+ self.embed, xs_pad, use_reentrant=False
502
+ )
503
+ else:
504
+ xs_pad = self.embed(xs_pad)
505
+
506
+ intermediate_outs = []
507
+ for layer_idx, encoder_layer in enumerate(self.encoders):
508
+ if max_layer is not None and layer_idx >= max_layer:
509
+ break
510
+
511
+ if (
512
+ self.training
513
+ and torch.empty(1).uniform_().item() < self.layer_drop_rate
514
+ ):
515
+ continue
516
+
517
+ if layer_idx + 1 in self.gradient_checkpoint_layers:
518
+ xs_pad, masks = torch.utils.checkpoint.checkpoint(
519
+ encoder_layer, xs_pad, masks, use_reentrant=False
520
+ )
521
+ else:
522
+ xs_pad, masks = encoder_layer(xs_pad, masks)
523
+
524
+ if return_all_hs:
525
+ if isinstance(xs_pad, tuple):
526
+ intermediate_outs.append(xs_pad[0])
527
+ else:
528
+ intermediate_outs.append(xs_pad)
529
+
530
+ elif layer_idx + 1 in self.interctc_layer_idx:
531
+ encoder_out = xs_pad
532
+
533
+ if isinstance(encoder_out, tuple):
534
+ encoder_out = encoder_out[0]
535
+
536
+ intermediate_outs.append((layer_idx + 1, encoder_out))
537
+
538
+ if self.interctc_use_conditioning:
539
+ ctc_out = ctc.softmax(encoder_out)
540
+
541
+ if isinstance(xs_pad, tuple):
542
+ xs_pad = list(xs_pad)
543
+ xs_pad[0] = xs_pad[0] + self.conditioning_layer(ctc_out)
544
+ xs_pad = tuple(xs_pad)
545
+ else:
546
+ xs_pad = xs_pad + self.conditioning_layer(ctc_out)
547
+
548
+ if isinstance(xs_pad, tuple):
549
+ xs_pad = xs_pad[0]
550
+
551
+ xs_pad = self.after_norm(xs_pad)
552
+ olens = masks.squeeze(1).sum(1)
553
+ if len(intermediate_outs) > 0:
554
+ return (xs_pad, intermediate_outs), olens, None
555
+ return xs_pad, olens, None
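For orientation, a shape-level sketch of driving this encoder follows; the sizes are illustrative, not the released configuration, and the vendored `src.espnet_import` modules are assumed to be importable:

```python
import torch

enc = EBranchformerEncoder(
    input_size=80,
    output_size=256,
    num_blocks=2,
    use_ffn=True,
    macaron_ffn=True,
    input_layer="linear",        # no conv subsampling, so olens == ilens
)
enc.eval()

xs = torch.randn(2, 100, 80)     # (batch, frames, feature dim)
ilens = torch.tensor([100, 73])
with torch.no_grad():
    out, olens, _ = enc(xs, ilens)
print(out.shape, olens)          # torch.Size([2, 100, 256]) tensor([100, 73])
```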
src/model/powsm/specaug.py ADDED
@@ -0,0 +1,384 @@
1
+ """SpecAugment module."""
2
+
3
+ from typing import Optional, Sequence, Union
4
+ import math
5
+ from typeguard import typechecked
6
+ import torch
7
+ from src.espnet_import.nets_utils import pad_list
8
+
9
+ DEFAULT_TIME_WARP_MODE = "bicubic"
10
+
11
+
12
+ def time_warp(x: torch.Tensor, window: int = 80, mode: str = DEFAULT_TIME_WARP_MODE):
13
+ """Time warping using torch.interpolate.
14
+
15
+ Args:
16
+ x: (Batch, Time, Freq)
17
+ window: time warp parameter
18
+ mode: Interpolate mode
19
+ """
20
+
21
+ # bicubic interpolation requires a tensor with 4 or more dimensions
22
+ org_size = x.size()
23
+ if x.dim() == 3:
24
+ # x: (Batch, Time, Freq) -> (Batch, 1, Time, Freq)
25
+ x = x[:, None]
26
+
27
+ t = x.shape[2]
28
+ if t - window <= window:
29
+ return x.view(*org_size)
30
+
31
+ center = torch.randint(window, t - window, (1,))[0]
32
+ warped = torch.randint(center - window, center + window, (1,))[0] + 1
33
+
34
+ # left: (Batch, Channel, warped, Freq)
35
+ # right: (Batch, Channel, time - warped, Freq)
36
+ left = torch.nn.functional.interpolate(
37
+ x[:, :, :center], (warped, x.shape[3]), mode=mode, align_corners=False
38
+ )
39
+ right = torch.nn.functional.interpolate(
40
+ x[:, :, center:], (t - warped, x.shape[3]), mode=mode, align_corners=False
41
+ )
42
+
43
+ if x.requires_grad:
44
+ x = torch.cat([left, right], dim=-2)
45
+ else:
46
+ x[:, :, :warped] = left
47
+ x[:, :, warped:] = right
48
+
49
+ return x.view(*org_size)
50
+
51
+
52
+ def mask_along_axis(
53
+ spec: torch.Tensor,
54
+ spec_lengths: torch.Tensor,
55
+ mask_width_range: Sequence[int] = (0, 30),
56
+ dim: int = 1,
57
+ num_mask: int = 2,
58
+ replace_with_zero: bool = True,
59
+ ):
60
+ """Apply mask along the specified direction.
61
+
62
+ Args:
63
+ spec: (Batch, Length, Freq)
64
+ spec_lengths: (Batch,): lengths are not used in this implementation
65
+ mask_width_range: mask widths are sampled uniformly from this range
66
+ """
67
+
68
+ org_size = spec.size()
69
+ if spec.dim() == 4:
70
+ # spec: (Batch, Channel, Length, Freq) -> (Batch * Channel, Length, Freq)
71
+ spec = spec.view(-1, spec.size(2), spec.size(3))
72
+
73
+ B = spec.shape[0]
74
+ # D = Length or Freq
75
+ D = spec.shape[dim]
76
+ # mask_length: (B, num_mask, 1)
77
+ mask_length = torch.randint(
78
+ mask_width_range[0],
79
+ mask_width_range[1],
80
+ (B, num_mask),
81
+ device=spec.device,
82
+ ).unsqueeze(2)
83
+
84
+ # mask_pos: (B, num_mask, 1)
85
+ mask_pos = torch.randint(
86
+ 0, max(1, D - mask_length.max()), (B, num_mask), device=spec.device
87
+ ).unsqueeze(2)
88
+
89
+ # aran: (1, 1, D)
90
+ aran = torch.arange(D, device=spec.device)[None, None, :]
91
+ # mask: (Batch, num_mask, D)
92
+ mask = (mask_pos <= aran) * (aran < (mask_pos + mask_length))
93
+ # Combine the num_mask masks with any(): (Batch, num_mask, D) -> (Batch, D)
94
+ mask = mask.any(dim=1)
95
+ if dim == 1:
96
+ # mask: (Batch, Length, 1)
97
+ mask = mask.unsqueeze(2)
98
+ elif dim == 2:
99
+ # mask: (Batch, 1, Freq)
100
+ mask = mask.unsqueeze(1)
101
+
102
+ if replace_with_zero:
103
+ value = 0.0
104
+ else:
105
+ value = spec.mean()
106
+
107
+ if spec.requires_grad:
108
+ spec = spec.masked_fill(mask, value)
109
+ else:
110
+ spec = spec.masked_fill_(mask, value)
111
+ spec = spec.view(*org_size)
112
+ return spec, spec_lengths
113
+
114
+
115
+ class TimeWarp(torch.nn.Module):
116
+ """Time warping using torch.interpolate.
117
+
118
+ Args:
119
+ window: time warp parameter
120
+ mode: Interpolate mode
121
+ """
122
+
123
+ def __init__(self, window: int = 80, mode: str = DEFAULT_TIME_WARP_MODE):
124
+ super().__init__()
125
+ self.window = window
126
+ self.mode = mode
127
+
128
+ def extra_repr(self):
129
+ return f"window={self.window}, mode={self.mode}"
130
+
131
+ def forward(self, x: torch.Tensor, x_lengths: torch.Tensor = None):
132
+ """Forward function.
133
+
134
+ Args:
135
+ x: (Batch, Time, Freq)
136
+ x_lengths: (Batch,)
137
+ """
138
+
139
+ if x_lengths is None or all(le == x_lengths[0] for le in x_lengths):
140
+ # Note: the same warping is applied to every sample in the batch
141
+ y = time_warp(x, window=self.window, mode=self.mode)
142
+ else:
143
+ # FIXME(kamo): I have no idea to batchify Timewarp
144
+ ys = []
145
+ for i in range(x.size(0)):
146
+ _y = time_warp(
147
+ x[i][None, : x_lengths[i]],
148
+ window=self.window,
149
+ mode=self.mode,
150
+ )[0]
151
+ ys.append(_y)
152
+ y = pad_list(ys, 0.0)
153
+
154
+ return y, x_lengths
155
+
156
+
157
+ class MaskAlongAxis(torch.nn.Module):
158
+ @typechecked
159
+ def __init__(
160
+ self,
161
+ mask_width_range: Union[int, Sequence[int]] = (0, 30),
162
+ num_mask: int = 2,
163
+ dim: Union[int, str] = "time",
164
+ replace_with_zero: bool = True,
165
+ ):
166
+ if isinstance(mask_width_range, int):
167
+ mask_width_range = (0, mask_width_range)
168
+ if len(mask_width_range) != 2:
169
+ raise TypeError(
170
+ f"mask_width_range must be a tuple of int and int values: "
171
+ f"{mask_width_range}",
172
+ )
173
+
174
+ assert mask_width_range[1] > mask_width_range[0]
175
+ if isinstance(dim, str):
176
+ if dim == "time":
177
+ dim = 1
178
+ elif dim == "freq":
179
+ dim = 2
180
+ else:
181
+ raise ValueError("dim must be int, 'time' or 'freq'")
182
+ if dim == 1:
183
+ self.mask_axis = "time"
184
+ elif dim == 2:
185
+ self.mask_axis = "freq"
186
+ else:
187
+ self.mask_axis = "unknown"
188
+
189
+ super().__init__()
190
+ self.mask_width_range = mask_width_range
191
+ self.num_mask = num_mask
192
+ self.dim = dim
193
+ self.replace_with_zero = replace_with_zero
194
+
195
+ def extra_repr(self):
196
+ return (
197
+ f"mask_width_range={self.mask_width_range}, "
198
+ f"num_mask={self.num_mask}, axis={self.mask_axis}"
199
+ )
200
+
201
+ def forward(self, spec: torch.Tensor, spec_lengths: torch.Tensor = None):
202
+ """Forward function.
203
+
204
+ Args:
205
+ spec: (Batch, Length, Freq)
206
+ """
207
+
208
+ return mask_along_axis(
209
+ spec,
210
+ spec_lengths,
211
+ mask_width_range=self.mask_width_range,
212
+ dim=self.dim,
213
+ num_mask=self.num_mask,
214
+ replace_with_zero=self.replace_with_zero,
215
+ )
216
+
217
+
218
+ class MaskAlongAxisVariableMaxWidth(torch.nn.Module):
219
+ """Mask input spec along a specified axis with variable maximum width.
220
+
221
+ Formula:
222
+ max_width = max_width_ratio * seq_len
223
+ """
224
+
225
+ @typechecked
226
+ def __init__(
227
+ self,
228
+ mask_width_ratio_range: Union[float, Sequence[float]] = (0.0, 0.05),
229
+ num_mask: int = 2,
230
+ dim: Union[int, str] = "time",
231
+ replace_with_zero: bool = True,
232
+ ):
233
+ if isinstance(mask_width_ratio_range, float):
234
+ mask_width_ratio_range = (0.0, mask_width_ratio_range)
235
+ if len(mask_width_ratio_range) != 2:
236
+ raise TypeError(
237
+ f"mask_width_ratio_range must be a tuple of float and float values: "
238
+ f"{mask_width_ratio_range}",
239
+ )
240
+
241
+ assert mask_width_ratio_range[1] > mask_width_ratio_range[0]
242
+ if isinstance(dim, str):
243
+ if dim == "time":
244
+ dim = 1
245
+ elif dim == "freq":
246
+ dim = 2
247
+ else:
248
+ raise ValueError("dim must be int, 'time' or 'freq'")
249
+ if dim == 1:
250
+ self.mask_axis = "time"
251
+ elif dim == 2:
252
+ self.mask_axis = "freq"
253
+ else:
254
+ self.mask_axis = "unknown"
255
+
256
+ super().__init__()
257
+ self.mask_width_ratio_range = mask_width_ratio_range
258
+ self.num_mask = num_mask
259
+ self.dim = dim
260
+ self.replace_with_zero = replace_with_zero
261
+
262
+ def extra_repr(self):
263
+ return (
264
+ f"mask_width_ratio_range={self.mask_width_ratio_range}, "
265
+ f"num_mask={self.num_mask}, axis={self.mask_axis}"
266
+ )
267
+
268
+ def forward(self, spec: torch.Tensor, spec_lengths: torch.Tensor = None):
269
+ """Forward function.
270
+
271
+ Args:
272
+ spec: (Batch, Length, Freq)
273
+ """
274
+
275
+ max_seq_len = spec.shape[self.dim]
276
+ min_mask_width = math.floor(max_seq_len * self.mask_width_ratio_range[0])
277
+ min_mask_width = max([0, min_mask_width])
278
+ max_mask_width = math.floor(max_seq_len * self.mask_width_ratio_range[1])
279
+ max_mask_width = min([max_seq_len, max_mask_width])
280
+
281
+ if max_mask_width > min_mask_width:
282
+ return mask_along_axis(
283
+ spec,
284
+ spec_lengths,
285
+ mask_width_range=(min_mask_width, max_mask_width),
286
+ dim=self.dim,
287
+ num_mask=self.num_mask,
288
+ replace_with_zero=self.replace_with_zero,
289
+ )
290
+ return spec, spec_lengths
291
+
292
+
293
+ class SpecAug(torch.nn.Module):
294
+ """Implementation of SpecAug.
295
+
296
+ Reference:
297
+ Daniel S. Park et al.
298
+ "SpecAugment: A Simple Data
299
+ Augmentation Method for Automatic Speech Recognition"
300
+
301
+ .. warning::
302
+ When using cuda mode, time_warp doesn't have reproducibility
303
+ due to `torch.nn.functional.interpolate`.
304
+
305
+ """
306
+
307
+ def __init__(
308
+ self,
309
+ apply_time_warp: bool = True,
310
+ time_warp_window: int = 5,
311
+ time_warp_mode: str = "bicubic",
312
+ apply_freq_mask: bool = True,
313
+ freq_mask_width_range: Union[int, Sequence[int]] = (0, 20),
314
+ num_freq_mask: int = 2,
315
+ apply_time_mask: bool = True,
316
+ time_mask_width_range: Optional[Union[int, Sequence[int]]] = None,
317
+ time_mask_width_ratio_range: Optional[Union[float, Sequence[float]]] = None,
318
+ num_time_mask: int = 2,
319
+ replace_with_zero: bool = True,
320
+ ):
321
+ if not apply_time_warp and not apply_time_mask and not apply_freq_mask:
322
+ raise ValueError(
323
+ "Either one of time_warp, time_mask, or freq_mask should be applied"
324
+ )
325
+ if (
326
+ apply_time_mask
327
+ and (time_mask_width_range is not None)
328
+ and (time_mask_width_ratio_range is not None)
329
+ ):
330
+ raise ValueError(
331
+ 'Either one of "time_mask_width_range" or '
332
+ '"time_mask_width_ratio_range" can be used'
333
+ )
334
+ super().__init__()
335
+ self.apply_time_warp = apply_time_warp
336
+ self.apply_freq_mask = apply_freq_mask
337
+ self.apply_time_mask = apply_time_mask
338
+
339
+ if apply_time_warp:
340
+ self.time_warp = TimeWarp(window=time_warp_window, mode=time_warp_mode)
341
+ else:
342
+ self.time_warp = None
343
+
344
+ if apply_freq_mask:
345
+ self.freq_mask = MaskAlongAxis(
346
+ dim="freq",
347
+ mask_width_range=freq_mask_width_range,
348
+ num_mask=num_freq_mask,
349
+ replace_with_zero=replace_with_zero,
350
+ )
351
+ else:
352
+ self.freq_mask = None
353
+
354
+ if apply_time_mask:
355
+ if time_mask_width_range is not None:
356
+ self.time_mask = MaskAlongAxis(
357
+ dim="time",
358
+ mask_width_range=time_mask_width_range,
359
+ num_mask=num_time_mask,
360
+ replace_with_zero=replace_with_zero,
361
+ )
362
+ elif time_mask_width_ratio_range is not None:
363
+ self.time_mask = MaskAlongAxisVariableMaxWidth(
364
+ dim="time",
365
+ mask_width_ratio_range=time_mask_width_ratio_range,
366
+ num_mask=num_time_mask,
367
+ replace_with_zero=replace_with_zero,
368
+ )
369
+ else:
370
+ raise ValueError(
371
+ 'Either one of "time_mask_width_range" or '
372
+ '"time_mask_width_ratio_range" should be used.'
373
+ )
374
+ else:
375
+ self.time_mask = None
376
+
377
+ def forward(self, x, x_lengths=None):
378
+ if self.time_warp is not None:
379
+ x, x_lengths = self.time_warp(x, x_lengths)
380
+ if self.freq_mask is not None:
381
+ x, x_lengths = self.freq_mask(x, x_lengths)
382
+ if self.time_mask is not None:
383
+ x, x_lengths = self.time_mask(x, x_lengths)
384
+ return x, x_lengths
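A minimal sketch of applying this module to a batch of log-mel features during training; the parameter values below are illustrative, not the training configuration:

```python
import torch

specaug = SpecAug(
    apply_time_warp=True,
    time_warp_window=5,
    apply_freq_mask=True,
    freq_mask_width_range=(0, 10),
    num_freq_mask=2,
    apply_time_mask=True,
    time_mask_width_range=(0, 20),
    num_time_mask=2,
)

feats = torch.randn(4, 300, 80)                 # (batch, frames, mel bins)
lengths = torch.tensor([300, 250, 280, 300])
aug, lengths = specaug(feats, lengths)          # shapes are preserved
```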
src/model/powsm/utils.py ADDED
@@ -0,0 +1,80 @@
1
+ import dataclasses
2
+ import warnings
3
+
4
+ import numpy as np
5
+ import torch
6
+
7
+
8
+ def force_gatherable(data, device):
9
+ """Change object to gatherable in torch.nn.DataParallel recursively
10
+
11
+ DataParallel places these restrictions on the returned value:
12
+ The object must be
13
+ - torch.cuda.Tensor
14
+ - with 1 or more dimensions (a 0-dim tensor triggers a warning),
15
+ or a list, tuple, dict.
16
+
17
+ """
18
+ if isinstance(data, dict):
19
+ return {k: force_gatherable(v, device) for k, v in data.items()}
20
+ # DataParallel can't handle NamedTuple well
21
+ elif isinstance(data, tuple) and type(data) is not tuple:
22
+ return type(data)(*[force_gatherable(o, device) for o in data])
23
+ elif isinstance(data, (list, tuple, set)):
24
+ return type(data)(force_gatherable(v, device) for v in data)
25
+ elif isinstance(data, np.ndarray):
26
+ return force_gatherable(torch.from_numpy(data), device)
27
+ elif isinstance(data, torch.Tensor):
28
+ if data.dim() == 0:
29
+ # To 1-dim array
30
+ data = data[None]
31
+ return data.to(device)
32
+ elif isinstance(data, float):
33
+ return torch.tensor([data], dtype=torch.float, device=device)
34
+ elif isinstance(data, int):
35
+ return torch.tensor([data], dtype=torch.long, device=device)
36
+ elif data is None:
37
+ return None
38
+ else:
39
+ warnings.warn(f"{type(data)} may not be gatherable by DataParallel")
40
+ return data
41
+
42
+
43
+ def to_device(data, device=None, dtype=None, non_blocking=False, copy=False):
44
+ """Change the device of object recursively"""
45
+ if isinstance(data, dict):
46
+ return {
47
+ k: to_device(v, device, dtype, non_blocking, copy) for k, v in data.items()
48
+ }
49
+ elif dataclasses.is_dataclass(data) and not isinstance(data, type):
50
+ return type(data)(
51
+ *[
52
+ to_device(v, device, dtype, non_blocking, copy)
53
+ for v in dataclasses.astuple(data)
54
+ ]
55
+ )
56
+ # Probably a NamedTuple (a tuple subclass); rebuild it via its own constructor.
57
+ elif isinstance(data, tuple) and type(data) is not tuple:
58
+ return type(data)(
59
+ *[to_device(o, device, dtype, non_blocking, copy) for o in data]
60
+ )
61
+ elif isinstance(data, (list, tuple)):
62
+ return type(data)(to_device(v, device, dtype, non_blocking, copy) for v in data)
63
+ elif isinstance(data, np.ndarray):
64
+ return to_device(torch.from_numpy(data), device, dtype, non_blocking, copy)
65
+ elif isinstance(data, torch.Tensor):
66
+ if dtype is not None:
67
+ dtype = str(dtype).removeprefix("torch.")
68
+ cur_dtype = str(data.dtype).removeprefix("torch.")
69
+
70
+ if not (
71
+ ("int" in dtype and "int" in cur_dtype)
72
+ or ("float" in dtype and "float" in cur_dtype)
73
+ ):
74
+ dtype = None # avoid conversion between int and float.
75
+ else:
76
+ dtype = getattr(torch, dtype)
77
+
78
+ return data.to(device, dtype, non_blocking, copy)
79
+ else:
80
+ return data
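A small sketch of how these helpers behave; the batch keys and values are illustrative, not taken from the training code:

```python
import numpy as np
import torch

batch = {
    "speech": np.zeros(16000, dtype=np.float32),
    "speech_lengths": 16000,
    "lang_sym": None,
}
batch = to_device(batch, device="cpu", dtype=torch.float32)
# np.ndarray -> torch.Tensor; plain ints and None are passed through unchanged

loss = torch.tensor(1.23)
stats = force_gatherable({"loss": loss}, device="cpu")
# 0-dim tensors are promoted to shape (1,) so DataParallel can gather them
```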
src/model/xeusphoneme/__init__.py ADDED
File without changes
src/model/xeusphoneme/builders.py ADDED
@@ -0,0 +1,307 @@
1
+ import copy
2
+ from pathlib import Path
3
+ from typing import Dict, Optional, Tuple
4
+ import argparse
5
+ import yaml
6
+ import json
7
+ import torch
8
+
9
+ from src.model.powsm.specaug import SpecAug
10
+ from src.model.powsm.e_branchformer import EBranchformerEncoder
11
+ from src.model.xeusphoneme.cnn_frontend import CNNFrontend as Wav2VecCNN
12
+ from src.model.xeusphoneme.linear_layer import LinearProjection
13
+ from src.core.utils import download_hf_snapshot
14
+ from src.model.xeusphoneme.xeuspr_model import XeusPRModel
15
+ from src.model.xeusphoneme.xeuspr_inference import XeusPRInference
16
+ from src.model.powsm.ctc import CTC
17
+ from src.utils import RankedLogger
18
+
19
+
20
+ log = RankedLogger(__name__, rank_zero_only=False)
21
+
22
+
23
+ class XeusPRTokenizer:
24
+ """Tokenizer that maps IPA phones to IDs using the xeuspr ipa_vocab.json."""
25
+
26
+ def __init__(self, vocab_file: str):
27
+ with open(vocab_file) as f:
28
+ self.vocab: Dict[str, int] = json.load(f)
29
+ self.unk_id = self.vocab.get("<unk>", 0)
30
+
31
+ def tokens2ids(self, tokens) -> list:
32
+ return [self.vocab.get(t, self.unk_id) for t in tokens]
33
+
34
+
35
+
36
+ def build_xeus_pr(
37
+ config_file: str,
38
+ checkpoint: Optional[str] = None,
39
+ vocab_file: Optional[str] = None,
40
+ ctc_config: Optional[dict] = None,
41
+ weighted_sum: bool = False,
42
+ interctc_layer_idx: Optional[list] = None,
43
+ interctc_weight: float = 0.0,
44
+ interctc_use_conditioning: bool = False,
45
+ interctc_ctc_type: str = "phone",
46
+ ctc_aux_config: Optional[dict] = None,
47
+ decoder_config: Optional[dict] = None,
48
+ ctc_weight: float = 1.0,
49
+ ) -> XeusPRModel:
50
+ """Build Xeus PR model from config and optional checkpoint.
51
+
52
+ Args:
53
+ config_file: Path to config yaml file
54
+ checkpoint: Path to model checkpoint (pretrained or fully trained)
55
+ vocab_file: Path to vocabulary file. If None, use vocab in config.
56
+ ctc_config: Optional dict of CTC config
57
+ weighted_sum: Whether to use weighted sum of transformer layers
58
+
59
+ Returns:
60
+ XeusPRModel
61
+ """
62
+ with open(config_file, "r", encoding="utf-8") as f:
63
+ args = argparse.Namespace(**yaml.safe_load(f))
64
+ if vocab_file is not None:
65
+ with open(vocab_file) as f:
66
+ tok2id = json.load(f)
67
+ id2tok = {v: k for k, v in tok2id.items()}
68
+ token_list = [id2tok[i] for i in range(len(id2tok))]
69
+ elif isinstance(args.token_list, str):
70
+ with open(args.token_list, encoding="utf-8") as f:
71
+ token_list = [line.rstrip() for line in f]
72
+ else:
73
+ token_list = list(args.token_list)
74
+ vocab_size = len(token_list)
75
+ log.info(f"Vocabulary size: {vocab_size}")
76
+
77
+ assert (
78
+ getattr(args, "frontend") == "wav2vec_cnn"
79
+ ), "Config must specify wav2vec_cnn frontend"
80
+ frontend = Wav2VecCNN(**args.frontend_conf)
81
+ input_size = frontend.output_size()
82
+
83
+ specaug = None
84
+ if hasattr(args, "specaug") and args.specaug == "specaug":
85
+ specaug = SpecAug(**args.specaug_conf)
86
+
87
+ normalize = None
88
+ assert (
89
+ getattr(args, "preencoder") == "linear"
90
+ ), "Config must specify linear preencoder"
91
+ preencoder = LinearProjection(input_size=input_size, **args.preencoder_conf)
92
+ input_size = preencoder.output_size()
93
+ assert (
94
+ args.encoder == "e_branchformer"
95
+ ), f"Only e_branchformer supported, got {args.encoder}"
96
+ encoder_conf = dict(args.encoder_conf)
97
+ if interctc_layer_idx:
98
+ encoder_conf["interctc_layer_idx"] = interctc_layer_idx
99
+ if interctc_use_conditioning:
100
+ encoder_conf["interctc_use_conditioning"] = True
101
+ encoder = EBranchformerEncoder(input_size=input_size, **encoder_conf)
102
+
103
+ ctc_config = ctc_config or getattr(args, "ctc_conf", {})
104
+ ctc_config_orig = copy.deepcopy(ctc_config)
105
+ # Build CTC
106
+ ctc = CTC(
107
+ odim=vocab_size,
108
+ encoder_output_size=encoder.output_size(),
109
+ **ctc_config,
110
+ )
111
+
112
+ # Build optional aux CTC (orthographic vocabulary)
113
+ ctc_aux = None
114
+ if ctc_aux_config is not None:
115
+ import sentencepiece as spm
116
+
117
+ ctc_aux_config = dict(ctc_aux_config) # copy to avoid mutating caller's dict
118
+ sp = spm.SentencePieceProcessor()
119
+ sp.load(ctc_aux_config.pop("vocab_file"))
120
+ aux_vocab_size = sp.get_piece_size()
121
+ ctc_aux = CTC(
122
+ odim=aux_vocab_size,
123
+ encoder_output_size=encoder.output_size(),
124
+ ctc_type="builtin",
125
+ **ctc_aux_config,
126
+ )
127
+ log.info(f"Built aux CTC with vocab size {aux_vocab_size}")
128
+
129
+ # Build optional attention decoder
130
+ decoder = None
131
+ if decoder_config:
132
+ from src.model.powsm.transformer_decoder import TransformerDecoder
133
+
134
+ decoder = TransformerDecoder(
135
+ vocab_size=vocab_size,
136
+ encoder_output_size=encoder.output_size(),
137
+ **decoder_config,
138
+ )
139
+
140
+ # Build model
141
+ model = XeusPRModel(
142
+ encoder=encoder,
143
+ ctc=ctc,
144
+ token_list=token_list,
145
+ frontend=frontend,
146
+ specaug=specaug,
147
+ normalize=normalize,
148
+ preencoder=preencoder,
149
+ ignore_id=getattr(args, "ignore_id", -1),
150
+ sym_blank=getattr(args, "sym_blank", "<blank>"),
151
+ freeze_frontend=checkpoint is not None,
152
+ weighted_sum=weighted_sum,
153
+ interctc_weight=interctc_weight,
154
+ interctc_use_conditioning=interctc_use_conditioning,
155
+ interctc_ctc_type=interctc_ctc_type,
156
+ ctc_aux=ctc_aux,
157
+ decoder=decoder,
158
+ ctc_weight=ctc_weight,
159
+ )
160
+
161
+ if checkpoint:
162
+ state_dict = torch.load(checkpoint, map_location="cpu", weights_only=False)
163
+ if "state_dict" in state_dict:
164
+ # convert to standard xeus style checkpoint
165
+ state_dict = state_dict["state_dict"] # for finetuned lightning checkpoints
166
+ state_dict = {
167
+ k.replace("net.", ""): v
168
+ for k, v in state_dict.items()
169
+ if k.startswith("net.")
170
+ }
171
+ load_info = model.load_state_dict(state_dict, strict=False)
172
+ log.info(f"Loaded checkpoint: {checkpoint} with load info: {load_info}")
173
+ print(f"Loaded checkpoint: {checkpoint} with load info: {load_info}")
174
+
175
+ model.training_args = args
176
+ model._net_config = {
177
+ "ctc_config": ctc_config_orig,
178
+ "weighted_sum": weighted_sum,
179
+ "interctc_layer_idx": interctc_layer_idx,
180
+ "interctc_weight": interctc_weight,
181
+ "interctc_use_conditioning": interctc_use_conditioning,
182
+ "interctc_ctc_type": interctc_ctc_type,
183
+ "ctc_aux_config": ctc_aux_config,
184
+ "decoder_config": decoder_config,
185
+ "ctc_weight": ctc_weight,
186
+ }
187
+ return model
188
+
189
+
190
+ def build_xeus_pr_from_hf(
191
+ *,
192
+ work_dir: str,
193
+ hf_repo: Optional[str] = None,
194
+ force: bool = False,
195
+ config_file: Optional[str] = None,
196
+ checkpoint: Optional[str] = None,
197
+ vocab_file: Optional[str] = None,
198
+ ctc_config: Optional[dict] = None,
199
+ load_ckpt: bool = True,
200
+ weighted_sum: bool = False,
201
+ interctc_layer_idx: Optional[list] = None,
202
+ interctc_weight: float = 0.0,
203
+ interctc_use_conditioning: bool = False,
204
+ interctc_ctc_type: str = "phone",
205
+ ctc_aux_config: Optional[dict] = None,
206
+ decoder_config: Optional[dict] = None,
207
+ ctc_weight: float = 1.0,
208
+ ) -> XeusPRModel:
209
+ """Build Xeus PR model from local files or HuggingFace repo.
210
+
211
+ Args:
212
+ work_dir: Directory to store downloaded files from HF repo
213
+ hf_repo: HuggingFace repo name (e.g., "username/xeus-pr")
214
+ If None, load from local files only
215
+ force: Whether to force re-download from HF repo
216
+ config_file: Path to config file. If None, use default path in work_dir.
217
+ Takes precedence over hf_repo download.
218
+ checkpoint: Path to checkpoint file. If None, use default path in work_dir.
219
+ Takes precedence over hf_repo download.
220
+ vocab_file: Path to vocabulary file. If None, use path in config.
221
+ ctc_config: Optional dict of CTC config
222
+ load_ckpt: Whether to load checkpoint weights
223
+ weighted_sum: Whether to use weighted sum of transformer layers
224
+ Returns:
225
+ XeusPRModel
226
+ """
227
+ # Default relative paths in HF repo
228
+ REL_CONFIG = "model/config.yaml"
229
+ REL_CKPT = "model/xeus_checkpoint_new.pth"
230
+
231
+ # Download from HF if repo specified
232
+ if hf_repo:
233
+ log.info(f"Downloading snapshot from HuggingFace: {hf_repo}")
234
+ download_hf_snapshot(
235
+ repo_id=hf_repo,
236
+ force_download=force,
237
+ work_dir=work_dir,
238
+ )
239
+
240
+ # Resolve file paths
241
+ root = Path(work_dir)
242
+ cfg = config_file or str(root / REL_CONFIG)
243
+ ckpt = checkpoint or str(root / REL_CKPT)
244
+
245
+ # Verify files exist
246
+ assert Path(cfg).exists(), f"Config file not found: {cfg}"
247
+ if not load_ckpt:
248
+ ckpt = None
249
+ else:
250
+ assert Path(ckpt).exists(), f"Checkpoint file not found: {ckpt}"
251
+
252
+ log.info(f"Building model from config: {cfg}")
253
+ log.info(f"Loading checkpoint: {ckpt}")
254
+
255
+ return build_xeus_pr(
256
+ config_file=cfg,
257
+ checkpoint=ckpt,
258
+ vocab_file=vocab_file,
259
+ ctc_config=ctc_config,
260
+ weighted_sum=weighted_sum,
261
+ interctc_layer_idx=interctc_layer_idx,
262
+ interctc_weight=interctc_weight,
263
+ interctc_use_conditioning=interctc_use_conditioning,
264
+ interctc_ctc_type=interctc_ctc_type,
265
+ ctc_aux_config=ctc_aux_config,
266
+ decoder_config=decoder_config,
267
+ ctc_weight=ctc_weight,
268
+ )
269
+
270
+
271
+ def build_xeus_pr_inference(
272
+ work_dir: str,
273
+ checkpoint: str,
274
+ vocab_file: str,
275
+ device,
276
+ config_file: Optional[str] = None,
277
+ hf_repo: Optional[str] = None,
278
+ force_download: bool = False,
279
+ dtype: str = "float32",
280
+ ctc_config: Optional[dict] = None,
281
+ weighted_sum: bool = False,
282
+ interctc_layer_idx: Optional[list] = None,
283
+ interctc_weight: float = 0.0,
284
+ interctc_use_conditioning: bool = False,
285
+ interctc_ctc_type: str = "phone",
286
+ ctc_aux_config: Optional[dict] = None,
287
+ decoder_config: Optional[dict] = None,
288
+ ) -> XeusPRInference:
289
+ model = build_xeus_pr_from_hf(
290
+ work_dir=work_dir,
291
+ hf_repo=hf_repo,
292
+ force=force_download,
293
+ config_file=config_file,
294
+ checkpoint=checkpoint,
295
+ vocab_file=vocab_file,
296
+ ctc_config=ctc_config,
297
+ load_ckpt=True,
298
+ weighted_sum=weighted_sum,
299
+ interctc_layer_idx=interctc_layer_idx,
300
+ interctc_weight=interctc_weight,
301
+ interctc_use_conditioning=interctc_use_conditioning,
302
+ interctc_ctc_type=interctc_ctc_type,
303
+ ctc_aux_config=ctc_aux_config,
304
+ decoder_config=decoder_config,
305
+ )
306
+ inference_obj = XeusPRInference(model, device=device, dtype=dtype)
307
+ return inference_obj
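A quick sketch of the tokenizer defined above, using the ipa_vocab.json bundled in this commit (the example phones are arbitrary):

```python
tok = XeusPRTokenizer("src/model/xeusphoneme/resources/ipa_vocab.json")
print(tok.tokens2ids(["p", "ɾ", "i"]))   # [88, 89, 189]
print(tok.tokens2ids(["not-a-phone"]))   # [3], unknown symbols map to the <unk> id
```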
src/model/xeusphoneme/cnn_frontend.py ADDED
@@ -0,0 +1,261 @@
1
+ from typing import List, Literal, Optional, Tuple, Union
2
+
3
+ import torch
4
+ from torch import Tensor, nn
5
+ from torch.nn import Module
6
+ from torch.nn import functional as F
7
+
8
+
9
+ def dim_1_layer_norm(x, eps=1e-05, gamma=None, beta=None):
10
+ """Functional version of Dim1LayerNorm."""
11
+
12
+ B, D, T = x.shape
13
+ mean = torch.mean(x, 1, keepdim=True)
14
+ variance = torch.mean((x - mean) ** 2, 1, keepdim=True)
15
+
16
+ x = (x - mean) * torch.rsqrt(variance + eps)
17
+
18
+ if gamma is not None:
19
+ x = x * gamma.view(1, -1, 1)
20
+ if beta is not None:
21
+ x = x + beta.view(1, -1, 1)
22
+ return x
23
+
24
+
25
+ class Dim1LayerNorm(Module):
26
+ def __init__(self, normalized_shape, eps=1e-05, elementwise_affine=True, bias=True):
27
+ """LayerNorm on middle dim.
28
+
29
+ It assumes the input is shape B, D, T
30
+ to avoid transposing.
31
+ Faster than TransposedLayerNorm, but
32
+ may lead to minor numerical differences.
33
+ """
34
+ super().__init__()
35
+ self.normalized_shape = normalized_shape
36
+ self.eps = eps
37
+ self.elementwise_affine = elementwise_affine
38
+
39
+ self.weight = None
40
+ self.bias = None
41
+ if elementwise_affine:
42
+ self.weight = nn.Parameter(torch.ones(normalized_shape))
43
+ if bias:
44
+ self.bias = nn.Parameter(torch.zeros(normalized_shape))
45
+
46
+ def forward(self, x):
47
+ assert x.size(1) == self.normalized_shape
48
+ return dim_1_layer_norm(x, self.eps, self.weight, self.bias)
49
+
50
+
51
+ class TransposedLayerNorm(nn.LayerNorm):
52
+ """Layer norm with transpose"""
53
+
54
+ def forward(self, input: Tensor) -> Tensor:
55
+ x = input.transpose(-2, -1)
56
+ x = nn.functional.layer_norm(
57
+ x, self.normalized_shape, self.weight, self.bias, self.eps
58
+ )
59
+ x = x.transpose(-2, -1)
60
+ return x
61
+
62
+
63
+ class ConvLayerBlock(Module):
64
+ """Convolution unit of FeatureExtractor"""
65
+
66
+ def __init__(
67
+ self,
68
+ in_channels: int,
69
+ out_channels: int,
70
+ kernel_size: int,
71
+ stride: int,
72
+ bias: bool,
73
+ layer_norm: Optional[Module],
74
+ conv_mode: str,
75
+ ):
76
+ super().__init__()
77
+ self.kernel_size = kernel_size
78
+ self.stride = stride
79
+ self.layer_norm = layer_norm
80
+
81
+ if conv_mode == "standard":
82
+ self.conv = nn.Conv1d(
83
+ in_channels=in_channels,
84
+ out_channels=out_channels,
85
+ kernel_size=kernel_size,
86
+ stride=stride,
87
+ bias=bias,
88
+ )
89
+ elif conv_mode == "depth_only":
90
+ self.conv = nn.Conv1d(
91
+ in_channels=in_channels,
92
+ out_channels=out_channels,
93
+ kernel_size=kernel_size,
94
+ stride=stride,
95
+ bias=bias,
96
+ groups=in_channels,
97
+ )
98
+ elif conv_mode == "depth_sep":
99
+ self.conv = nn.Sequential(
100
+ nn.Conv1d(
101
+ in_channels=in_channels,
102
+ out_channels=in_channels,
103
+ kernel_size=kernel_size,
104
+ stride=stride,
105
+ bias=bias,
106
+ groups=in_channels,
107
+ ),
108
+ nn.Conv1d(
109
+ in_channels=in_channels,
110
+ out_channels=out_channels,
111
+ kernel_size=1,
112
+ stride=1,
113
+ bias=bias,
114
+ ),
115
+ )
116
+ nn.init.kaiming_normal_(self.conv.weight)
117
+
118
+ def forward(
119
+ self,
120
+ x: Tensor,
121
+ length: Optional[Tensor],
122
+ ) -> Tuple[Tensor, Optional[Tensor]]:
123
+ """ConvLayerBlock Forward.
124
+
125
+ Args:
126
+ x (Tensor): Shape: ``[batch, in_channels, in_frame]``.
127
+ length (Tensor or None, optional): Shape ``[batch, ]``.
128
+ Returns:
129
+ Tensor: Shape ``[batch, out_channels, out_frames]``.
130
+ Optional[Tensor]: Shape ``[batch, ]``.
131
+ """
132
+ x = self.conv(x)
133
+ if self.layer_norm is not None:
134
+ x = self.layer_norm(x)
135
+ x = nn.functional.gelu(x)
136
+
137
+ if length is not None:
138
+ length = (
139
+ torch.div(length - self.kernel_size, self.stride, rounding_mode="floor")
140
+ + 1
141
+ )
142
+ # When input length is 0, the resulting length can be negative.
143
+ length = torch.max(torch.zeros_like(length), length)
144
+ return x, length
145
+
146
+
147
+ class CNNFrontend(Module):
148
+ """Convolutional feature extractor.
149
+
150
+ Typically used in SSL models.
151
+ Uses raw waveforms as input.
152
+ """
153
+
154
+ def __init__(
155
+ self,
156
+ norm_mode: str,
157
+ conv_mode: str,
158
+ bias: bool,
159
+ shapes: List[Tuple[int, int, int]] = [
160
+ (512, 10, 5),
161
+ (512, 3, 2),
162
+ (512, 3, 2),
163
+ (512, 3, 2),
164
+ (512, 3, 2),
165
+ (512, 2, 2),
166
+ (512, 2, 2),
167
+ ],
168
+ fs: Union[int, str] = 16000,
169
+ normalize_audio: bool = False,
170
+ normalize_output: bool = False,
171
+ layer_norm_cls: Literal["transposed", "dim1"] = "transposed",
172
+ ):
173
+
174
+ super().__init__()
175
+
176
+ if norm_mode not in ["group_norm", "layer_norm"]:
177
+ raise ValueError("Invalid norm mode")
178
+
179
+ if conv_mode not in ["standard", "depth_only", "depth_sep"]:
180
+ raise ValueError("Invalid cnn mode")
181
+
182
+ self.output_channels = shapes[-1][0]
183
+ self.normalize_audio = normalize_audio
184
+
185
+ if layer_norm_cls == "dim1":
186
+ layer_norm_func = Dim1LayerNorm
187
+ else:
188
+ layer_norm_func = TransposedLayerNorm
189
+
190
+ blocks = []
191
+ in_channels = 1
192
+ self.downsampling_factor = 1
193
+ for i, (out_channels, kernel_size, stride) in enumerate(shapes):
194
+ normalization = None
195
+ if norm_mode == "group_norm" and i == 0:
196
+ normalization = nn.GroupNorm(
197
+ num_groups=out_channels,
198
+ num_channels=out_channels,
199
+ affine=True,
200
+ )
201
+ elif norm_mode == "layer_norm":
202
+ normalization = layer_norm_func(
203
+ normalized_shape=out_channels,
204
+ )
205
+ blocks.append(
206
+ ConvLayerBlock(
207
+ in_channels=in_channels,
208
+ out_channels=out_channels,
209
+ kernel_size=kernel_size,
210
+ stride=stride,
211
+ bias=bias,
212
+ layer_norm=normalization,
213
+ conv_mode=conv_mode,
214
+ )
215
+ )
216
+ in_channels = out_channels
217
+ self.downsampling_factor *= stride
218
+ self.layers = nn.Sequential(*blocks)
219
+
220
+ if normalize_output:
221
+ self.final_norm = nn.LayerNorm(self.output_channels)
222
+ else:
223
+ self.final_norm = nn.Identity()
224
+
225
+ def output_size(self) -> int:
226
+ return self.output_channels
227
+
228
+ def forward(
229
+ self,
230
+ x: Tensor,
231
+ length: Optional[Tensor],
232
+ ) -> Tuple[Tensor, Optional[Tensor]]:
233
+ """CNNFrontend Forward.
234
+
235
+ Args:
236
+ x (Tensor):
237
+ Input Tensor representing a batch of audio,
238
+ shape: ``[batch, time]``.
239
+ length (Tensor or None, optional):
240
+ Valid length of each input sample. shape: ``[batch, ]``.
241
+
242
+ Returns:
243
+ Tensor:
244
+ The resulting feature, shape: ``[batch, frame, feature]``
245
+ Optional[Tensor]:
246
+ Valid length of each output sample. shape: ``[batch, ]``.
247
+ """
248
+ if x.ndim != 2:
249
+ raise ValueError(
250
+ f"Expected the input to be 2D (batch, time). Found: {list(x.shape)}"
251
+ )
252
+
253
+ if self.normalize_audio:
254
+ x = F.layer_norm(x, x.shape)
255
+
256
+ x = x.unsqueeze(1) # (batch, channel==1, frame)
257
+ for layer in self.layers:
258
+ x, length = layer(x, length) # (batch, feature, frame)
259
+ x = x.transpose(1, 2) # (batch, frame, feature)
260
+ x = self.final_norm(x)
261
+ return x, length
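The default shape stack mirrors the wav2vec 2.0 feature extractor: strides 5·2·2·2·2·2·2 give a total downsampling factor of 320, i.e. one frame per 20 ms at 16 kHz. A quick sketch (the norm/conv settings below are illustrative):

```python
import torch

frontend = CNNFrontend(norm_mode="layer_norm", conv_mode="standard", bias=False)
wav = torch.randn(1, 16000)                     # one second of 16 kHz audio
lengths = torch.tensor([16000])
feats, out_lengths = frontend(wav, lengths)
print(frontend.downsampling_factor)             # 320
print(feats.shape, out_lengths)                 # torch.Size([1, 49, 512]) tensor([49])
```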
src/model/xeusphoneme/linear_layer.py ADDED
@@ -0,0 +1,21 @@
1
+ """Linear Projection."""
2
+
3
+ from typing import Tuple
4
+ import torch
5
+
6
+
7
+ class LinearProjection(torch.nn.Module):
8
+ def __init__(self, input_size: int, output_size: int, dropout: float = 0.0):
9
+ super().__init__()
10
+ self.output_dim = output_size
11
+ self.linear_out = torch.nn.Linear(input_size, output_size)
12
+ self.dropout = torch.nn.Dropout(dropout)
13
+
14
+ def forward(
15
+ self, input: torch.Tensor, input_lengths: torch.Tensor
16
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
17
+ output = self.linear_out(self.dropout(input))
18
+ return output, input_lengths # no state in this layer
19
+
20
+ def output_size(self) -> int:
21
+ return self.output_dim
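A one-line usage sketch, projecting the 512-d CNN features to an assumed encoder width of 1024 (sizes illustrative):

```python
import torch

proj = LinearProjection(input_size=512, output_size=1024, dropout=0.1)
feats = torch.randn(2, 49, 512)
out, lengths = proj(feats, torch.tensor([49, 40]))
print(out.shape)                                # torch.Size([2, 49, 1024])
```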
src/model/xeusphoneme/resources/ipa_vocab.json ADDED
@@ -0,0 +1,430 @@
1
+ {
2
+ "<blank>": 0,
3
+ "<sos>": 1,
4
+ "<eos>": 2,
5
+ "<unk>": 3,
6
+ "ʈ": 4,
7
+ "ʎː": 5,
8
+ "cː": 6,
9
+ "œ̞": 7,
10
+ "ʔʲ": 8,
11
+ "o̤": 9,
12
+ "ɠ": 10,
13
+ "ø": 11,
14
+ "kˀ": 12,
15
+ "e̝": 13,
16
+ "ʈ͡ʂ": 14,
17
+ "ɡʰ": 15,
18
+ "ɟ": 16,
19
+ "z": 17,
20
+ "ʃˠ": 18,
21
+ "vˠ": 19,
22
+ "ǃʰ": 20,
23
+ "dʷ": 21,
24
+ "ĩ": 22,
25
+ "nˠ": 23,
26
+ "ə": 24,
27
+ "t͡ʃʰ": 25,
28
+ "d̤": 26,
29
+ "fʲ": 27,
30
+ "xʷ": 28,
31
+ "ɛ̃": 29,
32
+ "ʃʰ": 30,
33
+ "ʃ̩": 31,
34
+ "ɤˀ": 32,
35
+ "əː": 33,
36
+ "ɛ̯": 34,
37
+ "ɞ": 35,
38
+ "yː": 36,
39
+ "fʷ": 37,
40
+ "ẽ": 38,
41
+ "rˤ": 39,
42
+ "ɒ": 40,
43
+ "ɲː": 41,
44
+ "j": 42,
45
+ "f": 43,
46
+ "ɲ̥": 44,
47
+ "ʃː": 45,
48
+ "l": 46,
49
+ "ʒ̩": 47,
50
+ "ɛ̝": 48,
51
+ "ð̞": 49,
52
+ "ʃʲ": 50,
53
+ "ɛ": 51,
54
+ "ɟː": 52,
55
+ "ʌ": 53,
56
+ "ʍ": 54,
57
+ "kʰ": 55,
58
+ "p͡f": 56,
59
+ "ɜː": 57,
60
+ "ɘ": 58,
61
+ "bʷ": 59,
62
+ "sː": 60,
63
+ "ɡː": 61,
64
+ "o̝": 62,
65
+ "cʼ": 63,
66
+ "tʰ": 64,
67
+ "kʷ": 65,
68
+ "ŋ̥": 66,
69
+ "r̝": 67,
70
+ "ɸː": 68,
71
+ "u̝": 69,
72
+ "ṳ": 70,
73
+ "β̞": 71,
74
+ "ɾː": 72,
75
+ "ɔˤ": 73,
76
+ "ʎ": 74,
77
+ "ʊ̃": 75,
78
+ "pˀ": 76,
79
+ "m̩": 77,
80
+ "ɕː": 78,
81
+ "ɪ̯": 79,
82
+ "ɖʰ": 80,
83
+ "ɰ": 81,
84
+ "t̠": 82,
85
+ "t͡ʃʲ": 83,
86
+ "ɡ̤": 84,
87
+ "j̩": 85,
88
+ "ɭ̩": 86,
89
+ "ŋ̰": 87,
90
+ "p": 88,
91
+ "ɾ": 89,
92
+ "sʲ": 90,
93
+ "ɲ̤": 91,
94
+ "cʰ": 92,
95
+ "a̯": 93,
96
+ "ɡʷ": 94,
97
+ "t͡s": 95,
98
+ "ɨ̯": 96,
99
+ "n̩": 97,
100
+ "ʌː": 98,
101
+ "ɤ": 99,
102
+ "l̩": 100,
103
+ "l̴": 101,
104
+ "pʲ": 102,
105
+ "k": 103,
106
+ "jː": 104,
107
+ "ɛ̈": 105,
108
+ "t͡ʃː": 106,
109
+ "dˠ": 107,
110
+ "ɱ̩": 108,
111
+ "ɯː": 109,
112
+ "kʼ": 110,
113
+ "ɑ̯": 111,
114
+ "zʷ": 112,
115
+ "çː": 113,
116
+ "ã": 114,
117
+ "sˠ": 115,
118
+ "s̻": 116,
119
+ "ɐ": 117,
120
+ "ɸʷ": 118,
121
+ "ɔ̃": 119,
122
+ "bˠ": 120,
123
+ "ʈː": 121,
124
+ "ʂ": 122,
125
+ "ɑ": 123,
126
+ "ë": 124,
127
+ "ɸ": 125,
128
+ "ɮʲ": 126,
129
+ "nː": 127,
130
+ "mʷ": 128,
131
+ "ǁ": 129,
132
+ "ʒ": 130,
133
+ "jˠ": 131,
134
+ "d": 132,
135
+ "tː": 133,
136
+ "ɤ̆": 134,
137
+ "s̺": 135,
138
+ "mː": 136,
139
+ "ɻ": 137,
140
+ "l̪": 138,
141
+ "ɜ": 139,
142
+ "ɓ": 140,
143
+ "ü": 141,
144
+ "lʲ": 142,
145
+ "tˠ": 143,
146
+ "ŋː": 144,
147
+ "ŋʲ": 145,
148
+ "h̩": 146,
149
+ "qʷ": 147,
150
+ "tʼ": 148,
151
+ "ə̯": 149,
152
+ "t͡sʲː": 150,
153
+ "m̤": 151,
154
+ "ɕʰ": 152,
155
+ "nʲ": 153,
156
+ "rˠ": 154,
157
+ "ɖ̤": 155,
158
+ "ø̈": 156,
159
+ "ɯˀ": 157,
160
+ "mʲ": 158,
161
+ "n̥": 159,
162
+ "mˤ": 160,
163
+ "ʒʲ": 161,
164
+ "æ": 162,
165
+ "tʷ": 163,
166
+ "d̪": 164,
167
+ "ʔ": 165,
168
+ "a̠": 166,
169
+ "ɾˠ": 167,
170
+ "ʉ": 168,
171
+ "ɔ̯": 169,
172
+ "zʲ": 170,
173
+ "ɳː": 171,
174
+ "t͡sː": 172,
175
+ "æ̯": 173,
176
+ "r̤": 174,
177
+ "ɑː": 175,
178
+ "ɘː": 176,
179
+ "ə˞": 177,
180
+ "zˤ": 178,
181
+ "õ": 179,
182
+ "əˀ": 180,
183
+ "e": 181,
184
+ "nˤ": 182,
185
+ "u": 183,
186
+ "ɑ̃": 184,
187
+ "o": 185,
188
+ "ħ": 186,
189
+ "ŋ": 187,
190
+ "mˠ": 188,
191
+ "i": 189,
192
+ "rʲ": 190,
193
+ "ɔ": 191,
194
+ "xʰ": 192,
195
+ "dˤ": 193,
196
+ "s̩": 194,
197
+ "t͡ɕʰ": 195,
198
+ "ɔ̈": 196,
199
+ "ĕ": 197,
200
+ "ɴ": 198,
201
+ "k͡x": 199,
202
+ "d͡ʒ": 200,
203
+ "dʲ": 201,
204
+ "æ̞": 202,
205
+ "ɡ̃": 203,
206
+ "uː": 204,
207
+ "pʰ": 205,
208
+ "ʁ": 206,
209
+ "n̪": 207,
210
+ "zˠ": 208,
211
+ "ø̞": 209,
212
+ "ɔː": 210,
213
+ "ɳ": 211,
214
+ "vʲ": 212,
215
+ "œ̃": 213,
216
+ "ɾ̝": 214,
217
+ "ũ": 215,
218
+ "ĭ": 216,
219
+ "ɐ̯": 217,
220
+ "ʁ̝": 218,
221
+ "qʼ": 219,
222
+ "β": 220,
223
+ "pʼ": 221,
224
+ "ɡ͡b": 222,
225
+ "oː": 223,
226
+ "ɲ": 224,
227
+ "j̃": 225,
228
+ "l̠": 226,
229
+ "a": 227,
230
+ "d͡ʑ": 228,
231
+ "œː": 229,
232
+ "t̪": 230,
233
+ "zː": 231,
234
+ "ʁ̩": 232,
235
+ "ɔ̤": 233,
236
+ "œ": 234,
237
+ "dʰ": 235,
238
+ "lː": 236,
239
+ "z̤": 237,
240
+ "sʰ": 238,
241
+ "ʏ̯": 239,
242
+ "ð": 240,
243
+ "r̩": 241,
244
+ "n̤": 242,
245
+ "ɭʲ": 243,
246
+ "ɭː": 244,
247
+ "ə̃": 245,
248
+ "ä": 246,
249
+ "ʀ": 247,
250
+ "æː": 248,
251
+ "ɡʲ": 249,
252
+ "ɪ̃": 250,
253
+ "lˠ": 251,
254
+ "ʊː": 252,
255
+ "cʲ": 253,
256
+ "ă": 254,
257
+ "d͡ʒː": 255,
258
+ "i̯": 256,
259
+ "ʉː": 257,
260
+ "t͡ɕː": 258,
261
+ "ɬ": 259,
262
+ "fˀ": 260,
263
+ "bʲ": 261,
264
+ "ɐ̃": 262,
265
+ "ɣ̤": 263,
266
+ "xʲ": 264,
267
+ "ɛ̆": 265,
268
+ "θ": 266,
269
+ "ɵː": 267,
270
+ "ɨ̞": 268,
271
+ "ɡ": 269,
272
+ "ð̠": 270,
273
+ "l̤": 271,
274
+ "w̃": 272,
275
+ "ɹ": 273,
276
+ "ɣʲ": 274,
277
+ "wˠ": 275,
278
+ "u̯": 276,
279
+ "wː": 277,
280
+ "ʐ": 278,
281
+ "ɵ": 279,
282
+ "ðˠ": 280,
283
+ "t͡ʃʼ": 281,
284
+ "pʷ": 282,
285
+ "v̤": 283,
286
+ "ǀʰ": 284,
287
+ "x": 285,
288
+ "ɥ": 286,
289
+ "ʂː": 287,
290
+ "r": 288,
291
+ "o̞": 289,
292
+ "ðˤ": 290,
293
+ "ɨ̃": 291,
294
+ "ʊ": 292,
295
+ "ʙ": 293,
296
+ "b̤": 294,
297
+ "ŋ̤": 295,
298
+ "kʲ": 296,
299
+ "ʏː": 297,
300
+ "ʄ": 298,
301
+ "eː": 299,
302
+ "ɗ": 300,
303
+ "ʏ̈": 301,
304
+ "ɛˤ": 302,
305
+ "w": 303,
306
+ "pː": 304,
307
+ "ɖ": 305,
308
+ "ɧ": 306,
309
+ "h": 307,
310
+ "ǁʰ": 308,
311
+ "hʲ": 309,
312
+ "ʃ": 310,
313
+ "ɑ̈": 311,
314
+ "d͡z": 312,
315
+ "bˤ": 313,
316
+ "k͡p": 314,
317
+ "ð̩": 315,
318
+ "n̠": 316,
319
+ "bː": 317,
320
+ "f̩": 318,
321
+ "wʲ": 319,
322
+ "o̯": 320,
323
+ "ʁː": 321,
324
+ "pˠ": 322,
325
+ "kː": 323,
326
+ "ɪˤ": 324,
327
+ "ʑː": 325,
328
+ "ʌ̃": 326,
329
+ "ɪː": 327,
330
+ "ǃ": 328,
331
+ "ç": 329,
332
+ "s": 330,
333
+ "hː": 331,
334
+ "rː": 332,
335
+ "tˤ": 333,
336
+ "ɦʲ": 334,
337
+ "ŋ̩": 335,
338
+ "m̥": 336,
339
+ "ɖː": 337,
340
+ "ɭ": 338,
341
+ "mˀ": 339,
342
+ "n": 340,
343
+ "iː": 341,
344
+ "æ̝": 342,
345
+ "xː": 343,
346
+ "i̤": 344,
347
+ "ɽ̤": 345,
348
+ "ɶ": 346,
349
+ "ˀs": 347,
350
+ "l̥": 348,
351
+ "ɱ": 349,
352
+ "e̞": 350,
353
+ "ʋ": 351,
354
+ "y̯": 352,
355
+ "lˤ": 353,
356
+ "ö": 354,
357
+ "a̝": 355,
358
+ "ɶː": 356,
359
+ "t͡sʼ": 357,
360
+ "s̠": 358,
361
+ "t͡sʲ": 359,
362
+ "ɪ": 360,
363
+ "y̆": 361,
364
+ "ɤː": 362,
365
+ "ɟʰ": 363,
366
+ "ʒː": 364,
367
+ "tʲ": 365,
368
+ "ɕ": 366,
369
+ "ɨ": 367,
370
+ "c": 368,
371
+ "t͡ʃ": 369,
372
+ "ʑ": 370,
373
+ "ʝ": 371,
374
+ "ʋ̥": 372,
375
+ "ɢ": 373,
376
+ "ɛː": 374,
377
+ "b": 375,
378
+ "øː": 376,
379
+ "ǀ": 377,
380
+ "ʏ": 378,
381
+ "i̝": 379,
382
+ "ʊ̯": 380,
383
+ "ʊˤ": 381,
384
+ "ɐˤ": 382,
385
+ "r̥": 383,
386
+ "t͡sʰ": 384,
387
+ "aː": 385,
388
+ "t͡ɬ": 386,
389
+ "ʋː": 387,
390
+ "sˤ": 388,
391
+ "s̪": 389,
392
+ "dː": 390,
393
+ "ɪ̈": 391,
394
+ "ɨː": 392,
395
+ "ɽʷ": 393,
396
+ "ʕ": 394,
397
+ "ɒː": 395,
398
+ "χ": 396,
399
+ "fˠ": 397,
400
+ "ɯ": 398,
401
+ "hˠ": 399,
402
+ "jˤ": 400,
403
+ "tˀ": 401,
404
+ "ɣ": 402,
405
+ "y": 403,
406
+ "ɦ": 404,
407
+ "ʈʰ": 405,
408
+ "t͡ɕ": 406,
409
+ "vː": 407,
410
+ "m": 408,
411
+ "ɮ": 409,
412
+ "e̤": 410,
413
+ "ʋʲ": 411,
414
+ "æ̃": 412,
415
+ "v": 413,
416
+ "ɽ": 414,
417
+ "t": 415,
418
+ "a̤": 416,
419
+ "e̯": 417,
420
+ "ɜ˞": 418,
421
+ "q": 419,
422
+ "bʰ": 420,
423
+ "t͡sˠ": 421,
424
+ "ʂʰ": 422,
425
+ "fː": 423,
426
+ "sʷ": 424,
427
+ "ɾʲ": 425,
428
+ "w̤": 426,
429
+ "fˤ": 427
430
+ }
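For orientation, a minimal sketch of turning this file into the id-ordered token list the model expects; the vocabulary maps each token to its CTC output id, with `<blank>` at index 0:

    import json

    with open("src/model/xeusphoneme/resources/ipa_vocab.json") as f:
        vocab = json.load(f)  # {token: id}, IPA phones plus <blank>/<sos>/<eos>/<unk>

    token_list = [tok for tok, _ in sorted(vocab.items(), key=lambda kv: kv[1])]
    id2token = {idx: tok for tok, idx in vocab.items()}
    assert token_list[0] == "<blank>" and len(token_list) == len(vocab)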
src/model/xeusphoneme/xeuspr_inference.py ADDED
@@ -0,0 +1,86 @@
+ # Compatible with the distributed inference API; uses a greedy CTC inference strategy.
+ # python -m src.model.xeusphoneme.xeuspr_inference
+ import torch
+ import numpy as np
+ import torch.nn.functional as F
+ from typing import Union, List, Dict, Any, Optional
+
+ from src.recipe.phone_recognition.greedy_ctc_strategy import GreedyCTCInference
+
+
+ class XeusPRInference:
+     """Greedy inference for Xeus Phoneme Recognition model."""
+
+     def __init__(
+         self,
+         model: torch.nn.Module,
+         device: str = "cpu",
+         dtype: str = "float32",
+     ):
+         self.device = device
+         self.dtype = getattr(torch, dtype)
+         self.model = model.to(device=self.device, dtype=self.dtype).eval()
+
+         self.token_list = model.token_list
+         self.blank_id = model.get_blank_id()
+         self.ignore_id = getattr(model, "ignore_id", -1)
+         self.inference_strategy = GreedyCTCInference(
+             token_list=self.token_list, blank_id=self.blank_id
+         )
+
+     @torch.no_grad()
+     def __call__(
+         self, speech: Union[torch.Tensor, np.ndarray], **kwargs
+     ) -> List[Dict[str, Any]]:
+         """
+         Perform greedy inference.
+         Args:
+             speech: Input speech of shape (nsamples,) or (batch, nsamples)
+         Returns:
+             List of result dicts matching the Powsm API
+         """
+         # 1. Prepare input
+         if isinstance(speech, np.ndarray):
+             speech = torch.from_numpy(speech)
+
+         if speech.dim() == 1:
+             speech = speech.unsqueeze(0)
+
+         speech = speech.to(device=self.device, dtype=self.dtype)
+         speech_lengths = torch.full(
+             (speech.size(0),), speech.size(1), device=self.device, dtype=torch.long
+         )
+         results = self.inference_strategy(
+             model=self.model,
+             speech=speech,
+             speech_lengths=speech_lengths,
+             **kwargs,
+         )
+         return results
+
+
+ if __name__ == "__main__":
+     from src.model.xeusphoneme.builders import build_xeus_pr_inference
+
+     # Example usage
+     ckpt_path = "path/to/checkpoints/last.ckpt"
+     work_dir = "path/to/exp/cache/xeus"
+     vocab_file = "src/model/xeusphoneme/resources/ipa_vocab.json"
+     device = "cpu" if not torch.cuda.is_available() else "cuda:0"
+     inference_obj = build_xeus_pr_inference(
+         work_dir=work_dir,
+         checkpoint=ckpt_path,
+         vocab_file=vocab_file,
+         hf_repo="espnet/xeus",
+         config_file=None,
+         device=device,
+         force_download=False,
+     )
+     import torchaudio
+
+     speechpath = "path/to/test_audio.wav"
+     speech = torchaudio.load(speechpath)[0].squeeze(0)
+     # speech = speech[: 16000 * 40]  # keep at most 40 seconds of audio
+     # dummy_speech = np.random.randn(16000 * 5).astype(np.float32)  # 5 seconds of audio
+     results = inference_obj(speech=speech)
+     print(results)
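For reference, the structure of `results` as produced by the greedy CTC strategy further below (the phone strings here are illustrative only):

    example = [
        {
            "predicted_transcript": "h/ɛ/l/oː",  # all decoded tokens joined with "/"
            "processed_transcript": "hɛloː",     # special <...> tokens stripped out
        }
    ]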
src/model/xeusphoneme/xeuspr_model.py ADDED
@@ -0,0 +1,378 @@
+ #!/usr/bin/env python3
+ # -*- coding: utf-8 -*-
+
+ # Copyright 2025 William Chen. Adapted from ESPnet.
+ # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
+ """Xeus Phoneme Recognition Model.
+
+ Usage:
+     python -m src.model.xeusphoneme.xeuspr_model \
+         --work_dir path/to/cache/xeus
+ """
12
+ from typing import Any, Dict, Optional, Tuple, Union
13
+ import argparse
14
+
15
+ import torch
16
+ import torch.nn.functional as F
17
+ import torchaudio
18
+ from src.model.powsm.utils import force_gatherable
19
+ from src.espnet_import.nets_utils import make_pad_mask, pad_list, th_accuracy
20
+ from src.espnet_import.label_smoothing_loss import LabelSmoothingLoss
21
+
22
+ try:
23
+ from src.recipe.phone_recognition.error_calculator import (
24
+ ErrorCalculator,
25
+ )
26
+ except ImportError:
27
+
28
+ class ErrorCalculator:
29
+ """No-op stub when rapidfuzz/panphon are unavailable."""
30
+
31
+ def __init__(self, *args, **kwargs):
32
+ pass
33
+
34
+ def __call__(self, *args, **kwargs):
35
+ return {}
36
+
37
+
38
+ from src.model.powsm.ctc import CTC
39
+ from src.utils import RankedLogger
40
+
41
+ log = RankedLogger(__name__, rank_zero_only=False)
42
+
43
+
44
+ class XeusPRModel(torch.nn.Module):
45
+ """Encoder-only CTC model for phone recognition using Xeus pretrained weights."""
46
+
47
+ def __init__(
48
+ self,
49
+ encoder: Any,
50
+ ctc: CTC,
51
+ token_list: Union[Tuple, list],
52
+ frontend: Optional[Any] = None,
53
+ specaug: Optional[Any] = None,
54
+ normalize: Optional[Any] = None,
55
+ preencoder: Optional[Any] = None,
56
+ ignore_id: int = -1,
57
+ sym_blank: str = "<blank>",
58
+ freeze_frontend: bool = True,
59
+ weighted_sum: bool = False,
60
+ interctc_weight: float = 0.0,
61
+ interctc_use_conditioning: bool = False,
62
+ interctc_ctc_type: str = "phone",
63
+ ctc_aux: Optional[Any] = None,
64
+ decoder: Optional[Any] = None,
65
+ ctc_weight: float = 1.0,
66
+ lsm_weight: float = 0.0,
67
+ sym_sos: str = "<sos>",
68
+ sym_eos: str = "<eos>",
69
+ **kwargs,
70
+ ):
71
+ super().__init__()
72
+ self.frontend = frontend
73
+ self.specaug = specaug
74
+ self.normalize = normalize
75
+ self.preencoder = preencoder
76
+ self.encoder = encoder
77
+ self.ctc = ctc
78
+ self.ctc_aux = ctc_aux
79
+ self.interctc_ctc_type = interctc_ctc_type
80
+ if interctc_use_conditioning:
81
+ vocab_size_cond = (
82
+ ctc_aux.ctc_lo.out_features
83
+ if interctc_ctc_type == "ortho" and ctc_aux is not None
84
+ else len(token_list)
85
+ )
86
+ self.encoder.conditioning_layer = torch.nn.Linear(
87
+ vocab_size_cond, encoder.output_size()
88
+ )
89
+ self.encoder.interctc_use_conditioning = True
90
+ self.token_list = list(token_list)
91
+ self.ignore_id = ignore_id
92
+ self.blank_id = token_list.index(sym_blank) if sym_blank in token_list else 0
93
+ sym_space = kwargs.get("sym_space", "<space>")
94
+ self.freeze_frontend = freeze_frontend
95
+ self.error_calculator = ErrorCalculator(
96
+ token_list,
97
+ blank_id=self.blank_id,
98
+ sym_space=sym_space,
99
+ ignore_id=ignore_id,
100
+ log_phone_metrics=True,
101
+ )
102
+
103
+ self.decoder = decoder
104
+ self.ctc_weight = ctc_weight
105
+ if decoder is not None:
106
+ self.sos = token_list.index(sym_sos)
107
+ self.eos = token_list.index(sym_eos)
108
+ self.criterion_att = LabelSmoothingLoss(
109
+ size=len(token_list),
110
+ padding_idx=ignore_id,
111
+ smoothing=lsm_weight,
112
+ normalize_length=False,
113
+ )
114
+
115
+ self.weighted_sum = weighted_sum
116
+ if self.weighted_sum:
117
+ n_layers = encoder.num_blocks
118
+ assert (
119
+ n_layers is not None and n_layers > 0
120
+ ), "Cannot infer number of encoder layers for weighted_sum"
121
+ self.layer_weights = torch.nn.Parameter(torch.zeros(int(n_layers)))
122
+ self.interctc_weight = interctc_weight
123
+ self.sampling_rate = 16000
124
+
125
+ def points_by_frames(self) -> int:
126
+ """Samples per encoder frame (CNN downsampling factor)."""
127
+ return self.frontend.downsampling_factor
128
+
129
+ @torch.no_grad()
130
+ def forced_align(self, speech, speech_lengths, text, text_lengths, utt_id=None):
131
+ """CTC forced alignment via torchaudio.functional.forced_align (batch size 1)."""
132
+ assert speech.shape[0] == 1, "forced_align requires batch size 1"
133
+ text = text[:, : text_lengths.max()]
134
+ logits, logit_lengths = self.ctc_logits(speech, speech_lengths)
135
+ log_probs = F.log_softmax(logits, dim=-1)
136
+ align_label, align_prob = torchaudio.functional.forced_align(
137
+ log_probs, text, logit_lengths, text_lengths, blank=self.blank_id
138
+ )
139
+ return align_label, align_prob
140
+
141
+ def collect_feats(
142
+ self, speech: torch.Tensor, speech_lengths: torch.Tensor, **kwargs
143
+ ) -> Dict[str, torch.Tensor]:
144
+ """Extract features for stats collection."""
145
+ feats, feats_lengths = self._extract_feats(speech, speech_lengths)
146
+ return {"feats": feats, "feats_lengths": feats_lengths}
147
+
148
+ def forward(self, speech, speech_lengths, text, text_lengths, **kwargs):
149
+ encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
150
+
151
+ intermediate_outs = None
152
+ if isinstance(encoder_out, tuple):
153
+ intermediate_outs = encoder_out[1]
154
+ encoder_out = encoder_out[0]
155
+
156
+ loss_ctc, stats = self._calc_ctc_loss(
157
+ encoder_out, encoder_out_lens, text, text_lengths, **kwargs
158
+ )
159
+
160
+ if self.interctc_weight > 0.0 and intermediate_outs:
161
+ if self.interctc_ctc_type == "ortho" and self.ctc_aux is not None:
162
+ ctc_inter = self.ctc_aux
163
+ ys_inter = kwargs.get("asr_text_tokens")
164
+ ys_inter_lens = kwargs.get("asr_text_length")
165
+ else:
166
+ ctc_inter = self.ctc
167
+ ys_inter = torch.where(text == -1, self.ignore_id, text)[
168
+ :, : text_lengths.max()
169
+ ]
170
+ ys_inter_lens = text_lengths
171
+
172
+ if ys_inter is not None and ys_inter_lens is not None:
173
+ loss_interctc = 0.0
174
+ for layer_idx, intermediate_out in intermediate_outs:
175
+ loss_ic = ctc_inter(
176
+ intermediate_out,
177
+ encoder_out_lens,
178
+ ys_inter,
179
+ ys_inter_lens,
180
+ )
181
+ loss_interctc = loss_interctc + loss_ic
182
+ stats[f"loss_interctc_layer{layer_idx}"] = loss_ic.detach()
183
+ loss_interctc = loss_interctc / len(intermediate_outs)
184
+ loss_ctc = (
185
+ 1 - self.interctc_weight
186
+ ) * loss_ctc + self.interctc_weight * loss_interctc
187
+
188
+ # Attention branch
189
+ if self.ctc_weight < 1.0 and self.decoder is not None:
190
+ loss_att, acc_att = self._calc_att_loss(
191
+ encoder_out, encoder_out_lens, text, text_lengths
192
+ )
193
+ stats["loss_att"] = loss_att.detach()
194
+ stats["acc_att"] = acc_att
195
+ loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att
196
+ else:
197
+ loss = loss_ctc
198
+
199
+ loss, stats, weight = force_gatherable(
200
+ (loss, stats, speech.shape[0]), loss.device
201
+ )
202
+ return {"loss": loss, "stats": stats, "weight": weight}
203
+
204
+ def _extract_feats(
205
+ self, speech: torch.Tensor, speech_lengths: torch.Tensor
206
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
207
+ """Extract features using frontend."""
208
+ speech = speech[:, : speech_lengths.max()]
209
+ return (
210
+ self.frontend(speech, speech_lengths)
211
+ if self.frontend
212
+ else (speech, speech_lengths)
213
+ )
214
+
215
+ def _apply_preprocessing(
216
+ self, speech: torch.Tensor, speech_lengths: torch.Tensor
217
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
218
+ """Apply frontend, specaug, normalize, and preencoder."""
219
+ speech, speech_lengths = self._extract_feats(speech, speech_lengths)
220
+
221
+ if self.specaug and self.training:
222
+ speech, speech_lengths = self.specaug(speech, speech_lengths)
223
+
224
+ if self.normalize:
225
+ speech, speech_lengths = self.normalize(speech, speech_lengths)
226
+
227
+ if self.preencoder:
228
+ speech, speech_lengths = self.preencoder(speech, speech_lengths)
229
+
230
+ return speech, speech_lengths
231
+
232
+ def encode(
233
+ self, speech: torch.Tensor, speech_lengths: torch.Tensor
234
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
235
+ """Encode speech to frame-level representations.
236
+
237
+ When weighted_sum=True, returns a weighted sum of all encoder layers.
238
+ Otherwise, calls the encoder without return_all_hs; if interctc_layer_idx
239
+ is configured on the encoder, returns (final_out, [(layer_idx, tensor), ...]).
240
+ """
241
+ speech, speech_lengths = self._apply_preprocessing(speech, speech_lengths)
242
+ pad_masks = make_pad_mask(speech_lengths).to(speech.device)
243
+ if self.weighted_sum:
244
+ encoder_out, encoder_out_lens, _ = self.encoder(
245
+ speech, speech_lengths, masks=pad_masks, return_all_hs=True
246
+ )
247
+ hs_list = encoder_out[1]
248
+ assert len(hs_list) == self.layer_weights.numel()
249
+ w = torch.softmax(self.layer_weights, dim=0).to(
250
+ hs_list[0].device, hs_list[0].dtype
251
+ )
252
+ hs = torch.stack(hs_list, dim=0) # (L, B, T, D)
253
+ return (w.view(-1, 1, 1, 1) * hs).sum(0), encoder_out_lens
254
+ else:
255
+ ctc_for_encoder = (
256
+ self.ctc_aux
257
+ if self.interctc_ctc_type == "ortho" and self.ctc_aux is not None
258
+ else self.ctc
259
+ )
260
+ encoder_out, encoder_out_lens, _ = self.encoder(
261
+ speech, speech_lengths, masks=pad_masks, ctc=ctc_for_encoder
262
+ )
263
+ return encoder_out, encoder_out_lens
264
+
265
+ def ctc_collapse_batch(self, x: torch.Tensor, max_length: int, pad: int = -1):
266
+ B, T = x.shape
267
+ blank = self.blank_id
268
+ x_prev = torch.cat(
269
+ [torch.full((B, 1), blank, device=x.device, dtype=x.dtype), x[:, :-1]],
270
+ dim=1,
271
+ )
272
+ keep = (x != blank) & ((x_prev == blank) | (x != x_prev))
273
+ pos = keep.long().cumsum(1) - 1
274
+ lengths = keep.sum(1)
275
+ out = torch.full((B, T), pad, device=x.device, dtype=x.dtype)
276
+ # Compute batch indices and output positions for kept elements
277
+ batch_idx = (
278
+ torch.arange(B, device=x.device, dtype=torch.long).unsqueeze(1).expand_as(x)
279
+ )
280
+ output_pos = pos.clone()
281
+ # Only use positions where keep is True
282
+ batch_idx_keep = batch_idx[keep]
283
+ output_pos_keep = output_pos[keep]
284
+ # Flatten the output and set values at correct positions
285
+ flat_out = out.view(-1)
286
+ flat_idx = batch_idx_keep * T + output_pos_keep
287
+ flat_out[flat_idx] = x[keep]
288
+ out = flat_out.view(B, T)
289
+ ##### Trim to max_length from ground truth lengths
290
+ out = out[:, :max_length]
291
+ lengths = torch.clamp(lengths, max=max_length)
292
+ return out, lengths
293
+
294
+ def _calc_att_loss(self, encoder_out, encoder_out_lens, ys_pad, ys_pad_lens):
295
+ ys_pad = torch.where(ys_pad == -1, self.ignore_id, ys_pad)
296
+ ys = [y[y != self.ignore_id][:l] for y, l in zip(ys_pad, ys_pad_lens)]
297
+ _sos = ys_pad.new([self.sos])
298
+ _eos = ys_pad.new([self.eos])
299
+ ys_in = [torch.cat([_sos, y]) for y in ys]
300
+ ys_out = [torch.cat([y, _eos]) for y in ys]
301
+ ys_in_pad = pad_list(ys_in, self.eos)
302
+ ys_out_pad = pad_list(ys_out, self.ignore_id)
303
+ ys_in_lens = torch.tensor([len(y) for y in ys_in], device=ys_pad.device)
304
+
305
+ decoder_out, _ = self.decoder(
306
+ encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens
307
+ )
308
+ loss_att = self.criterion_att(decoder_out, ys_out_pad)
309
+ acc_att = th_accuracy(
310
+ decoder_out.view(-1, len(self.token_list)),
311
+ ys_out_pad,
312
+ ignore_label=self.ignore_id,
313
+ )
314
+ return loss_att, acc_att
315
+
316
+ def _calc_ctc_loss(
317
+ self, encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, **kwargs
318
+ ):
319
+ ys_pad = torch.where(ys_pad == -1, self.ignore_id, ys_pad)
320
+ ys_pad = ys_pad[:, : ys_pad_lens.max()]
321
+ loss_ctc = self.ctc(
322
+ encoder_out,
323
+ encoder_out_lens,
324
+ ys_pad,
325
+ ys_pad_lens,
326
+ lang_sym=kwargs.get("lang_sym"),
327
+ accent_sym=kwargs.get("accent_sym"),
328
+ )
329
+ stats = {}
330
+ assert self.error_calculator is not None, "ErrorCalculator not initialized"
331
+ if not self.training: # err calc, slow?
332
+ with torch.no_grad():
333
+ ys_hat = self.ctc.argmax(encoder_out).data # greedy-top1
334
+ metrics = self.error_calculator(
335
+ ys_hat.cpu(), ys_pad.cpu(), ys_pad_lens.cpu()
336
+ )
337
+ for k, v in metrics.items():
338
+ stats[k + "_ctc"] = v
339
+ return loss_ctc, stats
340
+
341
+ def ctc_logits(
342
+ self, speech: torch.Tensor, speech_lengths: torch.Tensor
343
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
344
+ """Get CTC logits for inference."""
345
+ encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
346
+ if isinstance(encoder_out, tuple):
347
+ encoder_out = encoder_out[0]
348
+ return self.ctc.ctc_lo(encoder_out), encoder_out_lens
349
+
350
+ def encoder_output_size(self) -> int:
351
+ return self.encoder.output_size()
352
+
353
+ def get_blank_id(self) -> int:
354
+ return self.blank_id
355
+
356
+ def get_frontend(self):
357
+ return self.frontend
358
+
359
+ def get_trainable_parameters(self):
360
+ trainable_params = {"head": [], "encoder": []}
361
+ for n, p in self.named_parameters():
362
+ if (
363
+ n.startswith("ctc")
364
+ or n.startswith("decoder")
365
+ or n.startswith("criterion_att")
366
+ ):
367
+ trainable_params["head"].append(p)
368
+ elif n.startswith("encoder"):
369
+ trainable_params["encoder"].append(p)
370
+ elif n.startswith("frontend"):
371
+ if self.freeze_frontend:
372
+ p.requires_grad = False
373
+ else:
374
+ trainable_params["encoder"].append(p)
375
+ else:
376
+ # freeze other parts:
377
+ p.requires_grad = False
378
+ return trainable_params
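A minimal sketch of feeding these parameter groups to an optimizer; `model` and the learning rates below are placeholders, not values taken from this repo:

    import torch

    # model: an instantiated XeusPRModel (hypothetical here)
    groups = model.get_trainable_parameters()  # {"head": [...], "encoder": [...]}
    optimizer = torch.optim.AdamW(
        [
            {"params": groups["head"], "lr": 1e-3},
            {"params": groups["encoder"], "lr": 1e-5},
        ],
        weight_decay=0.01,
    )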
src/recipe/__init__.py ADDED
File without changes
src/recipe/phone_recognition/__init__.py ADDED
File without changes
src/recipe/phone_recognition/greedy_ctc_strategy.py ADDED
@@ -0,0 +1,63 @@
+ import torch
+ from typing import List, Dict, Any, Union
+
+
+ def ctc_collapse_vectorized(
+     ids: torch.Tensor, blank_id: int, ignore_id: int = -1
+ ) -> List[List[int]]:
+     """Optimized CTC collapse for batch tensors."""
+     mask = torch.ones_like(ids, dtype=torch.bool)
+     mask[:, 1:] = ids[:, 1:] != ids[:, :-1]
+     mask &= ids != blank_id
+     if ignore_id != -1:
+         mask &= ids != ignore_id
+
+     return [ids[i][mask[i]].tolist() for i in range(ids.size(0))]
+
+
+ class GreedyCTCInference:
+     """A scalable inference engine for any CTC-based phone recognizer."""
+
+     def __init__(self, token_list: List[str], blank_id: int):
+         self.token_list = token_list
+         self.blank_id = blank_id
+
+     @torch.no_grad()
+     def __call__(
+         self,
+         model: torch.nn.Module,
+         speech: torch.Tensor,
+         speech_lengths: torch.Tensor,
+         **kwargs
+     ) -> List[Dict[str, Any]]:
+         # 1. Standardized forward pass
+         # Works as long as model has .encode() and .ctc
+         encoder_out, _ = model.encode(speech, speech_lengths)
+         if isinstance(encoder_out, tuple):
+             encoder_out = encoder_out[0]
+         logits = model.ctc.ctc_lo(encoder_out)
+
+         # 2. Greedy search
+         y_hat = torch.argmax(logits, dim=-1)
+
+         # 3. Collapse
+         collapsed_ids = ctc_collapse_vectorized(y_hat, self.blank_id)
+
+         # 4. Map to text
+         results = []
+         for ids in collapsed_ids:
+             tokens = [self.token_list[i] for i in ids]
+             raw_text = "/".join(tokens)
+             # Filter special tokens
+             clean_tokens = [
+                 t for t in tokens if not (t.startswith("<") and t.endswith(">"))
+             ]
+             processed = "".join(clean_tokens).strip()  # replace(self.sym_space, " ")
+
+             results.append(
+                 {
+                     "processed_transcript": processed,
+                     "predicted_transcript": raw_text,
+                 }
+             )
+         return results
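A toy check of the collapse rule above (merge repeats, then drop blanks); with blank_id = 0 the frame sequence below reduces to [5, 3, 5]:

    import torch
    from src.recipe.phone_recognition.greedy_ctc_strategy import ctc_collapse_vectorized

    ids = torch.tensor([[0, 5, 5, 0, 3, 3, 0, 5]])   # per-frame argmax ids, blank = 0
    print(ctc_collapse_vectorized(ids, blank_id=0))  # [[5, 3, 5]]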
src/utils/__init__.py ADDED
@@ -0,0 +1 @@
+ from src.utils.pylogger import RankedLogger
src/utils/pylogger.py ADDED
@@ -0,0 +1,23 @@
+ import logging
+ from typing import Mapping, Optional
+
+
+ class RankedLogger(logging.LoggerAdapter):
+     """Simplified logger for single-process inference (no Lightning)."""
+
+     def __init__(
+         self,
+         name: str = __name__,
+         rank_zero_only: bool = False,
+         extra: Optional[Mapping[str, object]] = None,
+     ) -> None:
+         logger = logging.getLogger(name)
+         super().__init__(logger=logger, extra=extra)
+         self.rank_zero_only = rank_zero_only
+
+     def log(
+         self, level: int, msg: str, rank: Optional[int] = None, *args, **kwargs
+     ) -> None:
+         if self.isEnabledFor(level):
+             msg, kwargs = self.process(msg, kwargs)
+             self.logger.log(level, msg, *args, **kwargs)
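A minimal usage sketch; in this single-process variant the `rank` argument and `rank_zero_only` flag are stored but never filter anything:

    import logging
    from src.utils import RankedLogger

    logging.basicConfig(level=logging.INFO)
    log = RankedLogger(__name__, rank_zero_only=True)
    log.info("checkpoint loaded")  # behaves like a plain logging.LoggerAdapter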