ZhiqiAi committed on
Commit 693e27f · verified · 1 Parent(s): d4b7dc7

Upload 16 files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ tts-2025-04-27@197cd135f2a2451b9cab9cf2add1c1ab.wav filter=lfs diff=lfs merge=lfs -text
conformer/conformer/__init__.py ADDED
@@ -0,0 +1,15 @@
+ # Copyright (c) 2021, Soohwan Kim. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from .model import Conformer
conformer/conformer/activation.py ADDED
@@ -0,0 +1,42 @@
+ # Copyright (c) 2021, Soohwan Kim. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import torch.nn as nn
+ from torch import Tensor
+
+
+ class Swish(nn.Module):
+     """
+     Swish is a smooth, non-monotonic function that consistently matches or outperforms ReLU on deep networks applied
+     to a variety of challenging domains such as image classification and machine translation.
+     """
+     def __init__(self):
+         super(Swish, self).__init__()
+
+     def forward(self, inputs: Tensor) -> Tensor:
+         return inputs * inputs.sigmoid()
+
+
+ class GLU(nn.Module):
+     """
+     The gating mechanism is called Gated Linear Units (GLU), which was first introduced for natural language processing
+     in the paper "Language Modeling with Gated Convolutional Networks".
+     """
+     def __init__(self, dim: int) -> None:
+         super(GLU, self).__init__()
+         self.dim = dim
+
+     def forward(self, inputs: Tensor) -> Tensor:
+         outputs, gate = inputs.chunk(2, dim=self.dim)
+         return outputs * gate.sigmoid()
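For orientation, a minimal shape check for these two activations (an illustrative sketch; the dimensions are arbitrary, assuming the package layout uploaded here):

import torch
from conformer.conformer.activation import Swish, GLU

x = torch.randn(2, 8, 10)
print(Swish()(x).shape)     # torch.Size([2, 8, 10]): elementwise x * sigmoid(x)
print(GLU(dim=1)(x).shape)  # torch.Size([2, 4, 10]): chunking halves the gated dimension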
conformer/conformer/attention.py ADDED
@@ -0,0 +1,151 @@
+ # Copyright (c) 2021, Soohwan Kim. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import math
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch import Tensor
+ from typing import Optional
+
+ from .embedding import PositionalEncoding
+ from .modules import Linear
+
+
+ class RelativeMultiHeadAttention(nn.Module):
+     """
+     Multi-head attention with relative positional encoding.
+     This concept was proposed in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context".
+
+     Args:
+         d_model (int): The dimension of the model
+         num_heads (int): The number of attention heads.
+         dropout_p (float): probability of dropout
+
+     Inputs: query, key, value, pos_embedding, mask
+         - **query** (batch, time, dim): Tensor containing query vectors
+         - **key** (batch, time, dim): Tensor containing key vectors
+         - **value** (batch, time, dim): Tensor containing value vectors
+         - **pos_embedding** (batch, time, dim): Positional embedding tensor
+         - **mask** (batch, 1, time2) or (batch, time1, time2): Tensor containing indices to be masked
+
+     Returns:
+         - **outputs**: Tensor produced by the relative multi-head attention module.
+     """
+     def __init__(
+             self,
+             d_model: int = 512,
+             num_heads: int = 16,
+             dropout_p: float = 0.1,
+     ):
+         super(RelativeMultiHeadAttention, self).__init__()
+         assert d_model % num_heads == 0, "d_model % num_heads should be zero."
+         self.d_model = d_model
+         self.d_head = int(d_model / num_heads)
+         self.num_heads = num_heads
+         self.sqrt_dim = math.sqrt(d_model)
+
+         self.query_proj = Linear(d_model, d_model)
+         self.key_proj = Linear(d_model, d_model)
+         self.value_proj = Linear(d_model, d_model)
+         self.pos_proj = Linear(d_model, d_model, bias=False)
+
+         self.dropout = nn.Dropout(p=dropout_p)
+         self.u_bias = nn.Parameter(torch.Tensor(self.num_heads, self.d_head))
+         self.v_bias = nn.Parameter(torch.Tensor(self.num_heads, self.d_head))
+         torch.nn.init.xavier_uniform_(self.u_bias)
+         torch.nn.init.xavier_uniform_(self.v_bias)
+
+         self.out_proj = Linear(d_model, d_model)
+
+     def forward(
+             self,
+             query: Tensor,
+             key: Tensor,
+             value: Tensor,
+             pos_embedding: Tensor,
+             mask: Optional[Tensor] = None,
+     ) -> Tensor:
+         batch_size = value.size(0)
+
+         query = self.query_proj(query).view(batch_size, -1, self.num_heads, self.d_head)
+         key = self.key_proj(key).view(batch_size, -1, self.num_heads, self.d_head).permute(0, 2, 1, 3)
+         value = self.value_proj(value).view(batch_size, -1, self.num_heads, self.d_head).permute(0, 2, 1, 3)
+         pos_embedding = self.pos_proj(pos_embedding).view(batch_size, -1, self.num_heads, self.d_head)
+
+         content_score = torch.matmul((query + self.u_bias).transpose(1, 2), key.transpose(2, 3))
+         pos_score = torch.matmul((query + self.v_bias).transpose(1, 2), pos_embedding.permute(0, 2, 3, 1))
+         pos_score = self._relative_shift(pos_score)
+
+         score = (content_score + pos_score) / self.sqrt_dim
+
+         if mask is not None:
+             mask = mask.unsqueeze(1)
+             score.masked_fill_(mask, -1e9)
+
+         attn = F.softmax(score, -1)
+         attn = self.dropout(attn)
+
+         context = torch.matmul(attn, value).transpose(1, 2)
+         context = context.contiguous().view(batch_size, -1, self.d_model)
+
+         return self.out_proj(context)
+
+     def _relative_shift(self, pos_score: Tensor) -> Tensor:
+         batch_size, num_heads, seq_length1, seq_length2 = pos_score.size()
+         zeros = pos_score.new_zeros(batch_size, num_heads, seq_length1, 1)
+         padded_pos_score = torch.cat([zeros, pos_score], dim=-1)
+
+         padded_pos_score = padded_pos_score.view(batch_size, num_heads, seq_length2 + 1, seq_length1)
+         pos_score = padded_pos_score[:, :, 1:].view_as(pos_score)
+
+         return pos_score
+
+
+ class MultiHeadedSelfAttentionModule(nn.Module):
+     """
+     Conformer employs multi-headed self-attention (MHSA) while integrating an important technique from Transformer-XL,
+     the relative sinusoidal positional encoding scheme. The relative positional encoding allows the self-attention
+     module to generalize better to different input lengths, and the resulting encoder is more robust to variance in
+     utterance length. Conformer uses pre-norm residual units with dropout, which helps with training
+     and regularizing deeper models.
+
+     Args:
+         d_model (int): The dimension of the model
+         num_heads (int): The number of attention heads.
+         dropout_p (float): probability of dropout
+
+     Inputs: inputs, mask
+         - **inputs** (batch, time, dim): Tensor containing input vectors
+         - **mask** (batch, 1, time2) or (batch, time1, time2): Tensor containing indices to be masked
+
+     Returns:
+         - **outputs** (batch, time, dim): Tensor produced by the relative multi-headed self-attention module.
+     """
+     def __init__(self, d_model: int, num_heads: int, dropout_p: float = 0.1):
+         super(MultiHeadedSelfAttentionModule, self).__init__()
+         self.positional_encoding = PositionalEncoding(d_model)
+         self.layer_norm = nn.LayerNorm(d_model)
+         self.attention = RelativeMultiHeadAttention(d_model, num_heads, dropout_p)
+         self.dropout = nn.Dropout(p=dropout_p)
+
+     def forward(self, inputs: Tensor, mask: Optional[Tensor] = None):
+         batch_size, seq_length, _ = inputs.size()
+         pos_embedding = self.positional_encoding(seq_length)
+         pos_embedding = pos_embedding.repeat(batch_size, 1, 1)
+
+         inputs = self.layer_norm(inputs)
+         outputs = self.attention(inputs, inputs, inputs, pos_embedding=pos_embedding, mask=mask)
+
+         return self.dropout(outputs)
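A quick sanity check of the MHSA module's shape contract (a sketch with arbitrary dimensions; d_model must be divisible by num_heads):

import torch
from conformer.conformer.attention import MultiHeadedSelfAttentionModule

mhsa = MultiHeadedSelfAttentionModule(d_model=144, num_heads=4)
x = torch.randn(2, 50, 144)  # (batch, time, dim)
print(mhsa(x).shape)         # torch.Size([2, 50, 144]): attention preserves the input shape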
conformer/conformer/convolution.py ADDED
@@ -0,0 +1,186 @@
+ # Copyright (c) 2021, Soohwan Kim. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import torch
+ import torch.nn as nn
+ from torch import Tensor
+ from typing import Tuple
+
+ from .activation import Swish, GLU
+ from .modules import Transpose
+
+
+ class DepthwiseConv1d(nn.Module):
+     """
+     When groups == in_channels and out_channels == K * in_channels, where K is a positive integer,
+     this operation is termed in the literature as depthwise convolution.
+
+     Args:
+         in_channels (int): Number of channels in the input
+         out_channels (int): Number of channels produced by the convolution
+         kernel_size (int or tuple): Size of the convolving kernel
+         stride (int, optional): Stride of the convolution. Default: 1
+         padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0
+         bias (bool, optional): If True, adds a learnable bias to the output. Default: False
+
+     Inputs: inputs
+         - **inputs** (batch, in_channels, time): Tensor containing input vectors
+
+     Returns: outputs
+         - **outputs** (batch, out_channels, time): Tensor produced by depthwise 1-D convolution.
+     """
+     def __init__(
+             self,
+             in_channels: int,
+             out_channels: int,
+             kernel_size: int,
+             stride: int = 1,
+             padding: int = 0,
+             bias: bool = False,
+     ) -> None:
+         super(DepthwiseConv1d, self).__init__()
+         assert out_channels % in_channels == 0, "out_channels should be a constant multiple of in_channels"
+         self.conv = nn.Conv1d(
+             in_channels=in_channels,
+             out_channels=out_channels,
+             kernel_size=kernel_size,
+             groups=in_channels,
+             stride=stride,
+             padding=padding,
+             bias=bias,
+         )
+
+     def forward(self, inputs: Tensor) -> Tensor:
+         return self.conv(inputs)
+
+
+ class PointwiseConv1d(nn.Module):
+     """
+     When kernel size == 1, a conv1d is termed in the literature as pointwise convolution.
+     This operation is often used to match dimensions.
+
+     Args:
+         in_channels (int): Number of channels in the input
+         out_channels (int): Number of channels produced by the convolution
+         stride (int, optional): Stride of the convolution. Default: 1
+         padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0
+         bias (bool, optional): If True, adds a learnable bias to the output. Default: True
+
+     Inputs: inputs
+         - **inputs** (batch, in_channels, time): Tensor containing input vectors
+
+     Returns: outputs
+         - **outputs** (batch, out_channels, time): Tensor produced by pointwise 1-D convolution.
+     """
+     def __init__(
+             self,
+             in_channels: int,
+             out_channels: int,
+             stride: int = 1,
+             padding: int = 0,
+             bias: bool = True,
+     ) -> None:
+         super(PointwiseConv1d, self).__init__()
+         self.conv = nn.Conv1d(
+             in_channels=in_channels,
+             out_channels=out_channels,
+             kernel_size=1,
+             stride=stride,
+             padding=padding,
+             bias=bias,
+         )
+
+     def forward(self, inputs: Tensor) -> Tensor:
+         return self.conv(inputs)
+
+
+ class ConformerConvModule(nn.Module):
+     """
+     The Conformer convolution module starts with a pointwise convolution and a gated linear unit (GLU).
+     This is followed by a single 1-D depthwise convolution layer. Batchnorm is deployed just after the convolution
+     to aid training deep models.
+
+     Args:
+         in_channels (int): Number of channels in the input
+         kernel_size (int or tuple, optional): Size of the convolving kernel. Default: 31
+         dropout_p (float, optional): probability of dropout
+
+     Inputs: inputs
+         inputs (batch, time, dim): Tensor containing input sequences
+
+     Outputs: outputs
+         outputs (batch, time, dim): Tensor produced by the conformer convolution module.
+     """
+     def __init__(
+             self,
+             in_channels: int,
+             kernel_size: int = 31,
+             expansion_factor: int = 2,
+             dropout_p: float = 0.1,
+     ) -> None:
+         super(ConformerConvModule, self).__init__()
+         assert (kernel_size - 1) % 2 == 0, "kernel_size should be an odd number for 'SAME' padding"
+         assert expansion_factor == 2, "Currently, only supports expansion_factor 2"
+
+         self.sequential = nn.Sequential(
+             nn.LayerNorm(in_channels),
+             Transpose(shape=(1, 2)),
+             PointwiseConv1d(in_channels, in_channels * expansion_factor, stride=1, padding=0, bias=True),
+             GLU(dim=1),
+             DepthwiseConv1d(in_channels, in_channels, kernel_size, stride=1, padding=(kernel_size - 1) // 2),
+             nn.BatchNorm1d(in_channels),
+             Swish(),
+             PointwiseConv1d(in_channels, in_channels, stride=1, padding=0, bias=True),
+             nn.Dropout(p=dropout_p),
+         )
+
+     def forward(self, inputs: Tensor) -> Tensor:
+         return self.sequential(inputs).transpose(1, 2)
+
+
+ class Conv2dSubampling(nn.Module):
+     """
+     Convolutional 2D subsampling (to 1/4 length)
+
+     Args:
+         in_channels (int): Number of channels in the input image
+         out_channels (int): Number of channels produced by the convolution
+
+     Inputs: inputs
+         - **inputs** (batch, time, dim): Tensor containing sequence of inputs
+
+     Returns: outputs, output_lengths
+         - **outputs** (batch, time, dim): Tensor produced by the convolution
+         - **output_lengths** (batch): list of sequence output lengths
+     """
+     def __init__(self, in_channels: int, out_channels: int) -> None:
+         super(Conv2dSubampling, self).__init__()
+         self.sequential = nn.Sequential(
+             nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2),
+             nn.ReLU(),
+             nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=2),
+             nn.ReLU(),
+         )
+
+     def forward(self, inputs: Tensor, input_lengths: Tensor) -> Tuple[Tensor, Tensor]:
+         outputs = self.sequential(inputs.unsqueeze(1))
+         batch_size, channels, subsampled_lengths, subsampled_dim = outputs.size()
+
+         outputs = outputs.permute(0, 2, 1, 3)
+         outputs = outputs.contiguous().view(batch_size, subsampled_lengths, channels * subsampled_dim)
+
+         output_lengths = input_lengths >> 2
+         output_lengths -= 1
+
+         return outputs, output_lengths
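For orientation, a shape sketch for the two building blocks above (arbitrary dimensions; the 'SAME' padding keeps the time axis, while the 2-D subsampling quarters it):

import torch
from conformer.conformer.convolution import ConformerConvModule, Conv2dSubampling

conv = ConformerConvModule(in_channels=144, kernel_size=31)
x = torch.randn(2, 50, 144)
print(conv(x).shape)           # torch.Size([2, 50, 144])

subsample = Conv2dSubampling(in_channels=1, out_channels=144)
outputs, lengths = subsample(torch.randn(2, 128, 80), torch.LongTensor([128, 100]))
print(outputs.shape, lengths)  # torch.Size([2, 31, 2736]) tensor([31, 24]): lengths become (L >> 2) - 1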
conformer/conformer/embedding.py ADDED
@@ -0,0 +1,42 @@
+ # Copyright (c) 2021, Soohwan Kim. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import math
+ import torch
+ import torch.nn as nn
+ from torch import Tensor
+
+
+ class PositionalEncoding(nn.Module):
+     """
+     Positional Encoding proposed in "Attention Is All You Need".
+     Since the transformer contains no recurrence and no convolution, in order for the model to make
+     use of the order of the sequence, we must add some positional information.
+
+     "Attention Is All You Need" uses sine and cosine functions of different frequencies:
+         PE_(pos, 2i)   = sin(pos / power(10000, 2i / d_model))
+         PE_(pos, 2i+1) = cos(pos / power(10000, 2i / d_model))
+     """
+     def __init__(self, d_model: int = 512, max_len: int = 10000) -> None:
+         super(PositionalEncoding, self).__init__()
+         pe = torch.zeros(max_len, d_model, requires_grad=False)
+         position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
+         div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))
+         pe[:, 0::2] = torch.sin(position * div_term)
+         pe[:, 1::2] = torch.cos(position * div_term)
+         pe = pe.unsqueeze(0)
+         self.register_buffer('pe', pe)
+
+     def forward(self, length: int) -> Tensor:
+         return self.pe[:, :length]
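Note that, unlike the usual formulation, this module's forward takes a length rather than a tensor and returns the raw encoding table slice (a sketch with arbitrary dimensions):

import torch
from conformer.conformer.embedding import PositionalEncoding

pos_enc = PositionalEncoding(d_model=144)
pe = pos_enc(50)    # pass a sequence length, not a tensor
print(pe.shape)     # torch.Size([1, 50, 144]): consumed by RelativeMultiHeadAttention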
conformer/conformer/encoder.py ADDED
@@ -0,0 +1,204 @@
+ # Copyright (c) 2021, Soohwan Kim. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import torch
+ import torch.nn as nn
+ from torch import Tensor
+ from typing import Tuple
+
+ from .feed_forward import FeedForwardModule
+ from .attention import MultiHeadedSelfAttentionModule
+ from .convolution import (
+     ConformerConvModule,
+     Conv2dSubampling,
+ )
+ from .modules import (
+     ResidualConnectionModule,
+     Linear,
+ )
+
+
+ class ConformerBlock(nn.Module):
+     """
+     A Conformer block contains two feed-forward modules sandwiching the multi-headed self-attention module
+     and the convolution module. This sandwich structure is inspired by Macaron-Net, which proposes replacing
+     the original feed-forward layer in the Transformer block with two half-step feed-forward layers,
+     one before the attention layer and one after.
+
+     Args:
+         encoder_dim (int, optional): Dimension of conformer encoder
+         num_attention_heads (int, optional): Number of attention heads
+         feed_forward_expansion_factor (int, optional): Expansion factor of feed forward module
+         conv_expansion_factor (int, optional): Expansion factor of conformer convolution module
+         feed_forward_dropout_p (float, optional): Probability of feed forward module dropout
+         attention_dropout_p (float, optional): Probability of attention module dropout
+         conv_dropout_p (float, optional): Probability of conformer convolution module dropout
+         conv_kernel_size (int or tuple, optional): Size of the convolving kernel
+         half_step_residual (bool): Flag indicating whether to use half-step residual or not
+
+     Inputs: inputs
+         - **inputs** (batch, time, dim): Tensor containing input vectors
+
+     Returns: outputs
+         - **outputs** (batch, time, dim): Tensor produced by the conformer block.
+     """
+     def __init__(
+             self,
+             encoder_dim: int = 512,
+             num_attention_heads: int = 8,
+             feed_forward_expansion_factor: int = 4,
+             conv_expansion_factor: int = 2,
+             feed_forward_dropout_p: float = 0.1,
+             attention_dropout_p: float = 0.1,
+             conv_dropout_p: float = 0.1,
+             conv_kernel_size: int = 31,
+             half_step_residual: bool = True,
+     ):
+         super(ConformerBlock, self).__init__()
+         if half_step_residual:
+             self.feed_forward_residual_factor = 0.5
+         else:
+             self.feed_forward_residual_factor = 1
+
+         self.sequential = nn.Sequential(
+             ResidualConnectionModule(
+                 module=FeedForwardModule(
+                     encoder_dim=encoder_dim,
+                     expansion_factor=feed_forward_expansion_factor,
+                     dropout_p=feed_forward_dropout_p,
+                 ),
+                 module_factor=self.feed_forward_residual_factor,
+             ),
+             ResidualConnectionModule(
+                 module=MultiHeadedSelfAttentionModule(
+                     d_model=encoder_dim,
+                     num_heads=num_attention_heads,
+                     dropout_p=attention_dropout_p,
+                 ),
+             ),
+             ResidualConnectionModule(
+                 module=ConformerConvModule(
+                     in_channels=encoder_dim,
+                     kernel_size=conv_kernel_size,
+                     expansion_factor=conv_expansion_factor,
+                     dropout_p=conv_dropout_p,
+                 ),
+             ),
+             ResidualConnectionModule(
+                 module=FeedForwardModule(
+                     encoder_dim=encoder_dim,
+                     expansion_factor=feed_forward_expansion_factor,
+                     dropout_p=feed_forward_dropout_p,
+                 ),
+                 module_factor=self.feed_forward_residual_factor,
+             ),
+             nn.LayerNorm(encoder_dim),
+         )
+
+     def forward(self, inputs: Tensor) -> Tensor:
+         return self.sequential(inputs)
+
+
+ class ConformerEncoder(nn.Module):
+     """
+     The Conformer encoder first processes the input with a convolution subsampling layer and then
+     with a number of conformer blocks.
+
+     Args:
+         input_dim (int, optional): Dimension of input vector
+         encoder_dim (int, optional): Dimension of conformer encoder
+         num_layers (int, optional): Number of conformer blocks
+         num_attention_heads (int, optional): Number of attention heads
+         feed_forward_expansion_factor (int, optional): Expansion factor of feed forward module
+         conv_expansion_factor (int, optional): Expansion factor of conformer convolution module
+         feed_forward_dropout_p (float, optional): Probability of feed forward module dropout
+         attention_dropout_p (float, optional): Probability of attention module dropout
+         conv_dropout_p (float, optional): Probability of conformer convolution module dropout
+         conv_kernel_size (int or tuple, optional): Size of the convolving kernel
+         half_step_residual (bool): Flag indicating whether to use half-step residual or not
+
+     Inputs: inputs, input_lengths
+         - **inputs** (batch, time, dim): Tensor containing input vectors
+         - **input_lengths** (batch): list of sequence input lengths
+
+     Returns: outputs, output_lengths
+         - **outputs** (batch, time, dim): Tensor produced by the conformer encoder.
+         - **output_lengths** (batch): list of sequence output lengths
+     """
+     def __init__(
+             self,
+             input_dim: int = 80,
+             encoder_dim: int = 512,
+             num_layers: int = 17,
+             num_attention_heads: int = 8,
+             feed_forward_expansion_factor: int = 4,
+             conv_expansion_factor: int = 2,
+             input_dropout_p: float = 0.1,
+             feed_forward_dropout_p: float = 0.1,
+             attention_dropout_p: float = 0.1,
+             conv_dropout_p: float = 0.1,
+             conv_kernel_size: int = 31,
+             half_step_residual: bool = True,
+     ):
+         super(ConformerEncoder, self).__init__()
+         self.conv_subsample = Conv2dSubampling(in_channels=1, out_channels=encoder_dim)
+         self.input_projection = nn.Sequential(
+             Linear(encoder_dim * (((input_dim - 1) // 2 - 1) // 2), encoder_dim),
+             nn.Dropout(p=input_dropout_p),
+         )
+         self.layers = nn.ModuleList([ConformerBlock(
+             encoder_dim=encoder_dim,
+             num_attention_heads=num_attention_heads,
+             feed_forward_expansion_factor=feed_forward_expansion_factor,
+             conv_expansion_factor=conv_expansion_factor,
+             feed_forward_dropout_p=feed_forward_dropout_p,
+             attention_dropout_p=attention_dropout_p,
+             conv_dropout_p=conv_dropout_p,
+             conv_kernel_size=conv_kernel_size,
+             half_step_residual=half_step_residual,
+         ) for _ in range(num_layers)])
+
+     def count_parameters(self) -> int:
+         """ Count parameters of encoder """
+         return sum([p.numel() for p in self.parameters()])
+
+     def update_dropout(self, dropout_p: float) -> None:
+         """ Update dropout probability of encoder """
+         for name, child in self.named_children():
+             if isinstance(child, nn.Dropout):
+                 child.p = dropout_p
+
+     def forward(self, inputs: Tensor, input_lengths: Tensor) -> Tuple[Tensor, Tensor]:
+         """
+         Forward propagate `inputs` for encoder training.
+
+         Args:
+             inputs (torch.FloatTensor): An input sequence passed to the encoder. Typically this will be a padded
+                 `FloatTensor` of size ``(batch, seq_length, dimension)``.
+             input_lengths (torch.LongTensor): The length of input tensor. ``(batch)``
+
+         Returns:
+             (Tensor, Tensor)
+
+             * outputs (torch.FloatTensor): An output sequence of the encoder. `FloatTensor` of size
+                 ``(batch, seq_length, dimension)``
+             * output_lengths (torch.LongTensor): The length of output tensor. ``(batch)``
+         """
+         outputs, output_lengths = self.conv_subsample(inputs, input_lengths)
+         outputs = self.input_projection(outputs)
+
+         for layer in self.layers:
+             outputs = layer(outputs)
+
+         return outputs, output_lengths
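End-to-end, the encoder quarters the time axis and projects to encoder_dim (an illustrative sketch; the hyperparameters are arbitrary small values):

import torch
from conformer.conformer.encoder import ConformerEncoder

encoder = ConformerEncoder(input_dim=80, encoder_dim=144, num_layers=2, num_attention_heads=4)
x = torch.randn(2, 128, 80)                 # (batch, time, fbank_dim)
lengths = torch.LongTensor([128, 100])
out, out_lengths = encoder(x, lengths)
print(out.shape, out_lengths)               # torch.Size([2, 31, 144]) tensor([31, 24])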
conformer/conformer/feed_forward.py ADDED
@@ -0,0 +1,57 @@
+ # Copyright (c) 2021, Soohwan Kim. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import torch
+ import torch.nn as nn
+ from torch import Tensor
+
+ from .activation import Swish
+ from .modules import Linear
+
+
+ class FeedForwardModule(nn.Module):
+     """
+     The Conformer feed-forward module follows pre-norm residual units and applies layer normalization within the
+     residual unit and on the input before the first linear layer. This module also applies Swish activation and
+     dropout, which help regularize the network.
+
+     Args:
+         encoder_dim (int): Dimension of conformer encoder
+         expansion_factor (int): Expansion factor of feed forward module.
+         dropout_p (float): Ratio of dropout
+
+     Inputs: inputs
+         - **inputs** (batch, time, dim): Tensor containing input sequences
+
+     Outputs: outputs
+         - **outputs** (batch, time, dim): Tensor produced by the feed forward module.
+     """
+     def __init__(
+             self,
+             encoder_dim: int = 512,
+             expansion_factor: int = 4,
+             dropout_p: float = 0.1,
+     ) -> None:
+         super(FeedForwardModule, self).__init__()
+         self.sequential = nn.Sequential(
+             nn.LayerNorm(encoder_dim),
+             Linear(encoder_dim, encoder_dim * expansion_factor, bias=True),
+             Swish(),
+             nn.Dropout(p=dropout_p),
+             Linear(encoder_dim * expansion_factor, encoder_dim, bias=True),
+             nn.Dropout(p=dropout_p),
+         )
+
+     def forward(self, inputs: Tensor) -> Tensor:
+         return self.sequential(inputs)
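A one-line shape check for the feed-forward module (a sketch with arbitrary dimensions):

import torch
from conformer.conformer.feed_forward import FeedForwardModule

ffn = FeedForwardModule(encoder_dim=144, expansion_factor=4)
x = torch.randn(2, 50, 144)
print(ffn(x).shape)  # torch.Size([2, 50, 144]): expands to 4 * 144 internally, then projects back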
conformer/conformer/model.py ADDED
@@ -0,0 +1,108 @@
+ # Copyright (c) 2021, Soohwan Kim. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import torch
+ import torch.nn as nn
+ from torch import Tensor
+ from typing import Tuple
+
+ from .encoder import ConformerEncoder
+ from .modules import Linear
+
+
+ class Conformer(nn.Module):
+     """
+     Conformer: Convolution-augmented Transformer for Speech Recognition.
+     The paper used a single-LSTM Transducer decoder; currently only the conformer encoder
+     shown in the paper is implemented.
+
+     Args:
+         num_classes (int): Number of classification classes
+         input_dim (int, optional): Dimension of input vector
+         encoder_dim (int, optional): Dimension of conformer encoder
+         num_encoder_layers (int, optional): Number of conformer blocks
+         num_attention_heads (int, optional): Number of attention heads
+         feed_forward_expansion_factor (int, optional): Expansion factor of feed forward module
+         conv_expansion_factor (int, optional): Expansion factor of conformer convolution module
+         feed_forward_dropout_p (float, optional): Probability of feed forward module dropout
+         attention_dropout_p (float, optional): Probability of attention module dropout
+         conv_dropout_p (float, optional): Probability of conformer convolution module dropout
+         conv_kernel_size (int or tuple, optional): Size of the convolving kernel
+         half_step_residual (bool): Flag indicating whether to use half-step residual or not
+
+     Inputs: inputs, input_lengths
+         - **inputs** (batch, time, dim): Tensor containing input vectors
+         - **input_lengths** (batch): list of sequence input lengths
+
+     Returns: outputs, output_lengths
+         - **outputs** (batch, time, num_classes): Tensor produced by the conformer.
+         - **output_lengths** (batch): list of sequence output lengths
+     """
+     def __init__(
+             self,
+             num_classes: int,
+             input_dim: int = 80,
+             encoder_dim: int = 512,
+             num_encoder_layers: int = 17,
+             num_attention_heads: int = 8,
+             feed_forward_expansion_factor: int = 4,
+             conv_expansion_factor: int = 2,
+             input_dropout_p: float = 0.1,
+             feed_forward_dropout_p: float = 0.1,
+             attention_dropout_p: float = 0.1,
+             conv_dropout_p: float = 0.1,
+             conv_kernel_size: int = 31,
+             half_step_residual: bool = True,
+     ) -> None:
+         super(Conformer, self).__init__()
+         self.encoder = ConformerEncoder(
+             input_dim=input_dim,
+             encoder_dim=encoder_dim,
+             num_layers=num_encoder_layers,
+             num_attention_heads=num_attention_heads,
+             feed_forward_expansion_factor=feed_forward_expansion_factor,
+             conv_expansion_factor=conv_expansion_factor,
+             input_dropout_p=input_dropout_p,
+             feed_forward_dropout_p=feed_forward_dropout_p,
+             attention_dropout_p=attention_dropout_p,
+             conv_dropout_p=conv_dropout_p,
+             conv_kernel_size=conv_kernel_size,
+             half_step_residual=half_step_residual,
+         )
+         self.fc = Linear(encoder_dim, num_classes, bias=False)
+
+     def count_parameters(self) -> int:
+         """ Count parameters of encoder """
+         return self.encoder.count_parameters()
+
+     def update_dropout(self, dropout_p) -> None:
+         """ Update dropout probability of model """
+         self.encoder.update_dropout(dropout_p)
+
+     def forward(self, inputs: Tensor, input_lengths: Tensor) -> Tuple[Tensor, Tensor]:
+         """
+         Forward propagate `inputs` and `input_lengths` for training.
+
+         Args:
+             inputs (torch.FloatTensor): An input sequence passed to the encoder. Typically this will be a padded
+                 `FloatTensor` of size ``(batch, seq_length, dimension)``.
+             input_lengths (torch.LongTensor): The length of input tensor. ``(batch)``
+
+         Returns:
+             * predictions (torch.FloatTensor): Result of model predictions.
+         """
+         encoder_outputs, encoder_output_lengths = self.encoder(inputs, input_lengths)
+         outputs = self.fc(encoder_outputs)
+         outputs = nn.functional.log_softmax(outputs, dim=-1)
+         return outputs, encoder_output_lengths
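A minimal usage sketch for the full ASR model, e.g. for a CTC-style head (the hyperparameters are arbitrary small values):

import torch
from conformer.conformer.model import Conformer

model = Conformer(num_classes=10, input_dim=80, encoder_dim=144,
                  num_encoder_layers=2, num_attention_heads=4)
inputs = torch.randn(2, 128, 80)              # (batch, time, dim) padded fbank features
input_lengths = torch.LongTensor([128, 100])
outputs, output_lengths = model(inputs, input_lengths)
print(outputs.shape, output_lengths)          # torch.Size([2, 31, 10]) tensor([31, 24]): log-probabilities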
conformer/conformer/model_def.py ADDED
@@ -0,0 +1,104 @@
+ # Copyright (c) 2021, Soohwan Kim. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import torch
+ import torch.nn as nn
+ from torch import Tensor
+ from typing import Tuple
+
+ from .encoder import ConformerEncoder
+ from .modules import Linear
+
+
+ class Conformer(nn.Module):
+     """
+     Conformer: Convolution-augmented Transformer for Speech Recognition.
+     The paper used a single-LSTM Transducer decoder; currently only the conformer encoder
+     shown in the paper is implemented. This variant has no classification head and returns
+     the raw encoder outputs.
+
+     Args:
+         input_dim (int, optional): Dimension of input vector
+         encoder_dim (int, optional): Dimension of conformer encoder
+         num_encoder_layers (int, optional): Number of conformer blocks
+         num_attention_heads (int, optional): Number of attention heads
+         feed_forward_expansion_factor (int, optional): Expansion factor of feed forward module
+         conv_expansion_factor (int, optional): Expansion factor of conformer convolution module
+         feed_forward_dropout_p (float, optional): Probability of feed forward module dropout
+         attention_dropout_p (float, optional): Probability of attention module dropout
+         conv_dropout_p (float, optional): Probability of conformer convolution module dropout
+         conv_kernel_size (int or tuple, optional): Size of the convolving kernel
+         half_step_residual (bool): Flag indicating whether to use half-step residual or not
+
+     Inputs: inputs, input_lengths
+         - **inputs** (batch, time, dim): Tensor containing input vectors
+         - **input_lengths** (batch): list of sequence input lengths
+
+     Returns: outputs, output_lengths
+         - **outputs** (batch, time, dim): Tensor produced by the conformer encoder.
+         - **output_lengths** (batch): list of sequence output lengths
+     """
+     def __init__(
+             self,
+             input_dim: int = 80,
+             encoder_dim: int = 512,
+             num_encoder_layers: int = 17,
+             num_attention_heads: int = 8,
+             feed_forward_expansion_factor: int = 4,
+             conv_expansion_factor: int = 2,
+             input_dropout_p: float = 0.1,
+             feed_forward_dropout_p: float = 0.1,
+             attention_dropout_p: float = 0.1,
+             conv_dropout_p: float = 0.1,
+             conv_kernel_size: int = 31,
+             half_step_residual: bool = True,
+     ) -> None:
+         super(Conformer, self).__init__()
+         self.encoder = ConformerEncoder(
+             input_dim=input_dim,
+             encoder_dim=encoder_dim,
+             num_layers=num_encoder_layers,
+             num_attention_heads=num_attention_heads,
+             feed_forward_expansion_factor=feed_forward_expansion_factor,
+             conv_expansion_factor=conv_expansion_factor,
+             input_dropout_p=input_dropout_p,
+             feed_forward_dropout_p=feed_forward_dropout_p,
+             attention_dropout_p=attention_dropout_p,
+             conv_dropout_p=conv_dropout_p,
+             conv_kernel_size=conv_kernel_size,
+             half_step_residual=half_step_residual,
+         )
+
+     def count_parameters(self) -> int:
+         """ Count parameters of encoder """
+         return self.encoder.count_parameters()
+
+     def update_dropout(self, dropout_p) -> None:
+         """ Update dropout probability of model """
+         self.encoder.update_dropout(dropout_p)
+
+     def forward(self, inputs: Tensor, input_lengths: Tensor) -> Tuple[Tensor, Tensor]:
+         """
+         Forward propagate `inputs` and `input_lengths` for training.
+
+         Args:
+             inputs (torch.FloatTensor): An input sequence passed to the encoder. Typically this will be a padded
+                 `FloatTensor` of size ``(batch, seq_length, dimension)``.
+             input_lengths (torch.LongTensor): The length of input tensor. ``(batch)``
+
+         Returns:
+             * predictions (torch.FloatTensor): Result of model predictions.
+         """
+         encoder_outputs, encoder_output_lengths = self.encoder(inputs, input_lengths)
+         return encoder_outputs, encoder_output_lengths
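Unlike model.py, this variant skips the `fc` + log-softmax head, which is what model_pinyin.py below relies on when it uses the conformer purely as a feature extractor (a sketch with arbitrary dimensions):

import torch
from conformer.conformer.model_def import Conformer

encoder_only = Conformer(input_dim=80, encoder_dim=128, num_encoder_layers=2, num_attention_heads=4)
feats, feat_lengths = encoder_only(torch.randn(2, 45, 80), torch.LongTensor([45, 40]))
print(feats.shape)  # torch.Size([2, 10, 128]): encoder features, no classification head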
conformer/conformer/modules.py ADDED
@@ -0,0 +1,73 @@
+ # Copyright (c) 2021, Soohwan Kim. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.init as init
+ from torch import Tensor
+
+
+ class ResidualConnectionModule(nn.Module):
+     """
+     Residual Connection Module.
+     outputs = (module(inputs) x module_factor + inputs x input_factor)
+     """
+     def __init__(self, module: nn.Module, module_factor: float = 1.0, input_factor: float = 1.0):
+         super(ResidualConnectionModule, self).__init__()
+         self.module = module
+         self.module_factor = module_factor
+         self.input_factor = input_factor
+
+     def forward(self, inputs: Tensor) -> Tensor:
+         return (self.module(inputs) * self.module_factor) + (inputs * self.input_factor)
+
+
+ class Linear(nn.Module):
+     """
+     Wrapper class of torch.nn.Linear.
+     Weights are initialized with Xavier uniform initialization and biases are initialized to zeros.
+     """
+     def __init__(self, in_features: int, out_features: int, bias: bool = True) -> None:
+         super(Linear, self).__init__()
+         self.linear = nn.Linear(in_features, out_features, bias=bias)
+         init.xavier_uniform_(self.linear.weight)
+         if bias:
+             init.zeros_(self.linear.bias)
+
+     def forward(self, x: Tensor) -> Tensor:
+         return self.linear(x)
+
+
+ class View(nn.Module):
+     """ Wrapper class of Tensor.view() for Sequential modules. """
+     def __init__(self, shape: tuple, contiguous: bool = False):
+         super(View, self).__init__()
+         self.shape = shape
+         self.contiguous = contiguous
+
+     def forward(self, x: Tensor) -> Tensor:
+         if self.contiguous:
+             x = x.contiguous()
+
+         return x.view(*self.shape)
+
+
+ class Transpose(nn.Module):
+     """ Wrapper class of Tensor.transpose() for Sequential modules. """
+     def __init__(self, shape: tuple):
+         super(Transpose, self).__init__()
+         self.shape = shape
+
+     def forward(self, x: Tensor) -> Tensor:
+         return x.transpose(*self.shape)
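As a quick illustration of the half-step residual used by ConformerBlock (a sketch; any submodule that preserves shape works):

import torch
import torch.nn as nn
from conformer.conformer.modules import ResidualConnectionModule

block = ResidualConnectionModule(module=nn.Linear(8, 8), module_factor=0.5)  # half-step residual
x = torch.randn(2, 8)
y = block(x)  # equals 0.5 * linear(x) + x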
model_pinyin.py ADDED
@@ -0,0 +1,299 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import pytorch_lightning as pl
+ from torch.optim import Adam
+ from torch.optim.lr_scheduler import CosineAnnealingLR
+ import math
+ from conformer.conformer.model_def import Conformer
+
+ class PositionalEncoding(nn.Module):
+     """Positional encoding module."""
+     def __init__(self, d_model, max_len=5000):
+         super(PositionalEncoding, self).__init__()
+         pe = torch.zeros(max_len, d_model)
+         position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
+         div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
+         pe[:, 0::2] = torch.sin(position * div_term)
+         pe[:, 1::2] = torch.cos(position * div_term)
+         pe = pe.unsqueeze(0)
+         self.register_buffer('pe', pe)
+
+     def forward(self, x):
+         return x + self.pe[:, :x.size(1)]
+
+
+ class CrossAttention(nn.Module):
+     """Cross-attention module: one cross-attention layer followed by one self-attention layer."""
+     def __init__(self, query_dim, key_dim, heads=4, dropout=0.1):
+         super(CrossAttention, self).__init__()
+         # Cross-attention layer
+         self.cross_attn = nn.MultiheadAttention(query_dim, heads, dropout=dropout, batch_first=True)
+         self.norm1 = nn.LayerNorm(query_dim)
+         self.norm2 = nn.LayerNorm(query_dim)
+         self.ffn1 = nn.Sequential(
+             nn.Linear(query_dim, query_dim * 4),
+             nn.ReLU(),
+             nn.Dropout(dropout),
+             nn.Linear(query_dim * 4, query_dim),
+             nn.Dropout(dropout)
+         )
+
+         # Self-attention layer
+         self.self_attn = nn.MultiheadAttention(query_dim, heads, dropout=dropout, batch_first=True)
+         self.norm3 = nn.LayerNorm(query_dim)
+         self.norm4 = nn.LayerNorm(query_dim)
+         self.ffn2 = nn.Sequential(
+             nn.Linear(query_dim, query_dim * 4),
+             nn.ReLU(),
+             nn.Dropout(dropout),
+             nn.Linear(query_dim * 4, query_dim),
+             nn.Dropout(dropout)
+         )
+
+         # Projection layers
+         self.proj_key = nn.Linear(key_dim, query_dim)
+         self.proj_value = nn.Linear(key_dim, query_dim)
+
+     def forward(self, query, key, value, key_padding_mask=None):
+         # When audio input is present, apply cross-attention first.
+         # Project key and value to the query dimension
+         key_proj = self.proj_key(key)
+         value_proj = self.proj_value(value)
+
+         # Cross-attention
+         query_norm = self.norm1(query)
+         cross_attn_output, _ = self.cross_attn(query_norm, key_proj, value_proj,
+                                                key_padding_mask=key_padding_mask)
+         query = query + cross_attn_output
+         query = query + self.ffn1(self.norm2(query))
+
+         # Then self-attention
+         query_norm = self.norm3(query)
+         self_attn_output, _ = self.self_attn(query_norm, query_norm, query_norm)
+         query = query + self_attn_output
+         query = query + self.ffn2(self.norm4(query))
+
+         return query
+
+
+ class MMKWS2(nn.Module):
+     def __init__(
+         self,
+         # anchor
+         text_dim=64,
+         audio_dim=1024,
+         hidden_dim=128,
+         # compare
+         dim=80,
+         encoder_dim=128,
+         num_encoder_layers=6,
+         num_attention_heads=4,
+         dropout=0.1,
+         num_transformer_layers=2
+     ):
+         super(MMKWS2, self).__init__()
+         # Project audio embeddings down to hidden_dim
+         self.audio_proj = nn.Linear(audio_dim, hidden_dim)
+         # Text embedding lookup
+         self.text_proj = nn.Embedding(num_embeddings=402, embedding_dim=hidden_dim)  # 401 + padding (-1)
+         # Positional encoding
+         self.pos_enc = PositionalEncoding(hidden_dim)
+         # Cross-attention module
+         self.cross_attn = CrossAttention(hidden_dim, hidden_dim, heads=num_attention_heads, dropout=dropout)
+
+         # Conformer layers
+         self.conformer = Conformer(
+             input_dim=dim,
+             encoder_dim=encoder_dim,
+             num_encoder_layers=num_encoder_layers,
+             num_attention_heads=num_attention_heads,
+         )
+
+         # Feature projection (maps the conformer output dimension to hidden_dim)
+         self.feat_proj = nn.Linear(encoder_dim, hidden_dim)
+
+         # Transformer layers
+         self.transformer_encoder = nn.TransformerEncoder(
+             nn.TransformerEncoderLayer(
+                 d_model=hidden_dim,
+                 nhead=num_attention_heads,
+                 dim_feedforward=hidden_dim*4,
+                 dropout=dropout,
+                 batch_first=True
+             ),
+             num_layers=num_transformer_layers
+         )
+
+         # GRU classifier
+         self.gru = nn.GRU(hidden_dim, hidden_dim, batch_first=True, bidirectional=True)
+         self.classifier = nn.Sequential(
+             nn.Linear(hidden_dim*2, hidden_dim),
+             nn.ReLU(),
+             nn.Dropout(dropout),
+             nn.Linear(hidden_dim, 1)
+         )
+
+         # Sequence label prediction
+         self.seq_classifier = nn.Linear(hidden_dim, 1)
+
+     def forward(self, anchor_wave_embedding, anchor_text_embedding, compare_wave, compare_lengths):
+         batch_size = anchor_wave_embedding.size(0)
+
+         # 1. Process the anchor text embedding
+         text_feat = self.text_proj(anchor_text_embedding)  # [B, S, hidden_dim]
+         text_feat = self.pos_enc(text_feat)
+
+         # 2. Process the anchor wave audio embedding
+         audio_feat = self.audio_proj(anchor_wave_embedding)  # [B, S, hidden_dim]
+         audio_feat = self.pos_enc(audio_feat)
+
+         # 3. Cross-attention: fuse text and audio features
+         fused_feat = self.cross_attn(text_feat, audio_feat, audio_feat, key_padding_mask=None)  # [B, S, hidden_dim]
+
+         # 4. Process the fbank features of compare_wave
+         compare_feat = self.conformer(compare_wave, compare_lengths)[0]  # [B, T, encoder_dim]
+         compare_feat = self.feat_proj(compare_feat)  # [B, T, hidden_dim]
+         compare_feat = self.pos_enc(compare_feat)
+
+         # 5. Concatenate features
+         text_len = fused_feat.size(1)
+         combined_feat = torch.cat([fused_feat, compare_feat], dim=1)  # [B, S+T, hidden_dim]
+         combined_feat = self.transformer_encoder(combined_feat)  # [B, S+T, hidden_dim]
+
+         # 6. GRU classification
+         gru_out, _ = self.gru(combined_feat)  # [B, S+T, hidden_dim*2]
+
+         # Global classification
+         global_feat = gru_out[:, -1, :]  # take the last time step
+         logits = self.classifier(global_feat).squeeze(-1)  # [B]
+
+         # Sequence label prediction
+         seq_logits = self.seq_classifier(combined_feat[:, :text_len, :]).squeeze(-1)  # [B, S]
+         return logits, seq_logits
+
+     def enrollment(self, anchor_wave_embedding, anchor_text_embedding):
+         batch_size = anchor_wave_embedding.size(0)
+
+         # 1. Process the anchor text embedding
+         text_feat = self.text_proj(anchor_text_embedding)  # [B, S, hidden_dim]
+         text_feat = self.pos_enc(text_feat)
+
+         # 2. Process the anchor wave audio embedding
+         audio_feat = self.audio_proj(anchor_wave_embedding)  # [B, S, hidden_dim]
+         audio_feat = self.pos_enc(audio_feat)
+
+         # 3. Cross-attention: fuse text and audio features
+         fused_feat = self.cross_attn(text_feat, audio_feat, audio_feat, key_padding_mask=None)  # [B, S, hidden_dim]
+
+         return fused_feat
+
+     def verification(self, fused_feat, compare_wave, compare_lengths):
+         batch_size = fused_feat.size(0)
+
+         # 4. Process the fbank features of compare_wave
+         compare_feat = self.conformer(compare_wave, compare_lengths)[0]  # [B, T, encoder_dim]
+         compare_feat = self.feat_proj(compare_feat)  # [B, T, hidden_dim]
+         compare_feat = self.pos_enc(compare_feat)
+
+         # 5. Concatenate features
+         text_len = fused_feat.size(1)
+         combined_feat = torch.cat([fused_feat, compare_feat], dim=1)  # [B, S+T, hidden_dim]
+         combined_feat = self.transformer_encoder(combined_feat)  # [B, S+T, hidden_dim]
+
+         # 6. GRU classification
+         gru_out, _ = self.gru(combined_feat)  # [B, S+T, hidden_dim*2]
+
+         # Global classification
+         global_feat = gru_out[:, -1, :]  # take the last time step
+         logits = self.classifier(global_feat).squeeze(-1)  # [B]
+
+         return logits
+
+
+ def count_verification_params(model):
+     modules = [
+         model.conformer,
+         model.feat_proj,
+         model.transformer_encoder,
+         model.gru,
+         model.classifier
+     ]
+     total = 0
+     for m in modules:
+         total += sum(p.numel() for p in m.parameters())
+     return total
+
+ model = MMKWS2(
+     text_dim=64,
+     audio_dim=1024,
+     hidden_dim=128,
+     dim=80,
+     encoder_dim=128,
+     num_encoder_layers=6,
+     num_attention_heads=4,
+     dropout=0.1,
+     num_transformer_layers=2
+ )
+ print(f"Verification-related parameter count: {count_verification_params(model):,}")  # ~3.5M parameters
+
+ # if __name__ == "__main__":
+ #     # Build an example batch
+ #     batch_size = 2
+
+ #     # Create dummy data
+ #     anchor_text = torch.randint(0, 402, (batch_size, 8))  # text input (pinyin indices)
+ #     anchor_wave = torch.randn(batch_size, 256, 1024)      # audio embedding
+ #     compare_wave = torch.randn(batch_size, 45, 80)        # fbank features
+
+ #     # Length information
+ #     anchor_lengths = torch.LongTensor([8, 6])             # actual lengths of the two samples
+ #     compare_lengths = torch.LongTensor([45, 40])
+
+ #     # Build the model
+ #     model = MMKWS2(
+ #         text_dim=64,
+ #         audio_dim=1024,
+ #         hidden_dim=128,
+ #         dim=80,
+ #         encoder_dim=128,
+ #         num_encoder_layers=6,
+ #         num_attention_heads=4,
+ #         dropout=0.1,
+ #         num_transformer_layers=2
+ #     )
+
+ #     # Count model parameters
+ #     total_params = sum(p.numel() for p in model.parameters())
+ #     trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+
+ #     print(f"Total parameters: {total_params:,}")
+ #     print(f"Trainable parameters: {trainable_params:,}")
+
+ #     # Print the model structure
+ #     print("\nModel structure:")
+ #     print(model)
+
+ #     # Inference
+ #     print("\nRunning inference...")
+ #     model.eval()
+ #     with torch.no_grad():
+ #         print(anchor_text.shape)
+ #         print(anchor_wave.shape)
+ #         print(compare_wave.shape)
+ #         print(anchor_lengths.shape)
+ #         print(compare_lengths.shape)
+ #         # Full-input inference (matching the forward signature above)
+ #         logits, seq_logits = model(
+ #             anchor_wave_embedding=anchor_wave,
+ #             anchor_text_embedding=anchor_text,
+ #             compare_wave=compare_wave,
+ #             compare_lengths=compare_lengths
+ #         )
+
+ #         print("\nInference results:")
+ #         print(f"Classification logits shape: {logits.shape}")
+ #         print(f"Sequence logits shape: {seq_logits.shape}")
pinyin_index.json ADDED
@@ -0,0 +1,808 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
+ {
+   "pinyin_to_index": {
+     "a": 0,
+     "ai": 1,
+     "an": 2,
+     "ang": 3,
+     "ao": 4,
+     "ba": 5,
+     "bai": 6,
+     "ban": 7,
+     "bang": 8,
+     "bao": 9,
+     "bei": 10,
+     "ben": 11,
+     "beng": 12,
+     "bi": 13,
+     "bian": 14,
+     "biao": 15,
+     "bie": 16,
+     "bin": 17,
+     "bing": 18,
+     "bo": 19,
+     "bu": 20,
+     "ca": 21,
+     "cai": 22,
+     "can": 23,
+     "cang": 24,
+     "cao": 25,
+     "ce": 26,
+     "cen": 27,
+     "ceng": 28,
+     "cha": 29,
+     "chai": 30,
+     "chan": 31,
+     "chang": 32,
+     "chao": 33,
+     "che": 34,
+     "chen": 35,
+     "cheng": 36,
+     "chi": 37,
+     "chong": 38,
+     "chou": 39,
+     "chu": 40,
+     "chuai": 41,
+     "chuan": 42,
+     "chuang": 43,
+     "chui": 44,
+     "chun": 45,
+     "chuo": 46,
+     "ci": 47,
+     "cong": 48,
+     "cou": 49,
+     "cu": 50,
+     "cuan": 51,
+     "cui": 52,
+     "cun": 53,
+     "cuo": 54,
+     "da": 55,
+     "dai": 56,
+     "dan": 57,
+     "dang": 58,
+     "dao": 59,
+     "de": 60,
+     "dei": 61,
+     "deng": 62,
+     "di": 63,
+     "dia": 64,
+     "dian": 65,
+     "diao": 66,
+     "die": 67,
+     "ding": 68,
+     "diu": 69,
+     "dong": 70,
+     "dou": 71,
+     "du": 72,
+     "duan": 73,
+     "dui": 74,
+     "dun": 75,
+     "duo": 76,
+     "e": 77,
+     "en": 78,
+     "er": 79,
+     "fa": 80,
+     "fan": 81,
+     "fang": 82,
+     "fei": 83,
+     "fen": 84,
+     "feng": 85,
+     "fo": 86,
+     "fou": 87,
+     "fu": 88,
+     "ga": 89,
+     "gai": 90,
+     "gan": 91,
+     "gang": 92,
+     "gao": 93,
+     "ge": 94,
+     "gei": 95,
+     "gen": 96,
+     "geng": 97,
+     "gong": 98,
+     "gou": 99,
+     "gu": 100,
+     "gua": 101,
+     "guai": 102,
+     "guan": 103,
+     "guang": 104,
+     "gui": 105,
+     "gun": 106,
+     "guo": 107,
+     "ha": 108,
+     "hai": 109,
+     "han": 110,
+     "hang": 111,
+     "hao": 112,
+     "he": 113,
+     "hei": 114,
+     "hen": 115,
+     "heng": 116,
+     "hong": 117,
+     "hou": 118,
+     "hu": 119,
+     "hua": 120,
+     "huai": 121,
+     "huan": 122,
+     "huang": 123,
+     "hui": 124,
+     "hun": 125,
+     "huo": 126,
+     "ji": 127,
+     "jia": 128,
+     "jian": 129,
+     "jiang": 130,
+     "jiao": 131,
+     "jie": 132,
+     "jin": 133,
+     "jing": 134,
+     "jiong": 135,
+     "jiu": 136,
+     "ju": 137,
+     "juan": 138,
+     "jue": 139,
+     "jun": 140,
+     "ka": 141,
+     "kai": 142,
+     "kan": 143,
+     "kang": 144,
+     "kao": 145,
+     "ke": 146,
+     "ken": 147,
+     "keng": 148,
+     "kong": 149,
+     "kou": 150,
+     "ku": 151,
+     "kua": 152,
+     "kuai": 153,
+     "kuan": 154,
+     "kuang": 155,
+     "kui": 156,
+     "kun": 157,
+     "kuo": 158,
+     "la": 159,
+     "lai": 160,
+     "lan": 161,
+     "lang": 162,
+     "lao": 163,
+     "le": 164,
+     "lei": 165,
+     "leng": 166,
+     "li": 167,
+     "lia": 168,
+     "lian": 169,
+     "liang": 170,
+     "liao": 171,
+     "lie": 172,
+     "lin": 173,
+     "ling": 174,
+     "liu": 175,
+     "long": 176,
+     "lou": 177,
+     "lu": 178,
+     "luan": 179,
+     "lun": 180,
+     "luo": 181,
+     "lv": 182,
+     "lve": 183,
+     "ma": 184,
+     "mai": 185,
+     "man": 186,
+     "mang": 187,
+     "mao": 188,
+     "me": 189,
+     "mei": 190,
+     "men": 191,
+     "meng": 192,
+     "mi": 193,
+     "mian": 194,
+     "miao": 195,
+     "mie": 196,
+     "min": 197,
+     "ming": 198,
+     "miu": 199,
+     "mo": 200,
+     "mou": 201,
+     "mu": 202,
+     "na": 203,
+     "nai": 204,
+     "nan": 205,
+     "nang": 206,
+     "nao": 207,
+     "ne": 208,
+     "nei": 209,
+     "nen": 210,
+     "neng": 211,
+     "ni": 212,
+     "nian": 213,
+     "niang": 214,
+     "niao": 215,
+     "nie": 216,
+     "nin": 217,
+     "ning": 218,
+     "niu": 219,
+     "nong": 220,
+     "nu": 221,
+     "nuan": 222,
+     "nuo": 223,
+     "nv": 224,
+     "nve": 225,
+     "o": 226,
+     "ou": 227,
+     "pa": 228,
+     "pai": 229,
+     "pan": 230,
+     "pang": 231,
+     "pao": 232,
+     "pei": 233,
+     "pen": 234,
+     "peng": 235,
+     "pi": 236,
+     "pian": 237,
+     "piao": 238,
+     "pie": 239,
+     "pin": 240,
+     "ping": 241,
+     "po": 242,
+     "pou": 243,
+     "pu": 244,
+     "qi": 245,
+     "qia": 246,
+     "qian": 247,
+     "qiang": 248,
+     "qiao": 249,
+     "qie": 250,
+     "qin": 251,
+     "qing": 252,
+     "qiong": 253,
+     "qiu": 254,
+     "qu": 255,
+     "quan": 256,
+     "que": 257,
+     "qun": 258,
+     "ran": 259,
+     "rang": 260,
+     "rao": 261,
+     "re": 262,
+     "ren": 263,
+     "reng": 264,
+     "ri": 265,
+     "rong": 266,
+     "rou": 267,
+     "ru": 268,
+     "ruan": 269,
+     "rui": 270,
+     "run": 271,
+     "ruo": 272,
+     "sa": 273,
+     "sai": 274,
+     "san": 275,
+     "sang": 276,
+     "sao": 277,
+     "se": 278,
+     "sen": 279,
+     "seng": 280,
+     "sha": 281,
+     "shai": 282,
+     "shan": 283,
+     "shang": 284,
+     "shao": 285,
+     "she": 286,
+     "shei": 287,
+     "shen": 288,
+     "sheng": 289,
+     "shi": 290,
+     "shou": 291,
+     "shu": 292,
+     "shua": 293,
+     "shuai": 294,
+     "shuan": 295,
+     "shuang": 296,
+     "shui": 297,
+     "shun": 298,
+     "shuo": 299,
+     "si": 300,
+     "song": 301,
+     "sou": 302,
+     "su": 303,
+     "suan": 304,
+     "sui": 305,
+     "sun": 306,
+     "suo": 307,
+     "ta": 308,
+     "tai": 309,
+     "tan": 310,
+     "tang": 311,
+     "tao": 312,
+     "te": 313,
+     "teng": 314,
+     "ti": 315,
+     "tian": 316,
+     "tiao": 317,
+     "tie": 318,
+     "ting": 319,
+     "tong": 320,
+     "tou": 321,
+     "tu": 322,
+     "tuan": 323,
+     "tui": 324,
+     "tun": 325,
+     "tuo": 326,
+     "wa": 327,
+     "wai": 328,
+     "wan": 329,
+     "wang": 330,
+     "wei": 331,
+     "wen": 332,
+     "weng": 333,
+     "wo": 334,
+     "wu": 335,
+     "xi": 336,
+     "xia": 337,
+     "xian": 338,
+     "xiang": 339,
+     "xiao": 340,
+     "xie": 341,
+     "xin": 342,
+     "xing": 343,
+     "xiong": 344,
+     "xiu": 345,
+     "xu": 346,
+     "xuan": 347,
+     "xue": 348,
+     "xun": 349,
+     "ya": 350,
+     "yan": 351,
+     "yang": 352,
+     "yao": 353,
+     "ye": 354,
+     "yi": 355,
+     "yin": 356,
+     "ying": 357,
+     "yo": 358,
+     "yong": 359,
+     "you": 360,
+     "yu": 361,
+     "yuan": 362,
+     "yue": 363,
+     "yun": 364,
+     "za": 365,
+     "zai": 366,
+     "zan": 367,
+     "zang": 368,
+     "zao": 369,
+     "ze": 370,
+     "zei": 371,
+     "zen": 372,
+     "zeng": 373,
+     "zha": 374,
+     "zhai": 375,
+     "zhan": 376,
+     "zhang": 377,
+     "zhao": 378,
+     "zhe": 379,
+     "zhen": 380,
+     "zheng": 381,
+     "zhi": 382,
+     "zhong": 383,
+     "zhou": 384,
+     "zhu": 385,
+     "zhua": 386,
+     "zhuai": 387,
+     "zhuan": 388,
+     "zhuang": 389,
+     "zhui": 390,
+     "zhun": 391,
+     "zhuo": 392,
+     "zi": 393,
+     "zong": 394,
+     "zou": 395,
+     "zu": 396,
+     "zuan": 397,
+     "zui": 398,
+     "zun": 399,
+     "zuo": 400
+   },
+   "index_to_pinyin": {
+     "0": "a",
+     "1": "ai",
+     "2": "an",
+     "3": "ang",
+     "4": "ao",
+     "5": "ba",
+     "6": "bai",
+     "7": "ban",
+     "8": "bang",
+     "9": "bao",
+     "10": "bei",
+     "11": "ben",
+     "12": "beng",
+     "13": "bi",
+     "14": "bian",
+     "15": "biao",
+     "16": "bie",
+     "17": "bin",
+     "18": "bing",
+     "19": "bo",
+     "20": "bu",
+     "21": "ca",
+     "22": "cai",
+     "23": "can",
+     "24": "cang",
+     "25": "cao",
+     "26": "ce",
+     "27": "cen",
+     "28": "ceng",
+     "29": "cha",
+     "30": "chai",
+     "31": "chan",
+     "32": "chang",
+     "33": "chao",
+     "34": "che",
+     "35": "chen",
+     "36": "cheng",
+     "37": "chi",
+     "38": "chong",
+     "39": "chou",
+     "40": "chu",
+     "41": "chuai",
+     "42": "chuan",
+     "43": "chuang",
+     "44": "chui",
+     "45": "chun",
+     "46": "chuo",
+     "47": "ci",
+     "48": "cong",
+     "49": "cou",
+     "50": "cu",
+     "51": "cuan",
+     "52": "cui",
+     "53": "cun",
+     "54": "cuo",
+     "55": "da",
+     "56": "dai",
+     "57": "dan",
+     "58": "dang",
+     "59": "dao",
+     "60": "de",
+     "61": "dei",
+     "62": "deng",
+     "63": "di",
+     "64": "dia",
+     "65": "dian",
+     "66": "diao",
+     "67": "die",
+     "68": "ding",
+     "69": "diu",
+     "70": "dong",
+     "71": "dou",
+     "72": "du",
+     "73": "duan",
+     "74": "dui",
+     "75": "dun",
+     "76": "duo",
+     "77": "e",
+     "78": "en",
+     "79": "er",
+     "80": "fa",
+     "81": "fan",
+     "82": "fang",
+     "83": "fei",
+     "84": "fen",
+     "85": "feng",
+     "86": "fo",
+     "87": "fou",
+     "88": "fu",
+     "89": "ga",
+     "90": "gai",
+     "91": "gan",
+     "92": "gang",
+     "93": "gao",
+     "94": "ge",
+     "95": "gei",
+     "96": "gen",
+     "97": "geng",
+     "98": "gong",
+     "99": "gou",
+     "100": "gu",
+     "101": "gua",
+     "102": "guai",
+     "103": "guan",
+     "104": "guang",
+     "105": "gui",
+     "106": "gun",
+     "107": "guo",
+     "108": "ha",
+     "109": "hai",
+     "110": "han",
+     "111": "hang",
+     "112": "hao",
+     "113": "he",
+     "114": "hei",
+     "115": "hen",
+     "116": "heng",
+     "117": "hong",
+     "118": "hou",
+     "119": "hu",
+     "120": "hua",
+     "121": "huai",
+     "122": "huan",
+     "123": "huang",
+     "124": "hui",
+     "125": "hun",
+     "126": "huo",
+     "127": "ji",
+     "128": "jia",
+     "129": "jian",
+     "130": "jiang",
+     "131": "jiao",
+     "132": "jie",
+     "133": "jin",
+     "134": "jing",
+     "135": "jiong",
+     "136": "jiu",
+     "137": "ju",
+     "138": "juan",
+     "139": "jue",
+     "140": "jun",
+     "141": "ka",
+     "142": "kai",
+     "143": "kan",
+     "144": "kang",
+     "145": "kao",
+     "146": "ke",
+     "147": "ken",
+     "148": "keng",
+     "149": "kong",
+     "150": "kou",
+     "151": "ku",
+     "152": "kua",
+     "153": "kuai",
+     "154": "kuan",
+     "155": "kuang",
+     "156": "kui",
+     "157": "kun",
+     "158": "kuo",
+     "159": "la",
+     "160": "lai",
+     "161": "lan",
+     "162": "lang",
+     "163": "lao",
+     "164": "le",
+     "165": "lei",
+     "166": "leng",
+     "167": "li",
+     "168": "lia",
+     "169": "lian",
+     "170": "liang",
+     "171": "liao",
+     "172": "lie",
+     "173": "lin",
+     "174": "ling",
+     "175": "liu",
+     "176": "long",
+     "177": "lou",
+     "178": "lu",
+     "179": "luan",
+     "180": "lun",
+     "181": "luo",
+     "182": "lv",
+     "183": "lve",
+     "184": "ma",
+     "185": "mai",
+     "186": "man",
+     "187": "mang",
+     "188": "mao",
+     "189": "me",
+     "190": "mei",
+     "191": "men",
+     "192": "meng",
+     "193": "mi",
+     "194": "mian",
+     "195": "miao",
+     "196": "mie",
+     "197": "min",
+     "198": "ming",
+     "199": "miu",
+     "200": "mo",
+     "201": "mou",
+     "202": "mu",
+     "203": "na",
+     "204": "nai",
+     "205": "nan",
+     "206": "nang",
+     "207": "nao",
+     "208": "ne",
+     "209": "nei",
+     "210": "nen",
+     "211": "neng",
+     "212": "ni",
+     "213": "nian",
+     "214": "niang",
+     "215": "niao",
+     "216": "nie",
+     "217": "nin",
+     "218": "ning",
+     "219": "niu",
+     "220": "nong",
+     "221": "nu",
+     "222": "nuan",
+     "223": "nuo",
+     "224": "nv",
+     "225": "nve",
+     "226": "o",
+     "227": "ou",
+     "228": "pa",
+     "229": "pai",
+     "230": "pan",
+     "231": "pang",
+     "232": "pao",
+     "233": "pei",
+     "234": "pen",
+     "235": "peng",
+     "236": "pi",
+     "237": "pian",
+     "238": "piao",
+     "239": "pie",
+     "240": "pin",
+     "241": "ping",
+     "242": "po",
+     "243": "pou",
+     "244": "pu",
+     "245": "qi",
+     "246": "qia",
+     "247": "qian",
+     "248": "qiang",
+     "249": "qiao",
+     "250": "qie",
+     "251": "qin",
+     "252": "qing",
+     "253": "qiong",
+     "254": "qiu",
+     "255": "qu",
+     "256": "quan",
+     "257": "que",
+     "258": "qun",
+     "259": "ran",
+     "260": "rang",
+     "261": "rao",
+     "262": "re",
+     "263": "ren",
+     "264": "reng",
+     "265": "ri",
+     "266": "rong",
+     "267": "rou",
+     "268": "ru",
+     "269": "ruan",
+     "270": "rui",
+     "271": "run",
+     "272": "ruo",
+     "273": "sa",
+     "274": "sai",
+     "275": "san",
+     "276": "sang",
+     "277": "sao",
+     "278": "se",
+     "279": "sen",
+     "280": "seng",
+     "281": "sha",
+     "282": "shai",
+     "283": "shan",
+     "284": "shang",
+     "285": "shao",
+     "286": "she",
+     "287": "shei",
+     "288": "shen",
+     "289": "sheng",
+     "290": "shi",
+     "291": "shou",
+     "292": "shu",
+     "293": "shua",
+     "294": "shuai",
+     "295": "shuan",
+     "296": "shuang",
+     "297": "shui",
+     "298": "shun",
+     "299": "shuo",
+     "300": "si",
+     "301": "song",
+     "302": "sou",
+     "303": "su",
+     "304": "suan",
+     "305": "sui",
+     "306": "sun",
+     "307": "suo",
+     "308": "ta",
+     "309": "tai",
+     "310": "tan",
+     "311": "tang",
+     "312": "tao",
+     "313": "te",
+     "314": "teng",
+     "315": "ti",
+     "316": "tian",
+     "317": "tiao",
+     "318": "tie",
+     "319": "ting",
+     "320": "tong",
+     "321": "tou",
+     "322": "tu",
+     "323": "tuan",
+     "324": "tui",
+     "325": "tun",
+     "326": "tuo",
+     "327": "wa",
+     "328": "wai",
+     "329": "wan",
+     "330": "wang",
+     "331": "wei",
+     "332": "wen",
+     "333": "weng",
+     "334": "wo",
+     "335": "wu",
+     "336": "xi",
+     "337": "xia",
+     "338": "xian",
+     "339": "xiang",
+     "340": "xiao",
+     "341": "xie",
+     "342": "xin",
+     "343": "xing",
+     "344": "xiong",
+     "345": "xiu",
+     "346": "xu",
+     "347": "xuan",
+     "348": "xue",
+     "349": "xun",
+     "350": "ya",
+     "351": "yan",
+     "352": "yang",
+     "353": "yao",
+     "354": "ye",
+     "355": "yi",
+     "356": "yin",
+     "357": "ying",
+     "358": "yo",
+     "359": "yong",
+     "360": "you",
+     "361": "yu",
+     "362": "yuan",
+     "363": "yue",
+     "364": "yun",
+     "365": "za",
+     "366": "zai",
+     "367": "zan",
+     "368": "zang",
+     "369": "zao",
+     "370": "ze",
+     "371": "zei",
+     "372": "zen",
+     "373": "zeng",
+     "374": "zha",
+     "375": "zhai",
+     "376": "zhan",
+     "377": "zhang",
+     "378": "zhao",
+     "379": "zhe",
+     "380": "zhen",
+     "381": "zheng",
+     "382": "zhi",
+     "383": "zhong",
+     "384": "zhou",
+     "385": "zhu",
+     "386": "zhua",
+     "387": "zhuai",
+     "388": "zhuan",
+     "389": "zhuang",
+     "390": "zhui",
+     "391": "zhun",
+     "392": "zhuo",
+     "393": "zi",
+     "394": "zong",
+     "395": "zou",
+     "396": "zu",
+     "397": "zuan",
+     "398": "zui",
+     "399": "zun",
+     "400": "zuo"
+   }
+ }
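pinyin_index.json maps each of the 401 toneless Mandarin pinyin syllables to an integer id and back. A minimal sketch of how it is consumed, mirroring `load_pinyin_index` and the enrollment code in ui.py further down; note the `+ 1` offset there, which presumably reserves id 0 (for padding):

```python
import json
from pypinyin import lazy_pinyin

def load_pinyin_index(save_path):
    # Load the forward and inverse pinyin/index mappings from this file.
    with open(save_path, "r", encoding="utf-8") as f:
        data = json.load(f)
    return data["pinyin_to_index"], data["index_to_pinyin"]

pinyin_to_index, index_to_pinyin = load_pinyin_index("pinyin_index.json")

# "你好" -> ["ni", "hao"] -> ids [212, 112]; ui.py shifts by +1 so that 0 stays
# free (presumably as a padding id).
ids = [pinyin_to_index[p] + 1 for p in lazy_pinyin("你好")]
print(ids)  # [213, 113]
```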
stepstep=024500.ckpt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e6bd2c480df21e2cd9ee5eadef5d67e69b059e00c30159889502ac1848132950
+ size 714212484
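This is a Git LFS pointer, not the checkpoint itself; after cloning, `git lfs pull` fetches the actual ~714 MB file. ui.py loads it with `MMKWS2_Wrapper.load_from_checkpoint("stepstep=024500.ckpt")`. The doubled "stepstep" in the name appears to come from the `filename="step{step:06d}"` pattern in train_pinyin.py below, which PyTorch Lightning expands to `step=024500` (metric names are auto-inserted by default), and 24500 is consistent with checkpoints being written every 500 training steps.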
train_pinyin.py ADDED
@@ -0,0 +1,125 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch.utils.data import DataLoader
+ from torch.optim.lr_scheduler import LambdaLR
+
+ import numpy as np
+ from sklearn.metrics import roc_auc_score
+ from sklearn.metrics import roc_curve
+
+ import pytorch_lightning as pl
+ from pytorch_lightning import LightningModule, Trainer
+ from pytorch_lightning.callbacks import ModelCheckpoint
+ from transformers import Wav2Vec2FeatureExtractor, HubertModel
+ from model_pinyin import MMKWS2
+ from torch.utils.data import Dataset, ConcatDataset
+
+ class MMKWS2_Wrapper(LightningModule):
+     def __init__(self):
+         super().__init__()
+         self.model = MMKWS2()
+         self.criterion = nn.BCEWithLogitsLoss()
+         # Create hubert_model as a local first rather than assigning it directly
+         # as a class attribute; it is a frozen feature extractor, never trained.
+         hubert_model = HubertModel.from_pretrained("TencentGameMate/chinese-hubert-large").half().eval()
+         self._hubert_model = hubert_model  # underscore prefix marks internal use
+
+     def training_step(self, batch, batch_idx):
+         anchor_wave, anchor_text_embedding, compare_wave, compare_lengths, label, seq_label = \
+             batch['anchor_wave'], batch['anchor_embedding'], batch['compare_wave'], batch['compare_lengths'], batch['label'], batch['seq_label']
+
+         # Extract frozen HuBERT features for the anchor (enrollment) audio.
+         with torch.no_grad():
+             outputs = self._hubert_model(anchor_wave.half())
+             anchor_wave_embedding = outputs.last_hidden_state
+
+         anchor_wave_embedding = anchor_wave_embedding.to(anchor_wave.dtype)
+
+         logits, seq_logits = self.model(
+             anchor_wave_embedding,
+             anchor_text_embedding,
+             compare_wave,
+             compare_lengths
+         )
+
+         # Utterance-level binary classification loss.
+         utt_loss = self.criterion(logits, label.float())
+
+         # Sequence-level loss; positions where seq_label == -1 are masked out.
+         mask = (seq_label != -1).float()
+         seq_label_valid = seq_label.clone()
+         seq_label_valid[seq_label == -1] = 0  # neutralize padding so -1 cannot leak into the loss
+
+         seq_loss = F.binary_cross_entropy_with_logits(
+             seq_logits, seq_label_valid.float(), weight=mask, reduction='sum'
+         ) / (mask.sum() + 1e-6)
+
+         loss = utt_loss + seq_loss
+
+         # Log the losses at every step.
+         self.log('train/utt_loss', utt_loss, on_step=True, on_epoch=False, prog_bar=True)
+         self.log('train/seq_loss', seq_loss, on_step=True, on_epoch=False, prog_bar=True)
+         self.log('train/loss', loss, on_step=True, on_epoch=False, prog_bar=True)
+         return loss
+
+     def configure_optimizers(self):
+         optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
+
+         # Staircase decay: multiply the LR by 0.95 once every 1000 steps.
+         def lr_lambda(step):
+             return 0.95 ** (step // 1000)
+
+         scheduler = torch.optim.lr_scheduler.LambdaLR(
+             optimizer,
+             lr_lambda=lr_lambda
+         )
+
+         return {
+             "optimizer": optimizer,
+             "lr_scheduler": {
+                 "scheduler": scheduler,
+                 "interval": "step",  # step the scheduler once per training step
+                 "frequency": 1
+             },
+         }
+
+
+ # Set up the Trainer and run training.
+ if __name__ == "__main__":
+     pl.seed_everything(2024)
+     import os
+     os.environ['CUDA_VISIBLE_DEVICES'] = '7'
+     from dataset_pinyin import PairDataset
+     from dataloader_pinyin import get_dataloader
+
+     # Build the datasets.
+     dataset1 = PairDataset(
+         '/nvme01/aizq/kws-agent/data/anchor_pairs.parquet',
+         '/nvme01/aizq/kws-agent/data/WenetPhrase_base/M_S',
+         augment=True
+     )
+     dataset2 = PairDataset(
+         '/nvme01/aizq/kws-agent/data/synthetic_pairs.parquet',
+         '/nvme01/aizq/kws-agent/data/WenetPhrase_base/M_S',
+         augment=True
+     )
+     dataset = ConcatDataset([dataset1, dataset2])
+     # Build the dataloader.
+     dataloader = get_dataloader(dataset, batch_size=1024)
+
+     model = MMKWS2_Wrapper()
+     model_checkpoint = ModelCheckpoint(
+         dirpath="/nvme01/aizq/kws-agent/ckpts",
+         filename="step{step:06d}",
+         save_top_k=-1,
+         save_on_train_epoch_end=False,  # checkpoint on step count, not at epoch end
+         every_n_train_steps=500  # save a checkpoint every 500 training steps
+     )
+     logger = pl.loggers.TensorBoardLogger('/nvme01/aizq/kws-agent/logs/', name='MMKWS+')
+     trainer = Trainer(
+         devices=1,
+         accelerator='gpu',
+         logger=logger,
+         max_epochs=4,  # train for 4 epochs
+         callbacks=[model_checkpoint],
+         accumulate_grad_batches=2,  # effective batch size 2048 (1024 x 2)
+     )
+     trainer.fit(model, train_dataloaders=dataloader)
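For reference, the `lr_lambda` above implements a staircase schedule: starting from 1e-3, the learning rate is multiplied by 0.95 once every 1000 optimizer steps. A quick sanity check (the `lr_at` helper is only for illustration); note that at step 24500, the step of the uploaded checkpoint, the LR has decayed to roughly 2.9e-4:

```python
base_lr = 1e-3

def lr_at(step):
    # Same rule as lr_lambda in configure_optimizers: decay 5% every 1000 steps.
    return base_lr * 0.95 ** (step // 1000)

for s in (0, 1000, 10000, 24500):
    print(s, round(lr_at(s), 6))
# 0 0.001
# 1000 0.00095
# 10000 0.000599
# 24500 0.000292
```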
tts-2025-04-27@197cd135f2a2451b9cab9cf2add1c1ab.wav ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:3947a005d2c85f427cead98d3aeb65200bb85a956383aad43467b1b6a7f0c5f1
+ size 279404
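Another Git LFS pointer, this time for a small (~279 KB) WAV file. This is the wake-up prompt that ui.py plays back when the verification score crosses the detection threshold (see `audio_path` in `streaming_detect_handler`).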
ui.py ADDED
@@ -0,0 +1,167 @@
+ import sys
+ import os
+ import json
+ import time
+ import numpy as np
+ import torch
+ import torchaudio
+ import gradio as gr
+ from torchaudio.compliance.kaldi import fbank
+ from pypinyin import lazy_pinyin
+ from train_pinyin import MMKWS2_Wrapper
+
+ # Device selection and model loading.
+ # device = torch.device("cuda:4")
+ device = torch.device("cpu")
+ wrapper = MMKWS2_Wrapper.load_from_checkpoint(
+     "stepstep=024500.ckpt",
+     map_location=device,
+ )
+ wrapper.eval()
+
+ # Enrollment state.
+ registered = {"text": "", "audios": []}
+ enroll = None
+ enroll_text = None
+ last_wake_time = 0
+
+ def load_pinyin_index(save_path):
+     """Load the pinyin-to-index and index-to-pinyin mappings."""
+     with open(save_path, "r", encoding="utf-8") as f:
+         data = json.load(f)
+     return data["pinyin_to_index"], data["index_to_pinyin"]
+
+ pinyin_to_index, index_to_pinyin = load_pinyin_index("pinyin_index.json")
+
+ def add_audio(text, audio, audio_list):
+     if not text:
+         return audio_list, "请先输入唤醒词文本"
+     # gr.Audio(type="filepath") yields a file path string (or None).
+     if not audio:
+         return audio_list, "请上传或录制音频"
+     audio_list = audio_list or []
+     if len(audio_list) >= 5:
+         return audio_list, "最多支持5条音频"
+     audio_list.append(audio)
+     return audio_list, f"已录入 {len(audio_list)} 条音频"
+
+ def register_keyword(text, audio_list):
+     if not text:
+         return gr.update(value="请先输入唤醒词文本")
+     if not audio_list or len(audio_list) == 0:
+         return gr.update(value="请至少上传或录制一条音频")
+     registered["text"] = text
+     registered["audios"] = audio_list
+     global enroll_text
+     enroll_text = text
+     fused_feats = []
+     for audio in audio_list:
+         anchor_wave, _ = torchaudio.load(audio)
+         # Map the wake word to pinyin ids (+1 keeps index 0 free as padding).
+         anchor_text_embedding = torch.tensor([pinyin_to_index[p] + 1 for p in lazy_pinyin(text)])
+         anchor_wave = anchor_wave.to(device)
+         anchor_text_embedding = anchor_text_embedding.to(device).unsqueeze(0)
+         with torch.no_grad():
+             outputs = wrapper._hubert_model(anchor_wave.half())
+             anchor_wave_embedding = outputs.last_hidden_state
+         anchor_wave_embedding = anchor_wave_embedding.to(anchor_wave.dtype)
+         fused_feat = wrapper.model.enrollment(
+             anchor_wave_embedding,
+             anchor_text_embedding
+         )
+         fused_feats.append(fused_feat)
+     # Fuse the enrollment clips with an element-wise max over utterances.
+     fused_feats = torch.cat(fused_feats, dim=0)
+     fused_feats, _ = fused_feats.max(dim=0)
+     fused_feats = fused_feats.unsqueeze(0)
+     global enroll
+     enroll = fused_feats
+     return gr.update(value=f"注册完成,唤醒词:{text},音频数:{len(audio_list)}")
+
+ def update_gallery(audio_list):
+     if audio_list and len(audio_list) > 0:
+         return gr.update(visible=True, value=audio_list[-1])
+     else:
+         return gr.update(visible=False, value=None)
+
+ def streaming_detect_handler(current_audio, state, audio_player):
+     global last_wake_time, enroll_text, enroll
+     # Streaming mic chunks arrive as (sample_rate, samples) tuples.
+     if current_audio is None or current_audio[1] is None or len(current_audio[1]) == 0:
+         return state, gr.update()
+     if enroll_text is None:
+         return state, gr.update()
+     # Sliding window: keep 5 streamed chunks per enrolled character
+     # (roughly 0.25 s per character at stream_every=0.05).
+     pad = len(enroll_text) * 5
+     state = (state or []) + [current_audio]
+     state = state[-pad:]
+     if len(state) < pad:
+         return state, gr.update()
+     sr = state[0][0]
+     audio_list = [x[1] for x in state]
+     audio_concat = np.concatenate(audio_list, axis=0)
+     # int16 PCM -> float32 in [-1, 1], resample to 16 kHz, peak-normalize.
+     audio_concat = audio_concat.astype(np.float32) / 32768.0
+     audio_concat = torch.from_numpy(audio_concat).unsqueeze(0)
+     audio_concat = torchaudio.functional.resample(audio_concat, sr, 16000)
+     audio_concat = audio_concat / torch.max(torch.abs(audio_concat))
+     compare_wave = fbank(audio_concat, num_mel_bins=80)
+     compare_wave = compare_wave.to(device).unsqueeze(0)
+     compare_lengths = torch.tensor([compare_wave.size(1)], device=compare_wave.device)
+     if enroll is None:
+         return state, None
+     current_time = time.time()
+     # Refractory period: ignore detections within 2 s of the last wake-up.
+     if current_time - last_wake_time <= 2:
+         return state, gr.update()
+     with torch.no_grad():
+         preds = wrapper.model.verification(
+             enroll,
+             compare_wave,
+             compare_lengths
+         )
+     preds = torch.sigmoid(preds).item()
+     if preds >= 0.85:
+         print(f"Wake up! {preds}")
+         last_wake_time = current_time
+         audio_path = "tts-2025-04-27@197cd135f2a2451b9cab9cf2add1c1ab.wav"
+         state = []
+         return state, gr.Audio(value=audio_path, visible=True, autoplay=True)
+     if preds >= 0.6:
+         print(f"Preds! {preds}")
+     return state, None
+
+ with gr.Blocks() as demo:
+     gr.Markdown("# 自定义关键词检测 Demo")
+     with gr.Row():
+         with gr.Column(scale=1):
+             gr.Markdown("## 注册唤醒词")
+             text_input = gr.Textbox(label="唤醒词文本", placeholder="请输入唤醒词")
+             audio_list_state = gr.State([])
+             audio_input = gr.Audio(label="上传或录制音频", type="filepath")
+             add_btn = gr.Button("添加音频")
+             audio_status = gr.Textbox(label="音频状态", interactive=False)
+             audio_gallery = gr.Audio(label="已添加音频", type="filepath", interactive=False, visible=False)
+             register_btn = gr.Button("注册完成")
+             register_status = gr.Textbox(label="注册状态", interactive=False)
+             add_btn.click(
+                 add_audio,
+                 inputs=[text_input, audio_input, audio_list_state],
+                 outputs=[audio_list_state, audio_status]
+             ).then(
+                 update_gallery,
+                 inputs=audio_list_state,
+                 outputs=audio_gallery
+             )
+             register_btn.click(
+                 register_keyword,
+                 inputs=[text_input, audio_list_state],
+                 outputs=register_status
+             )
+         with gr.Column(scale=2):
+             gr.Markdown("## 实时检测")
+             mic = gr.Audio(sources="microphone", streaming=True, label="实时监听")
+             state = gr.State(value=[])
+             audio_player = gr.Audio(label="唤醒提示", visible=False)
+             mic.stream(
+                 streaming_detect_handler,
+                 inputs=[mic, state, audio_player],
+                 outputs=[state, audio_player],
+                 time_limit=1000,
+                 stream_every=0.05,
+             )
+
+ if __name__ == "__main__":
+     demo.launch()
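Usage note: running `python ui.py` launches the Gradio demo. The flow is: enter a wake word, record or upload up to five enrollment clips, then register; the clips are fused into a single enrollment representation by an element-wise max. The microphone component then streams chunks roughly every 50 ms (`stream_every=0.05`), the handler keeps a sliding window of 5 chunks per enrolled character (about 0.25 s per character), and a verification score of at least 0.85 triggers wake-up, plays the bundled TTS prompt, and enforces a 2-second refractory period before the next detection.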