Upload 16 files
Browse files- .gitattributes +1 -0
- conformer/conformer/__init__.py +15 -0
- conformer/conformer/activation.py +42 -0
- conformer/conformer/attention.py +151 -0
- conformer/conformer/convolution.py +186 -0
- conformer/conformer/embedding.py +42 -0
- conformer/conformer/encoder.py +204 -0
- conformer/conformer/feed_forward.py +57 -0
- conformer/conformer/model.py +108 -0
- conformer/conformer/model_def.py +104 -0
- conformer/conformer/modules.py +73 -0
- model_pinyin.py +299 -0
- pinyin_index.json +808 -0
- stepstep=024500.ckpt +3 -0
- train_pinyin.py +125 -0
- tts-2025-04-27@197cd135f2a2451b9cab9cf2add1c1ab.wav +3 -0
- ui.py +167 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
tts-2025-04-27@197cd135f2a2451b9cab9cf2add1c1ab.wav filter=lfs diff=lfs merge=lfs -text
|
conformer/conformer/__init__.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2021, Soohwan Kim. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from .model import Conformer
|
conformer/conformer/activation.py
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2021, Soohwan Kim. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import torch.nn as nn
|
| 16 |
+
from torch import Tensor
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class Swish(nn.Module):
|
| 20 |
+
"""
|
| 21 |
+
Swish is a smooth, non-monotonic function that consistently matches or outperforms ReLU on deep networks applied
|
| 22 |
+
to a variety of challenging domains such as Image classification and Machine translation.
|
| 23 |
+
"""
|
| 24 |
+
def __init__(self):
|
| 25 |
+
super(Swish, self).__init__()
|
| 26 |
+
|
| 27 |
+
def forward(self, inputs: Tensor) -> Tensor:
|
| 28 |
+
return inputs * inputs.sigmoid()
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class GLU(nn.Module):
|
| 32 |
+
"""
|
| 33 |
+
The gating mechanism is called Gated Linear Units (GLU), which was first introduced for natural language processing
|
| 34 |
+
in the paper “Language Modeling with Gated Convolutional Networks”
|
| 35 |
+
"""
|
| 36 |
+
def __init__(self, dim: int) -> None:
|
| 37 |
+
super(GLU, self).__init__()
|
| 38 |
+
self.dim = dim
|
| 39 |
+
|
| 40 |
+
def forward(self, inputs: Tensor) -> Tensor:
|
| 41 |
+
outputs, gate = inputs.chunk(2, dim=self.dim)
|
| 42 |
+
return outputs * gate.sigmoid()
|
conformer/conformer/attention.py
ADDED
|
@@ -0,0 +1,151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2021, Soohwan Kim. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import math
|
| 16 |
+
import torch
|
| 17 |
+
import torch.nn as nn
|
| 18 |
+
import torch.nn.functional as F
|
| 19 |
+
from torch import Tensor
|
| 20 |
+
from typing import Optional
|
| 21 |
+
|
| 22 |
+
from .embedding import PositionalEncoding
|
| 23 |
+
from .modules import Linear
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class RelativeMultiHeadAttention(nn.Module):
|
| 27 |
+
"""
|
| 28 |
+
Multi-head attention with relative positional encoding.
|
| 29 |
+
This concept was proposed in the "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
|
| 30 |
+
|
| 31 |
+
Args:
|
| 32 |
+
d_model (int): The dimension of model
|
| 33 |
+
num_heads (int): The number of attention heads.
|
| 34 |
+
dropout_p (float): probability of dropout
|
| 35 |
+
|
| 36 |
+
Inputs: query, key, value, pos_embedding, mask
|
| 37 |
+
- **query** (batch, time, dim): Tensor containing query vector
|
| 38 |
+
- **key** (batch, time, dim): Tensor containing key vector
|
| 39 |
+
- **value** (batch, time, dim): Tensor containing value vector
|
| 40 |
+
- **pos_embedding** (batch, time, dim): Positional embedding tensor
|
| 41 |
+
- **mask** (batch, 1, time2) or (batch, time1, time2): Tensor containing indices to be masked
|
| 42 |
+
|
| 43 |
+
Returns:
|
| 44 |
+
- **outputs**: Tensor produces by relative multi head attention module.
|
| 45 |
+
"""
|
| 46 |
+
def __init__(
|
| 47 |
+
self,
|
| 48 |
+
d_model: int = 512,
|
| 49 |
+
num_heads: int = 16,
|
| 50 |
+
dropout_p: float = 0.1,
|
| 51 |
+
):
|
| 52 |
+
super(RelativeMultiHeadAttention, self).__init__()
|
| 53 |
+
assert d_model % num_heads == 0, "d_model % num_heads should be zero."
|
| 54 |
+
self.d_model = d_model
|
| 55 |
+
self.d_head = int(d_model / num_heads)
|
| 56 |
+
self.num_heads = num_heads
|
| 57 |
+
self.sqrt_dim = math.sqrt(d_model)
|
| 58 |
+
|
| 59 |
+
self.query_proj = Linear(d_model, d_model)
|
| 60 |
+
self.key_proj = Linear(d_model, d_model)
|
| 61 |
+
self.value_proj = Linear(d_model, d_model)
|
| 62 |
+
self.pos_proj = Linear(d_model, d_model, bias=False)
|
| 63 |
+
|
| 64 |
+
self.dropout = nn.Dropout(p=dropout_p)
|
| 65 |
+
self.u_bias = nn.Parameter(torch.Tensor(self.num_heads, self.d_head))
|
| 66 |
+
self.v_bias = nn.Parameter(torch.Tensor(self.num_heads, self.d_head))
|
| 67 |
+
torch.nn.init.xavier_uniform_(self.u_bias)
|
| 68 |
+
torch.nn.init.xavier_uniform_(self.v_bias)
|
| 69 |
+
|
| 70 |
+
self.out_proj = Linear(d_model, d_model)
|
| 71 |
+
|
| 72 |
+
def forward(
|
| 73 |
+
self,
|
| 74 |
+
query: Tensor,
|
| 75 |
+
key: Tensor,
|
| 76 |
+
value: Tensor,
|
| 77 |
+
pos_embedding: Tensor,
|
| 78 |
+
mask: Optional[Tensor] = None,
|
| 79 |
+
) -> Tensor:
|
| 80 |
+
batch_size = value.size(0)
|
| 81 |
+
|
| 82 |
+
query = self.query_proj(query).view(batch_size, -1, self.num_heads, self.d_head)
|
| 83 |
+
key = self.key_proj(key).view(batch_size, -1, self.num_heads, self.d_head).permute(0, 2, 1, 3)
|
| 84 |
+
value = self.value_proj(value).view(batch_size, -1, self.num_heads, self.d_head).permute(0, 2, 1, 3)
|
| 85 |
+
pos_embedding = self.pos_proj(pos_embedding).view(batch_size, -1, self.num_heads, self.d_head)
|
| 86 |
+
|
| 87 |
+
content_score = torch.matmul((query + self.u_bias).transpose(1, 2), key.transpose(2, 3))
|
| 88 |
+
pos_score = torch.matmul((query + self.v_bias).transpose(1, 2), pos_embedding.permute(0, 2, 3, 1))
|
| 89 |
+
pos_score = self._relative_shift(pos_score)
|
| 90 |
+
|
| 91 |
+
score = (content_score + pos_score) / self.sqrt_dim
|
| 92 |
+
|
| 93 |
+
if mask is not None:
|
| 94 |
+
mask = mask.unsqueeze(1)
|
| 95 |
+
score.masked_fill_(mask, -1e9)
|
| 96 |
+
|
| 97 |
+
attn = F.softmax(score, -1)
|
| 98 |
+
attn = self.dropout(attn)
|
| 99 |
+
|
| 100 |
+
context = torch.matmul(attn, value).transpose(1, 2)
|
| 101 |
+
context = context.contiguous().view(batch_size, -1, self.d_model)
|
| 102 |
+
|
| 103 |
+
return self.out_proj(context)
|
| 104 |
+
|
| 105 |
+
def _relative_shift(self, pos_score: Tensor) -> Tensor:
|
| 106 |
+
batch_size, num_heads, seq_length1, seq_length2 = pos_score.size()
|
| 107 |
+
zeros = pos_score.new_zeros(batch_size, num_heads, seq_length1, 1)
|
| 108 |
+
padded_pos_score = torch.cat([zeros, pos_score], dim=-1)
|
| 109 |
+
|
| 110 |
+
padded_pos_score = padded_pos_score.view(batch_size, num_heads, seq_length2 + 1, seq_length1)
|
| 111 |
+
pos_score = padded_pos_score[:, :, 1:].view_as(pos_score)
|
| 112 |
+
|
| 113 |
+
return pos_score
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
class MultiHeadedSelfAttentionModule(nn.Module):
|
| 117 |
+
"""
|
| 118 |
+
Conformer employ multi-headed self-attention (MHSA) while integrating an important technique from Transformer-XL,
|
| 119 |
+
the relative sinusoidal positional encoding scheme. The relative positional encoding allows the self-attention
|
| 120 |
+
module to generalize better on different input length and the resulting encoder is more robust to the variance of
|
| 121 |
+
the utterance length. Conformer use prenorm residual units with dropout which helps training
|
| 122 |
+
and regularizing deeper models.
|
| 123 |
+
|
| 124 |
+
Args:
|
| 125 |
+
d_model (int): The dimension of model
|
| 126 |
+
num_heads (int): The number of attention heads.
|
| 127 |
+
dropout_p (float): probability of dropout
|
| 128 |
+
|
| 129 |
+
Inputs: inputs, mask
|
| 130 |
+
- **inputs** (batch, time, dim): Tensor containing input vector
|
| 131 |
+
- **mask** (batch, 1, time2) or (batch, time1, time2): Tensor containing indices to be masked
|
| 132 |
+
|
| 133 |
+
Returns:
|
| 134 |
+
- **outputs** (batch, time, dim): Tensor produces by relative multi headed self attention module.
|
| 135 |
+
"""
|
| 136 |
+
def __init__(self, d_model: int, num_heads: int, dropout_p: float = 0.1):
|
| 137 |
+
super(MultiHeadedSelfAttentionModule, self).__init__()
|
| 138 |
+
self.positional_encoding = PositionalEncoding(d_model)
|
| 139 |
+
self.layer_norm = nn.LayerNorm(d_model)
|
| 140 |
+
self.attention = RelativeMultiHeadAttention(d_model, num_heads, dropout_p)
|
| 141 |
+
self.dropout = nn.Dropout(p=dropout_p)
|
| 142 |
+
|
| 143 |
+
def forward(self, inputs: Tensor, mask: Optional[Tensor] = None):
|
| 144 |
+
batch_size, seq_length, _ = inputs.size()
|
| 145 |
+
pos_embedding = self.positional_encoding(seq_length)
|
| 146 |
+
pos_embedding = pos_embedding.repeat(batch_size, 1, 1)
|
| 147 |
+
|
| 148 |
+
inputs = self.layer_norm(inputs)
|
| 149 |
+
outputs = self.attention(inputs, inputs, inputs, pos_embedding=pos_embedding, mask=mask)
|
| 150 |
+
|
| 151 |
+
return self.dropout(outputs)
|
conformer/conformer/convolution.py
ADDED
|
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2021, Soohwan Kim. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import torch
|
| 16 |
+
import torch.nn as nn
|
| 17 |
+
from torch import Tensor
|
| 18 |
+
from typing import Tuple
|
| 19 |
+
|
| 20 |
+
from .activation import Swish, GLU
|
| 21 |
+
from .modules import Transpose
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class DepthwiseConv1d(nn.Module):
|
| 25 |
+
"""
|
| 26 |
+
When groups == in_channels and out_channels == K * in_channels, where K is a positive integer,
|
| 27 |
+
this operation is termed in literature as depthwise convolution.
|
| 28 |
+
|
| 29 |
+
Args:
|
| 30 |
+
in_channels (int): Number of channels in the input
|
| 31 |
+
out_channels (int): Number of channels produced by the convolution
|
| 32 |
+
kernel_size (int or tuple): Size of the convolving kernel
|
| 33 |
+
stride (int, optional): Stride of the convolution. Default: 1
|
| 34 |
+
padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0
|
| 35 |
+
bias (bool, optional): If True, adds a learnable bias to the output. Default: True
|
| 36 |
+
|
| 37 |
+
Inputs: inputs
|
| 38 |
+
- **inputs** (batch, in_channels, time): Tensor containing input vector
|
| 39 |
+
|
| 40 |
+
Returns: outputs
|
| 41 |
+
- **outputs** (batch, out_channels, time): Tensor produces by depthwise 1-D convolution.
|
| 42 |
+
"""
|
| 43 |
+
def __init__(
|
| 44 |
+
self,
|
| 45 |
+
in_channels: int,
|
| 46 |
+
out_channels: int,
|
| 47 |
+
kernel_size: int,
|
| 48 |
+
stride: int = 1,
|
| 49 |
+
padding: int = 0,
|
| 50 |
+
bias: bool = False,
|
| 51 |
+
) -> None:
|
| 52 |
+
super(DepthwiseConv1d, self).__init__()
|
| 53 |
+
assert out_channels % in_channels == 0, "out_channels should be constant multiple of in_channels"
|
| 54 |
+
self.conv = nn.Conv1d(
|
| 55 |
+
in_channels=in_channels,
|
| 56 |
+
out_channels=out_channels,
|
| 57 |
+
kernel_size=kernel_size,
|
| 58 |
+
groups=in_channels,
|
| 59 |
+
stride=stride,
|
| 60 |
+
padding=padding,
|
| 61 |
+
bias=bias,
|
| 62 |
+
)
|
| 63 |
+
|
| 64 |
+
def forward(self, inputs: Tensor) -> Tensor:
|
| 65 |
+
return self.conv(inputs)
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
class PointwiseConv1d(nn.Module):
|
| 69 |
+
"""
|
| 70 |
+
When kernel size == 1 conv1d, this operation is termed in literature as pointwise convolution.
|
| 71 |
+
This operation often used to match dimensions.
|
| 72 |
+
|
| 73 |
+
Args:
|
| 74 |
+
in_channels (int): Number of channels in the input
|
| 75 |
+
out_channels (int): Number of channels produced by the convolution
|
| 76 |
+
stride (int, optional): Stride of the convolution. Default: 1
|
| 77 |
+
padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0
|
| 78 |
+
bias (bool, optional): If True, adds a learnable bias to the output. Default: True
|
| 79 |
+
|
| 80 |
+
Inputs: inputs
|
| 81 |
+
- **inputs** (batch, in_channels, time): Tensor containing input vector
|
| 82 |
+
|
| 83 |
+
Returns: outputs
|
| 84 |
+
- **outputs** (batch, out_channels, time): Tensor produces by pointwise 1-D convolution.
|
| 85 |
+
"""
|
| 86 |
+
def __init__(
|
| 87 |
+
self,
|
| 88 |
+
in_channels: int,
|
| 89 |
+
out_channels: int,
|
| 90 |
+
stride: int = 1,
|
| 91 |
+
padding: int = 0,
|
| 92 |
+
bias: bool = True,
|
| 93 |
+
) -> None:
|
| 94 |
+
super(PointwiseConv1d, self).__init__()
|
| 95 |
+
self.conv = nn.Conv1d(
|
| 96 |
+
in_channels=in_channels,
|
| 97 |
+
out_channels=out_channels,
|
| 98 |
+
kernel_size=1,
|
| 99 |
+
stride=stride,
|
| 100 |
+
padding=padding,
|
| 101 |
+
bias=bias,
|
| 102 |
+
)
|
| 103 |
+
|
| 104 |
+
def forward(self, inputs: Tensor) -> Tensor:
|
| 105 |
+
return self.conv(inputs)
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
class ConformerConvModule(nn.Module):
|
| 109 |
+
"""
|
| 110 |
+
Conformer convolution module starts with a pointwise convolution and a gated linear unit (GLU).
|
| 111 |
+
This is followed by a single 1-D depthwise convolution layer. Batchnorm is deployed just after the convolution
|
| 112 |
+
to aid training deep models.
|
| 113 |
+
|
| 114 |
+
Args:
|
| 115 |
+
in_channels (int): Number of channels in the input
|
| 116 |
+
kernel_size (int or tuple, optional): Size of the convolving kernel Default: 31
|
| 117 |
+
dropout_p (float, optional): probability of dropout
|
| 118 |
+
|
| 119 |
+
Inputs: inputs
|
| 120 |
+
inputs (batch, time, dim): Tensor contains input sequences
|
| 121 |
+
|
| 122 |
+
Outputs: outputs
|
| 123 |
+
outputs (batch, time, dim): Tensor produces by conformer convolution module.
|
| 124 |
+
"""
|
| 125 |
+
def __init__(
|
| 126 |
+
self,
|
| 127 |
+
in_channels: int,
|
| 128 |
+
kernel_size: int = 31,
|
| 129 |
+
expansion_factor: int = 2,
|
| 130 |
+
dropout_p: float = 0.1,
|
| 131 |
+
) -> None:
|
| 132 |
+
super(ConformerConvModule, self).__init__()
|
| 133 |
+
assert (kernel_size - 1) % 2 == 0, "kernel_size should be a odd number for 'SAME' padding"
|
| 134 |
+
assert expansion_factor == 2, "Currently, Only Supports expansion_factor 2"
|
| 135 |
+
|
| 136 |
+
self.sequential = nn.Sequential(
|
| 137 |
+
nn.LayerNorm(in_channels),
|
| 138 |
+
Transpose(shape=(1, 2)),
|
| 139 |
+
PointwiseConv1d(in_channels, in_channels * expansion_factor, stride=1, padding=0, bias=True),
|
| 140 |
+
GLU(dim=1),
|
| 141 |
+
DepthwiseConv1d(in_channels, in_channels, kernel_size, stride=1, padding=(kernel_size - 1) // 2),
|
| 142 |
+
nn.BatchNorm1d(in_channels),
|
| 143 |
+
Swish(),
|
| 144 |
+
PointwiseConv1d(in_channels, in_channels, stride=1, padding=0, bias=True),
|
| 145 |
+
nn.Dropout(p=dropout_p),
|
| 146 |
+
)
|
| 147 |
+
|
| 148 |
+
def forward(self, inputs: Tensor) -> Tensor:
|
| 149 |
+
return self.sequential(inputs).transpose(1, 2)
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
class Conv2dSubampling(nn.Module):
|
| 153 |
+
"""
|
| 154 |
+
Convolutional 2D subsampling (to 1/4 length)
|
| 155 |
+
|
| 156 |
+
Args:
|
| 157 |
+
in_channels (int): Number of channels in the input image
|
| 158 |
+
out_channels (int): Number of channels produced by the convolution
|
| 159 |
+
|
| 160 |
+
Inputs: inputs
|
| 161 |
+
- **inputs** (batch, time, dim): Tensor containing sequence of inputs
|
| 162 |
+
|
| 163 |
+
Returns: outputs, output_lengths
|
| 164 |
+
- **outputs** (batch, time, dim): Tensor produced by the convolution
|
| 165 |
+
- **output_lengths** (batch): list of sequence output lengths
|
| 166 |
+
"""
|
| 167 |
+
def __init__(self, in_channels: int, out_channels: int) -> None:
|
| 168 |
+
super(Conv2dSubampling, self).__init__()
|
| 169 |
+
self.sequential = nn.Sequential(
|
| 170 |
+
nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2),
|
| 171 |
+
nn.ReLU(),
|
| 172 |
+
nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=2),
|
| 173 |
+
nn.ReLU(),
|
| 174 |
+
)
|
| 175 |
+
|
| 176 |
+
def forward(self, inputs: Tensor, input_lengths: Tensor) -> Tuple[Tensor, Tensor]:
|
| 177 |
+
outputs = self.sequential(inputs.unsqueeze(1))
|
| 178 |
+
batch_size, channels, subsampled_lengths, sumsampled_dim = outputs.size()
|
| 179 |
+
|
| 180 |
+
outputs = outputs.permute(0, 2, 1, 3)
|
| 181 |
+
outputs = outputs.contiguous().view(batch_size, subsampled_lengths, channels * sumsampled_dim)
|
| 182 |
+
|
| 183 |
+
output_lengths = input_lengths >> 2
|
| 184 |
+
output_lengths -= 1
|
| 185 |
+
|
| 186 |
+
return outputs, output_lengths
|
conformer/conformer/embedding.py
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2021, Soohwan Kim. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import math
|
| 16 |
+
import torch
|
| 17 |
+
import torch.nn as nn
|
| 18 |
+
from torch import Tensor
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class PositionalEncoding(nn.Module):
|
| 22 |
+
"""
|
| 23 |
+
Positional Encoding proposed in "Attention Is All You Need".
|
| 24 |
+
Since transformer contains no recurrence and no convolution, in order for the model to make
|
| 25 |
+
use of the order of the sequence, we must add some positional information.
|
| 26 |
+
|
| 27 |
+
"Attention Is All You Need" use sine and cosine functions of different frequencies:
|
| 28 |
+
PE_(pos, 2i) = sin(pos / power(10000, 2i / d_model))
|
| 29 |
+
PE_(pos, 2i+1) = cos(pos / power(10000, 2i / d_model))
|
| 30 |
+
"""
|
| 31 |
+
def __init__(self, d_model: int = 512, max_len: int = 10000) -> None:
|
| 32 |
+
super(PositionalEncoding, self).__init__()
|
| 33 |
+
pe = torch.zeros(max_len, d_model, requires_grad=False)
|
| 34 |
+
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
|
| 35 |
+
div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))
|
| 36 |
+
pe[:, 0::2] = torch.sin(position * div_term)
|
| 37 |
+
pe[:, 1::2] = torch.cos(position * div_term)
|
| 38 |
+
pe = pe.unsqueeze(0)
|
| 39 |
+
self.register_buffer('pe', pe)
|
| 40 |
+
|
| 41 |
+
def forward(self, length: int) -> Tensor:
|
| 42 |
+
return self.pe[:, :length]
|
conformer/conformer/encoder.py
ADDED
|
@@ -0,0 +1,204 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2021, Soohwan Kim. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import torch
|
| 16 |
+
import torch.nn as nn
|
| 17 |
+
from torch import Tensor
|
| 18 |
+
from typing import Tuple
|
| 19 |
+
|
| 20 |
+
from .feed_forward import FeedForwardModule
|
| 21 |
+
from .attention import MultiHeadedSelfAttentionModule
|
| 22 |
+
from .convolution import (
|
| 23 |
+
ConformerConvModule,
|
| 24 |
+
Conv2dSubampling,
|
| 25 |
+
)
|
| 26 |
+
from .modules import (
|
| 27 |
+
ResidualConnectionModule,
|
| 28 |
+
Linear,
|
| 29 |
+
)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class ConformerBlock(nn.Module):
|
| 33 |
+
"""
|
| 34 |
+
Conformer block contains two Feed Forward modules sandwiching the Multi-Headed Self-Attention module
|
| 35 |
+
and the Convolution module. This sandwich structure is inspired by Macaron-Net, which proposes replacing
|
| 36 |
+
the original feed-forward layer in the Transformer block into two half-step feed-forward layers,
|
| 37 |
+
one before the attention layer and one after.
|
| 38 |
+
|
| 39 |
+
Args:
|
| 40 |
+
encoder_dim (int, optional): Dimension of conformer encoder
|
| 41 |
+
num_attention_heads (int, optional): Number of attention heads
|
| 42 |
+
feed_forward_expansion_factor (int, optional): Expansion factor of feed forward module
|
| 43 |
+
conv_expansion_factor (int, optional): Expansion factor of conformer convolution module
|
| 44 |
+
feed_forward_dropout_p (float, optional): Probability of feed forward module dropout
|
| 45 |
+
attention_dropout_p (float, optional): Probability of attention module dropout
|
| 46 |
+
conv_dropout_p (float, optional): Probability of conformer convolution module dropout
|
| 47 |
+
conv_kernel_size (int or tuple, optional): Size of the convolving kernel
|
| 48 |
+
half_step_residual (bool): Flag indication whether to use half step residual or not
|
| 49 |
+
|
| 50 |
+
Inputs: inputs
|
| 51 |
+
- **inputs** (batch, time, dim): Tensor containing input vector
|
| 52 |
+
|
| 53 |
+
Returns: outputs
|
| 54 |
+
- **outputs** (batch, time, dim): Tensor produces by conformer block.
|
| 55 |
+
"""
|
| 56 |
+
def __init__(
|
| 57 |
+
self,
|
| 58 |
+
encoder_dim: int = 512,
|
| 59 |
+
num_attention_heads: int = 8,
|
| 60 |
+
feed_forward_expansion_factor: int = 4,
|
| 61 |
+
conv_expansion_factor: int = 2,
|
| 62 |
+
feed_forward_dropout_p: float = 0.1,
|
| 63 |
+
attention_dropout_p: float = 0.1,
|
| 64 |
+
conv_dropout_p: float = 0.1,
|
| 65 |
+
conv_kernel_size: int = 31,
|
| 66 |
+
half_step_residual: bool = True,
|
| 67 |
+
):
|
| 68 |
+
super(ConformerBlock, self).__init__()
|
| 69 |
+
if half_step_residual:
|
| 70 |
+
self.feed_forward_residual_factor = 0.5
|
| 71 |
+
else:
|
| 72 |
+
self.feed_forward_residual_factor = 1
|
| 73 |
+
|
| 74 |
+
self.sequential = nn.Sequential(
|
| 75 |
+
ResidualConnectionModule(
|
| 76 |
+
module=FeedForwardModule(
|
| 77 |
+
encoder_dim=encoder_dim,
|
| 78 |
+
expansion_factor=feed_forward_expansion_factor,
|
| 79 |
+
dropout_p=feed_forward_dropout_p,
|
| 80 |
+
),
|
| 81 |
+
module_factor=self.feed_forward_residual_factor,
|
| 82 |
+
),
|
| 83 |
+
ResidualConnectionModule(
|
| 84 |
+
module=MultiHeadedSelfAttentionModule(
|
| 85 |
+
d_model=encoder_dim,
|
| 86 |
+
num_heads=num_attention_heads,
|
| 87 |
+
dropout_p=attention_dropout_p,
|
| 88 |
+
),
|
| 89 |
+
),
|
| 90 |
+
ResidualConnectionModule(
|
| 91 |
+
module=ConformerConvModule(
|
| 92 |
+
in_channels=encoder_dim,
|
| 93 |
+
kernel_size=conv_kernel_size,
|
| 94 |
+
expansion_factor=conv_expansion_factor,
|
| 95 |
+
dropout_p=conv_dropout_p,
|
| 96 |
+
),
|
| 97 |
+
),
|
| 98 |
+
ResidualConnectionModule(
|
| 99 |
+
module=FeedForwardModule(
|
| 100 |
+
encoder_dim=encoder_dim,
|
| 101 |
+
expansion_factor=feed_forward_expansion_factor,
|
| 102 |
+
dropout_p=feed_forward_dropout_p,
|
| 103 |
+
),
|
| 104 |
+
module_factor=self.feed_forward_residual_factor,
|
| 105 |
+
),
|
| 106 |
+
nn.LayerNorm(encoder_dim),
|
| 107 |
+
)
|
| 108 |
+
|
| 109 |
+
def forward(self, inputs: Tensor) -> Tensor:
|
| 110 |
+
return self.sequential(inputs)
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
class ConformerEncoder(nn.Module):
|
| 114 |
+
"""
|
| 115 |
+
Conformer encoder first processes the input with a convolution subsampling layer and then
|
| 116 |
+
with a number of conformer blocks.
|
| 117 |
+
|
| 118 |
+
Args:
|
| 119 |
+
input_dim (int, optional): Dimension of input vector
|
| 120 |
+
encoder_dim (int, optional): Dimension of conformer encoder
|
| 121 |
+
num_layers (int, optional): Number of conformer blocks
|
| 122 |
+
num_attention_heads (int, optional): Number of attention heads
|
| 123 |
+
feed_forward_expansion_factor (int, optional): Expansion factor of feed forward module
|
| 124 |
+
conv_expansion_factor (int, optional): Expansion factor of conformer convolution module
|
| 125 |
+
feed_forward_dropout_p (float, optional): Probability of feed forward module dropout
|
| 126 |
+
attention_dropout_p (float, optional): Probability of attention module dropout
|
| 127 |
+
conv_dropout_p (float, optional): Probability of conformer convolution module dropout
|
| 128 |
+
conv_kernel_size (int or tuple, optional): Size of the convolving kernel
|
| 129 |
+
half_step_residual (bool): Flag indication whether to use half step residual or not
|
| 130 |
+
|
| 131 |
+
Inputs: inputs, input_lengths
|
| 132 |
+
- **inputs** (batch, time, dim): Tensor containing input vector
|
| 133 |
+
- **input_lengths** (batch): list of sequence input lengths
|
| 134 |
+
|
| 135 |
+
Returns: outputs, output_lengths
|
| 136 |
+
- **outputs** (batch, out_channels, time): Tensor produces by conformer encoder.
|
| 137 |
+
- **output_lengths** (batch): list of sequence output lengths
|
| 138 |
+
"""
|
| 139 |
+
def __init__(
|
| 140 |
+
self,
|
| 141 |
+
input_dim: int = 80,
|
| 142 |
+
encoder_dim: int = 512,
|
| 143 |
+
num_layers: int = 17,
|
| 144 |
+
num_attention_heads: int = 8,
|
| 145 |
+
feed_forward_expansion_factor: int = 4,
|
| 146 |
+
conv_expansion_factor: int = 2,
|
| 147 |
+
input_dropout_p: float = 0.1,
|
| 148 |
+
feed_forward_dropout_p: float = 0.1,
|
| 149 |
+
attention_dropout_p: float = 0.1,
|
| 150 |
+
conv_dropout_p: float = 0.1,
|
| 151 |
+
conv_kernel_size: int = 31,
|
| 152 |
+
half_step_residual: bool = True,
|
| 153 |
+
):
|
| 154 |
+
super(ConformerEncoder, self).__init__()
|
| 155 |
+
self.conv_subsample = Conv2dSubampling(in_channels=1, out_channels=encoder_dim)
|
| 156 |
+
self.input_projection = nn.Sequential(
|
| 157 |
+
Linear(encoder_dim * (((input_dim - 1) // 2 - 1) // 2), encoder_dim),
|
| 158 |
+
nn.Dropout(p=input_dropout_p),
|
| 159 |
+
)
|
| 160 |
+
self.layers = nn.ModuleList([ConformerBlock(
|
| 161 |
+
encoder_dim=encoder_dim,
|
| 162 |
+
num_attention_heads=num_attention_heads,
|
| 163 |
+
feed_forward_expansion_factor=feed_forward_expansion_factor,
|
| 164 |
+
conv_expansion_factor=conv_expansion_factor,
|
| 165 |
+
feed_forward_dropout_p=feed_forward_dropout_p,
|
| 166 |
+
attention_dropout_p=attention_dropout_p,
|
| 167 |
+
conv_dropout_p=conv_dropout_p,
|
| 168 |
+
conv_kernel_size=conv_kernel_size,
|
| 169 |
+
half_step_residual=half_step_residual,
|
| 170 |
+
) for _ in range(num_layers)])
|
| 171 |
+
|
| 172 |
+
def count_parameters(self) -> int:
|
| 173 |
+
""" Count parameters of encoder """
|
| 174 |
+
return sum([p.numel() for p in self.parameters()])
|
| 175 |
+
|
| 176 |
+
def update_dropout(self, dropout_p: float) -> None:
|
| 177 |
+
""" Update dropout probability of encoder """
|
| 178 |
+
for name, child in self.named_children():
|
| 179 |
+
if isinstance(child, nn.Dropout):
|
| 180 |
+
child.p = dropout_p
|
| 181 |
+
|
| 182 |
+
def forward(self, inputs: Tensor, input_lengths: Tensor) -> Tuple[Tensor, Tensor]:
|
| 183 |
+
"""
|
| 184 |
+
Forward propagate a `inputs` for encoder training.
|
| 185 |
+
|
| 186 |
+
Args:
|
| 187 |
+
inputs (torch.FloatTensor): A input sequence passed to encoder. Typically for inputs this will be a padded
|
| 188 |
+
`FloatTensor` of size ``(batch, seq_length, dimension)``.
|
| 189 |
+
input_lengths (torch.LongTensor): The length of input tensor. ``(batch)``
|
| 190 |
+
|
| 191 |
+
Returns:
|
| 192 |
+
(Tensor, Tensor)
|
| 193 |
+
|
| 194 |
+
* outputs (torch.FloatTensor): A output sequence of encoder. `FloatTensor` of size
|
| 195 |
+
``(batch, seq_length, dimension)``
|
| 196 |
+
* output_lengths (torch.LongTensor): The length of output tensor. ``(batch)``
|
| 197 |
+
"""
|
| 198 |
+
outputs, output_lengths = self.conv_subsample(inputs, input_lengths)
|
| 199 |
+
outputs = self.input_projection(outputs)
|
| 200 |
+
|
| 201 |
+
for layer in self.layers:
|
| 202 |
+
outputs = layer(outputs)
|
| 203 |
+
|
| 204 |
+
return outputs, output_lengths
|
conformer/conformer/feed_forward.py
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2021, Soohwan Kim. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import torch
|
| 16 |
+
import torch.nn as nn
|
| 17 |
+
from torch import Tensor
|
| 18 |
+
|
| 19 |
+
from .activation import Swish
|
| 20 |
+
from .modules import Linear
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class FeedForwardModule(nn.Module):
|
| 24 |
+
"""
|
| 25 |
+
Conformer Feed Forward Module follow pre-norm residual units and apply layer normalization within the residual unit
|
| 26 |
+
and on the input before the first linear layer. This module also apply Swish activation and dropout, which helps
|
| 27 |
+
regularizing the network.
|
| 28 |
+
|
| 29 |
+
Args:
|
| 30 |
+
encoder_dim (int): Dimension of conformer encoder
|
| 31 |
+
expansion_factor (int): Expansion factor of feed forward module.
|
| 32 |
+
dropout_p (float): Ratio of dropout
|
| 33 |
+
|
| 34 |
+
Inputs: inputs
|
| 35 |
+
- **inputs** (batch, time, dim): Tensor contains input sequences
|
| 36 |
+
|
| 37 |
+
Outputs: outputs
|
| 38 |
+
- **outputs** (batch, time, dim): Tensor produces by feed forward module.
|
| 39 |
+
"""
|
| 40 |
+
def __init__(
|
| 41 |
+
self,
|
| 42 |
+
encoder_dim: int = 512,
|
| 43 |
+
expansion_factor: int = 4,
|
| 44 |
+
dropout_p: float = 0.1,
|
| 45 |
+
) -> None:
|
| 46 |
+
super(FeedForwardModule, self).__init__()
|
| 47 |
+
self.sequential = nn.Sequential(
|
| 48 |
+
nn.LayerNorm(encoder_dim),
|
| 49 |
+
Linear(encoder_dim, encoder_dim * expansion_factor, bias=True),
|
| 50 |
+
Swish(),
|
| 51 |
+
nn.Dropout(p=dropout_p),
|
| 52 |
+
Linear(encoder_dim * expansion_factor, encoder_dim, bias=True),
|
| 53 |
+
nn.Dropout(p=dropout_p),
|
| 54 |
+
)
|
| 55 |
+
|
| 56 |
+
def forward(self, inputs: Tensor) -> Tensor:
|
| 57 |
+
return self.sequential(inputs)
|
conformer/conformer/model.py
ADDED
|
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2021, Soohwan Kim. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import torch
|
| 16 |
+
import torch.nn as nn
|
| 17 |
+
from torch import Tensor
|
| 18 |
+
from typing import Tuple
|
| 19 |
+
|
| 20 |
+
from .encoder import ConformerEncoder
|
| 21 |
+
from .modules import Linear
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class Conformer(nn.Module):
|
| 25 |
+
"""
|
| 26 |
+
Conformer: Convolution-augmented Transformer for Speech Recognition
|
| 27 |
+
The paper used a one-lstm Transducer decoder, currently still only implemented
|
| 28 |
+
the conformer encoder shown in the paper.
|
| 29 |
+
|
| 30 |
+
Args:
|
| 31 |
+
num_classes (int): Number of classification classes
|
| 32 |
+
input_dim (int, optional): Dimension of input vector
|
| 33 |
+
encoder_dim (int, optional): Dimension of conformer encoder
|
| 34 |
+
num_encoder_layers (int, optional): Number of conformer blocks
|
| 35 |
+
num_attention_heads (int, optional): Number of attention heads
|
| 36 |
+
feed_forward_expansion_factor (int, optional): Expansion factor of feed forward module
|
| 37 |
+
conv_expansion_factor (int, optional): Expansion factor of conformer convolution module
|
| 38 |
+
feed_forward_dropout_p (float, optional): Probability of feed forward module dropout
|
| 39 |
+
attention_dropout_p (float, optional): Probability of attention module dropout
|
| 40 |
+
conv_dropout_p (float, optional): Probability of conformer convolution module dropout
|
| 41 |
+
conv_kernel_size (int or tuple, optional): Size of the convolving kernel
|
| 42 |
+
half_step_residual (bool): Flag indication whether to use half step residual or not
|
| 43 |
+
|
| 44 |
+
Inputs: inputs, input_lengths
|
| 45 |
+
- **inputs** (batch, time, dim): Tensor containing input vector
|
| 46 |
+
- **input_lengths** (batch): list of sequence input lengths
|
| 47 |
+
|
| 48 |
+
Returns: outputs, output_lengths
|
| 49 |
+
- **outputs** (batch, out_channels, time): Tensor produces by conformer.
|
| 50 |
+
- **output_lengths** (batch): list of sequence output lengths
|
| 51 |
+
"""
|
| 52 |
+
def __init__(
|
| 53 |
+
self,
|
| 54 |
+
num_classes: int,
|
| 55 |
+
input_dim: int = 80,
|
| 56 |
+
encoder_dim: int = 512,
|
| 57 |
+
num_encoder_layers: int = 17,
|
| 58 |
+
num_attention_heads: int = 8,
|
| 59 |
+
feed_forward_expansion_factor: int = 4,
|
| 60 |
+
conv_expansion_factor: int = 2,
|
| 61 |
+
input_dropout_p: float = 0.1,
|
| 62 |
+
feed_forward_dropout_p: float = 0.1,
|
| 63 |
+
attention_dropout_p: float = 0.1,
|
| 64 |
+
conv_dropout_p: float = 0.1,
|
| 65 |
+
conv_kernel_size: int = 31,
|
| 66 |
+
half_step_residual: bool = True,
|
| 67 |
+
) -> None:
|
| 68 |
+
super(Conformer, self).__init__()
|
| 69 |
+
self.encoder = ConformerEncoder(
|
| 70 |
+
input_dim=input_dim,
|
| 71 |
+
encoder_dim=encoder_dim,
|
| 72 |
+
num_layers=num_encoder_layers,
|
| 73 |
+
num_attention_heads=num_attention_heads,
|
| 74 |
+
feed_forward_expansion_factor=feed_forward_expansion_factor,
|
| 75 |
+
conv_expansion_factor=conv_expansion_factor,
|
| 76 |
+
input_dropout_p=input_dropout_p,
|
| 77 |
+
feed_forward_dropout_p=feed_forward_dropout_p,
|
| 78 |
+
attention_dropout_p=attention_dropout_p,
|
| 79 |
+
conv_dropout_p=conv_dropout_p,
|
| 80 |
+
conv_kernel_size=conv_kernel_size,
|
| 81 |
+
half_step_residual=half_step_residual,
|
| 82 |
+
)
|
| 83 |
+
self.fc = Linear(encoder_dim, num_classes, bias=False)
|
| 84 |
+
|
| 85 |
+
def count_parameters(self) -> int:
|
| 86 |
+
""" Count parameters of encoder """
|
| 87 |
+
return self.encoder.count_parameters()
|
| 88 |
+
|
| 89 |
+
def update_dropout(self, dropout_p) -> None:
|
| 90 |
+
""" Update dropout probability of model """
|
| 91 |
+
self.encoder.update_dropout(dropout_p)
|
| 92 |
+
|
| 93 |
+
def forward(self, inputs: Tensor, input_lengths: Tensor) -> Tuple[Tensor, Tensor]:
|
| 94 |
+
"""
|
| 95 |
+
Forward propagate a `inputs` and `targets` pair for training.
|
| 96 |
+
|
| 97 |
+
Args:
|
| 98 |
+
inputs (torch.FloatTensor): A input sequence passed to encoder. Typically for inputs this will be a padded
|
| 99 |
+
`FloatTensor` of size ``(batch, seq_length, dimension)``.
|
| 100 |
+
input_lengths (torch.LongTensor): The length of input tensor. ``(batch)``
|
| 101 |
+
|
| 102 |
+
Returns:
|
| 103 |
+
* predictions (torch.FloatTensor): Result of model predictions.
|
| 104 |
+
"""
|
| 105 |
+
encoder_outputs, encoder_output_lengths = self.encoder(inputs, input_lengths)
|
| 106 |
+
outputs = self.fc(encoder_outputs)
|
| 107 |
+
outputs = nn.functional.log_softmax(outputs, dim=-1)
|
| 108 |
+
return outputs, encoder_output_lengths
|
conformer/conformer/model_def.py
ADDED
|
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2021, Soohwan Kim. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import torch
|
| 16 |
+
import torch.nn as nn
|
| 17 |
+
from torch import Tensor
|
| 18 |
+
from typing import Tuple
|
| 19 |
+
|
| 20 |
+
from .encoder import ConformerEncoder
|
| 21 |
+
from .modules import Linear
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class Conformer(nn.Module):
|
| 25 |
+
"""
|
| 26 |
+
Conformer: Convolution-augmented Transformer for Speech Recognition
|
| 27 |
+
The paper used a one-lstm Transducer decoder, currently still only implemented
|
| 28 |
+
the conformer encoder shown in the paper.
|
| 29 |
+
|
| 30 |
+
Args:
|
| 31 |
+
num_classes (int): Number of classification classes
|
| 32 |
+
input_dim (int, optional): Dimension of input vector
|
| 33 |
+
encoder_dim (int, optional): Dimension of conformer encoder
|
| 34 |
+
num_encoder_layers (int, optional): Number of conformer blocks
|
| 35 |
+
num_attention_heads (int, optional): Number of attention heads
|
| 36 |
+
feed_forward_expansion_factor (int, optional): Expansion factor of feed forward module
|
| 37 |
+
conv_expansion_factor (int, optional): Expansion factor of conformer convolution module
|
| 38 |
+
feed_forward_dropout_p (float, optional): Probability of feed forward module dropout
|
| 39 |
+
attention_dropout_p (float, optional): Probability of attention module dropout
|
| 40 |
+
conv_dropout_p (float, optional): Probability of conformer convolution module dropout
|
| 41 |
+
conv_kernel_size (int or tuple, optional): Size of the convolving kernel
|
| 42 |
+
half_step_residual (bool): Flag indication whether to use half step residual or not
|
| 43 |
+
|
| 44 |
+
Inputs: inputs, input_lengths
|
| 45 |
+
- **inputs** (batch, time, dim): Tensor containing input vector
|
| 46 |
+
- **input_lengths** (batch): list of sequence input lengths
|
| 47 |
+
|
| 48 |
+
Returns: outputs, output_lengths
|
| 49 |
+
- **outputs** (batch, out_channels, time): Tensor produces by conformer.
|
| 50 |
+
- **output_lengths** (batch): list of sequence output lengths
|
| 51 |
+
"""
|
| 52 |
+
def __init__(
|
| 53 |
+
self,
|
| 54 |
+
input_dim: int = 80,
|
| 55 |
+
encoder_dim: int = 512,
|
| 56 |
+
num_encoder_layers: int = 17,
|
| 57 |
+
num_attention_heads: int = 8,
|
| 58 |
+
feed_forward_expansion_factor: int = 4,
|
| 59 |
+
conv_expansion_factor: int = 2,
|
| 60 |
+
input_dropout_p: float = 0.1,
|
| 61 |
+
feed_forward_dropout_p: float = 0.1,
|
| 62 |
+
attention_dropout_p: float = 0.1,
|
| 63 |
+
conv_dropout_p: float = 0.1,
|
| 64 |
+
conv_kernel_size: int = 31,
|
| 65 |
+
half_step_residual: bool = True,
|
| 66 |
+
) -> None:
|
| 67 |
+
super(Conformer, self).__init__()
|
| 68 |
+
self.encoder = ConformerEncoder(
|
| 69 |
+
input_dim=input_dim,
|
| 70 |
+
encoder_dim=encoder_dim,
|
| 71 |
+
num_layers=num_encoder_layers,
|
| 72 |
+
num_attention_heads=num_attention_heads,
|
| 73 |
+
feed_forward_expansion_factor=feed_forward_expansion_factor,
|
| 74 |
+
conv_expansion_factor=conv_expansion_factor,
|
| 75 |
+
input_dropout_p=input_dropout_p,
|
| 76 |
+
feed_forward_dropout_p=feed_forward_dropout_p,
|
| 77 |
+
attention_dropout_p=attention_dropout_p,
|
| 78 |
+
conv_dropout_p=conv_dropout_p,
|
| 79 |
+
conv_kernel_size=conv_kernel_size,
|
| 80 |
+
half_step_residual=half_step_residual,
|
| 81 |
+
)
|
| 82 |
+
|
| 83 |
+
def count_parameters(self) -> int:
|
| 84 |
+
""" Count parameters of encoder """
|
| 85 |
+
return self.encoder.count_parameters()
|
| 86 |
+
|
| 87 |
+
def update_dropout(self, dropout_p) -> None:
|
| 88 |
+
""" Update dropout probability of model """
|
| 89 |
+
self.encoder.update_dropout(dropout_p)
|
| 90 |
+
|
| 91 |
+
def forward(self, inputs: Tensor, input_lengths: Tensor) -> Tuple[Tensor, Tensor]:
|
| 92 |
+
"""
|
| 93 |
+
Forward propagate a `inputs` and `targets` pair for training.
|
| 94 |
+
|
| 95 |
+
Args:
|
| 96 |
+
inputs (torch.FloatTensor): A input sequence passed to encoder. Typically for inputs this will be a padded
|
| 97 |
+
`FloatTensor` of size ``(batch, seq_length, dimension)``.
|
| 98 |
+
input_lengths (torch.LongTensor): The length of input tensor. ``(batch)``
|
| 99 |
+
|
| 100 |
+
Returns:
|
| 101 |
+
* predictions (torch.FloatTensor): Result of model predictions.
|
| 102 |
+
"""
|
| 103 |
+
encoder_outputs, encoder_output_lengths = self.encoder(inputs, input_lengths)
|
| 104 |
+
return encoder_outputs, encoder_output_lengths
|
conformer/conformer/modules.py
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2021, Soohwan Kim. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import torch
|
| 16 |
+
import torch.nn as nn
|
| 17 |
+
import torch.nn.init as init
|
| 18 |
+
from torch import Tensor
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class ResidualConnectionModule(nn.Module):
|
| 22 |
+
"""
|
| 23 |
+
Residual Connection Module.
|
| 24 |
+
outputs = (module(inputs) x module_factor + inputs x input_factor)
|
| 25 |
+
"""
|
| 26 |
+
def __init__(self, module: nn.Module, module_factor: float = 1.0, input_factor: float = 1.0):
|
| 27 |
+
super(ResidualConnectionModule, self).__init__()
|
| 28 |
+
self.module = module
|
| 29 |
+
self.module_factor = module_factor
|
| 30 |
+
self.input_factor = input_factor
|
| 31 |
+
|
| 32 |
+
def forward(self, inputs: Tensor) -> Tensor:
|
| 33 |
+
return (self.module(inputs) * self.module_factor) + (inputs * self.input_factor)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class Linear(nn.Module):
|
| 37 |
+
"""
|
| 38 |
+
Wrapper class of torch.nn.Linear
|
| 39 |
+
Weight initialize by xavier initialization and bias initialize to zeros.
|
| 40 |
+
"""
|
| 41 |
+
def __init__(self, in_features: int, out_features: int, bias: bool = True) -> None:
|
| 42 |
+
super(Linear, self).__init__()
|
| 43 |
+
self.linear = nn.Linear(in_features, out_features, bias=bias)
|
| 44 |
+
init.xavier_uniform_(self.linear.weight)
|
| 45 |
+
if bias:
|
| 46 |
+
init.zeros_(self.linear.bias)
|
| 47 |
+
|
| 48 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 49 |
+
return self.linear(x)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
class View(nn.Module):
|
| 53 |
+
""" Wrapper class of torch.view() for Sequential module. """
|
| 54 |
+
def __init__(self, shape: tuple, contiguous: bool = False):
|
| 55 |
+
super(View, self).__init__()
|
| 56 |
+
self.shape = shape
|
| 57 |
+
self.contiguous = contiguous
|
| 58 |
+
|
| 59 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 60 |
+
if self.contiguous:
|
| 61 |
+
x = x.contiguous()
|
| 62 |
+
|
| 63 |
+
return x.view(*self.shape)
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
class Transpose(nn.Module):
|
| 67 |
+
""" Wrapper class of torch.transpose() for Sequential module. """
|
| 68 |
+
def __init__(self, shape: tuple):
|
| 69 |
+
super(Transpose, self).__init__()
|
| 70 |
+
self.shape = shape
|
| 71 |
+
|
| 72 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 73 |
+
return x.transpose(*self.shape)
|
model_pinyin.py
ADDED
|
@@ -0,0 +1,299 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
import pytorch_lightning as pl
|
| 5 |
+
from torch.optim import Adam
|
| 6 |
+
from torch.optim.lr_scheduler import CosineAnnealingLR
|
| 7 |
+
import math
|
| 8 |
+
from conformer.conformer.model_def import Conformer
|
| 9 |
+
|
| 10 |
+
class PositionalEncoding(nn.Module):
|
| 11 |
+
"""位置编码模块"""
|
| 12 |
+
def __init__(self, d_model, max_len=5000):
|
| 13 |
+
super(PositionalEncoding, self).__init__()
|
| 14 |
+
pe = torch.zeros(max_len, d_model)
|
| 15 |
+
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
|
| 16 |
+
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
|
| 17 |
+
pe[:, 0::2] = torch.sin(position * div_term)
|
| 18 |
+
pe[:, 1::2] = torch.cos(position * div_term)
|
| 19 |
+
pe = pe.unsqueeze(0)
|
| 20 |
+
self.register_buffer('pe', pe)
|
| 21 |
+
|
| 22 |
+
def forward(self, x):
|
| 23 |
+
return x + self.pe[:, :x.size(1)]
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class CrossAttention(nn.Module):
|
| 27 |
+
"""交叉注意力模块 - 一层交叉注意力加一层自注意力"""
|
| 28 |
+
def __init__(self, query_dim, key_dim, heads=4, dropout=0.1):
|
| 29 |
+
super(CrossAttention, self).__init__()
|
| 30 |
+
# 交叉注意力层
|
| 31 |
+
self.cross_attn = nn.MultiheadAttention(query_dim, heads, dropout=dropout, batch_first=True)
|
| 32 |
+
self.norm1 = nn.LayerNorm(query_dim)
|
| 33 |
+
self.norm2 = nn.LayerNorm(query_dim)
|
| 34 |
+
self.ffn1 = nn.Sequential(
|
| 35 |
+
nn.Linear(query_dim, query_dim * 4),
|
| 36 |
+
nn.ReLU(),
|
| 37 |
+
nn.Dropout(dropout),
|
| 38 |
+
nn.Linear(query_dim * 4, query_dim),
|
| 39 |
+
nn.Dropout(dropout)
|
| 40 |
+
)
|
| 41 |
+
|
| 42 |
+
# 自注意力层
|
| 43 |
+
self.self_attn = nn.MultiheadAttention(query_dim, heads, dropout=dropout, batch_first=True)
|
| 44 |
+
self.norm3 = nn.LayerNorm(query_dim)
|
| 45 |
+
self.norm4 = nn.LayerNorm(query_dim)
|
| 46 |
+
self.ffn2 = nn.Sequential(
|
| 47 |
+
nn.Linear(query_dim, query_dim * 4),
|
| 48 |
+
nn.ReLU(),
|
| 49 |
+
nn.Dropout(dropout),
|
| 50 |
+
nn.Linear(query_dim * 4, query_dim),
|
| 51 |
+
nn.Dropout(dropout)
|
| 52 |
+
)
|
| 53 |
+
|
| 54 |
+
# 投影层
|
| 55 |
+
self.proj_key = nn.Linear(key_dim, query_dim)
|
| 56 |
+
self.proj_value = nn.Linear(key_dim, query_dim)
|
| 57 |
+
|
| 58 |
+
def forward(self, query, key, value, key_padding_mask=None):
|
| 59 |
+
# 有音频输入时,先进行交叉注意力
|
| 60 |
+
# 投影key和value到query的维度
|
| 61 |
+
key_proj = self.proj_key(key)
|
| 62 |
+
value_proj = self.proj_value(value)
|
| 63 |
+
|
| 64 |
+
# 交叉注意力
|
| 65 |
+
query_norm = self.norm1(query)
|
| 66 |
+
cross_attn_output, _ = self.cross_attn(query_norm, key_proj, value_proj,
|
| 67 |
+
key_padding_mask=key_padding_mask)
|
| 68 |
+
query = query + cross_attn_output
|
| 69 |
+
query = query + self.ffn1(self.norm2(query))
|
| 70 |
+
|
| 71 |
+
# 然后进行自注意力
|
| 72 |
+
query_norm = self.norm3(query)
|
| 73 |
+
self_attn_output, _ = self.self_attn(query_norm, query_norm, query_norm)
|
| 74 |
+
query = query + self_attn_output
|
| 75 |
+
query = query + self.ffn2(self.norm4(query))
|
| 76 |
+
|
| 77 |
+
return query
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
class MMKWS2(nn.Module):
|
| 81 |
+
def __init__(
|
| 82 |
+
self,
|
| 83 |
+
# anchor
|
| 84 |
+
text_dim=64,
|
| 85 |
+
audio_dim=1024,
|
| 86 |
+
hidden_dim=128,
|
| 87 |
+
# compare
|
| 88 |
+
dim=80,
|
| 89 |
+
encoder_dim=128,
|
| 90 |
+
num_encoder_layers=6,
|
| 91 |
+
num_attention_heads=4,
|
| 92 |
+
dropout=0.1,
|
| 93 |
+
num_transformer_layers=2
|
| 94 |
+
):
|
| 95 |
+
super(MMKWS2, self).__init__()
|
| 96 |
+
# 音频嵌入降维
|
| 97 |
+
self.audio_proj = nn.Linear(audio_dim, hidden_dim)
|
| 98 |
+
# 文本嵌入投影
|
| 99 |
+
self.text_proj = nn.Embedding(num_embeddings=402, embedding_dim=hidden_dim) # 401 + padding -1
|
| 100 |
+
# 位置编码
|
| 101 |
+
self.pos_enc = PositionalEncoding(hidden_dim)
|
| 102 |
+
# 交叉注意力模块
|
| 103 |
+
self.cross_attn = CrossAttention(hidden_dim, hidden_dim, heads=num_attention_heads, dropout=dropout)
|
| 104 |
+
|
| 105 |
+
# Conformer层
|
| 106 |
+
self.conformer = Conformer(
|
| 107 |
+
input_dim=dim,
|
| 108 |
+
encoder_dim=encoder_dim,
|
| 109 |
+
num_encoder_layers=num_encoder_layers,
|
| 110 |
+
num_attention_heads=num_attention_heads,
|
| 111 |
+
)
|
| 112 |
+
|
| 113 |
+
# 特征映射层(将conformer输出维度映射到hidden_dim)
|
| 114 |
+
self.feat_proj = nn.Linear(encoder_dim, hidden_dim)
|
| 115 |
+
|
| 116 |
+
# Transformer层
|
| 117 |
+
self.transformer_encoder = nn.TransformerEncoder(
|
| 118 |
+
nn.TransformerEncoderLayer(
|
| 119 |
+
d_model=hidden_dim,
|
| 120 |
+
nhead=num_attention_heads,
|
| 121 |
+
dim_feedforward=hidden_dim*4,
|
| 122 |
+
dropout=dropout,
|
| 123 |
+
batch_first=True
|
| 124 |
+
),
|
| 125 |
+
num_layers=num_transformer_layers
|
| 126 |
+
)
|
| 127 |
+
|
| 128 |
+
# GRU分类器
|
| 129 |
+
self.gru = nn.GRU(hidden_dim, hidden_dim, batch_first=True, bidirectional=True)
|
| 130 |
+
self.classifier = nn.Sequential(
|
| 131 |
+
nn.Linear(hidden_dim*2, hidden_dim),
|
| 132 |
+
nn.ReLU(),
|
| 133 |
+
nn.Dropout(dropout),
|
| 134 |
+
nn.Linear(hidden_dim, 1)
|
| 135 |
+
)
|
| 136 |
+
|
| 137 |
+
# 序列标签预测
|
| 138 |
+
self.seq_classifier = nn.Linear(hidden_dim, 1)
|
| 139 |
+
|
| 140 |
+
def forward(self, anchor_wave_embedding, anchor_text_embedding, compare_wave, compare_lengths):
|
| 141 |
+
batch_size = anchor_wave_embedding.size(0)
|
| 142 |
+
|
| 143 |
+
# 1. 处理anchor_text嵌入
|
| 144 |
+
text_feat = self.text_proj(anchor_text_embedding) # [B, S, hidden_dim]
|
| 145 |
+
text_feat = self.pos_enc(text_feat)
|
| 146 |
+
|
| 147 |
+
# 2. 处理anchor_wave音频嵌入
|
| 148 |
+
audio_feat = self.audio_proj(anchor_wave_embedding) # [B, S, hidden_dim]
|
| 149 |
+
audio_feat = self.pos_enc(audio_feat)
|
| 150 |
+
|
| 151 |
+
# 3. 交叉注意力:文本和音频特征融合
|
| 152 |
+
fused_feat = self.cross_attn(text_feat, audio_feat, audio_feat, key_padding_mask=None) # [B, S, hidden_dim]
|
| 153 |
+
|
| 154 |
+
# 4. 处理compare_wave的fbank特征
|
| 155 |
+
compare_feat = self.conformer(compare_wave, compare_lengths)[0] # [B, T, encoder_dim]
|
| 156 |
+
compare_feat = self.feat_proj(compare_feat) # [B, T, hidden_dim]
|
| 157 |
+
compare_feat = self.pos_enc(compare_feat)
|
| 158 |
+
|
| 159 |
+
# 5. 合并特征
|
| 160 |
+
text_len = fused_feat.size(1)
|
| 161 |
+
combined_feat = torch.cat([fused_feat, compare_feat], dim=1) # [B, S+T, hidden_dim]
|
| 162 |
+
combined_feat = self.transformer_encoder(combined_feat) # [B, S+T, hidden_dim]
|
| 163 |
+
|
| 164 |
+
# 7. GRU分类
|
| 165 |
+
gru_out, _ = self.gru(combined_feat) # [B, S+T, hidden_dim*2]
|
| 166 |
+
|
| 167 |
+
# 全局分类
|
| 168 |
+
global_feat = gru_out[:, -1, :] # 取最后一个时间步
|
| 169 |
+
logits = self.classifier(global_feat).squeeze(-1) # [B]
|
| 170 |
+
|
| 171 |
+
# 序列标签预测
|
| 172 |
+
seq_logits = self.seq_classifier(combined_feat[:, :text_len, :]).squeeze(-1) # [B, S]
|
| 173 |
+
return logits, seq_logits
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
def enrollment(self, anchor_wave_embedding, anchor_text_embedding):
|
| 177 |
+
batch_size = anchor_wave_embedding.size(0)
|
| 178 |
+
|
| 179 |
+
# 1. 处理anchor_text嵌入
|
| 180 |
+
text_feat = self.text_proj(anchor_text_embedding) # [B, S, hidden_dim]
|
| 181 |
+
text_feat = self.pos_enc(text_feat)
|
| 182 |
+
|
| 183 |
+
# 2. 处理anchor_wave音频嵌入
|
| 184 |
+
audio_feat = self.audio_proj(anchor_wave_embedding) # [B, S, hidden_dim]
|
| 185 |
+
audio_feat = self.pos_enc(audio_feat)
|
| 186 |
+
|
| 187 |
+
# 3. 交叉注意力:文本和音频特征融合
|
| 188 |
+
fused_feat = self.cross_attn(text_feat, audio_feat, audio_feat, key_padding_mask=None) # [B, S, hidden_dim]
|
| 189 |
+
|
| 190 |
+
return fused_feat
|
| 191 |
+
|
| 192 |
+
def verification(self, fused_feat, compare_wave, compare_lengths):
|
| 193 |
+
batch_size = fused_feat.size(0)
|
| 194 |
+
|
| 195 |
+
# 4. 处理compare_wave的fbank特征
|
| 196 |
+
compare_feat = self.conformer(compare_wave, compare_lengths)[0] # [B, T, encoder_dim]
|
| 197 |
+
compare_feat = self.feat_proj(compare_feat) # [B, T, hidden_dim]
|
| 198 |
+
compare_feat = self.pos_enc(compare_feat)
|
| 199 |
+
|
| 200 |
+
# 5. 合并特征
|
| 201 |
+
text_len = fused_feat.size(1)
|
| 202 |
+
combined_feat = torch.cat([fused_feat, compare_feat], dim=1) # [B, S+T, hidden_dim]
|
| 203 |
+
combined_feat = self.transformer_encoder(combined_feat) # [B, S+T, hidden_dim]
|
| 204 |
+
|
| 205 |
+
# 7. GRU分类
|
| 206 |
+
gru_out, _ = self.gru(combined_feat) # [B, S+T, hidden_dim*2]
|
| 207 |
+
|
| 208 |
+
# 全局分类
|
| 209 |
+
global_feat = gru_out[:, -1, :] # 取最后一个时间步
|
| 210 |
+
logits = self.classifier(global_feat).squeeze(-1) # [B]
|
| 211 |
+
|
| 212 |
+
return logits
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
def count_verification_params(model):
|
| 216 |
+
modules = [
|
| 217 |
+
model.conformer,
|
| 218 |
+
model.feat_proj,
|
| 219 |
+
model.transformer_encoder,
|
| 220 |
+
model.gru,
|
| 221 |
+
model.classifier
|
| 222 |
+
]
|
| 223 |
+
total = 0
|
| 224 |
+
for m in modules:
|
| 225 |
+
total += sum(p.numel() for p in m.parameters())
|
| 226 |
+
return total
|
| 227 |
+
|
| 228 |
+
model = MMKWS2(
|
| 229 |
+
text_dim=64,
|
| 230 |
+
audio_dim=1024,
|
| 231 |
+
hidden_dim=128,
|
| 232 |
+
dim=80,
|
| 233 |
+
encoder_dim=128,
|
| 234 |
+
num_encoder_layers=6,
|
| 235 |
+
num_attention_heads=4,
|
| 236 |
+
dropout=0.1,
|
| 237 |
+
num_transformer_layers=2
|
| 238 |
+
)
|
| 239 |
+
print(f"verification相关参数量: {count_verification_params(model):,}") # 3.5M模型参数量
|
| 240 |
+
|
| 241 |
+
# if __name__ == "__main__":
|
| 242 |
+
# # 创建一个示例batch
|
| 243 |
+
# batch_size = 2
|
| 244 |
+
|
| 245 |
+
# # 创建模拟数据
|
| 246 |
+
# anchor_embedding = torch.randn(batch_size, 8, 64) # 文本嵌入
|
| 247 |
+
# anchor_wave = torch.randn(batch_size, 256, 1024) # 音频嵌入
|
| 248 |
+
# compare_wave = torch.randn(batch_size, 45, 80) # Fbank特征
|
| 249 |
+
|
| 250 |
+
# # 创建长度信息
|
| 251 |
+
# anchor_lengths = torch.LongTensor([8, 6]) # 两个样本的实际长度
|
| 252 |
+
# compare_lengths = torch.LongTensor([45, 40])
|
| 253 |
+
|
| 254 |
+
# # 创建模型
|
| 255 |
+
# model = MMKWS2(
|
| 256 |
+
# text_dim=64,
|
| 257 |
+
# audio_dim=1024,
|
| 258 |
+
# hidden_dim=128,
|
| 259 |
+
# dim=80,
|
| 260 |
+
# encoder_dim=128,
|
| 261 |
+
# num_encoder_layers=6,
|
| 262 |
+
# num_attention_heads=4,
|
| 263 |
+
# dropout=0.1,
|
| 264 |
+
# num_transformer_layers=2
|
| 265 |
+
# )
|
| 266 |
+
|
| 267 |
+
# # 计算模型参数量
|
| 268 |
+
# total_params = sum(p.numel() for p in model.parameters())
|
| 269 |
+
# trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
| 270 |
+
|
| 271 |
+
# print(f"模型总参数量: {total_params:,}")
|
| 272 |
+
# print(f"可训练参数量: {trainable_params:,}")
|
| 273 |
+
|
| 274 |
+
# # 打印模型结构
|
| 275 |
+
# print("\n模型结构:")
|
| 276 |
+
# print(model)
|
| 277 |
+
|
| 278 |
+
# # 模型推理
|
| 279 |
+
# print("\n开始推理...")
|
| 280 |
+
# model.eval()
|
| 281 |
+
# with torch.no_grad():
|
| 282 |
+
|
| 283 |
+
# print(anchor_embedding.shape)
|
| 284 |
+
# print(anchor_wave.shape)
|
| 285 |
+
# print(compare_wave.shape)
|
| 286 |
+
# print(anchor_lengths.shape)
|
| 287 |
+
# print(compare_lengths.shape)
|
| 288 |
+
# # 完整输入推理
|
| 289 |
+
# logits, seq_logits, text_len = model(
|
| 290 |
+
# anchor_embedding=anchor_embedding,
|
| 291 |
+
# anchor_wave=anchor_wave,
|
| 292 |
+
# compare_wave=compare_wave,
|
| 293 |
+
# anchor_lengths=anchor_lengths,
|
| 294 |
+
# compare_lengths=compare_lengths
|
| 295 |
+
# )
|
| 296 |
+
|
| 297 |
+
# print("\n推理结果:")
|
| 298 |
+
# print(f"分类logits形状: {logits.shape}")
|
| 299 |
+
# print(f"序列logits形状: {seq_logits.shape}")
|
pinyin_index.json
ADDED
|
@@ -0,0 +1,808 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"pinyin_to_index": {
|
| 3 |
+
"a": 0,
|
| 4 |
+
"ai": 1,
|
| 5 |
+
"an": 2,
|
| 6 |
+
"ang": 3,
|
| 7 |
+
"ao": 4,
|
| 8 |
+
"ba": 5,
|
| 9 |
+
"bai": 6,
|
| 10 |
+
"ban": 7,
|
| 11 |
+
"bang": 8,
|
| 12 |
+
"bao": 9,
|
| 13 |
+
"bei": 10,
|
| 14 |
+
"ben": 11,
|
| 15 |
+
"beng": 12,
|
| 16 |
+
"bi": 13,
|
| 17 |
+
"bian": 14,
|
| 18 |
+
"biao": 15,
|
| 19 |
+
"bie": 16,
|
| 20 |
+
"bin": 17,
|
| 21 |
+
"bing": 18,
|
| 22 |
+
"bo": 19,
|
| 23 |
+
"bu": 20,
|
| 24 |
+
"ca": 21,
|
| 25 |
+
"cai": 22,
|
| 26 |
+
"can": 23,
|
| 27 |
+
"cang": 24,
|
| 28 |
+
"cao": 25,
|
| 29 |
+
"ce": 26,
|
| 30 |
+
"cen": 27,
|
| 31 |
+
"ceng": 28,
|
| 32 |
+
"cha": 29,
|
| 33 |
+
"chai": 30,
|
| 34 |
+
"chan": 31,
|
| 35 |
+
"chang": 32,
|
| 36 |
+
"chao": 33,
|
| 37 |
+
"che": 34,
|
| 38 |
+
"chen": 35,
|
| 39 |
+
"cheng": 36,
|
| 40 |
+
"chi": 37,
|
| 41 |
+
"chong": 38,
|
| 42 |
+
"chou": 39,
|
| 43 |
+
"chu": 40,
|
| 44 |
+
"chuai": 41,
|
| 45 |
+
"chuan": 42,
|
| 46 |
+
"chuang": 43,
|
| 47 |
+
"chui": 44,
|
| 48 |
+
"chun": 45,
|
| 49 |
+
"chuo": 46,
|
| 50 |
+
"ci": 47,
|
| 51 |
+
"cong": 48,
|
| 52 |
+
"cou": 49,
|
| 53 |
+
"cu": 50,
|
| 54 |
+
"cuan": 51,
|
| 55 |
+
"cui": 52,
|
| 56 |
+
"cun": 53,
|
| 57 |
+
"cuo": 54,
|
| 58 |
+
"da": 55,
|
| 59 |
+
"dai": 56,
|
| 60 |
+
"dan": 57,
|
| 61 |
+
"dang": 58,
|
| 62 |
+
"dao": 59,
|
| 63 |
+
"de": 60,
|
| 64 |
+
"dei": 61,
|
| 65 |
+
"deng": 62,
|
| 66 |
+
"di": 63,
|
| 67 |
+
"dia": 64,
|
| 68 |
+
"dian": 65,
|
| 69 |
+
"diao": 66,
|
| 70 |
+
"die": 67,
|
| 71 |
+
"ding": 68,
|
| 72 |
+
"diu": 69,
|
| 73 |
+
"dong": 70,
|
| 74 |
+
"dou": 71,
|
| 75 |
+
"du": 72,
|
| 76 |
+
"duan": 73,
|
| 77 |
+
"dui": 74,
|
| 78 |
+
"dun": 75,
|
| 79 |
+
"duo": 76,
|
| 80 |
+
"e": 77,
|
| 81 |
+
"en": 78,
|
| 82 |
+
"er": 79,
|
| 83 |
+
"fa": 80,
|
| 84 |
+
"fan": 81,
|
| 85 |
+
"fang": 82,
|
| 86 |
+
"fei": 83,
|
| 87 |
+
"fen": 84,
|
| 88 |
+
"feng": 85,
|
| 89 |
+
"fo": 86,
|
| 90 |
+
"fou": 87,
|
| 91 |
+
"fu": 88,
|
| 92 |
+
"ga": 89,
|
| 93 |
+
"gai": 90,
|
| 94 |
+
"gan": 91,
|
| 95 |
+
"gang": 92,
|
| 96 |
+
"gao": 93,
|
| 97 |
+
"ge": 94,
|
| 98 |
+
"gei": 95,
|
| 99 |
+
"gen": 96,
|
| 100 |
+
"geng": 97,
|
| 101 |
+
"gong": 98,
|
| 102 |
+
"gou": 99,
|
| 103 |
+
"gu": 100,
|
| 104 |
+
"gua": 101,
|
| 105 |
+
"guai": 102,
|
| 106 |
+
"guan": 103,
|
| 107 |
+
"guang": 104,
|
| 108 |
+
"gui": 105,
|
| 109 |
+
"gun": 106,
|
| 110 |
+
"guo": 107,
|
| 111 |
+
"ha": 108,
|
| 112 |
+
"hai": 109,
|
| 113 |
+
"han": 110,
|
| 114 |
+
"hang": 111,
|
| 115 |
+
"hao": 112,
|
| 116 |
+
"he": 113,
|
| 117 |
+
"hei": 114,
|
| 118 |
+
"hen": 115,
|
| 119 |
+
"heng": 116,
|
| 120 |
+
"hong": 117,
|
| 121 |
+
"hou": 118,
|
| 122 |
+
"hu": 119,
|
| 123 |
+
"hua": 120,
|
| 124 |
+
"huai": 121,
|
| 125 |
+
"huan": 122,
|
| 126 |
+
"huang": 123,
|
| 127 |
+
"hui": 124,
|
| 128 |
+
"hun": 125,
|
| 129 |
+
"huo": 126,
|
| 130 |
+
"ji": 127,
|
| 131 |
+
"jia": 128,
|
| 132 |
+
"jian": 129,
|
| 133 |
+
"jiang": 130,
|
| 134 |
+
"jiao": 131,
|
| 135 |
+
"jie": 132,
|
| 136 |
+
"jin": 133,
|
| 137 |
+
"jing": 134,
|
| 138 |
+
"jiong": 135,
|
| 139 |
+
"jiu": 136,
|
| 140 |
+
"ju": 137,
|
| 141 |
+
"juan": 138,
|
| 142 |
+
"jue": 139,
|
| 143 |
+
"jun": 140,
|
| 144 |
+
"ka": 141,
|
| 145 |
+
"kai": 142,
|
| 146 |
+
"kan": 143,
|
| 147 |
+
"kang": 144,
|
| 148 |
+
"kao": 145,
|
| 149 |
+
"ke": 146,
|
| 150 |
+
"ken": 147,
|
| 151 |
+
"keng": 148,
|
| 152 |
+
"kong": 149,
|
| 153 |
+
"kou": 150,
|
| 154 |
+
"ku": 151,
|
| 155 |
+
"kua": 152,
|
| 156 |
+
"kuai": 153,
|
| 157 |
+
"kuan": 154,
|
| 158 |
+
"kuang": 155,
|
| 159 |
+
"kui": 156,
|
| 160 |
+
"kun": 157,
|
| 161 |
+
"kuo": 158,
|
| 162 |
+
"la": 159,
|
| 163 |
+
"lai": 160,
|
| 164 |
+
"lan": 161,
|
| 165 |
+
"lang": 162,
|
| 166 |
+
"lao": 163,
|
| 167 |
+
"le": 164,
|
| 168 |
+
"lei": 165,
|
| 169 |
+
"leng": 166,
|
| 170 |
+
"li": 167,
|
| 171 |
+
"lia": 168,
|
| 172 |
+
"lian": 169,
|
| 173 |
+
"liang": 170,
|
| 174 |
+
"liao": 171,
|
| 175 |
+
"lie": 172,
|
| 176 |
+
"lin": 173,
|
| 177 |
+
"ling": 174,
|
| 178 |
+
"liu": 175,
|
| 179 |
+
"long": 176,
|
| 180 |
+
"lou": 177,
|
| 181 |
+
"lu": 178,
|
| 182 |
+
"luan": 179,
|
| 183 |
+
"lun": 180,
|
| 184 |
+
"luo": 181,
|
| 185 |
+
"lv": 182,
|
| 186 |
+
"lve": 183,
|
| 187 |
+
"ma": 184,
|
| 188 |
+
"mai": 185,
|
| 189 |
+
"man": 186,
|
| 190 |
+
"mang": 187,
|
| 191 |
+
"mao": 188,
|
| 192 |
+
"me": 189,
|
| 193 |
+
"mei": 190,
|
| 194 |
+
"men": 191,
|
| 195 |
+
"meng": 192,
|
| 196 |
+
"mi": 193,
|
| 197 |
+
"mian": 194,
|
| 198 |
+
"miao": 195,
|
| 199 |
+
"mie": 196,
|
| 200 |
+
"min": 197,
|
| 201 |
+
"ming": 198,
|
| 202 |
+
"miu": 199,
|
| 203 |
+
"mo": 200,
|
| 204 |
+
"mou": 201,
|
| 205 |
+
"mu": 202,
|
| 206 |
+
"na": 203,
|
| 207 |
+
"nai": 204,
|
| 208 |
+
"nan": 205,
|
| 209 |
+
"nang": 206,
|
| 210 |
+
"nao": 207,
|
| 211 |
+
"ne": 208,
|
| 212 |
+
"nei": 209,
|
| 213 |
+
"nen": 210,
|
| 214 |
+
"neng": 211,
|
| 215 |
+
"ni": 212,
|
| 216 |
+
"nian": 213,
|
| 217 |
+
"niang": 214,
|
| 218 |
+
"niao": 215,
|
| 219 |
+
"nie": 216,
|
| 220 |
+
"nin": 217,
|
| 221 |
+
"ning": 218,
|
| 222 |
+
"niu": 219,
|
| 223 |
+
"nong": 220,
|
| 224 |
+
"nu": 221,
|
| 225 |
+
"nuan": 222,
|
| 226 |
+
"nuo": 223,
|
| 227 |
+
"nv": 224,
|
| 228 |
+
"nve": 225,
|
| 229 |
+
"o": 226,
|
| 230 |
+
"ou": 227,
|
| 231 |
+
"pa": 228,
|
| 232 |
+
"pai": 229,
|
| 233 |
+
"pan": 230,
|
| 234 |
+
"pang": 231,
|
| 235 |
+
"pao": 232,
|
| 236 |
+
"pei": 233,
|
| 237 |
+
"pen": 234,
|
| 238 |
+
"peng": 235,
|
| 239 |
+
"pi": 236,
|
| 240 |
+
"pian": 237,
|
| 241 |
+
"piao": 238,
|
| 242 |
+
"pie": 239,
|
| 243 |
+
"pin": 240,
|
| 244 |
+
"ping": 241,
|
| 245 |
+
"po": 242,
|
| 246 |
+
"pou": 243,
|
| 247 |
+
"pu": 244,
|
| 248 |
+
"qi": 245,
|
| 249 |
+
"qia": 246,
|
| 250 |
+
"qian": 247,
|
| 251 |
+
"qiang": 248,
|
| 252 |
+
"qiao": 249,
|
| 253 |
+
"qie": 250,
|
| 254 |
+
"qin": 251,
|
| 255 |
+
"qing": 252,
|
| 256 |
+
"qiong": 253,
|
| 257 |
+
"qiu": 254,
|
| 258 |
+
"qu": 255,
|
| 259 |
+
"quan": 256,
|
| 260 |
+
"que": 257,
|
| 261 |
+
"qun": 258,
|
| 262 |
+
"ran": 259,
|
| 263 |
+
"rang": 260,
|
| 264 |
+
"rao": 261,
|
| 265 |
+
"re": 262,
|
| 266 |
+
"ren": 263,
|
| 267 |
+
"reng": 264,
|
| 268 |
+
"ri": 265,
|
| 269 |
+
"rong": 266,
|
| 270 |
+
"rou": 267,
|
| 271 |
+
"ru": 268,
|
| 272 |
+
"ruan": 269,
|
| 273 |
+
"rui": 270,
|
| 274 |
+
"run": 271,
|
| 275 |
+
"ruo": 272,
|
| 276 |
+
"sa": 273,
|
| 277 |
+
"sai": 274,
|
| 278 |
+
"san": 275,
|
| 279 |
+
"sang": 276,
|
| 280 |
+
"sao": 277,
|
| 281 |
+
"se": 278,
|
| 282 |
+
"sen": 279,
|
| 283 |
+
"seng": 280,
|
| 284 |
+
"sha": 281,
|
| 285 |
+
"shai": 282,
|
| 286 |
+
"shan": 283,
|
| 287 |
+
"shang": 284,
|
| 288 |
+
"shao": 285,
|
| 289 |
+
"she": 286,
|
| 290 |
+
"shei": 287,
|
| 291 |
+
"shen": 288,
|
| 292 |
+
"sheng": 289,
|
| 293 |
+
"shi": 290,
|
| 294 |
+
"shou": 291,
|
| 295 |
+
"shu": 292,
|
| 296 |
+
"shua": 293,
|
| 297 |
+
"shuai": 294,
|
| 298 |
+
"shuan": 295,
|
| 299 |
+
"shuang": 296,
|
| 300 |
+
"shui": 297,
|
| 301 |
+
"shun": 298,
|
| 302 |
+
"shuo": 299,
|
| 303 |
+
"si": 300,
|
| 304 |
+
"song": 301,
|
| 305 |
+
"sou": 302,
|
| 306 |
+
"su": 303,
|
| 307 |
+
"suan": 304,
|
| 308 |
+
"sui": 305,
|
| 309 |
+
"sun": 306,
|
| 310 |
+
"suo": 307,
|
| 311 |
+
"ta": 308,
|
| 312 |
+
"tai": 309,
|
| 313 |
+
"tan": 310,
|
| 314 |
+
"tang": 311,
|
| 315 |
+
"tao": 312,
|
| 316 |
+
"te": 313,
|
| 317 |
+
"teng": 314,
|
| 318 |
+
"ti": 315,
|
| 319 |
+
"tian": 316,
|
| 320 |
+
"tiao": 317,
|
| 321 |
+
"tie": 318,
|
| 322 |
+
"ting": 319,
|
| 323 |
+
"tong": 320,
|
| 324 |
+
"tou": 321,
|
| 325 |
+
"tu": 322,
|
| 326 |
+
"tuan": 323,
|
| 327 |
+
"tui": 324,
|
| 328 |
+
"tun": 325,
|
| 329 |
+
"tuo": 326,
|
| 330 |
+
"wa": 327,
|
| 331 |
+
"wai": 328,
|
| 332 |
+
"wan": 329,
|
| 333 |
+
"wang": 330,
|
| 334 |
+
"wei": 331,
|
| 335 |
+
"wen": 332,
|
| 336 |
+
"weng": 333,
|
| 337 |
+
"wo": 334,
|
| 338 |
+
"wu": 335,
|
| 339 |
+
"xi": 336,
|
| 340 |
+
"xia": 337,
|
| 341 |
+
"xian": 338,
|
| 342 |
+
"xiang": 339,
|
| 343 |
+
"xiao": 340,
|
| 344 |
+
"xie": 341,
|
| 345 |
+
"xin": 342,
|
| 346 |
+
"xing": 343,
|
| 347 |
+
"xiong": 344,
|
| 348 |
+
"xiu": 345,
|
| 349 |
+
"xu": 346,
|
| 350 |
+
"xuan": 347,
|
| 351 |
+
"xue": 348,
|
| 352 |
+
"xun": 349,
|
| 353 |
+
"ya": 350,
|
| 354 |
+
"yan": 351,
|
| 355 |
+
"yang": 352,
|
| 356 |
+
"yao": 353,
|
| 357 |
+
"ye": 354,
|
| 358 |
+
"yi": 355,
|
| 359 |
+
"yin": 356,
|
| 360 |
+
"ying": 357,
|
| 361 |
+
"yo": 358,
|
| 362 |
+
"yong": 359,
|
| 363 |
+
"you": 360,
|
| 364 |
+
"yu": 361,
|
| 365 |
+
"yuan": 362,
|
| 366 |
+
"yue": 363,
|
| 367 |
+
"yun": 364,
|
| 368 |
+
"za": 365,
|
| 369 |
+
"zai": 366,
|
| 370 |
+
"zan": 367,
|
| 371 |
+
"zang": 368,
|
| 372 |
+
"zao": 369,
|
| 373 |
+
"ze": 370,
|
| 374 |
+
"zei": 371,
|
| 375 |
+
"zen": 372,
|
| 376 |
+
"zeng": 373,
|
| 377 |
+
"zha": 374,
|
| 378 |
+
"zhai": 375,
|
| 379 |
+
"zhan": 376,
|
| 380 |
+
"zhang": 377,
|
| 381 |
+
"zhao": 378,
|
| 382 |
+
"zhe": 379,
|
| 383 |
+
"zhen": 380,
|
| 384 |
+
"zheng": 381,
|
| 385 |
+
"zhi": 382,
|
| 386 |
+
"zhong": 383,
|
| 387 |
+
"zhou": 384,
|
| 388 |
+
"zhu": 385,
|
| 389 |
+
"zhua": 386,
|
| 390 |
+
"zhuai": 387,
|
| 391 |
+
"zhuan": 388,
|
| 392 |
+
"zhuang": 389,
|
| 393 |
+
"zhui": 390,
|
| 394 |
+
"zhun": 391,
|
| 395 |
+
"zhuo": 392,
|
| 396 |
+
"zi": 393,
|
| 397 |
+
"zong": 394,
|
| 398 |
+
"zou": 395,
|
| 399 |
+
"zu": 396,
|
| 400 |
+
"zuan": 397,
|
| 401 |
+
"zui": 398,
|
| 402 |
+
"zun": 399,
|
| 403 |
+
"zuo": 400
|
| 404 |
+
},
|
| 405 |
+
"index_to_pinyin": {
|
| 406 |
+
"0": "a",
|
| 407 |
+
"1": "ai",
|
| 408 |
+
"2": "an",
|
| 409 |
+
"3": "ang",
|
| 410 |
+
"4": "ao",
|
| 411 |
+
"5": "ba",
|
| 412 |
+
"6": "bai",
|
| 413 |
+
"7": "ban",
|
| 414 |
+
"8": "bang",
|
| 415 |
+
"9": "bao",
|
| 416 |
+
"10": "bei",
|
| 417 |
+
"11": "ben",
|
| 418 |
+
"12": "beng",
|
| 419 |
+
"13": "bi",
|
| 420 |
+
"14": "bian",
|
| 421 |
+
"15": "biao",
|
| 422 |
+
"16": "bie",
|
| 423 |
+
"17": "bin",
|
| 424 |
+
"18": "bing",
|
| 425 |
+
"19": "bo",
|
| 426 |
+
"20": "bu",
|
| 427 |
+
"21": "ca",
|
| 428 |
+
"22": "cai",
|
| 429 |
+
"23": "can",
|
| 430 |
+
"24": "cang",
|
| 431 |
+
"25": "cao",
|
| 432 |
+
"26": "ce",
|
| 433 |
+
"27": "cen",
|
| 434 |
+
"28": "ceng",
|
| 435 |
+
"29": "cha",
|
| 436 |
+
"30": "chai",
|
| 437 |
+
"31": "chan",
|
| 438 |
+
"32": "chang",
|
| 439 |
+
"33": "chao",
|
| 440 |
+
"34": "che",
|
| 441 |
+
"35": "chen",
|
| 442 |
+
"36": "cheng",
|
| 443 |
+
"37": "chi",
|
| 444 |
+
"38": "chong",
|
| 445 |
+
"39": "chou",
|
| 446 |
+
"40": "chu",
|
| 447 |
+
"41": "chuai",
|
| 448 |
+
"42": "chuan",
|
| 449 |
+
"43": "chuang",
|
| 450 |
+
"44": "chui",
|
| 451 |
+
"45": "chun",
|
| 452 |
+
"46": "chuo",
|
| 453 |
+
"47": "ci",
|
| 454 |
+
"48": "cong",
|
| 455 |
+
"49": "cou",
|
| 456 |
+
"50": "cu",
|
| 457 |
+
"51": "cuan",
|
| 458 |
+
"52": "cui",
|
| 459 |
+
"53": "cun",
|
| 460 |
+
"54": "cuo",
|
| 461 |
+
"55": "da",
|
| 462 |
+
"56": "dai",
|
| 463 |
+
"57": "dan",
|
| 464 |
+
"58": "dang",
|
| 465 |
+
"59": "dao",
|
| 466 |
+
"60": "de",
|
| 467 |
+
"61": "dei",
|
| 468 |
+
"62": "deng",
|
| 469 |
+
"63": "di",
|
| 470 |
+
"64": "dia",
|
| 471 |
+
"65": "dian",
|
| 472 |
+
"66": "diao",
|
| 473 |
+
"67": "die",
|
| 474 |
+
"68": "ding",
|
| 475 |
+
"69": "diu",
|
| 476 |
+
"70": "dong",
|
| 477 |
+
"71": "dou",
|
| 478 |
+
"72": "du",
|
| 479 |
+
"73": "duan",
|
| 480 |
+
"74": "dui",
|
| 481 |
+
"75": "dun",
|
| 482 |
+
"76": "duo",
|
| 483 |
+
"77": "e",
|
| 484 |
+
"78": "en",
|
| 485 |
+
"79": "er",
|
| 486 |
+
"80": "fa",
|
| 487 |
+
"81": "fan",
|
| 488 |
+
"82": "fang",
|
| 489 |
+
"83": "fei",
|
| 490 |
+
"84": "fen",
|
| 491 |
+
"85": "feng",
|
| 492 |
+
"86": "fo",
|
| 493 |
+
"87": "fou",
|
| 494 |
+
"88": "fu",
|
| 495 |
+
"89": "ga",
|
| 496 |
+
"90": "gai",
|
| 497 |
+
"91": "gan",
|
| 498 |
+
"92": "gang",
|
| 499 |
+
"93": "gao",
|
| 500 |
+
"94": "ge",
|
| 501 |
+
"95": "gei",
|
| 502 |
+
"96": "gen",
|
| 503 |
+
"97": "geng",
|
| 504 |
+
"98": "gong",
|
| 505 |
+
"99": "gou",
|
| 506 |
+
"100": "gu",
|
| 507 |
+
"101": "gua",
|
| 508 |
+
"102": "guai",
|
| 509 |
+
"103": "guan",
|
| 510 |
+
"104": "guang",
|
| 511 |
+
"105": "gui",
|
| 512 |
+
"106": "gun",
|
| 513 |
+
"107": "guo",
|
| 514 |
+
"108": "ha",
|
| 515 |
+
"109": "hai",
|
| 516 |
+
"110": "han",
|
| 517 |
+
"111": "hang",
|
| 518 |
+
"112": "hao",
|
| 519 |
+
"113": "he",
|
| 520 |
+
"114": "hei",
|
| 521 |
+
"115": "hen",
|
| 522 |
+
"116": "heng",
|
| 523 |
+
"117": "hong",
|
| 524 |
+
"118": "hou",
|
| 525 |
+
"119": "hu",
|
| 526 |
+
"120": "hua",
|
| 527 |
+
"121": "huai",
|
| 528 |
+
"122": "huan",
|
| 529 |
+
"123": "huang",
|
| 530 |
+
"124": "hui",
|
| 531 |
+
"125": "hun",
|
| 532 |
+
"126": "huo",
|
| 533 |
+
"127": "ji",
|
| 534 |
+
"128": "jia",
|
| 535 |
+
"129": "jian",
|
| 536 |
+
"130": "jiang",
|
| 537 |
+
"131": "jiao",
|
| 538 |
+
"132": "jie",
|
| 539 |
+
"133": "jin",
|
| 540 |
+
"134": "jing",
|
| 541 |
+
"135": "jiong",
|
| 542 |
+
"136": "jiu",
|
| 543 |
+
"137": "ju",
|
| 544 |
+
"138": "juan",
|
| 545 |
+
"139": "jue",
|
| 546 |
+
"140": "jun",
|
| 547 |
+
"141": "ka",
|
| 548 |
+
"142": "kai",
|
| 549 |
+
"143": "kan",
|
| 550 |
+
"144": "kang",
|
| 551 |
+
"145": "kao",
|
| 552 |
+
"146": "ke",
|
| 553 |
+
"147": "ken",
|
| 554 |
+
"148": "keng",
|
| 555 |
+
"149": "kong",
|
| 556 |
+
"150": "kou",
|
| 557 |
+
"151": "ku",
|
| 558 |
+
"152": "kua",
|
| 559 |
+
"153": "kuai",
|
| 560 |
+
"154": "kuan",
|
| 561 |
+
"155": "kuang",
|
| 562 |
+
"156": "kui",
|
| 563 |
+
"157": "kun",
|
| 564 |
+
"158": "kuo",
|
| 565 |
+
"159": "la",
|
| 566 |
+
"160": "lai",
|
| 567 |
+
"161": "lan",
|
| 568 |
+
"162": "lang",
|
| 569 |
+
"163": "lao",
|
| 570 |
+
"164": "le",
|
| 571 |
+
"165": "lei",
|
| 572 |
+
"166": "leng",
|
| 573 |
+
"167": "li",
|
| 574 |
+
"168": "lia",
|
| 575 |
+
"169": "lian",
|
| 576 |
+
"170": "liang",
|
| 577 |
+
"171": "liao",
|
| 578 |
+
"172": "lie",
|
| 579 |
+
"173": "lin",
|
| 580 |
+
"174": "ling",
|
| 581 |
+
"175": "liu",
|
| 582 |
+
"176": "long",
|
| 583 |
+
"177": "lou",
|
| 584 |
+
"178": "lu",
|
| 585 |
+
"179": "luan",
|
| 586 |
+
"180": "lun",
|
| 587 |
+
"181": "luo",
|
| 588 |
+
"182": "lv",
|
| 589 |
+
"183": "lve",
|
| 590 |
+
"184": "ma",
|
| 591 |
+
"185": "mai",
|
| 592 |
+
"186": "man",
|
| 593 |
+
"187": "mang",
|
| 594 |
+
"188": "mao",
|
| 595 |
+
"189": "me",
|
| 596 |
+
"190": "mei",
|
| 597 |
+
"191": "men",
|
| 598 |
+
"192": "meng",
|
| 599 |
+
"193": "mi",
|
| 600 |
+
"194": "mian",
|
| 601 |
+
"195": "miao",
|
| 602 |
+
"196": "mie",
|
| 603 |
+
"197": "min",
|
| 604 |
+
"198": "ming",
|
| 605 |
+
"199": "miu",
|
| 606 |
+
"200": "mo",
|
| 607 |
+
"201": "mou",
|
| 608 |
+
"202": "mu",
|
| 609 |
+
"203": "na",
|
| 610 |
+
"204": "nai",
|
| 611 |
+
"205": "nan",
|
| 612 |
+
"206": "nang",
|
| 613 |
+
"207": "nao",
|
| 614 |
+
"208": "ne",
|
| 615 |
+
"209": "nei",
|
| 616 |
+
"210": "nen",
|
| 617 |
+
"211": "neng",
|
| 618 |
+
"212": "ni",
|
| 619 |
+
"213": "nian",
|
| 620 |
+
"214": "niang",
|
| 621 |
+
"215": "niao",
|
| 622 |
+
"216": "nie",
|
| 623 |
+
"217": "nin",
|
| 624 |
+
"218": "ning",
|
| 625 |
+
"219": "niu",
|
| 626 |
+
"220": "nong",
|
| 627 |
+
"221": "nu",
|
| 628 |
+
"222": "nuan",
|
| 629 |
+
"223": "nuo",
|
| 630 |
+
"224": "nv",
|
| 631 |
+
"225": "nve",
|
| 632 |
+
"226": "o",
|
| 633 |
+
"227": "ou",
|
| 634 |
+
"228": "pa",
|
| 635 |
+
"229": "pai",
|
| 636 |
+
"230": "pan",
|
| 637 |
+
"231": "pang",
|
| 638 |
+
"232": "pao",
|
| 639 |
+
"233": "pei",
|
| 640 |
+
"234": "pen",
|
| 641 |
+
"235": "peng",
|
| 642 |
+
"236": "pi",
|
| 643 |
+
"237": "pian",
|
| 644 |
+
"238": "piao",
|
| 645 |
+
"239": "pie",
|
| 646 |
+
"240": "pin",
|
| 647 |
+
"241": "ping",
|
| 648 |
+
"242": "po",
|
| 649 |
+
"243": "pou",
|
| 650 |
+
"244": "pu",
|
| 651 |
+
"245": "qi",
|
| 652 |
+
"246": "qia",
|
| 653 |
+
"247": "qian",
|
| 654 |
+
"248": "qiang",
|
| 655 |
+
"249": "qiao",
|
| 656 |
+
"250": "qie",
|
| 657 |
+
"251": "qin",
|
| 658 |
+
"252": "qing",
|
| 659 |
+
"253": "qiong",
|
| 660 |
+
"254": "qiu",
|
| 661 |
+
"255": "qu",
|
| 662 |
+
"256": "quan",
|
| 663 |
+
"257": "que",
|
| 664 |
+
"258": "qun",
|
| 665 |
+
"259": "ran",
|
| 666 |
+
"260": "rang",
|
| 667 |
+
"261": "rao",
|
| 668 |
+
"262": "re",
|
| 669 |
+
"263": "ren",
|
| 670 |
+
"264": "reng",
|
| 671 |
+
"265": "ri",
|
| 672 |
+
"266": "rong",
|
| 673 |
+
"267": "rou",
|
| 674 |
+
"268": "ru",
|
| 675 |
+
"269": "ruan",
|
| 676 |
+
"270": "rui",
|
| 677 |
+
"271": "run",
|
| 678 |
+
"272": "ruo",
|
| 679 |
+
"273": "sa",
|
| 680 |
+
"274": "sai",
|
| 681 |
+
"275": "san",
|
| 682 |
+
"276": "sang",
|
| 683 |
+
"277": "sao",
|
| 684 |
+
"278": "se",
|
| 685 |
+
"279": "sen",
|
| 686 |
+
"280": "seng",
|
| 687 |
+
"281": "sha",
|
| 688 |
+
"282": "shai",
|
| 689 |
+
"283": "shan",
|
| 690 |
+
"284": "shang",
|
| 691 |
+
"285": "shao",
|
| 692 |
+
"286": "she",
|
| 693 |
+
"287": "shei",
|
| 694 |
+
"288": "shen",
|
| 695 |
+
"289": "sheng",
|
| 696 |
+
"290": "shi",
|
| 697 |
+
"291": "shou",
|
| 698 |
+
"292": "shu",
|
| 699 |
+
"293": "shua",
|
| 700 |
+
"294": "shuai",
|
| 701 |
+
"295": "shuan",
|
| 702 |
+
"296": "shuang",
|
| 703 |
+
"297": "shui",
|
| 704 |
+
"298": "shun",
|
| 705 |
+
"299": "shuo",
|
| 706 |
+
"300": "si",
|
| 707 |
+
"301": "song",
|
| 708 |
+
"302": "sou",
|
| 709 |
+
"303": "su",
|
| 710 |
+
"304": "suan",
|
| 711 |
+
"305": "sui",
|
| 712 |
+
"306": "sun",
|
| 713 |
+
"307": "suo",
|
| 714 |
+
"308": "ta",
|
| 715 |
+
"309": "tai",
|
| 716 |
+
"310": "tan",
|
| 717 |
+
"311": "tang",
|
| 718 |
+
"312": "tao",
|
| 719 |
+
"313": "te",
|
| 720 |
+
"314": "teng",
|
| 721 |
+
"315": "ti",
|
| 722 |
+
"316": "tian",
|
| 723 |
+
"317": "tiao",
|
| 724 |
+
"318": "tie",
|
| 725 |
+
"319": "ting",
|
| 726 |
+
"320": "tong",
|
| 727 |
+
"321": "tou",
|
| 728 |
+
"322": "tu",
|
| 729 |
+
"323": "tuan",
|
| 730 |
+
"324": "tui",
|
| 731 |
+
"325": "tun",
|
| 732 |
+
"326": "tuo",
|
| 733 |
+
"327": "wa",
|
| 734 |
+
"328": "wai",
|
| 735 |
+
"329": "wan",
|
| 736 |
+
"330": "wang",
|
| 737 |
+
"331": "wei",
|
| 738 |
+
"332": "wen",
|
| 739 |
+
"333": "weng",
|
| 740 |
+
"334": "wo",
|
| 741 |
+
"335": "wu",
|
| 742 |
+
"336": "xi",
|
| 743 |
+
"337": "xia",
|
| 744 |
+
"338": "xian",
|
| 745 |
+
"339": "xiang",
|
| 746 |
+
"340": "xiao",
|
| 747 |
+
"341": "xie",
|
| 748 |
+
"342": "xin",
|
| 749 |
+
"343": "xing",
|
| 750 |
+
"344": "xiong",
|
| 751 |
+
"345": "xiu",
|
| 752 |
+
"346": "xu",
|
| 753 |
+
"347": "xuan",
|
| 754 |
+
"348": "xue",
|
| 755 |
+
"349": "xun",
|
| 756 |
+
"350": "ya",
|
| 757 |
+
"351": "yan",
|
| 758 |
+
"352": "yang",
|
| 759 |
+
"353": "yao",
|
| 760 |
+
"354": "ye",
|
| 761 |
+
"355": "yi",
|
| 762 |
+
"356": "yin",
|
| 763 |
+
"357": "ying",
|
| 764 |
+
"358": "yo",
|
| 765 |
+
"359": "yong",
|
| 766 |
+
"360": "you",
|
| 767 |
+
"361": "yu",
|
| 768 |
+
"362": "yuan",
|
| 769 |
+
"363": "yue",
|
| 770 |
+
"364": "yun",
|
| 771 |
+
"365": "za",
|
| 772 |
+
"366": "zai",
|
| 773 |
+
"367": "zan",
|
| 774 |
+
"368": "zang",
|
| 775 |
+
"369": "zao",
|
| 776 |
+
"370": "ze",
|
| 777 |
+
"371": "zei",
|
| 778 |
+
"372": "zen",
|
| 779 |
+
"373": "zeng",
|
| 780 |
+
"374": "zha",
|
| 781 |
+
"375": "zhai",
|
| 782 |
+
"376": "zhan",
|
| 783 |
+
"377": "zhang",
|
| 784 |
+
"378": "zhao",
|
| 785 |
+
"379": "zhe",
|
| 786 |
+
"380": "zhen",
|
| 787 |
+
"381": "zheng",
|
| 788 |
+
"382": "zhi",
|
| 789 |
+
"383": "zhong",
|
| 790 |
+
"384": "zhou",
|
| 791 |
+
"385": "zhu",
|
| 792 |
+
"386": "zhua",
|
| 793 |
+
"387": "zhuai",
|
| 794 |
+
"388": "zhuan",
|
| 795 |
+
"389": "zhuang",
|
| 796 |
+
"390": "zhui",
|
| 797 |
+
"391": "zhun",
|
| 798 |
+
"392": "zhuo",
|
| 799 |
+
"393": "zi",
|
| 800 |
+
"394": "zong",
|
| 801 |
+
"395": "zou",
|
| 802 |
+
"396": "zu",
|
| 803 |
+
"397": "zuan",
|
| 804 |
+
"398": "zui",
|
| 805 |
+
"399": "zun",
|
| 806 |
+
"400": "zuo"
|
| 807 |
+
}
|
| 808 |
+
}
|
stepstep=024500.ckpt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:e6bd2c480df21e2cd9ee5eadef5d67e69b059e00c30159889502ac1848132950
|
| 3 |
+
size 714212484
|
train_pinyin.py
ADDED
|
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
from torch.utils.data import DataLoader
|
| 5 |
+
from torch.optim.lr_scheduler import LambdaLR
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
from sklearn.metrics import roc_auc_score
|
| 9 |
+
from sklearn.metrics import roc_curve
|
| 10 |
+
|
| 11 |
+
import pytorch_lightning as pl
|
| 12 |
+
from pytorch_lightning import LightningModule, Trainer
|
| 13 |
+
from pytorch_lightning.callbacks import ModelCheckpoint
|
| 14 |
+
from transformers import Wav2Vec2FeatureExtractor, HubertModel
|
| 15 |
+
from model_pinyin import MMKWS2
|
| 16 |
+
from torch.utils.data import Dataset, ConcatDataset
|
| 17 |
+
|
| 18 |
+
class MMKWS2_Wrapper(LightningModule):
|
| 19 |
+
def __init__(self):
|
| 20 |
+
super().__init__()
|
| 21 |
+
self.model = MMKWS2()
|
| 22 |
+
self.criterion = nn.BCEWithLogitsLoss()
|
| 23 |
+
# 将hubert_model设为临时变量而非类属性
|
| 24 |
+
hubert_model = HubertModel.from_pretrained("TencentGameMate/chinese-hubert-large").half().eval()
|
| 25 |
+
self._hubert_model = hubert_model # 使用下划线前缀表示内部使用
|
| 26 |
+
|
| 27 |
+
def training_step(self, batch, batch_idx):
|
| 28 |
+
anchor_wave, anchor_text_embedding, compare_wave, compare_lengths, label, seq_label = \
|
| 29 |
+
batch['anchor_wave'], batch['anchor_embedding'], batch['compare_wave'], batch['compare_lengths'], batch['label'], batch['seq_label']
|
| 30 |
+
|
| 31 |
+
with torch.no_grad():
|
| 32 |
+
outputs = self._hubert_model(anchor_wave.half())
|
| 33 |
+
anchor_wave_embedding = outputs.last_hidden_state
|
| 34 |
+
|
| 35 |
+
anchor_wave_embedding = anchor_wave_embedding.to(anchor_wave.dtype)
|
| 36 |
+
|
| 37 |
+
logits, seq_logits = self.model(
|
| 38 |
+
anchor_wave_embedding,
|
| 39 |
+
anchor_text_embedding,
|
| 40 |
+
compare_wave,
|
| 41 |
+
compare_lengths
|
| 42 |
+
)
|
| 43 |
+
|
| 44 |
+
# 句级二分类loss
|
| 45 |
+
utt_loss = self.criterion(logits, label.float())
|
| 46 |
+
|
| 47 |
+
# 序列loss(mask掉seq_label为-1的部分)
|
| 48 |
+
mask = (seq_label != -1).float()
|
| 49 |
+
seq_label_valid = seq_label.clone()
|
| 50 |
+
seq_label_valid[seq_label == -1] = 0 # 避免-1影响loss
|
| 51 |
+
|
| 52 |
+
seq_loss = F.binary_cross_entropy_with_logits(
|
| 53 |
+
seq_logits, seq_label_valid.float(), weight=mask, reduction='sum'
|
| 54 |
+
) / (mask.sum() + 1e-6)
|
| 55 |
+
|
| 56 |
+
loss = utt_loss + seq_loss
|
| 57 |
+
|
| 58 |
+
# 每500步记录日志
|
| 59 |
+
self.log('train/utt_loss', utt_loss, on_step=True, on_epoch=False, prog_bar=True)
|
| 60 |
+
self.log('train/seq_loss', seq_loss, on_step=True, on_epoch=False, prog_bar=True)
|
| 61 |
+
self.log('train/loss', loss, on_step=True, on_epoch=False, prog_bar=True)
|
| 62 |
+
return loss
|
| 63 |
+
|
| 64 |
+
def configure_optimizers(self):
|
| 65 |
+
optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
|
| 66 |
+
|
| 67 |
+
def lr_lambda(step):
|
| 68 |
+
return 0.95 ** (step // 1000)
|
| 69 |
+
|
| 70 |
+
scheduler = torch.optim.lr_scheduler.LambdaLR(
|
| 71 |
+
optimizer,
|
| 72 |
+
lr_lambda=lr_lambda
|
| 73 |
+
)
|
| 74 |
+
|
| 75 |
+
return {
|
| 76 |
+
"optimizer": optimizer,
|
| 77 |
+
"lr_scheduler": {
|
| 78 |
+
"scheduler": scheduler,
|
| 79 |
+
"interval": "step", # 按步更新
|
| 80 |
+
"frequency": 1
|
| 81 |
+
},
|
| 82 |
+
}
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
# 3. 设置 Trainer 和训练
|
| 86 |
+
if __name__ == "__main__":
|
| 87 |
+
pl.seed_everything(2024)
|
| 88 |
+
import os
|
| 89 |
+
os.environ['CUDA_VISIBLE_DEVICES'] = '7'
|
| 90 |
+
from dataset_pinyin import PairDataset
|
| 91 |
+
from dataloader_pinyin import get_dataloader
|
| 92 |
+
|
| 93 |
+
# 创建数据集
|
| 94 |
+
dataset1 = PairDataset(
|
| 95 |
+
'/nvme01/aizq/kws-agent/data/anchor_pairs.parquet',
|
| 96 |
+
'/nvme01/aizq/kws-agent/data/WenetPhrase_base/M_S',
|
| 97 |
+
augment=True
|
| 98 |
+
)
|
| 99 |
+
dataset2 = PairDataset(
|
| 100 |
+
'/nvme01/aizq/kws-agent/data/synthetic_pairs.parquet',
|
| 101 |
+
'/nvme01/aizq/kws-agent/data/WenetPhrase_base/M_S',
|
| 102 |
+
augment=True
|
| 103 |
+
)
|
| 104 |
+
dataset = ConcatDataset([dataset1, dataset2])
|
| 105 |
+
# 创建dataloader
|
| 106 |
+
dataloader = get_dataloader(dataset, batch_size=1024)
|
| 107 |
+
|
| 108 |
+
model = MMKWS2_Wrapper()
|
| 109 |
+
model_checkpoint = ModelCheckpoint(
|
| 110 |
+
dirpath="/nvme01/aizq/kws-agent/ckpts",
|
| 111 |
+
filename="step{step:06d}",
|
| 112 |
+
save_top_k=-1,
|
| 113 |
+
save_on_train_epoch_end=False, # 按训练步数保存
|
| 114 |
+
every_n_train_steps=500 # 每10k步保存一次
|
| 115 |
+
)
|
| 116 |
+
logger = pl.loggers.TensorBoardLogger('/nvme01/aizq/kws-agent/logs/', name='MMKWS+')
|
| 117 |
+
trainer = Trainer(
|
| 118 |
+
devices=1,
|
| 119 |
+
accelerator='gpu',
|
| 120 |
+
logger=logger,
|
| 121 |
+
max_epochs=4, # 训练1轮
|
| 122 |
+
callbacks=[model_checkpoint],
|
| 123 |
+
accumulate_grad_batches=2, # 2048 batchsize
|
| 124 |
+
)
|
| 125 |
+
trainer.fit(model, train_dataloaders=dataloader)
|
tts-2025-04-27@197cd135f2a2451b9cab9cf2add1c1ab.wav
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:3947a005d2c85f427cead98d3aeb65200bb85a956383aad43467b1b6a7f0c5f1
|
| 3 |
+
size 279404
|
ui.py
ADDED
|
@@ -0,0 +1,167 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
import os
|
| 3 |
+
import json
|
| 4 |
+
import time
|
| 5 |
+
import numpy as np
|
| 6 |
+
import torch
|
| 7 |
+
import torchaudio
|
| 8 |
+
import gradio as gr
|
| 9 |
+
from torchaudio.compliance.kaldi import fbank
|
| 10 |
+
from pypinyin import lazy_pinyin
|
| 11 |
+
from train_pinyin import MMKWS2_Wrapper
|
| 12 |
+
|
| 13 |
+
# 设备与模型加载
|
| 14 |
+
# device = torch.device("cuda:4")
|
| 15 |
+
device = torch.device("cpu")
|
| 16 |
+
wrapper = MMKWS2_Wrapper.load_from_checkpoint(
|
| 17 |
+
"stepstep=024500.ckpt",
|
| 18 |
+
map_location=device,
|
| 19 |
+
)
|
| 20 |
+
wrapper.eval()
|
| 21 |
+
|
| 22 |
+
# 注册信息
|
| 23 |
+
registered = {"text": "", "audios": []}
|
| 24 |
+
enroll = None
|
| 25 |
+
enroll_text = None
|
| 26 |
+
last_wake_time = 0
|
| 27 |
+
|
| 28 |
+
def load_pinyin_index(save_path):
|
| 29 |
+
"""加载拼音索引映射"""
|
| 30 |
+
with open(save_path, "r", encoding="utf-8") as f:
|
| 31 |
+
data = json.load(f)
|
| 32 |
+
return data["pinyin_to_index"], data["index_to_pinyin"]
|
| 33 |
+
|
| 34 |
+
pinyin_to_index, index_to_pinyin = load_pinyin_index("pinyin_index.json")
|
| 35 |
+
|
| 36 |
+
def add_audio(text, audio, audio_list):
|
| 37 |
+
if not text:
|
| 38 |
+
return audio_list, "请先输入唤醒词文本"
|
| 39 |
+
if audio is None or audio[1] is None or len(audio[1]) == 0:
|
| 40 |
+
return audio_list, "请上传或录制音频"
|
| 41 |
+
audio_list = audio_list or []
|
| 42 |
+
if len(audio_list) >= 5:
|
| 43 |
+
return audio_list, "最多支持5条音频"
|
| 44 |
+
audio_list.append(audio)
|
| 45 |
+
return audio_list, f"已录入 {len(audio_list)} 条音频"
|
| 46 |
+
|
| 47 |
+
def register_keyword(text, audio_list):
|
| 48 |
+
if not text:
|
| 49 |
+
return gr.update(value="请先输入唤醒词文本")
|
| 50 |
+
if not audio_list or len(audio_list) == 0:
|
| 51 |
+
return gr.update(value="请至少上传或录制一条音频")
|
| 52 |
+
registered["text"] = text
|
| 53 |
+
registered["audios"] = audio_list
|
| 54 |
+
global enroll_text
|
| 55 |
+
enroll_text = text
|
| 56 |
+
fused_feats = []
|
| 57 |
+
for audio in audio_list:
|
| 58 |
+
anchor_wave, _ = torchaudio.load(audio)
|
| 59 |
+
anchor_text_embedding = torch.tensor([pinyin_to_index[p] + 1 for p in lazy_pinyin(text)])
|
| 60 |
+
anchor_wave = anchor_wave.to(device)
|
| 61 |
+
anchor_text_embedding = anchor_text_embedding.to(device).unsqueeze(0)
|
| 62 |
+
with torch.no_grad():
|
| 63 |
+
outputs = wrapper._hubert_model(anchor_wave.half())
|
| 64 |
+
anchor_wave_embedding = outputs.last_hidden_state
|
| 65 |
+
anchor_wave_embedding = anchor_wave_embedding.to(anchor_wave.dtype)
|
| 66 |
+
fused_feat = wrapper.model.enrollment(
|
| 67 |
+
anchor_wave_embedding,
|
| 68 |
+
anchor_text_embedding
|
| 69 |
+
)
|
| 70 |
+
fused_feats.append(fused_feat)
|
| 71 |
+
fused_feats = torch.cat(fused_feats, dim=0)
|
| 72 |
+
fused_feats, _ = fused_feats.max(dim=0)
|
| 73 |
+
fused_feats = fused_feats.unsqueeze(0)
|
| 74 |
+
global enroll
|
| 75 |
+
enroll = fused_feats
|
| 76 |
+
return gr.update(value=f"注册完成,唤醒词:{text},音频数:{len(audio_list)}")
|
| 77 |
+
|
| 78 |
+
def update_gallery(audio_list):
|
| 79 |
+
if audio_list and len(audio_list) > 0:
|
| 80 |
+
return gr.update(visible=True, value=audio_list[-1])
|
| 81 |
+
else:
|
| 82 |
+
return gr.update(visible=False, value=None)
|
| 83 |
+
|
| 84 |
+
def streaming_detect_handler(current_audio, state, audio_player):
|
| 85 |
+
global last_wake_time, enroll_text, enroll
|
| 86 |
+
if current_audio is None or current_audio[1] is None or len(current_audio[1]) == 0:
|
| 87 |
+
return state, gr.update()
|
| 88 |
+
if enroll_text is None:
|
| 89 |
+
return state, gr.update()
|
| 90 |
+
pad = len(enroll_text) * 5
|
| 91 |
+
state = (state or []) + [current_audio]
|
| 92 |
+
state = state[-pad:]
|
| 93 |
+
if len(state) < pad:
|
| 94 |
+
return state, gr.update()
|
| 95 |
+
sr = state[0][0]
|
| 96 |
+
audio_list = [x[1] for x in state]
|
| 97 |
+
audio_concat = np.concatenate(audio_list, axis=0)
|
| 98 |
+
audio_concat = audio_concat.astype(np.float32) / 32768.0
|
| 99 |
+
audio_concat = torch.from_numpy(audio_concat).unsqueeze(0)
|
| 100 |
+
audio_concat = torchaudio.functional.resample(audio_concat, sr, 16000)
|
| 101 |
+
audio_concat = audio_concat / torch.max(torch.abs(audio_concat))
|
| 102 |
+
compare_wave = fbank(audio_concat, num_mel_bins=80)
|
| 103 |
+
compare_wave = compare_wave.to(device).unsqueeze(0)
|
| 104 |
+
compare_lengths = torch.tensor([compare_wave.size(1)], device=compare_wave.device)
|
| 105 |
+
if enroll is None:
|
| 106 |
+
return state, None
|
| 107 |
+
current_time = time.time()
|
| 108 |
+
if current_time - last_wake_time <= 2:
|
| 109 |
+
return state, gr.update()
|
| 110 |
+
with torch.no_grad():
|
| 111 |
+
preds = wrapper.model.verification(
|
| 112 |
+
enroll,
|
| 113 |
+
compare_wave,
|
| 114 |
+
compare_lengths
|
| 115 |
+
)
|
| 116 |
+
preds = torch.sigmoid(preds).item()
|
| 117 |
+
if preds >= 0.85:
|
| 118 |
+
print(f"Wake up! {preds}")
|
| 119 |
+
last_wake_time = current_time
|
| 120 |
+
audio_path = "tts-2025-04-27@197cd135f2a2451b9cab9cf2add1c1ab.wav"
|
| 121 |
+
state = []
|
| 122 |
+
return state, gr.Audio(value=audio_path, visible=True, autoplay=True)
|
| 123 |
+
if preds >= 0.6:
|
| 124 |
+
print(f"Preds! {preds}")
|
| 125 |
+
return state, None
|
| 126 |
+
|
| 127 |
+
with gr.Blocks() as demo:
|
| 128 |
+
gr.Markdown("# 自定义关键词检测 Demo")
|
| 129 |
+
with gr.Row():
|
| 130 |
+
with gr.Column(scale=1):
|
| 131 |
+
gr.Markdown("## 注册唤醒词")
|
| 132 |
+
text_input = gr.Textbox(label="唤醒词文本", placeholder="请输入唤醒词")
|
| 133 |
+
audio_list_state = gr.State([])
|
| 134 |
+
audio_input = gr.Audio(label="上传或录制音频", type="filepath")
|
| 135 |
+
add_btn = gr.Button("添加音频")
|
| 136 |
+
audio_status = gr.Textbox(label="音频状态", interactive=False)
|
| 137 |
+
audio_gallery = gr.Audio(label="已添加音频", type="filepath", interactive=False, visible=False)
|
| 138 |
+
register_btn = gr.Button("注册完成")
|
| 139 |
+
register_status = gr.Textbox(label="注册状态", interactive=False)
|
| 140 |
+
add_btn.click(
|
| 141 |
+
add_audio,
|
| 142 |
+
inputs=[text_input, audio_input, audio_list_state],
|
| 143 |
+
outputs=[audio_list_state, audio_status]
|
| 144 |
+
).then(
|
| 145 |
+
update_gallery,
|
| 146 |
+
inputs=audio_list_state,
|
| 147 |
+
outputs=audio_gallery
|
| 148 |
+
)
|
| 149 |
+
register_btn.click(
|
| 150 |
+
register_keyword,
|
| 151 |
+
inputs=[text_input, audio_list_state],
|
| 152 |
+
outputs=register_status
|
| 153 |
+
)
|
| 154 |
+
with gr.Column(scale=2):
|
| 155 |
+
gr.Markdown("## 实时检测")
|
| 156 |
+
mic = gr.Audio(sources="microphone", streaming=True, label="实时监听")
|
| 157 |
+
state = gr.State(value=[])
|
| 158 |
+
audio_player = gr.Audio(label="唤醒提示", visible=False)
|
| 159 |
+
mic.stream(
|
| 160 |
+
streaming_detect_handler,
|
| 161 |
+
inputs=[mic, state, audio_player],
|
| 162 |
+
outputs=[state, audio_player],
|
| 163 |
+
time_limit=1000,
|
| 164 |
+
stream_every=0.05,
|
| 165 |
+
)
|
| 166 |
+
if __name__ == "__main__":
|
| 167 |
+
demo.launch()
|