tahirturk committed
Commit 686888f · 1 Parent(s): a9bcfd7

src changes

Files changed (49)
  1. src/chatterbox/__init__.py +0 -3
  2. src/chatterbox/models/__init__.py +0 -0
  3. src/chatterbox/models/s3gen/__init__.py +0 -2
  4. src/chatterbox/models/s3gen/configs.py +0 -10
  5. src/chatterbox/models/s3gen/const.py +0 -1
  6. src/chatterbox/models/s3gen/decoder.py +0 -317
  7. src/chatterbox/models/s3gen/f0_predictor.py +0 -55
  8. src/chatterbox/models/s3gen/flow.py +0 -290
  9. src/chatterbox/models/s3gen/flow_matching.py +0 -218
  10. src/chatterbox/models/s3gen/hifigan.py +0 -474
  11. src/chatterbox/models/s3gen/matcha/decoder.py +0 -443
  12. src/chatterbox/models/s3gen/matcha/flow_matching.py +0 -129
  13. src/chatterbox/models/s3gen/matcha/text_encoder.py +0 -413
  14. src/chatterbox/models/s3gen/matcha/transformer.py +0 -316
  15. src/chatterbox/models/s3gen/s3gen.py +0 -298
  16. src/chatterbox/models/s3gen/transformer/__init__.py +0 -0
  17. src/chatterbox/models/s3gen/transformer/activation.py +0 -84
  18. src/chatterbox/models/s3gen/transformer/attention.py +0 -330
  19. src/chatterbox/models/s3gen/transformer/convolution.py +0 -145
  20. src/chatterbox/models/s3gen/transformer/embedding.py +0 -294
  21. src/chatterbox/models/s3gen/transformer/encoder_layer.py +0 -236
  22. src/chatterbox/models/s3gen/transformer/positionwise_feed_forward.py +0 -115
  23. src/chatterbox/models/s3gen/transformer/subsampling.py +0 -383
  24. src/chatterbox/models/s3gen/transformer/upsample_encoder.py +0 -318
  25. src/chatterbox/models/s3gen/utils/class_utils.py +0 -71
  26. src/chatterbox/models/s3gen/utils/mask.py +0 -193
  27. src/chatterbox/models/s3gen/utils/mel.py +0 -85
  28. src/chatterbox/models/s3gen/xvector.py +0 -428
  29. src/chatterbox/models/s3tokenizer/__init__.py +0 -30
  30. src/chatterbox/models/s3tokenizer/s3tokenizer.py +0 -168
  31. src/chatterbox/models/t3/__init__.py +0 -1
  32. src/chatterbox/models/t3/inference/alignment_stream_analyzer.py +0 -178
  33. src/chatterbox/models/t3/inference/t3_hf_backend.py +0 -116
  34. src/chatterbox/models/t3/llama_configs.py +0 -37
  35. src/chatterbox/models/t3/modules/cond_enc.py +0 -97
  36. src/chatterbox/models/t3/modules/learned_pos_emb.py +0 -32
  37. src/chatterbox/models/t3/modules/perceiver.py +0 -212
  38. src/chatterbox/models/t3/modules/t3_config.py +0 -41
  39. src/chatterbox/models/t3/t3.py +0 -394
  40. src/chatterbox/models/tokenizers/__init__.py +0 -1
  41. src/chatterbox/models/tokenizers/tokenizer.py +0 -312
  42. src/chatterbox/models/utils.py +0 -4
  43. src/chatterbox/models/voice_encoder/__init__.py +0 -1
  44. src/chatterbox/models/voice_encoder/config.py +0 -18
  45. src/chatterbox/models/voice_encoder/melspec.py +0 -78
  46. src/chatterbox/models/voice_encoder/voice_encoder.py +0 -274
  47. src/chatterbox/mtl_tts.py +0 -301
  48. src/chatterbox/tts.py +0 -272
  49. src/chatterbox/vc.py +0 -104
src/chatterbox/__init__.py DELETED
@@ -1,3 +0,0 @@
1
- from .tts import ChatterboxTTS
2
- from .vc import ChatterboxVC
3
- from .mtl_tts import ChatterboxMultilingualTTS, SUPPORTED_LANGUAGES
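Note: the three names removed from the package root above were the library's public entry points. A one-line usage sketch (not part of this commit), assuming the package installs so that the top-level module is importable as chatterbox:

from chatterbox import ChatterboxTTS, ChatterboxVC, ChatterboxMultilingualTTS, SUPPORTED_LANGUAGES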
src/chatterbox/models/__init__.py DELETED
File without changes
src/chatterbox/models/s3gen/__init__.py DELETED
@@ -1,2 +0,0 @@
1
- from .s3gen import S3Token2Wav as S3Gen
2
- from .const import S3GEN_SR
src/chatterbox/models/s3gen/configs.py DELETED
@@ -1,10 +0,0 @@
1
- from ..utils import AttrDict
2
-
3
- CFM_PARAMS = AttrDict({
4
- "sigma_min": 1e-06,
5
- "solver": "euler",
6
- "t_scheduler": "cosine",
7
- "training_cfg_rate": 0.2,
8
- "inference_cfg_rate": 0.7,
9
- "reg_loss_type": "l1"
10
- })
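Note: CFM_PARAMS above bundles the conditional-flow-matching hyperparameters behind an AttrDict imported from the (also deleted) models/utils.py. A minimal sketch of how such a config is typically read; the AttrDict below is a stand-in and assumes the real one exposes keys as attributes:

class AttrDict(dict):
    """Stand-in: a dict whose keys are also readable as attributes."""
    def __getattr__(self, name):
        try:
            return self[name]
        except KeyError as exc:
            raise AttributeError(name) from exc

CFM_PARAMS = AttrDict({"sigma_min": 1e-06, "solver": "euler", "t_scheduler": "cosine"})
print(CFM_PARAMS.sigma_min, CFM_PARAMS["solver"])  # 1e-06 euler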
src/chatterbox/models/s3gen/const.py DELETED
@@ -1 +0,0 @@
1
- S3GEN_SR = 24000
src/chatterbox/models/s3gen/decoder.py DELETED
@@ -1,317 +0,0 @@
1
- # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- import torch
15
- import torch.nn as nn
16
- import torch.nn.functional as F
17
- from einops import pack, rearrange, repeat
18
-
19
- from .utils.mask import add_optional_chunk_mask
20
- from .matcha.decoder import SinusoidalPosEmb, Block1D, ResnetBlock1D, Downsample1D, \
21
- TimestepEmbedding, Upsample1D
22
- from .matcha.transformer import BasicTransformerBlock
23
-
24
-
25
- def mask_to_bias(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
26
- assert mask.dtype == torch.bool
27
- assert dtype in [torch.float32, torch.bfloat16, torch.float16]
28
- mask = mask.to(dtype)
29
- # attention mask bias
30
- # NOTE(Mddct): torch.finfo jit issues
31
- # chunk_masks = (1.0 - chunk_masks) * torch.finfo(dtype).min
32
- mask = (1.0 - mask) * -1.0e+10
33
- return mask
34
-
35
-
36
-
37
- class Transpose(torch.nn.Module):
38
- def __init__(self, dim0: int, dim1: int):
39
- super().__init__()
40
- self.dim0 = dim0
41
- self.dim1 = dim1
42
-
43
- def forward(self, x: torch.Tensor):
44
- x = torch.transpose(x, self.dim0, self.dim1)
45
- return x
46
-
47
-
48
- class CausalBlock1D(Block1D):
49
- def __init__(self, dim: int, dim_out: int):
50
- super(CausalBlock1D, self).__init__(dim, dim_out)
51
- self.block = torch.nn.Sequential(
52
- CausalConv1d(dim, dim_out, 3),
53
- Transpose(1, 2),
54
- nn.LayerNorm(dim_out),
55
- Transpose(1, 2),
56
- nn.Mish(),
57
- )
58
-
59
- def forward(self, x: torch.Tensor, mask: torch.Tensor):
60
- output = self.block(x * mask)
61
- return output * mask
62
-
63
-
64
- class CausalResnetBlock1D(ResnetBlock1D):
65
- def __init__(self, dim: int, dim_out: int, time_emb_dim: int, groups: int = 8):
66
- super(CausalResnetBlock1D, self).__init__(dim, dim_out, time_emb_dim, groups)
67
- self.block1 = CausalBlock1D(dim, dim_out)
68
- self.block2 = CausalBlock1D(dim_out, dim_out)
69
-
70
-
71
- class CausalConv1d(torch.nn.Conv1d):
72
- def __init__(
73
- self,
74
- in_channels: int,
75
- out_channels: int,
76
- kernel_size: int,
77
- stride: int = 1,
78
- dilation: int = 1,
79
- groups: int = 1,
80
- bias: bool = True,
81
- padding_mode: str = 'zeros',
82
- device=None,
83
- dtype=None
84
- ) -> None:
85
- super(CausalConv1d, self).__init__(in_channels, out_channels,
86
- kernel_size, stride,
87
- padding=0, dilation=dilation,
88
- groups=groups, bias=bias,
89
- padding_mode=padding_mode,
90
- device=device, dtype=dtype)
91
- assert stride == 1
92
- self.causal_padding = (kernel_size - 1, 0)
93
-
94
- def forward(self, x: torch.Tensor):
95
- x = F.pad(x, self.causal_padding)
96
- x = super(CausalConv1d, self).forward(x)
97
- return x
98
-
99
-
100
- class ConditionalDecoder(nn.Module):
101
- def __init__(
102
- self,
103
- in_channels=320,
104
- out_channels=80,
105
- causal=True,
106
- channels=[256],
107
- dropout=0.0,
108
- attention_head_dim=64,
109
- n_blocks=4,
110
- num_mid_blocks=12,
111
- num_heads=8,
112
- act_fn="gelu",
113
- ):
114
- """
115
- This decoder requires an input with the same shape as the target, so if your text content
116
- is shorter or longer than the output, resample it before feeding it to the decoder.
117
- """
118
- super().__init__()
119
- channels = tuple(channels)
120
- self.in_channels = in_channels
121
- self.out_channels = out_channels
122
- self.causal = causal
123
- self.time_embeddings = SinusoidalPosEmb(in_channels)
124
- time_embed_dim = channels[0] * 4
125
- self.time_mlp = TimestepEmbedding(
126
- in_channels=in_channels,
127
- time_embed_dim=time_embed_dim,
128
- act_fn="silu",
129
- )
130
- self.down_blocks = nn.ModuleList([])
131
- self.mid_blocks = nn.ModuleList([])
132
- self.up_blocks = nn.ModuleList([])
133
-
134
- # NOTE jrm: `static_chunk_size` is missing?
135
- self.static_chunk_size = 0
136
-
137
- output_channel = in_channels
138
- for i in range(len(channels)): # pylint: disable=consider-using-enumerate
139
- input_channel = output_channel
140
- output_channel = channels[i]
141
- is_last = i == len(channels) - 1
142
- resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) if self.causal else \
143
- ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
144
- transformer_blocks = nn.ModuleList(
145
- [
146
- BasicTransformerBlock(
147
- dim=output_channel,
148
- num_attention_heads=num_heads,
149
- attention_head_dim=attention_head_dim,
150
- dropout=dropout,
151
- activation_fn=act_fn,
152
- )
153
- for _ in range(n_blocks)
154
- ]
155
- )
156
- downsample = (
157
- Downsample1D(output_channel) if not is_last else
158
- CausalConv1d(output_channel, output_channel, 3) if self.causal else nn.Conv1d(output_channel, output_channel, 3, padding=1)
159
- )
160
- self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))
161
-
162
- for _ in range(num_mid_blocks):
163
- input_channel = channels[-1]
164
- out_channels = channels[-1]
165
- resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) if self.causal else \
166
- ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
167
-
168
- transformer_blocks = nn.ModuleList(
169
- [
170
- BasicTransformerBlock(
171
- dim=output_channel,
172
- num_attention_heads=num_heads,
173
- attention_head_dim=attention_head_dim,
174
- dropout=dropout,
175
- activation_fn=act_fn,
176
- )
177
- for _ in range(n_blocks)
178
- ]
179
- )
180
-
181
- self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks]))
182
-
183
- channels = channels[::-1] + (channels[0],)
184
- for i in range(len(channels) - 1):
185
- input_channel = channels[i] * 2
186
- output_channel = channels[i + 1]
187
- is_last = i == len(channels) - 2
188
- resnet = CausalResnetBlock1D(
189
- dim=input_channel,
190
- dim_out=output_channel,
191
- time_emb_dim=time_embed_dim,
192
- ) if self.causal else ResnetBlock1D(
193
- dim=input_channel,
194
- dim_out=output_channel,
195
- time_emb_dim=time_embed_dim,
196
- )
197
- transformer_blocks = nn.ModuleList(
198
- [
199
- BasicTransformerBlock(
200
- dim=output_channel,
201
- num_attention_heads=num_heads,
202
- attention_head_dim=attention_head_dim,
203
- dropout=dropout,
204
- activation_fn=act_fn,
205
- )
206
- for _ in range(n_blocks)
207
- ]
208
- )
209
- upsample = (
210
- Upsample1D(output_channel, use_conv_transpose=True)
211
- if not is_last
212
- else CausalConv1d(output_channel, output_channel, 3) if self.causal else nn.Conv1d(output_channel, output_channel, 3, padding=1)
213
- )
214
- self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample]))
215
- self.final_block = CausalBlock1D(channels[-1], channels[-1]) if self.causal else Block1D(channels[-1], channels[-1])
216
- self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
217
- self.initialize_weights()
218
-
219
- def initialize_weights(self):
220
- for m in self.modules():
221
- if isinstance(m, nn.Conv1d):
222
- nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
223
- if m.bias is not None:
224
- nn.init.constant_(m.bias, 0)
225
- elif isinstance(m, nn.GroupNorm):
226
- nn.init.constant_(m.weight, 1)
227
- nn.init.constant_(m.bias, 0)
228
- elif isinstance(m, nn.Linear):
229
- nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
230
- if m.bias is not None:
231
- nn.init.constant_(m.bias, 0)
232
-
233
- def forward(self, x, mask, mu, t, spks=None, cond=None):
234
- """Forward pass of the UNet1DConditional model.
235
-
236
- Args:
237
- x (torch.Tensor): shape (batch_size, in_channels, time)
238
- mask (_type_): shape (batch_size, 1, time)
239
- t (_type_): shape (batch_size)
240
- spks (_type_, optional): shape: (batch_size, condition_channels). Defaults to None.
241
- cond (_type_, optional): placeholder for future use. Defaults to None.
242
-
243
- Raises:
244
- ValueError: _description_
245
- ValueError: _description_
246
-
247
- Returns:
248
- _type_: _description_
249
- """
250
-
251
- t = self.time_embeddings(t).to(t.dtype)
252
- t = self.time_mlp(t)
253
-
254
- x = pack([x, mu], "b * t")[0]
255
-
256
- if spks is not None:
257
- spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
258
- x = pack([x, spks], "b * t")[0]
259
- if cond is not None:
260
- x = pack([x, cond], "b * t")[0]
261
-
262
- hiddens = []
263
- masks = [mask]
264
- for resnet, transformer_blocks, downsample in self.down_blocks:
265
- mask_down = masks[-1]
266
- x = resnet(x, mask_down, t)
267
- x = rearrange(x, "b c t -> b t c").contiguous()
268
- # attn_mask = torch.matmul(mask_down.transpose(1, 2).contiguous(), mask_down)
269
- attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, self.static_chunk_size, -1)
270
- attn_mask = mask_to_bias(attn_mask == 1, x.dtype)
271
- for transformer_block in transformer_blocks:
272
- x = transformer_block(
273
- hidden_states=x,
274
- attention_mask=attn_mask,
275
- timestep=t,
276
- )
277
- x = rearrange(x, "b t c -> b c t").contiguous()
278
- hiddens.append(x) # Save hidden states for skip connections
279
- x = downsample(x * mask_down)
280
- masks.append(mask_down[:, :, ::2])
281
- masks = masks[:-1]
282
- mask_mid = masks[-1]
283
-
284
- for resnet, transformer_blocks in self.mid_blocks:
285
- x = resnet(x, mask_mid, t)
286
- x = rearrange(x, "b c t -> b t c").contiguous()
287
- # attn_mask = torch.matmul(mask_mid.transpose(1, 2).contiguous(), mask_mid)
288
- attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, self.static_chunk_size, -1)
289
- attn_mask = mask_to_bias(attn_mask == 1, x.dtype)
290
- for transformer_block in transformer_blocks:
291
- x = transformer_block(
292
- hidden_states=x,
293
- attention_mask=attn_mask,
294
- timestep=t,
295
- )
296
- x = rearrange(x, "b t c -> b c t").contiguous()
297
-
298
- for resnet, transformer_blocks, upsample in self.up_blocks:
299
- mask_up = masks.pop()
300
- skip = hiddens.pop()
301
- x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
302
- x = resnet(x, mask_up, t)
303
- x = rearrange(x, "b c t -> b t c").contiguous()
304
- # attn_mask = torch.matmul(mask_up.transpose(1, 2).contiguous(), mask_up)
305
- attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, self.static_chunk_size, -1)
306
- attn_mask = mask_to_bias(attn_mask == 1, x.dtype)
307
- for transformer_block in transformer_blocks:
308
- x = transformer_block(
309
- hidden_states=x,
310
- attention_mask=attn_mask,
311
- timestep=t,
312
- )
313
- x = rearrange(x, "b t c -> b c t").contiguous()
314
- x = upsample(x * mask_up)
315
- x = self.final_block(x, mask_up)
316
- output = self.final_proj(x * mask_up)
317
- return output * mask
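Note: CausalConv1d above pads only on the left, so each output frame depends on the current and past frames but never on future ones. A small sketch of that padding behaviour using only stock PyTorch:

import torch
import torch.nn.functional as F

kernel_size = 3
x = torch.randn(1, 4, 10)                  # (batch, channels, time)
x_padded = F.pad(x, (kernel_size - 1, 0))  # left-pad by kernel_size - 1, nothing on the right
conv = torch.nn.Conv1d(4, 4, kernel_size, padding=0)
y = conv(x_padded)
assert y.shape[-1] == x.shape[-1]          # output length matches the unpadded input length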
src/chatterbox/models/s3gen/f0_predictor.py DELETED
@@ -1,55 +0,0 @@
1
- # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Kai Hu)
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- import torch
15
- import torch.nn as nn
16
- from torch.nn.utils.parametrizations import weight_norm
17
-
18
-
19
- class ConvRNNF0Predictor(nn.Module):
20
- def __init__(self,
21
- num_class: int = 1,
22
- in_channels: int = 80,
23
- cond_channels: int = 512
24
- ):
25
- super().__init__()
26
-
27
- self.num_class = num_class
28
- self.condnet = nn.Sequential(
29
- weight_norm(
30
- nn.Conv1d(in_channels, cond_channels, kernel_size=3, padding=1)
31
- ),
32
- nn.ELU(),
33
- weight_norm(
34
- nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
35
- ),
36
- nn.ELU(),
37
- weight_norm(
38
- nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
39
- ),
40
- nn.ELU(),
41
- weight_norm(
42
- nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
43
- ),
44
- nn.ELU(),
45
- weight_norm(
46
- nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
47
- ),
48
- nn.ELU(),
49
- )
50
- self.classifier = nn.Linear(in_features=cond_channels, out_features=self.num_class)
51
-
52
- def forward(self, x: torch.Tensor) -> torch.Tensor:
53
- x = self.condnet(x)
54
- x = x.transpose(1, 2)
55
- return torch.abs(self.classifier(x).squeeze(-1))
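Note: a shape-only usage sketch for the predictor above (not part of this commit), assuming the ConvRNNF0Predictor class shown in the diff is importable. Input is an 80-bin mel spectrogram (batch, n_mels, frames); output is one non-negative F0 value per frame:

import torch

predictor = ConvRNNF0Predictor(num_class=1, in_channels=80, cond_channels=512)
mel = torch.randn(2, 80, 120)   # (batch, n_mels, frames)
f0 = predictor(mel)             # (2, 120), non-negative via torch.abs
print(f0.shape)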
src/chatterbox/models/s3gen/flow.py DELETED
@@ -1,290 +0,0 @@
1
- # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- import logging
15
- import random
16
- from typing import Dict, Optional
17
-
18
- logger = logging.getLogger(__name__)
19
- import torch
20
- import torch.nn as nn
21
- from torch.nn import functional as F
22
- from .utils.mask import make_pad_mask
23
- from .configs import CFM_PARAMS
24
-
25
-
26
- class MaskedDiffWithXvec(torch.nn.Module):
27
- def __init__(
28
- self,
29
- input_size: int = 512,
30
- output_size: int = 80,
31
- spk_embed_dim: int = 192,
32
- output_type: str = "mel",
33
- vocab_size: int = 4096,
34
- input_frame_rate: int = 50,
35
- only_mask_loss: bool = True,
36
- encoder: torch.nn.Module = None,
37
- length_regulator: torch.nn.Module = None,
38
- decoder: torch.nn.Module = None,
39
- decoder_conf: Dict = {
40
- 'in_channels': 240,
41
- 'out_channel': 80,
42
- 'spk_emb_dim': 80,
43
- 'n_spks': 1,
44
- 'cfm_params': CFM_PARAMS,
45
- 'decoder_params': {
46
- 'channels': [256, 256],
47
- 'dropout': 0.0,
48
- 'attention_head_dim': 64,
49
- 'n_blocks': 4,
50
- 'num_mid_blocks': 12,
51
- 'num_heads': 8,
52
- 'act_fn': 'gelu',
53
- }
54
- },
55
- mel_feat_conf: Dict = {
56
- 'n_fft': 1024,
57
- 'num_mels': 80,
58
- 'sampling_rate': 22050,
59
- 'hop_size': 256,
60
- 'win_size': 1024,
61
- 'fmin': 0,
62
- 'fmax': 8000
63
- }
64
- ):
65
- super().__init__()
66
- self.input_size = input_size
67
- self.output_size = output_size
68
- self.decoder_conf = decoder_conf
69
- self.mel_feat_conf = mel_feat_conf
70
- self.vocab_size = vocab_size
71
- self.output_type = output_type
72
- self.input_frame_rate = input_frame_rate
73
- logging.info(f"input frame rate={self.input_frame_rate}")
74
- self.input_embedding = nn.Embedding(vocab_size, input_size)
75
- self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, output_size)
76
- self.encoder = encoder
77
- self.encoder_proj = torch.nn.Linear(self.encoder.output_size(), output_size)
78
- self.decoder = decoder
79
- self.length_regulator = length_regulator
80
- self.only_mask_loss = only_mask_loss
81
-
82
- def forward(
83
- self,
84
- batch: dict,
85
- device: torch.device,
86
- ) -> Dict[str, Optional[torch.Tensor]]:
87
- token = batch['speech_token'].to(device)
88
- token_len = batch['speech_token_len'].to(device)
89
- feat = batch['speech_feat'].to(device)
90
- feat_len = batch['speech_feat_len'].to(device)
91
- embedding = batch['embedding'].to(device)
92
-
93
- # xvec projection
94
- embedding = F.normalize(embedding, dim=1)
95
- embedding = self.spk_embed_affine_layer(embedding)
96
-
97
- # concat text and prompt_text
98
- mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(device)
99
- token = self.input_embedding(torch.clamp(token, min=0, max=self.input_embedding.num_embeddings-1)) * mask
100
-
101
- # text encode
102
- h, h_lengths = self.encoder(token, token_len)
103
- h = self.encoder_proj(h)
104
- h, h_lengths = self.length_regulator(h, feat_len)
105
-
106
- # get conditions
107
- conds = torch.zeros(feat.shape, device=token.device)
108
- for i, j in enumerate(feat_len):
109
- if random.random() < 0.5:
110
- continue
111
- index = random.randint(0, int(0.3 * j))
112
- conds[i, :index] = feat[i, :index]
113
- conds = conds.transpose(1, 2)
114
-
115
- mask = (~make_pad_mask(feat_len)).to(h)
116
- feat = F.interpolate(feat.unsqueeze(dim=1), size=h.shape[1:], mode="nearest").squeeze(dim=1)
117
- loss, _ = self.decoder.compute_loss(
118
- feat.transpose(1, 2).contiguous(),
119
- mask.unsqueeze(1),
120
- h.transpose(1, 2).contiguous(),
121
- embedding,
122
- cond=conds
123
- )
124
- return {'loss': loss}
125
-
126
- @torch.inference_mode()
127
- def inference(self,
128
- token,
129
- token_len,
130
- prompt_token,
131
- prompt_token_len,
132
- prompt_feat,
133
- prompt_feat_len,
134
- embedding,
135
- flow_cache):
136
- if self.fp16 is True:
137
- prompt_feat = prompt_feat.half()
138
- embedding = embedding.half()
139
-
140
- assert token.shape[0] == 1
141
- # xvec projection
142
- embedding = F.normalize(embedding, dim=1)
143
- embedding = self.spk_embed_affine_layer(embedding)
144
-
145
- # concat text and prompt_text
146
- token_len1, token_len2 = prompt_token.shape[1], token.shape[1]
147
- token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
148
- mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding)
149
-
150
- # Check for out-of-bounds token IDs
151
- vocab_size = self.input_embedding.num_embeddings
152
- if token.max() >= vocab_size or token.min() < 0:
153
- logging.warning(f"S3Gen: Token IDs out of bounds: min={token.min().item()}, max={token.max().item()}, vocab_size={vocab_size}")
154
-
155
- token = self.input_embedding(torch.clamp(token, min=0, max=vocab_size-1)) * mask
156
-
157
- # text encode
158
- h, h_lengths = self.encoder(token, token_len)
159
- h = self.encoder_proj(h)
160
- mel_len1, mel_len2 = prompt_feat.shape[1], int(token_len2 / self.input_frame_rate * 22050 / 256)
161
- h, h_lengths = self.length_regulator.inference(h[:, :token_len1], h[:, token_len1:], mel_len1, mel_len2, self.input_frame_rate)
162
-
163
- # get conditions
164
- conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device).to(h.dtype)
165
- conds[:, :mel_len1] = prompt_feat
166
- conds = conds.transpose(1, 2)
167
-
168
- mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
169
- feat, flow_cache = self.decoder(
170
- mu=h.transpose(1, 2).contiguous(),
171
- mask=mask.unsqueeze(1),
172
- spks=embedding,
173
- cond=conds,
174
- n_timesteps=10,
175
- prompt_len=mel_len1,
176
- flow_cache=flow_cache
177
- )
178
- feat = feat[:, :, mel_len1:]
179
- assert feat.shape[2] == mel_len2
180
- return feat.float(), flow_cache
181
-
182
-
183
- class CausalMaskedDiffWithXvec(torch.nn.Module):
184
- def __init__(
185
- self,
186
- input_size: int = 512,
187
- output_size: int = 80,
188
- spk_embed_dim: int = 192,
189
- output_type: str = "mel",
190
- vocab_size: int = 6561,
191
- input_frame_rate: int = 25,
192
- only_mask_loss: bool = True,
193
- token_mel_ratio: int = 2,
194
- pre_lookahead_len: int = 3,
195
- encoder: torch.nn.Module = None,
196
- decoder: torch.nn.Module = None,
197
- decoder_conf: Dict = {
198
- 'in_channels': 240,
199
- 'out_channel': 80,
200
- 'spk_emb_dim': 80,
201
- 'n_spks': 1,
202
- 'cfm_params': CFM_PARAMS,
203
- 'decoder_params': {
204
- 'channels': [256, 256],
205
- 'dropout': 0.0,
206
- 'attention_head_dim': 64,
207
- 'n_blocks': 4,
208
- 'num_mid_blocks': 12,
209
- 'num_heads': 8,
210
- 'act_fn': 'gelu',
211
- }
212
- },
213
- mel_feat_conf: Dict = {
214
- 'n_fft': 1024,
215
- 'num_mels': 80,
216
- 'sampling_rate': 22050,
217
- 'hop_size': 256,
218
- 'win_size': 1024,
219
- 'fmin': 0,
220
- 'fmax': 8000
221
- }
222
- ):
223
- super().__init__()
224
- self.input_size = input_size
225
- self.output_size = output_size
226
- self.decoder_conf = decoder_conf
227
- self.mel_feat_conf = mel_feat_conf
228
- self.vocab_size = vocab_size
229
- self.output_type = output_type
230
- self.input_frame_rate = input_frame_rate
231
- logging.info(f"input frame rate={self.input_frame_rate}")
232
- self.input_embedding = nn.Embedding(vocab_size, input_size)
233
- self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, output_size)
234
- self.encoder = encoder
235
- self.encoder_proj = torch.nn.Linear(self.encoder.output_size(), output_size)
236
- self.decoder = decoder
237
- self.only_mask_loss = only_mask_loss
238
- self.token_mel_ratio = token_mel_ratio
239
- self.pre_lookahead_len = pre_lookahead_len
240
-
241
- # FIXME: this was missing - just putting it in as false
242
- self.fp16 = False
243
-
244
- @torch.inference_mode()
245
- def inference(self,
246
- token,
247
- token_len,
248
- prompt_token,
249
- prompt_token_len,
250
- prompt_feat,
251
- prompt_feat_len,
252
- embedding,
253
- finalize):
254
- if self.fp16 is True:
255
- prompt_feat = prompt_feat.half()
256
- embedding = embedding.half()
257
-
258
- assert token.shape[0] == 1
259
- # xvec projection
260
- embedding = F.normalize(embedding, dim=1)
261
- embedding = self.spk_embed_affine_layer(embedding)
262
-
263
- # concat text and prompt_text
264
- token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
265
- mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding)
266
- token = self.input_embedding(torch.clamp(token, min=0, max=self.input_embedding.num_embeddings-1)) * mask
267
-
268
- # text encode
269
- h, h_lengths = self.encoder(token, token_len)
270
- if finalize is False:
271
- h = h[:, :-self.pre_lookahead_len * self.token_mel_ratio]
272
- mel_len1, mel_len2 = prompt_feat.shape[1], h.shape[1] - prompt_feat.shape[1]
273
- h = self.encoder_proj(h)
274
-
275
- # get conditions
276
- conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device).to(h.dtype)
277
- conds[:, :mel_len1] = prompt_feat
278
- conds = conds.transpose(1, 2)
279
-
280
- mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
281
- feat, _ = self.decoder(
282
- mu=h.transpose(1, 2).contiguous(),
283
- mask=mask.unsqueeze(1),
284
- spks=embedding,
285
- cond=conds,
286
- n_timesteps=10
287
- )
288
- feat = feat[:, :, mel_len1:]
289
- assert feat.shape[2] == mel_len2
290
- return feat.float(), None # NOTE jrm: why are they returning None here?
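Note: in MaskedDiffWithXvec.inference above, the target mel length for the new tokens is derived by converting token count (at input_frame_rate tokens per second) into 22050 Hz mel frames with hop size 256. A worked example of that arithmetic:

input_frame_rate = 50   # tokens per second (default in the class above)
token_len2 = 100        # number of new (non-prompt) speech tokens
mel_len2 = int(token_len2 / input_frame_rate * 22050 / 256)
print(mel_len2)         # 172 mel frames for roughly 2 seconds of audio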
src/chatterbox/models/s3gen/flow_matching.py DELETED
@@ -1,218 +0,0 @@
1
- # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- import threading
15
- import torch
16
- import torch.nn.functional as F
17
- from .matcha.flow_matching import BASECFM
18
- from .configs import CFM_PARAMS
19
-
20
-
21
- class ConditionalCFM(BASECFM):
22
- def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator: torch.nn.Module = None):
23
- super().__init__(
24
- n_feats=in_channels,
25
- cfm_params=cfm_params,
26
- n_spks=n_spks,
27
- spk_emb_dim=spk_emb_dim,
28
- )
29
- self.t_scheduler = cfm_params.t_scheduler
30
- self.training_cfg_rate = cfm_params.training_cfg_rate
31
- self.inference_cfg_rate = cfm_params.inference_cfg_rate
32
- in_channels = in_channels + (spk_emb_dim if n_spks > 0 else 0)
33
- # Just change the architecture of the estimator here
34
- self.estimator = estimator
35
- self.lock = threading.Lock()
36
-
37
- @torch.inference_mode()
38
- def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, prompt_len=0, flow_cache=torch.zeros(1, 80, 0, 2)):
39
- """Forward diffusion
40
-
41
- Args:
42
- mu (torch.Tensor): output of encoder
43
- shape: (batch_size, n_feats, mel_timesteps)
44
- mask (torch.Tensor): output_mask
45
- shape: (batch_size, 1, mel_timesteps)
46
- n_timesteps (int): number of diffusion steps
47
- temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
48
- spks (torch.Tensor, optional): speaker ids. Defaults to None.
49
- shape: (batch_size, spk_emb_dim)
50
- cond: Not used but kept for future purposes
51
-
52
- Returns:
53
- sample: generated mel-spectrogram
54
- shape: (batch_size, n_feats, mel_timesteps)
55
- """
56
-
57
- z = torch.randn_like(mu).to(mu.device).to(mu.dtype) * temperature
58
- cache_size = flow_cache.shape[2]
59
- # fix prompt and overlap part mu and z
60
- if cache_size != 0:
61
- z[:, :, :cache_size] = flow_cache[:, :, :, 0]
62
- mu[:, :, :cache_size] = flow_cache[:, :, :, 1]
63
- z_cache = torch.concat([z[:, :, :prompt_len], z[:, :, -34:]], dim=2)
64
- mu_cache = torch.concat([mu[:, :, :prompt_len], mu[:, :, -34:]], dim=2)
65
- flow_cache = torch.stack([z_cache, mu_cache], dim=-1)
66
-
67
- t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
68
- if self.t_scheduler == 'cosine':
69
- t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
70
- return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), flow_cache
71
-
72
- def solve_euler(self, x, t_span, mu, mask, spks, cond):
73
- """
74
- Fixed euler solver for ODEs.
75
- Args:
76
- x (torch.Tensor): random noise
77
- t_span (torch.Tensor): n_timesteps interpolated
78
- shape: (n_timesteps + 1,)
79
- mu (torch.Tensor): output of encoder
80
- shape: (batch_size, n_feats, mel_timesteps)
81
- mask (torch.Tensor): output_mask
82
- shape: (batch_size, 1, mel_timesteps)
83
- spks (torch.Tensor, optional): speaker ids. Defaults to None.
84
- shape: (batch_size, spk_emb_dim)
85
- cond: Not used but kept for future purposes
86
- """
87
- t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
88
- t = t.unsqueeze(dim=0)
89
-
90
- # I am storing this because I can later plot it by putting a debugger here and saving it to a file
91
- # Or in future might add like a return_all_steps flag
92
- sol = []
93
-
94
- # Do not use concat, it may cause memory format changed and trt infer with wrong results!
95
- x_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
96
- mask_in = torch.zeros([2, 1, x.size(2)], device=x.device, dtype=x.dtype)
97
- mu_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
98
- t_in = torch.zeros([2], device=x.device, dtype=x.dtype)
99
- spks_in = torch.zeros([2, 80], device=x.device, dtype=x.dtype)
100
- cond_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
101
- for step in range(1, len(t_span)):
102
- # Classifier-Free Guidance inference introduced in VoiceBox
103
- x_in[:] = x
104
- mask_in[:] = mask
105
- mu_in[0] = mu
106
- t_in[:] = t.unsqueeze(0)
107
- spks_in[0] = spks
108
- cond_in[0] = cond
109
- dphi_dt = self.forward_estimator(
110
- x_in, mask_in,
111
- mu_in, t_in,
112
- spks_in,
113
- cond_in
114
- )
115
- dphi_dt, cfg_dphi_dt = torch.split(dphi_dt, [x.size(0), x.size(0)], dim=0)
116
- dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt - self.inference_cfg_rate * cfg_dphi_dt)
117
- x = x + dt * dphi_dt
118
- t = t + dt
119
- sol.append(x)
120
- if step < len(t_span) - 1:
121
- dt = t_span[step + 1] - t
122
-
123
- return sol[-1].float()
124
-
125
- def forward_estimator(self, x, mask, mu, t, spks, cond):
126
- if isinstance(self.estimator, torch.nn.Module):
127
- return self.estimator.forward(x, mask, mu, t, spks, cond)
128
- else:
129
- with self.lock:
130
- self.estimator.set_input_shape('x', (2, 80, x.size(2)))
131
- self.estimator.set_input_shape('mask', (2, 1, x.size(2)))
132
- self.estimator.set_input_shape('mu', (2, 80, x.size(2)))
133
- self.estimator.set_input_shape('t', (2,))
134
- self.estimator.set_input_shape('spks', (2, 80))
135
- self.estimator.set_input_shape('cond', (2, 80, x.size(2)))
136
- # run trt engine
137
- self.estimator.execute_v2([x.contiguous().data_ptr(),
138
- mask.contiguous().data_ptr(),
139
- mu.contiguous().data_ptr(),
140
- t.contiguous().data_ptr(),
141
- spks.contiguous().data_ptr(),
142
- cond.contiguous().data_ptr(),
143
- x.data_ptr()])
144
- return x
145
-
146
- def compute_loss(self, x1, mask, mu, spks=None, cond=None):
147
- """Computes diffusion loss
148
-
149
- Args:
150
- x1 (torch.Tensor): Target
151
- shape: (batch_size, n_feats, mel_timesteps)
152
- mask (torch.Tensor): target mask
153
- shape: (batch_size, 1, mel_timesteps)
154
- mu (torch.Tensor): output of encoder
155
- shape: (batch_size, n_feats, mel_timesteps)
156
- spks (torch.Tensor, optional): speaker embedding. Defaults to None.
157
- shape: (batch_size, spk_emb_dim)
158
-
159
- Returns:
160
- loss: conditional flow matching loss
161
- y: conditional flow
162
- shape: (batch_size, n_feats, mel_timesteps)
163
- """
164
- b, _, t = mu.shape
165
-
166
- # random timestep
167
- t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype)
168
- if self.t_scheduler == 'cosine':
169
- t = 1 - torch.cos(t * 0.5 * torch.pi)
170
- # sample noise p(x_0)
171
- z = torch.randn_like(x1)
172
-
173
- y = (1 - (1 - self.sigma_min) * t) * z + t * x1
174
- u = x1 - (1 - self.sigma_min) * z
175
-
176
- # during training, we randomly drop condition to trade off mode coverage and sample fidelity
177
- if self.training_cfg_rate > 0:
178
- cfg_mask = torch.rand(b, device=x1.device) > self.training_cfg_rate
179
- mu = mu * cfg_mask.view(-1, 1, 1)
180
- spks = spks * cfg_mask.view(-1, 1)
181
- cond = cond * cfg_mask.view(-1, 1, 1)
182
-
183
- pred = self.estimator(y, mask, mu, t.squeeze(), spks, cond)
184
- loss = F.mse_loss(pred * mask, u * mask, reduction="sum") / (torch.sum(mask) * u.shape[1])
185
- return loss, y
186
-
187
-
188
- class CausalConditionalCFM(ConditionalCFM):
189
- def __init__(self, in_channels=240, cfm_params=CFM_PARAMS, n_spks=1, spk_emb_dim=80, estimator=None):
190
- super().__init__(in_channels, cfm_params, n_spks, spk_emb_dim, estimator)
191
- self.rand_noise = torch.randn([1, 80, 50 * 300])
192
-
193
- @torch.inference_mode()
194
- def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
195
- """Forward diffusion
196
-
197
- Args:
198
- mu (torch.Tensor): output of encoder
199
- shape: (batch_size, n_feats, mel_timesteps)
200
- mask (torch.Tensor): output_mask
201
- shape: (batch_size, 1, mel_timesteps)
202
- n_timesteps (int): number of diffusion steps
203
- temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
204
- spks (torch.Tensor, optional): speaker ids. Defaults to None.
205
- shape: (batch_size, spk_emb_dim)
206
- cond: Not used but kept for future purposes
207
-
208
- Returns:
209
- sample: generated mel-spectrogram
210
- shape: (batch_size, n_feats, mel_timesteps)
211
- """
212
-
213
- z = self.rand_noise[:, :, :mu.size(2)].to(mu.device).to(mu.dtype) * temperature
214
- # fix prompt and overlap part mu and z
215
- t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
216
- if self.t_scheduler == 'cosine':
217
- t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
218
- return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), None
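Note: ConditionalCFM.forward above applies a cosine warp to the Euler solver's time grid when t_scheduler == 'cosine'. A short sketch showing the warped step spacing (smaller steps near t = 0, larger near t = 1):

import torch

n_timesteps = 10
t_span = torch.linspace(0, 1, n_timesteps + 1)
t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)   # cosine warp, still spans [0, 1]
print(t_span[:3])   # e.g. tensor([0.0000, 0.0123, 0.0489])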
src/chatterbox/models/s3gen/hifigan.py DELETED
@@ -1,474 +0,0 @@
1
- # jrm: adapted from CosyVoice/cosyvoice/hifigan/generator.py
2
- # most modules should be reusable, but I found their SineGen changed a bit.
3
-
4
- # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Kai Hu)
5
- #
6
- # Licensed under the Apache License, Version 2.0 (the "License");
7
- # you may not use this file except in compliance with the License.
8
- # You may obtain a copy of the License at
9
- #
10
- # http://www.apache.org/licenses/LICENSE-2.0
11
- #
12
- # Unless required by applicable law or agreed to in writing, software
13
- # distributed under the License is distributed on an "AS IS" BASIS,
14
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
- # See the License for the specific language governing permissions and
16
- # limitations under the License.
17
-
18
- """HIFI-GAN"""
19
-
20
- from typing import Dict, Optional, List
21
- import numpy as np
22
- from scipy.signal import get_window
23
- import torch
24
- import torch.nn.functional as F
25
- from torch.nn import Conv1d
26
- from torch.nn import ConvTranspose1d
27
- from torch.nn.utils import remove_weight_norm
28
- from torch.nn.utils.parametrizations import weight_norm
29
- from torch.distributions.uniform import Uniform
30
- from torch import nn, sin, pow
31
- from torch.nn import Parameter
32
-
33
-
34
- class Snake(nn.Module):
35
- '''
36
- Implementation of a sine-based periodic activation function
37
- Shape:
38
- - Input: (B, C, T)
39
- - Output: (B, C, T), same shape as the input
40
- Parameters:
41
- - alpha - trainable parameter
42
- References:
43
- - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
44
- https://arxiv.org/abs/2006.08195
45
- Examples:
46
- >>> a1 = Snake(256)
47
- >>> x = torch.randn(256)
48
- >>> x = a1(x)
49
- '''
50
- def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
51
- '''
52
- Initialization.
53
- INPUT:
54
- - in_features: shape of the input
55
- - alpha: trainable parameter
56
- alpha is initialized to 1 by default, higher values = higher-frequency.
57
- alpha will be trained along with the rest of your model.
58
- '''
59
- super(Snake, self).__init__()
60
- self.in_features = in_features
61
-
62
- # initialize alpha
63
- self.alpha_logscale = alpha_logscale
64
- if self.alpha_logscale: # log scale alphas initialized to zeros
65
- self.alpha = Parameter(torch.zeros(in_features) * alpha)
66
- else: # linear scale alphas initialized to ones
67
- self.alpha = Parameter(torch.ones(in_features) * alpha)
68
-
69
- self.alpha.requires_grad = alpha_trainable
70
-
71
- self.no_div_by_zero = 0.000000001
72
-
73
- def forward(self, x):
74
- '''
75
- Forward pass of the function.
76
- Applies the function to the input elementwise.
77
- Snake := x + 1/a * sin^2(x * a)
78
- '''
79
- alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
80
- if self.alpha_logscale:
81
- alpha = torch.exp(alpha)
82
- x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
83
-
84
- return x
85
-
86
-
87
-
88
- def get_padding(kernel_size, dilation=1):
89
- return int((kernel_size * dilation - dilation) / 2)
90
-
91
- def init_weights(m, mean=0.0, std=0.01):
92
- classname = m.__class__.__name__
93
- if classname.find("Conv") != -1:
94
- m.weight.data.normal_(mean, std)
95
-
96
-
97
- """hifigan based generator implementation.
98
-
99
- This code is modified from https://github.com/jik876/hifi-gan
100
- ,https://github.com/kan-bayashi/ParallelWaveGAN and
101
- https://github.com/NVIDIA/BigVGAN
102
-
103
- """
104
-
105
-
106
- class ResBlock(torch.nn.Module):
107
- """Residual block module in HiFiGAN/BigVGAN."""
108
- def __init__(
109
- self,
110
- channels: int = 512,
111
- kernel_size: int = 3,
112
- dilations: List[int] = [1, 3, 5],
113
- ):
114
- super(ResBlock, self).__init__()
115
- self.convs1 = nn.ModuleList()
116
- self.convs2 = nn.ModuleList()
117
-
118
- for dilation in dilations:
119
- self.convs1.append(
120
- weight_norm(
121
- Conv1d(
122
- channels,
123
- channels,
124
- kernel_size,
125
- 1,
126
- dilation=dilation,
127
- padding=get_padding(kernel_size, dilation)
128
- )
129
- )
130
- )
131
- self.convs2.append(
132
- weight_norm(
133
- Conv1d(
134
- channels,
135
- channels,
136
- kernel_size,
137
- 1,
138
- dilation=1,
139
- padding=get_padding(kernel_size, 1)
140
- )
141
- )
142
- )
143
- self.convs1.apply(init_weights)
144
- self.convs2.apply(init_weights)
145
- self.activations1 = nn.ModuleList([
146
- Snake(channels, alpha_logscale=False)
147
- for _ in range(len(self.convs1))
148
- ])
149
- self.activations2 = nn.ModuleList([
150
- Snake(channels, alpha_logscale=False)
151
- for _ in range(len(self.convs2))
152
- ])
153
-
154
- def forward(self, x: torch.Tensor) -> torch.Tensor:
155
- for idx in range(len(self.convs1)):
156
- xt = self.activations1[idx](x)
157
- xt = self.convs1[idx](xt)
158
- xt = self.activations2[idx](xt)
159
- xt = self.convs2[idx](xt)
160
- x = xt + x
161
- return x
162
-
163
- def remove_weight_norm(self):
164
- for idx in range(len(self.convs1)):
165
- remove_weight_norm(self.convs1[idx])
166
- remove_weight_norm(self.convs2[idx])
167
-
168
-
169
- class SineGen(torch.nn.Module):
170
- """ Definition of sine generator
171
- SineGen(samp_rate, harmonic_num = 0,
172
- sine_amp = 0.1, noise_std = 0.003,
173
- voiced_threshold = 0,
174
- flag_for_pulse=False)
175
- samp_rate: sampling rate in Hz
176
- harmonic_num: number of harmonic overtones (default 0)
177
- sine_amp: amplitude of sine waveform (default 0.1)
178
- noise_std: std of Gaussian noise (default 0.003)
179
- voiced_threshold: F0 threshold for U/V classification (default 0)
180
- flag_for_pulse: this SineGen is used inside PulseGen (default False)
181
- Note: when flag_for_pulse is True, the first time step of a voiced
182
- segment is always sin(np.pi) or cos(0)
183
- """
184
-
185
- def __init__(self, samp_rate, harmonic_num=0,
186
- sine_amp=0.1, noise_std=0.003,
187
- voiced_threshold=0):
188
- super(SineGen, self).__init__()
189
- self.sine_amp = sine_amp
190
- self.noise_std = noise_std
191
- self.harmonic_num = harmonic_num
192
- self.sampling_rate = samp_rate
193
- self.voiced_threshold = voiced_threshold
194
-
195
- def _f02uv(self, f0):
196
- # generate uv signal
197
- uv = (f0 > self.voiced_threshold).type(torch.float32)
198
- return uv
199
-
200
- @torch.no_grad()
201
- def forward(self, f0):
202
- """
203
- :param f0: [B, 1, sample_len], Hz
204
- :return: [B, 1, sample_len]
205
- """
206
-
207
- F_mat = torch.zeros((f0.size(0), self.harmonic_num + 1, f0.size(-1))).to(f0.device)
208
- for i in range(self.harmonic_num + 1):
209
- F_mat[:, i: i + 1, :] = f0 * (i + 1) / self.sampling_rate
210
-
211
- theta_mat = 2 * np.pi * (torch.cumsum(F_mat, dim=-1) % 1)
212
- u_dist = Uniform(low=-np.pi, high=np.pi)
213
- phase_vec = u_dist.sample(sample_shape=(f0.size(0), self.harmonic_num + 1, 1)).to(F_mat.device)
214
- phase_vec[:, 0, :] = 0
215
-
216
- # generate sine waveforms
217
- sine_waves = self.sine_amp * torch.sin(theta_mat + phase_vec)
218
-
219
- # generate uv signal
220
- uv = self._f02uv(f0)
221
-
222
- # noise: for unvoiced should be similar to sine_amp
223
- # std = self.sine_amp/3 -> max value ~ self.sine_amp
224
- # . for voiced regions is self.noise_std
225
- noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
226
- noise = noise_amp * torch.randn_like(sine_waves)
227
-
228
- # first: set the unvoiced part to 0 by uv
229
- # then: additive noise
230
- sine_waves = sine_waves * uv + noise
231
- return sine_waves, uv, noise
232
-
233
-
234
- class SourceModuleHnNSF(torch.nn.Module):
235
- """ SourceModule for hn-nsf
236
- SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
237
- add_noise_std=0.003, voiced_threshod=0)
238
- sampling_rate: sampling_rate in Hz
239
- harmonic_num: number of harmonic above F0 (default: 0)
240
- sine_amp: amplitude of sine source signal (default: 0.1)
241
- add_noise_std: std of additive Gaussian noise (default: 0.003)
242
- note that amplitude of noise in unvoiced is decided
243
- by sine_amp
244
- voiced_threshold: threshold to set U/V given F0 (default: 0)
245
- Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
246
- F0_sampled (batchsize, length, 1)
247
- Sine_source (batchsize, length, 1)
248
- noise_source (batchsize, length 1)
249
- uv (batchsize, length, 1)
250
- """
251
-
252
- def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1,
253
- add_noise_std=0.003, voiced_threshod=0):
254
- super(SourceModuleHnNSF, self).__init__()
255
-
256
- self.sine_amp = sine_amp
257
- self.noise_std = add_noise_std
258
-
259
- # to produce sine waveforms
260
- self.l_sin_gen = SineGen(sampling_rate, harmonic_num,
261
- sine_amp, add_noise_std, voiced_threshod)
262
-
263
- # to merge source harmonics into a single excitation
264
- self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
265
- self.l_tanh = torch.nn.Tanh()
266
-
267
- def forward(self, x):
268
- """
269
- Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
270
- F0_sampled (batchsize, length, 1)
271
- Sine_source (batchsize, length, 1)
272
- noise_source (batchsize, length 1)
273
- """
274
- # source for harmonic branch
275
- with torch.no_grad():
276
- sine_wavs, uv, _ = self.l_sin_gen(x.transpose(1, 2))
277
- sine_wavs = sine_wavs.transpose(1, 2)
278
- uv = uv.transpose(1, 2)
279
- sine_merge = self.l_tanh(self.l_linear(sine_wavs))
280
-
281
- # source for noise branch, in the same shape as uv
282
- noise = torch.randn_like(uv) * self.sine_amp / 3
283
- return sine_merge, noise, uv
284
-
285
-
286
- class HiFTGenerator(nn.Module):
287
- """
288
- HiFTNet Generator: Neural Source Filter + ISTFTNet
289
- https://arxiv.org/abs/2309.09493
290
- """
291
- def __init__(
292
- self,
293
- in_channels: int = 80,
294
- base_channels: int = 512,
295
- nb_harmonics: int = 8,
296
- sampling_rate: int = 22050,
297
- nsf_alpha: float = 0.1,
298
- nsf_sigma: float = 0.003,
299
- nsf_voiced_threshold: float = 10,
300
- upsample_rates: List[int] = [8, 8],
301
- upsample_kernel_sizes: List[int] = [16, 16],
302
- istft_params: Dict[str, int] = {"n_fft": 16, "hop_len": 4},
303
- resblock_kernel_sizes: List[int] = [3, 7, 11],
304
- resblock_dilation_sizes: List[List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
305
- source_resblock_kernel_sizes: List[int] = [7, 11],
306
- source_resblock_dilation_sizes: List[List[int]] = [[1, 3, 5], [1, 3, 5]],
307
- lrelu_slope: float = 0.1,
308
- audio_limit: float = 0.99,
309
- f0_predictor: torch.nn.Module = None,
310
- ):
311
- super(HiFTGenerator, self).__init__()
312
-
313
- self.out_channels = 1
314
- self.nb_harmonics = nb_harmonics
315
- self.sampling_rate = sampling_rate
316
- self.istft_params = istft_params
317
- self.lrelu_slope = lrelu_slope
318
- self.audio_limit = audio_limit
319
-
320
- self.num_kernels = len(resblock_kernel_sizes)
321
- self.num_upsamples = len(upsample_rates)
322
- self.m_source = SourceModuleHnNSF(
323
- sampling_rate=sampling_rate,
324
- upsample_scale=np.prod(upsample_rates) * istft_params["hop_len"],
325
- harmonic_num=nb_harmonics,
326
- sine_amp=nsf_alpha,
327
- add_noise_std=nsf_sigma,
328
- voiced_threshod=nsf_voiced_threshold)
329
- self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates) * istft_params["hop_len"])
330
-
331
- self.conv_pre = weight_norm(
332
- Conv1d(in_channels, base_channels, 7, 1, padding=3)
333
- )
334
-
335
- # Up
336
- self.ups = nn.ModuleList()
337
- for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
338
- self.ups.append(
339
- weight_norm(
340
- ConvTranspose1d(
341
- base_channels // (2**i),
342
- base_channels // (2**(i + 1)),
343
- k,
344
- u,
345
- padding=(k - u) // 2,
346
- )
347
- )
348
- )
349
-
350
- # Down
351
- self.source_downs = nn.ModuleList()
352
- self.source_resblocks = nn.ModuleList()
353
- downsample_rates = [1] + upsample_rates[::-1][:-1]
354
- downsample_cum_rates = np.cumprod(downsample_rates)
355
- for i, (u, k, d) in enumerate(zip(downsample_cum_rates[::-1], source_resblock_kernel_sizes, source_resblock_dilation_sizes)):
356
- if u == 1:
357
- self.source_downs.append(
358
- Conv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), 1, 1)
359
- )
360
- else:
361
- self.source_downs.append(
362
- Conv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), u * 2, u, padding=(u // 2))
363
- )
364
-
365
- self.source_resblocks.append(
366
- ResBlock(base_channels // (2 ** (i + 1)), k, d)
367
- )
368
-
369
- self.resblocks = nn.ModuleList()
370
- for i in range(len(self.ups)):
371
- ch = base_channels // (2**(i + 1))
372
- for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
373
- self.resblocks.append(ResBlock(ch, k, d))
374
-
375
- self.conv_post = weight_norm(Conv1d(ch, istft_params["n_fft"] + 2, 7, 1, padding=3))
376
- self.ups.apply(init_weights)
377
- self.conv_post.apply(init_weights)
378
- self.reflection_pad = nn.ReflectionPad1d((1, 0))
379
- self.stft_window = torch.from_numpy(get_window("hann", istft_params["n_fft"], fftbins=True).astype(np.float32))
380
- self.f0_predictor = f0_predictor
381
-
382
- def remove_weight_norm(self):
383
- print('Removing weight norm...')
384
- for l in self.ups:
385
- remove_weight_norm(l)
386
- for l in self.resblocks:
387
- l.remove_weight_norm()
388
- remove_weight_norm(self.conv_pre)
389
- remove_weight_norm(self.conv_post)
390
- self.m_source.remove_weight_norm()
391
- for l in self.source_downs:
392
- remove_weight_norm(l)
393
- for l in self.source_resblocks:
394
- l.remove_weight_norm()
395
-
396
- def _stft(self, x):
397
- spec = torch.stft(
398
- x,
399
- self.istft_params["n_fft"], self.istft_params["hop_len"], self.istft_params["n_fft"], window=self.stft_window.to(x.device),
400
- return_complex=True)
401
- spec = torch.view_as_real(spec) # [B, F, TT, 2]
402
- return spec[..., 0], spec[..., 1]
403
-
404
- def _istft(self, magnitude, phase):
405
- magnitude = torch.clip(magnitude, max=1e2)
406
- real = magnitude * torch.cos(phase)
407
- img = magnitude * torch.sin(phase)
408
- inverse_transform = torch.istft(torch.complex(real, img), self.istft_params["n_fft"], self.istft_params["hop_len"],
409
- self.istft_params["n_fft"], window=self.stft_window.to(magnitude.device))
410
- return inverse_transform
411
-
412
- def decode(self, x: torch.Tensor, s: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor:
413
- s_stft_real, s_stft_imag = self._stft(s.squeeze(1))
414
- s_stft = torch.cat([s_stft_real, s_stft_imag], dim=1)
415
-
416
- x = self.conv_pre(x)
417
- for i in range(self.num_upsamples):
418
- x = F.leaky_relu(x, self.lrelu_slope)
419
- x = self.ups[i](x)
420
-
421
- if i == self.num_upsamples - 1:
422
- x = self.reflection_pad(x)
423
-
424
- # fusion
425
- si = self.source_downs[i](s_stft)
426
- si = self.source_resblocks[i](si)
427
- x = x + si
428
-
429
- xs = None
430
- for j in range(self.num_kernels):
431
- if xs is None:
432
- xs = self.resblocks[i * self.num_kernels + j](x)
433
- else:
434
- xs += self.resblocks[i * self.num_kernels + j](x)
435
- x = xs / self.num_kernels
436
-
437
- x = F.leaky_relu(x)
438
- x = self.conv_post(x)
439
- magnitude = torch.exp(x[:, :self.istft_params["n_fft"] // 2 + 1, :])
440
- phase = torch.sin(x[:, self.istft_params["n_fft"] // 2 + 1:, :])  # actually, sin is redundant
441
-
442
- x = self._istft(magnitude, phase)
443
- x = torch.clamp(x, -self.audio_limit, self.audio_limit)
444
- return x
445
-
446
- def forward(
447
- self,
448
- batch: dict,
449
- device: torch.device,
450
- ) -> Dict[str, Optional[torch.Tensor]]:
451
- speech_feat = batch['speech_feat'].transpose(1, 2).to(device)
452
- # mel->f0
453
- f0 = self.f0_predictor(speech_feat)
454
- # f0->source
455
- s = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t
456
- s, _, _ = self.m_source(s)
457
- s = s.transpose(1, 2)
458
- # mel+source->speech
459
- generated_speech = self.decode(x=speech_feat, s=s)
460
- return generated_speech, f0
461
-
462
- @torch.inference_mode()
463
- def inference(self, speech_feat: torch.Tensor, cache_source: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor:
464
- # mel->f0
465
- f0 = self.f0_predictor(speech_feat)
466
- # f0->source
467
- s = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t
468
- s, _, _ = self.m_source(s)
469
- s = s.transpose(1, 2)
470
- # use cache_source to avoid glitch
471
- if cache_source.shape[2] != 0:
472
- s[:, :, :cache_source.shape[2]] = cache_source
473
- generated_speech = self.decode(x=speech_feat, s=s)
474
- return generated_speech, s
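Note: a minimal sketch of the Snake activation defined above, x + (1/alpha) * sin^2(alpha * x), applied per channel with a broadcast alpha (the small no_div_by_zero epsilon from the class is omitted here for brevity):

import torch

def snake(x, alpha):
    return x + (1.0 / alpha) * torch.sin(alpha * x) ** 2

x = torch.linspace(-3, 3, 7).view(1, 1, -1)   # (batch, channels, time)
alpha = torch.ones(1, 1, 1)                   # alpha = 1, the default init above
print(snake(x, alpha))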
src/chatterbox/models/s3gen/matcha/decoder.py DELETED
@@ -1,443 +0,0 @@
1
- import math
2
- from typing import Optional
3
-
4
- import torch
5
- import torch.nn as nn
6
- import torch.nn.functional as F
7
- from conformer import ConformerBlock
8
- from diffusers.models.activations import get_activation
9
- from einops import pack, rearrange, repeat
10
-
11
- from .transformer import BasicTransformerBlock
12
-
13
-
14
- class SinusoidalPosEmb(torch.nn.Module):
15
- def __init__(self, dim):
16
- super().__init__()
17
- self.dim = dim
18
- assert self.dim % 2 == 0, "SinusoidalPosEmb requires dim to be even"
19
-
20
- def forward(self, x, scale=1000):
21
- if x.ndim < 1:
22
- x = x.unsqueeze(0)
23
- device = x.device
24
- half_dim = self.dim // 2
25
- emb = math.log(10000) / (half_dim - 1)
26
- emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb)
27
- emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
28
- emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
29
- return emb
30
-
31
-
32
- class Block1D(torch.nn.Module):
33
- def __init__(self, dim, dim_out, groups=8):
34
- super().__init__()
35
- self.block = torch.nn.Sequential(
36
- torch.nn.Conv1d(dim, dim_out, 3, padding=1),
37
- torch.nn.GroupNorm(groups, dim_out),
38
- nn.Mish(),
39
- )
40
-
41
- def forward(self, x, mask):
42
- output = self.block(x * mask)
43
- return output * mask
44
-
45
-
46
- class ResnetBlock1D(torch.nn.Module):
47
- def __init__(self, dim, dim_out, time_emb_dim, groups=8):
48
- super().__init__()
49
- self.mlp = torch.nn.Sequential(nn.Mish(), torch.nn.Linear(time_emb_dim, dim_out))
50
-
51
- self.block1 = Block1D(dim, dim_out, groups=groups)
52
- self.block2 = Block1D(dim_out, dim_out, groups=groups)
53
-
54
- self.res_conv = torch.nn.Conv1d(dim, dim_out, 1)
55
-
56
- def forward(self, x, mask, time_emb):
57
- h = self.block1(x, mask)
58
- h += self.mlp(time_emb).unsqueeze(-1)
59
- h = self.block2(h, mask)
60
- output = h + self.res_conv(x * mask)
61
- return output
62
-
63
-
64
- class Downsample1D(nn.Module):
65
- def __init__(self, dim):
66
- super().__init__()
67
- self.conv = torch.nn.Conv1d(dim, dim, 3, 2, 1)
68
-
69
- def forward(self, x):
70
- return self.conv(x)
71
-
72
-
73
- class TimestepEmbedding(nn.Module):
74
- def __init__(
75
- self,
76
- in_channels: int,
77
- time_embed_dim: int,
78
- act_fn: str = "silu",
79
- out_dim: int = None,
80
- post_act_fn: Optional[str] = None,
81
- cond_proj_dim=None,
82
- ):
83
- super().__init__()
84
-
85
- self.linear_1 = nn.Linear(in_channels, time_embed_dim)
86
-
87
- if cond_proj_dim is not None:
88
- self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
89
- else:
90
- self.cond_proj = None
91
-
92
- self.act = get_activation(act_fn)
93
-
94
- if out_dim is not None:
95
- time_embed_dim_out = out_dim
96
- else:
97
- time_embed_dim_out = time_embed_dim
98
- self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out)
99
-
100
- if post_act_fn is None:
101
- self.post_act = None
102
- else:
103
- self.post_act = get_activation(post_act_fn)
104
-
105
- def forward(self, sample, condition=None):
106
- if condition is not None:
107
- sample = sample + self.cond_proj(condition)
108
- sample = self.linear_1(sample)
109
-
110
- if self.act is not None:
111
- sample = self.act(sample)
112
-
113
- sample = self.linear_2(sample)
114
-
115
- if self.post_act is not None:
116
- sample = self.post_act(sample)
117
- return sample
118
-
119
-
120
- class Upsample1D(nn.Module):
121
- """A 1D upsampling layer with an optional convolution.
122
-
123
- Parameters:
124
- channels (`int`):
125
- number of channels in the inputs and outputs.
126
- use_conv (`bool`, default `False`):
127
- option to use a convolution.
128
- use_conv_transpose (`bool`, default `False`):
129
- option to use a convolution transpose.
130
- out_channels (`int`, optional):
131
- number of output channels. Defaults to `channels`.
132
- """
133
-
134
- def __init__(self, channels, use_conv=False, use_conv_transpose=True, out_channels=None, name="conv"):
135
- super().__init__()
136
- self.channels = channels
137
- self.out_channels = out_channels or channels
138
- self.use_conv = use_conv
139
- self.use_conv_transpose = use_conv_transpose
140
- self.name = name
141
-
142
- self.conv = None
143
- if use_conv_transpose:
144
- self.conv = nn.ConvTranspose1d(channels, self.out_channels, 4, 2, 1)
145
- elif use_conv:
146
- self.conv = nn.Conv1d(self.channels, self.out_channels, 3, padding=1)
147
-
148
- def forward(self, inputs):
149
- assert inputs.shape[1] == self.channels
150
- if self.use_conv_transpose:
151
- return self.conv(inputs)
152
-
153
- outputs = F.interpolate(inputs, scale_factor=2.0, mode="nearest")
154
-
155
- if self.use_conv:
156
- outputs = self.conv(outputs)
157
-
158
- return outputs
159
-
160
-
161
- class ConformerWrapper(ConformerBlock):
162
- def __init__( # pylint: disable=useless-super-delegation
163
- self,
164
- *,
165
- dim,
166
- dim_head=64,
167
- heads=8,
168
- ff_mult=4,
169
- conv_expansion_factor=2,
170
- conv_kernel_size=31,
171
- attn_dropout=0,
172
- ff_dropout=0,
173
- conv_dropout=0,
174
- conv_causal=False,
175
- ):
176
- super().__init__(
177
- dim=dim,
178
- dim_head=dim_head,
179
- heads=heads,
180
- ff_mult=ff_mult,
181
- conv_expansion_factor=conv_expansion_factor,
182
- conv_kernel_size=conv_kernel_size,
183
- attn_dropout=attn_dropout,
184
- ff_dropout=ff_dropout,
185
- conv_dropout=conv_dropout,
186
- conv_causal=conv_causal,
187
- )
188
-
189
- def forward(
190
- self,
191
- hidden_states,
192
- attention_mask,
193
- encoder_hidden_states=None,
194
- encoder_attention_mask=None,
195
- timestep=None,
196
- ):
197
- return super().forward(x=hidden_states, mask=attention_mask.bool())
198
-
199
-
200
- class Decoder(nn.Module):
201
- def __init__(
202
- self,
203
- in_channels,
204
- out_channels,
205
- channels=(256, 256),
206
- dropout=0.05,
207
- attention_head_dim=64,
208
- n_blocks=1,
209
- num_mid_blocks=2,
210
- num_heads=4,
211
- act_fn="snake",
212
- down_block_type="transformer",
213
- mid_block_type="transformer",
214
- up_block_type="transformer",
215
- ):
216
- super().__init__()
217
- channels = tuple(channels)
218
- self.in_channels = in_channels
219
- self.out_channels = out_channels
220
-
221
- self.time_embeddings = SinusoidalPosEmb(in_channels)
222
- time_embed_dim = channels[0] * 4
223
- self.time_mlp = TimestepEmbedding(
224
- in_channels=in_channels,
225
- time_embed_dim=time_embed_dim,
226
- act_fn="silu",
227
- )
228
-
229
- self.down_blocks = nn.ModuleList([])
230
- self.mid_blocks = nn.ModuleList([])
231
- self.up_blocks = nn.ModuleList([])
232
-
233
- output_channel = in_channels
234
- for i in range(len(channels)): # pylint: disable=consider-using-enumerate
235
- input_channel = output_channel
236
- output_channel = channels[i]
237
- is_last = i == len(channels) - 1
238
- resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
239
- transformer_blocks = nn.ModuleList(
240
- [
241
- self.get_block(
242
- down_block_type,
243
- output_channel,
244
- attention_head_dim,
245
- num_heads,
246
- dropout,
247
- act_fn,
248
- )
249
- for _ in range(n_blocks)
250
- ]
251
- )
252
- downsample = (
253
- Downsample1D(output_channel) if not is_last else nn.Conv1d(output_channel, output_channel, 3, padding=1)
254
- )
255
-
256
- self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))
257
-
258
- for i in range(num_mid_blocks):
259
- input_channel = channels[-1]
260
- out_channels = channels[-1]
261
-
262
- resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
263
-
264
- transformer_blocks = nn.ModuleList(
265
- [
266
- self.get_block(
267
- mid_block_type,
268
- output_channel,
269
- attention_head_dim,
270
- num_heads,
271
- dropout,
272
- act_fn,
273
- )
274
- for _ in range(n_blocks)
275
- ]
276
- )
277
-
278
- self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks]))
279
-
280
- channels = channels[::-1] + (channels[0],)
281
- for i in range(len(channels) - 1):
282
- input_channel = channels[i]
283
- output_channel = channels[i + 1]
284
- is_last = i == len(channels) - 2
285
-
286
- resnet = ResnetBlock1D(
287
- dim=2 * input_channel,
288
- dim_out=output_channel,
289
- time_emb_dim=time_embed_dim,
290
- )
291
- transformer_blocks = nn.ModuleList(
292
- [
293
- self.get_block(
294
- up_block_type,
295
- output_channel,
296
- attention_head_dim,
297
- num_heads,
298
- dropout,
299
- act_fn,
300
- )
301
- for _ in range(n_blocks)
302
- ]
303
- )
304
- upsample = (
305
- Upsample1D(output_channel, use_conv_transpose=True)
306
- if not is_last
307
- else nn.Conv1d(output_channel, output_channel, 3, padding=1)
308
- )
309
-
310
- self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample]))
311
-
312
- self.final_block = Block1D(channels[-1], channels[-1])
313
- self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
314
-
315
- self.initialize_weights()
316
- # nn.init.normal_(self.final_proj.weight)
317
-
318
- @staticmethod
319
- def get_block(block_type, dim, attention_head_dim, num_heads, dropout, act_fn):
320
- if block_type == "conformer":
321
- block = ConformerWrapper(
322
- dim=dim,
323
- dim_head=attention_head_dim,
324
- heads=num_heads,
325
- ff_mult=1,
326
- conv_expansion_factor=2,
327
- ff_dropout=dropout,
328
- attn_dropout=dropout,
329
- conv_dropout=dropout,
330
- conv_kernel_size=31,
331
- )
332
- elif block_type == "transformer":
333
- block = BasicTransformerBlock(
334
- dim=dim,
335
- num_attention_heads=num_heads,
336
- attention_head_dim=attention_head_dim,
337
- dropout=dropout,
338
- activation_fn=act_fn,
339
- )
340
- else:
341
- raise ValueError(f"Unknown block type {block_type}")
342
-
343
- return block
344
-
345
- def initialize_weights(self):
346
- for m in self.modules():
347
- if isinstance(m, nn.Conv1d):
348
- nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
349
-
350
- if m.bias is not None:
351
- nn.init.constant_(m.bias, 0)
352
-
353
- elif isinstance(m, nn.GroupNorm):
354
- nn.init.constant_(m.weight, 1)
355
- nn.init.constant_(m.bias, 0)
356
-
357
- elif isinstance(m, nn.Linear):
358
- nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
359
-
360
- if m.bias is not None:
361
- nn.init.constant_(m.bias, 0)
362
-
363
- def forward(self, x, mask, mu, t, spks=None, cond=None):
364
- """Forward pass of the UNet1DConditional model.
365
-
366
- Args:
367
- x (torch.Tensor): shape (batch_size, in_channels, time)
368
- mask (_type_): shape (batch_size, 1, time)
369
- t (_type_): shape (batch_size)
370
- spks (_type_, optional): shape: (batch_size, condition_channels). Defaults to None.
371
- cond (_type_, optional): placeholder for future use. Defaults to None.
372
-
373
- Raises:
374
- ValueError: _description_
375
- ValueError: _description_
376
-
377
- Returns:
378
- _type_: _description_
379
- """
380
-
381
- t = self.time_embeddings(t)
382
- t = self.time_mlp(t)
383
-
384
- x = pack([x, mu], "b * t")[0]
385
-
386
- if spks is not None:
387
- spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
388
- x = pack([x, spks], "b * t")[0]
389
-
390
- hiddens = []
391
- masks = [mask]
392
- for resnet, transformer_blocks, downsample in self.down_blocks:
393
- mask_down = masks[-1]
394
- x = resnet(x, mask_down, t)
395
- x = rearrange(x, "b c t -> b t c")
396
- mask_down = rearrange(mask_down, "b 1 t -> b t")
397
- for transformer_block in transformer_blocks:
398
- x = transformer_block(
399
- hidden_states=x,
400
- attention_mask=mask_down,
401
- timestep=t,
402
- )
403
- x = rearrange(x, "b t c -> b c t")
404
- mask_down = rearrange(mask_down, "b t -> b 1 t")
405
- hiddens.append(x) # Save hidden states for skip connections
406
- x = downsample(x * mask_down)
407
- masks.append(mask_down[:, :, ::2])
408
-
409
- masks = masks[:-1]
410
- mask_mid = masks[-1]
411
-
412
- for resnet, transformer_blocks in self.mid_blocks:
413
- x = resnet(x, mask_mid, t)
414
- x = rearrange(x, "b c t -> b t c")
415
- mask_mid = rearrange(mask_mid, "b 1 t -> b t")
416
- for transformer_block in transformer_blocks:
417
- x = transformer_block(
418
- hidden_states=x,
419
- attention_mask=mask_mid,
420
- timestep=t,
421
- )
422
- x = rearrange(x, "b t c -> b c t")
423
- mask_mid = rearrange(mask_mid, "b t -> b 1 t")
424
-
425
- for resnet, transformer_blocks, upsample in self.up_blocks:
426
- mask_up = masks.pop()
427
- x = resnet(pack([x, hiddens.pop()], "b * t")[0], mask_up, t)
428
- x = rearrange(x, "b c t -> b t c")
429
- mask_up = rearrange(mask_up, "b 1 t -> b t")
430
- for transformer_block in transformer_blocks:
431
- x = transformer_block(
432
- hidden_states=x,
433
- attention_mask=mask_up,
434
- timestep=t,
435
- )
436
- x = rearrange(x, "b t c -> b c t")
437
- mask_up = rearrange(mask_up, "b t -> b 1 t")
438
- x = upsample(x * mask_up)
439
-
440
- x = self.final_block(x, mask_up)
441
- output = self.final_proj(x * mask_up)
442
-
443
- return output * mask
src/chatterbox/models/s3gen/matcha/flow_matching.py DELETED
@@ -1,129 +0,0 @@
-from abc import ABC
-
-import torch
-import torch.nn.functional as F
-
-from .decoder import Decoder
-
-
-class BASECFM(torch.nn.Module, ABC):
-    def __init__(
-        self,
-        n_feats,
-        cfm_params,
-        n_spks=1,
-        spk_emb_dim=128,
-    ):
-        super().__init__()
-        self.n_feats = n_feats
-        self.n_spks = n_spks
-        self.spk_emb_dim = spk_emb_dim
-        self.solver = cfm_params.solver
-        if hasattr(cfm_params, "sigma_min"):
-            self.sigma_min = cfm_params.sigma_min
-        else:
-            self.sigma_min = 1e-4
-
-        self.estimator = None
-
-    @torch.inference_mode()
-    def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
-        """Forward diffusion
-
-        Args:
-            mu (torch.Tensor): output of encoder
-                shape: (batch_size, n_feats, mel_timesteps)
-            mask (torch.Tensor): output_mask
-                shape: (batch_size, 1, mel_timesteps)
-            n_timesteps (int): number of diffusion steps
-            temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
-            spks (torch.Tensor, optional): speaker ids. Defaults to None.
-                shape: (batch_size, spk_emb_dim)
-            cond: Not used but kept for future purposes
-
-        Returns:
-            sample: generated mel-spectrogram
-                shape: (batch_size, n_feats, mel_timesteps)
-        """
-        z = torch.randn_like(mu) * temperature
-        t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)
-        return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond)
-
-    def solve_euler(self, x, t_span, mu, mask, spks, cond):
-        """
-        Fixed euler solver for ODEs.
-        Args:
-            x (torch.Tensor): random noise
-            t_span (torch.Tensor): n_timesteps interpolated
-                shape: (n_timesteps + 1,)
-            mu (torch.Tensor): output of encoder
-                shape: (batch_size, n_feats, mel_timesteps)
-            mask (torch.Tensor): output_mask
-                shape: (batch_size, 1, mel_timesteps)
-            spks (torch.Tensor, optional): speaker ids. Defaults to None.
-                shape: (batch_size, spk_emb_dim)
-            cond: Not used but kept for future purposes
-        """
-        t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
-
-        # I am storing this because I can later plot it by putting a debugger here and saving it to a file
-        # Or in future might add like a return_all_steps flag
-        sol = []
-
-        for step in range(1, len(t_span)):
-            dphi_dt = self.estimator(x, mask, mu, t, spks, cond)
-
-            x = x + dt * dphi_dt
-            t = t + dt
-            sol.append(x)
-            if step < len(t_span) - 1:
-                dt = t_span[step + 1] - t
-
-        return sol[-1]
-
-    def compute_loss(self, x1, mask, mu, spks=None, cond=None):
-        """Computes diffusion loss
-
-        Args:
-            x1 (torch.Tensor): Target
-                shape: (batch_size, n_feats, mel_timesteps)
-            mask (torch.Tensor): target mask
-                shape: (batch_size, 1, mel_timesteps)
-            mu (torch.Tensor): output of encoder
-                shape: (batch_size, n_feats, mel_timesteps)
-            spks (torch.Tensor, optional): speaker embedding. Defaults to None.
-                shape: (batch_size, spk_emb_dim)
-
-        Returns:
-            loss: conditional flow matching loss
-            y: conditional flow
-                shape: (batch_size, n_feats, mel_timesteps)
-        """
-        b, _, t = mu.shape
-
-        # random timestep
-        t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype)
-        # sample noise p(x_0)
-        z = torch.randn_like(x1)
-
-        y = (1 - (1 - self.sigma_min) * t) * z + t * x1
-        u = x1 - (1 - self.sigma_min) * z
-
-        loss = F.mse_loss(self.estimator(y, mask, mu, t.squeeze(), spks), u, reduction="sum") / (
-            torch.sum(mask) * u.shape[1]
-        )
-        return loss, y
-
-
-class CFM(BASECFM):
-    def __init__(self, in_channels, out_channel, cfm_params, decoder_params, n_spks=1, spk_emb_dim=64):
-        super().__init__(
-            n_feats=in_channels,
-            cfm_params=cfm_params,
-            n_spks=n_spks,
-            spk_emb_dim=spk_emb_dim,
-        )
-
-        in_channels = in_channels + (spk_emb_dim if n_spks > 1 else 0)
-        # Just change the architecture of the estimator here
-        self.estimator = Decoder(in_channels=in_channels, out_channels=out_channel, **decoder_params)
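A minimal standalone sketch (not part of this commit): BASECFM.solve_euler above is a fixed-step Euler integration of the learned vector field from noise at t=0 to data at t=1. The snippet below reproduces the same update rule x <- x + dt * v(x, t) with a made-up toy_estimator standing in for the trained network; all names and shapes here are illustrative assumptions, not code from the repository.

import torch

def toy_estimator(x, t):
    # hypothetical stand-in for the trained flow estimator: pushes x toward zero
    return -x * (1.0 - t)

def euler_solve(x, n_timesteps=10):
    # same stepping scheme as BASECFM.solve_euler: march t from 0 to 1 with fixed dt
    t_span = torch.linspace(0, 1, n_timesteps + 1)
    t, dt = t_span[0], t_span[1] - t_span[0]
    for step in range(1, len(t_span)):
        dphi_dt = toy_estimator(x, t)  # vector field at (x, t)
        x = x + dt * dphi_dt           # Euler update
        t = t + dt
        if step < len(t_span) - 1:
            dt = t_span[step + 1] - t
    return x

x0 = torch.randn(1, 80, 50)  # noise with a mel-like (batch, n_feats, frames) shape
sample = euler_solve(x0)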
src/chatterbox/models/s3gen/matcha/text_encoder.py DELETED
@@ -1,413 +0,0 @@
1
- """ from https://github.com/jaywalnut310/glow-tts """
2
-
3
- import math
4
-
5
- import torch
6
- import torch.nn as nn
7
- from einops import rearrange
8
-
9
-
10
- def sequence_mask(length, max_length=None):
11
- if max_length is None:
12
- max_length = length.max()
13
- x = torch.arange(max_length, dtype=length.dtype, device=length.device)
14
- return x.unsqueeze(0) < length.unsqueeze(1)
15
-
16
-
17
-
18
- class LayerNorm(nn.Module):
19
- def __init__(self, channels, eps=1e-4):
20
- super().__init__()
21
- self.channels = channels
22
- self.eps = eps
23
-
24
- self.gamma = torch.nn.Parameter(torch.ones(channels))
25
- self.beta = torch.nn.Parameter(torch.zeros(channels))
26
-
27
- def forward(self, x):
28
- n_dims = len(x.shape)
29
- mean = torch.mean(x, 1, keepdim=True)
30
- variance = torch.mean((x - mean) ** 2, 1, keepdim=True)
31
-
32
- x = (x - mean) * torch.rsqrt(variance + self.eps)
33
-
34
- shape = [1, -1] + [1] * (n_dims - 2)
35
- x = x * self.gamma.view(*shape) + self.beta.view(*shape)
36
- return x
37
-
38
-
39
- class ConvReluNorm(nn.Module):
40
- def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout):
41
- super().__init__()
42
- self.in_channels = in_channels
43
- self.hidden_channels = hidden_channels
44
- self.out_channels = out_channels
45
- self.kernel_size = kernel_size
46
- self.n_layers = n_layers
47
- self.p_dropout = p_dropout
48
-
49
- self.conv_layers = torch.nn.ModuleList()
50
- self.norm_layers = torch.nn.ModuleList()
51
- self.conv_layers.append(torch.nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size // 2))
52
- self.norm_layers.append(LayerNorm(hidden_channels))
53
- self.relu_drop = torch.nn.Sequential(torch.nn.ReLU(), torch.nn.Dropout(p_dropout))
54
- for _ in range(n_layers - 1):
55
- self.conv_layers.append(
56
- torch.nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2)
57
- )
58
- self.norm_layers.append(LayerNorm(hidden_channels))
59
- self.proj = torch.nn.Conv1d(hidden_channels, out_channels, 1)
60
- self.proj.weight.data.zero_()
61
- self.proj.bias.data.zero_()
62
-
63
- def forward(self, x, x_mask):
64
- x_org = x
65
- for i in range(self.n_layers):
66
- x = self.conv_layers[i](x * x_mask)
67
- x = self.norm_layers[i](x)
68
- x = self.relu_drop(x)
69
- x = x_org + self.proj(x)
70
- return x * x_mask
71
-
72
-
73
- class DurationPredictor(nn.Module):
74
- def __init__(self, in_channels, filter_channels, kernel_size, p_dropout):
75
- super().__init__()
76
- self.in_channels = in_channels
77
- self.filter_channels = filter_channels
78
- self.p_dropout = p_dropout
79
-
80
- self.drop = torch.nn.Dropout(p_dropout)
81
- self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2)
82
- self.norm_1 = LayerNorm(filter_channels)
83
- self.conv_2 = torch.nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size // 2)
84
- self.norm_2 = LayerNorm(filter_channels)
85
- self.proj = torch.nn.Conv1d(filter_channels, 1, 1)
86
-
87
- def forward(self, x, x_mask):
88
- x = self.conv_1(x * x_mask)
89
- x = torch.relu(x)
90
- x = self.norm_1(x)
91
- x = self.drop(x)
92
- x = self.conv_2(x * x_mask)
93
- x = torch.relu(x)
94
- x = self.norm_2(x)
95
- x = self.drop(x)
96
- x = self.proj(x * x_mask)
97
- return x * x_mask
98
-
99
-
100
- class RotaryPositionalEmbeddings(nn.Module):
101
- """
102
- ## RoPE module
103
-
104
- Rotary encoding transforms pairs of features by rotating in the 2D plane.
105
- That is, it organizes the $d$ features as $\frac{d}{2}$ pairs.
106
- Each pair can be considered a coordinate in a 2D plane, and the encoding will rotate it
107
- by an angle depending on the position of the token.
108
- """
109
-
110
- def __init__(self, d: int, base: int = 10_000):
111
- r"""
112
- * `d` is the number of features $d$
113
- * `base` is the constant used for calculating $\Theta$
114
- """
115
- super().__init__()
116
-
117
- self.base = base
118
- self.d = int(d)
119
- self.cos_cached = None
120
- self.sin_cached = None
121
-
122
- def _build_cache(self, x: torch.Tensor):
123
- r"""
124
- Cache $\cos$ and $\sin$ values
125
- """
126
- # Return if cache is already built
127
- if self.cos_cached is not None and x.shape[0] <= self.cos_cached.shape[0]:
128
- return
129
-
130
- # Get sequence length
131
- seq_len = x.shape[0]
132
-
133
- # $\Theta = {\theta_i = 10000^{-\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
134
- theta = 1.0 / (self.base ** (torch.arange(0, self.d, 2).float() / self.d)).to(x.device)
135
-
136
- # Create position indexes `[0, 1, ..., seq_len - 1]`
137
- seq_idx = torch.arange(seq_len, device=x.device).float().to(x.device)
138
-
139
- # Calculate the product of position index and $\theta_i$
140
- idx_theta = torch.einsum("n,d->nd", seq_idx, theta)
141
-
142
- # Concatenate so that for row $m$ we have
143
- # $[m \theta_0, m \theta_1, ..., m \theta_{\frac{d}{2}}, m \theta_0, m \theta_1, ..., m \theta_{\frac{d}{2}}]$
144
- idx_theta2 = torch.cat([idx_theta, idx_theta], dim=1)
145
-
146
- # Cache them
147
- self.cos_cached = idx_theta2.cos()[:, None, None, :]
148
- self.sin_cached = idx_theta2.sin()[:, None, None, :]
149
-
150
- def _neg_half(self, x: torch.Tensor):
151
- # $\frac{d}{2}$
152
- d_2 = self.d // 2
153
-
154
- # Calculate $[-x^{(\frac{d}{2} + 1)}, -x^{(\frac{d}{2} + 2)}, ..., -x^{(d)}, x^{(1)}, x^{(2)}, ..., x^{(\frac{d}{2})}]$
155
- return torch.cat([-x[:, :, :, d_2:], x[:, :, :, :d_2]], dim=-1)
156
-
157
- def forward(self, x: torch.Tensor):
158
- """
159
- * `x` is the Tensor at the head of a key or a query with shape `[seq_len, batch_size, n_heads, d]`
160
- """
161
- # Cache $\cos$ and $\sin$ values
162
- x = rearrange(x, "b h t d -> t b h d")
163
-
164
- self._build_cache(x)
165
-
166
- # Split the features, we can choose to apply rotary embeddings only to a partial set of features.
167
- x_rope, x_pass = x[..., : self.d], x[..., self.d :]
168
-
169
- # Calculate
170
- # $[-x^{(\frac{d}{2} + 1)}, -x^{(\frac{d}{2} + 2)}, ..., -x^{(d)}, x^{(1)}, x^{(2)}, ..., x^{(\frac{d}{2})}]$
171
- neg_half_x = self._neg_half(x_rope)
172
-
173
- x_rope = (x_rope * self.cos_cached[: x.shape[0]]) + (neg_half_x * self.sin_cached[: x.shape[0]])
174
-
175
- return rearrange(torch.cat((x_rope, x_pass), dim=-1), "t b h d -> b h t d")
176
-
177
-
178
- class MultiHeadAttention(nn.Module):
179
- def __init__(
180
- self,
181
- channels,
182
- out_channels,
183
- n_heads,
184
- heads_share=True,
185
- p_dropout=0.0,
186
- proximal_bias=False,
187
- proximal_init=False,
188
- ):
189
- super().__init__()
190
- assert channels % n_heads == 0
191
-
192
- self.channels = channels
193
- self.out_channels = out_channels
194
- self.n_heads = n_heads
195
- self.heads_share = heads_share
196
- self.proximal_bias = proximal_bias
197
- self.p_dropout = p_dropout
198
- self.attn = None
199
-
200
- self.k_channels = channels // n_heads
201
- self.conv_q = torch.nn.Conv1d(channels, channels, 1)
202
- self.conv_k = torch.nn.Conv1d(channels, channels, 1)
203
- self.conv_v = torch.nn.Conv1d(channels, channels, 1)
204
-
205
- # from https://nn.labml.ai/transformers/rope/index.html
206
- self.query_rotary_pe = RotaryPositionalEmbeddings(self.k_channels * 0.5)
207
- self.key_rotary_pe = RotaryPositionalEmbeddings(self.k_channels * 0.5)
208
-
209
- self.conv_o = torch.nn.Conv1d(channels, out_channels, 1)
210
- self.drop = torch.nn.Dropout(p_dropout)
211
-
212
- torch.nn.init.xavier_uniform_(self.conv_q.weight)
213
- torch.nn.init.xavier_uniform_(self.conv_k.weight)
214
- if proximal_init:
215
- self.conv_k.weight.data.copy_(self.conv_q.weight.data)
216
- self.conv_k.bias.data.copy_(self.conv_q.bias.data)
217
- torch.nn.init.xavier_uniform_(self.conv_v.weight)
218
-
219
- def forward(self, x, c, attn_mask=None):
220
- q = self.conv_q(x)
221
- k = self.conv_k(c)
222
- v = self.conv_v(c)
223
-
224
- x, self.attn = self.attention(q, k, v, mask=attn_mask)
225
-
226
- x = self.conv_o(x)
227
- return x
228
-
229
- def attention(self, query, key, value, mask=None):
230
- b, d, t_s, t_t = (*key.size(), query.size(2))
231
- query = rearrange(query, "b (h c) t-> b h t c", h=self.n_heads)
232
- key = rearrange(key, "b (h c) t-> b h t c", h=self.n_heads)
233
- value = rearrange(value, "b (h c) t-> b h t c", h=self.n_heads)
234
-
235
- query = self.query_rotary_pe(query)
236
- key = self.key_rotary_pe(key)
237
-
238
- scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.k_channels)
239
-
240
- if self.proximal_bias:
241
- assert t_s == t_t, "Proximal bias is only available for self-attention."
242
- scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device, dtype=scores.dtype)
243
- if mask is not None:
244
- scores = scores.masked_fill(mask == 0, -1e4)
245
- p_attn = torch.nn.functional.softmax(scores, dim=-1)
246
- p_attn = self.drop(p_attn)
247
- output = torch.matmul(p_attn, value)
248
- output = output.transpose(2, 3).contiguous().view(b, d, t_t)
249
- return output, p_attn
250
-
251
- @staticmethod
252
- def _attention_bias_proximal(length):
253
- r = torch.arange(length, dtype=torch.float32)
254
- diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
255
- return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
256
-
257
-
258
- class FFN(nn.Module):
259
- def __init__(self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0.0):
260
- super().__init__()
261
- self.in_channels = in_channels
262
- self.out_channels = out_channels
263
- self.filter_channels = filter_channels
264
- self.kernel_size = kernel_size
265
- self.p_dropout = p_dropout
266
-
267
- self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2)
268
- self.conv_2 = torch.nn.Conv1d(filter_channels, out_channels, kernel_size, padding=kernel_size // 2)
269
- self.drop = torch.nn.Dropout(p_dropout)
270
-
271
- def forward(self, x, x_mask):
272
- x = self.conv_1(x * x_mask)
273
- x = torch.relu(x)
274
- x = self.drop(x)
275
- x = self.conv_2(x * x_mask)
276
- return x * x_mask
277
-
278
-
279
- class Encoder(nn.Module):
280
- def __init__(
281
- self,
282
- hidden_channels,
283
- filter_channels,
284
- n_heads,
285
- n_layers,
286
- kernel_size=1,
287
- p_dropout=0.0,
288
- **kwargs,
289
- ):
290
- super().__init__()
291
- self.hidden_channels = hidden_channels
292
- self.filter_channels = filter_channels
293
- self.n_heads = n_heads
294
- self.n_layers = n_layers
295
- self.kernel_size = kernel_size
296
- self.p_dropout = p_dropout
297
-
298
- self.drop = torch.nn.Dropout(p_dropout)
299
- self.attn_layers = torch.nn.ModuleList()
300
- self.norm_layers_1 = torch.nn.ModuleList()
301
- self.ffn_layers = torch.nn.ModuleList()
302
- self.norm_layers_2 = torch.nn.ModuleList()
303
- for _ in range(self.n_layers):
304
- self.attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout))
305
- self.norm_layers_1.append(LayerNorm(hidden_channels))
306
- self.ffn_layers.append(
307
- FFN(
308
- hidden_channels,
309
- hidden_channels,
310
- filter_channels,
311
- kernel_size,
312
- p_dropout=p_dropout,
313
- )
314
- )
315
- self.norm_layers_2.append(LayerNorm(hidden_channels))
316
-
317
- def forward(self, x, x_mask):
318
- attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
319
- for i in range(self.n_layers):
320
- x = x * x_mask
321
- y = self.attn_layers[i](x, x, attn_mask)
322
- y = self.drop(y)
323
- x = self.norm_layers_1[i](x + y)
324
- y = self.ffn_layers[i](x, x_mask)
325
- y = self.drop(y)
326
- x = self.norm_layers_2[i](x + y)
327
- x = x * x_mask
328
- return x
329
-
330
-
331
- class TextEncoder(nn.Module):
332
- def __init__(
333
- self,
334
- encoder_type,
335
- encoder_params,
336
- duration_predictor_params,
337
- n_vocab,
338
- n_spks=1,
339
- spk_emb_dim=128,
340
- ):
341
- super().__init__()
342
- self.encoder_type = encoder_type
343
- self.n_vocab = n_vocab
344
- self.n_feats = encoder_params.n_feats
345
- self.n_channels = encoder_params.n_channels
346
- self.spk_emb_dim = spk_emb_dim
347
- self.n_spks = n_spks
348
-
349
- self.emb = torch.nn.Embedding(n_vocab, self.n_channels)
350
- torch.nn.init.normal_(self.emb.weight, 0.0, self.n_channels**-0.5)
351
-
352
- if encoder_params.prenet:
353
- self.prenet = ConvReluNorm(
354
- self.n_channels,
355
- self.n_channels,
356
- self.n_channels,
357
- kernel_size=5,
358
- n_layers=3,
359
- p_dropout=0.5,
360
- )
361
- else:
362
- self.prenet = lambda x, x_mask: x
363
-
364
- self.encoder = Encoder(
365
- encoder_params.n_channels + (spk_emb_dim if n_spks > 1 else 0),
366
- encoder_params.filter_channels,
367
- encoder_params.n_heads,
368
- encoder_params.n_layers,
369
- encoder_params.kernel_size,
370
- encoder_params.p_dropout,
371
- )
372
-
373
- self.proj_m = torch.nn.Conv1d(self.n_channels + (spk_emb_dim if n_spks > 1 else 0), self.n_feats, 1)
374
- self.proj_w = DurationPredictor(
375
- self.n_channels + (spk_emb_dim if n_spks > 1 else 0),
376
- duration_predictor_params.filter_channels_dp,
377
- duration_predictor_params.kernel_size,
378
- duration_predictor_params.p_dropout,
379
- )
380
-
381
- def forward(self, x, x_lengths, spks=None):
382
- """Run forward pass to the transformer based encoder and duration predictor
383
-
384
- Args:
385
- x (torch.Tensor): text input
386
- shape: (batch_size, max_text_length)
387
- x_lengths (torch.Tensor): text input lengths
388
- shape: (batch_size,)
389
- spks (torch.Tensor, optional): speaker ids. Defaults to None.
390
- shape: (batch_size,)
391
-
392
- Returns:
393
- mu (torch.Tensor): average output of the encoder
394
- shape: (batch_size, n_feats, max_text_length)
395
- logw (torch.Tensor): log duration predicted by the duration predictor
396
- shape: (batch_size, 1, max_text_length)
397
- x_mask (torch.Tensor): mask for the text input
398
- shape: (batch_size, 1, max_text_length)
399
- """
400
- x = self.emb(x) * math.sqrt(self.n_channels)
401
- x = torch.transpose(x, 1, -1)
402
- x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
403
-
404
- x = self.prenet(x, x_mask)
405
- if self.n_spks > 1:
406
- x = torch.cat([x, spks.unsqueeze(-1).repeat(1, 1, x.shape[-1])], dim=1)
407
- x = self.encoder(x, x_mask)
408
- mu = self.proj_m(x) * x_mask
409
-
410
- x_dp = torch.detach(x)
411
- logw = self.proj_w(x_dp, x_mask)
412
-
413
- return mu, logw, x_mask
src/chatterbox/models/s3gen/matcha/transformer.py DELETED
@@ -1,316 +0,0 @@
1
- from typing import Any, Dict, Optional
2
-
3
- import torch
4
- import torch.nn as nn
5
- from diffusers.models.attention import (
6
- GEGLU,
7
- GELU,
8
- AdaLayerNorm,
9
- AdaLayerNormZero,
10
- ApproximateGELU,
11
- )
12
- from diffusers.models.attention_processor import Attention
13
- from diffusers.models.lora import LoRACompatibleLinear
14
- from diffusers.utils.torch_utils import maybe_allow_in_graph
15
-
16
-
17
- class SnakeBeta(nn.Module):
18
- """
19
- A modified Snake function which uses separate parameters for the magnitude of the periodic components
20
- Shape:
21
- - Input: (B, C, T)
22
- - Output: (B, C, T), same shape as the input
23
- Parameters:
24
- - alpha - trainable parameter that controls frequency
25
- - beta - trainable parameter that controls magnitude
26
- References:
27
- - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
28
- https://arxiv.org/abs/2006.08195
29
- Examples:
30
- >>> a1 = snakebeta(256)
31
- >>> x = torch.randn(256)
32
- >>> x = a1(x)
33
- """
34
-
35
- def __init__(self, in_features, out_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True):
36
- """
37
- Initialization.
38
- INPUT:
39
- - in_features: shape of the input
40
- - alpha - trainable parameter that controls frequency
41
- - beta - trainable parameter that controls magnitude
42
- alpha is initialized to 1 by default, higher values = higher-frequency.
43
- beta is initialized to 1 by default, higher values = higher-magnitude.
44
- alpha will be trained along with the rest of your model.
45
- """
46
- super().__init__()
47
- self.in_features = out_features if isinstance(out_features, list) else [out_features]
48
- self.proj = LoRACompatibleLinear(in_features, out_features)
49
-
50
- # initialize alpha
51
- self.alpha_logscale = alpha_logscale
52
- if self.alpha_logscale: # log scale alphas initialized to zeros
53
- self.alpha = nn.Parameter(torch.zeros(self.in_features) * alpha)
54
- self.beta = nn.Parameter(torch.zeros(self.in_features) * alpha)
55
- else: # linear scale alphas initialized to ones
56
- self.alpha = nn.Parameter(torch.ones(self.in_features) * alpha)
57
- self.beta = nn.Parameter(torch.ones(self.in_features) * alpha)
58
-
59
- self.alpha.requires_grad = alpha_trainable
60
- self.beta.requires_grad = alpha_trainable
61
-
62
- self.no_div_by_zero = 0.000000001
63
-
64
- def forward(self, x):
65
- """
66
- Forward pass of the function.
67
- Applies the function to the input elementwise.
68
- SnakeBeta ∶= x + 1/b * sin^2 (xa)
69
- """
70
- x = self.proj(x)
71
- if self.alpha_logscale:
72
- alpha = torch.exp(self.alpha)
73
- beta = torch.exp(self.beta)
74
- else:
75
- alpha = self.alpha
76
- beta = self.beta
77
-
78
- x = x + (1.0 / (beta + self.no_div_by_zero)) * torch.pow(torch.sin(x * alpha), 2)
79
-
80
- return x
81
-
82
-
83
- class FeedForward(nn.Module):
84
- r"""
85
- A feed-forward layer.
86
-
87
- Parameters:
88
- dim (`int`): The number of channels in the input.
89
- dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
90
- mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
91
- dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
92
- activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
93
- final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
94
- """
95
-
96
- def __init__(
97
- self,
98
- dim: int,
99
- dim_out: Optional[int] = None,
100
- mult: int = 4,
101
- dropout: float = 0.0,
102
- activation_fn: str = "geglu",
103
- final_dropout: bool = False,
104
- ):
105
- super().__init__()
106
- inner_dim = int(dim * mult)
107
- dim_out = dim_out if dim_out is not None else dim
108
-
109
- if activation_fn == "gelu":
110
- act_fn = GELU(dim, inner_dim)
111
- if activation_fn == "gelu-approximate":
112
- act_fn = GELU(dim, inner_dim, approximate="tanh")
113
- elif activation_fn == "geglu":
114
- act_fn = GEGLU(dim, inner_dim)
115
- elif activation_fn == "geglu-approximate":
116
- act_fn = ApproximateGELU(dim, inner_dim)
117
- elif activation_fn == "snakebeta":
118
- act_fn = SnakeBeta(dim, inner_dim)
119
-
120
- self.net = nn.ModuleList([])
121
- # project in
122
- self.net.append(act_fn)
123
- # project dropout
124
- self.net.append(nn.Dropout(dropout))
125
- # project out
126
- self.net.append(LoRACompatibleLinear(inner_dim, dim_out))
127
- # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
128
- if final_dropout:
129
- self.net.append(nn.Dropout(dropout))
130
-
131
- def forward(self, hidden_states):
132
- for module in self.net:
133
- hidden_states = module(hidden_states)
134
- return hidden_states
135
-
136
-
137
- @maybe_allow_in_graph
138
- class BasicTransformerBlock(nn.Module):
139
- r"""
140
- A basic Transformer block.
141
-
142
- Parameters:
143
- dim (`int`): The number of channels in the input and output.
144
- num_attention_heads (`int`): The number of heads to use for multi-head attention.
145
- attention_head_dim (`int`): The number of channels in each head.
146
- dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
147
- cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
148
- only_cross_attention (`bool`, *optional*):
149
- Whether to use only cross-attention layers. In this case two cross attention layers are used.
150
- double_self_attention (`bool`, *optional*):
151
- Whether to use two self-attention layers. In this case no cross attention layers are used.
152
- activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
153
- num_embeds_ada_norm (:
154
- obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
155
- attention_bias (:
156
- obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
157
- """
158
-
159
- def __init__(
160
- self,
161
- dim: int,
162
- num_attention_heads: int,
163
- attention_head_dim: int,
164
- dropout=0.0,
165
- cross_attention_dim: Optional[int] = None,
166
- activation_fn: str = "geglu",
167
- num_embeds_ada_norm: Optional[int] = None,
168
- attention_bias: bool = False,
169
- only_cross_attention: bool = False,
170
- double_self_attention: bool = False,
171
- upcast_attention: bool = False,
172
- norm_elementwise_affine: bool = True,
173
- norm_type: str = "layer_norm",
174
- final_dropout: bool = False,
175
- ):
176
- super().__init__()
177
- self.only_cross_attention = only_cross_attention
178
-
179
- self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
180
- self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
181
-
182
- if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
183
- raise ValueError(
184
- f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
185
- f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
186
- )
187
-
188
- # Define 3 blocks. Each block has its own normalization layer.
189
- # 1. Self-Attn
190
- if self.use_ada_layer_norm:
191
- self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
192
- elif self.use_ada_layer_norm_zero:
193
- self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
194
- else:
195
- self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
196
- self.attn1 = Attention(
197
- query_dim=dim,
198
- heads=num_attention_heads,
199
- dim_head=attention_head_dim,
200
- dropout=dropout,
201
- bias=attention_bias,
202
- cross_attention_dim=cross_attention_dim if only_cross_attention else None,
203
- upcast_attention=upcast_attention,
204
- )
205
-
206
- # 2. Cross-Attn
207
- if cross_attention_dim is not None or double_self_attention:
208
- # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
209
- # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
210
- # the second cross attention block.
211
- self.norm2 = (
212
- AdaLayerNorm(dim, num_embeds_ada_norm)
213
- if self.use_ada_layer_norm
214
- else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
215
- )
216
- self.attn2 = Attention(
217
- query_dim=dim,
218
- cross_attention_dim=cross_attention_dim if not double_self_attention else None,
219
- heads=num_attention_heads,
220
- dim_head=attention_head_dim,
221
- dropout=dropout,
222
- bias=attention_bias,
223
- upcast_attention=upcast_attention,
224
- # scale_qk=False, # uncomment this to not to use flash attention
225
- ) # is self-attn if encoder_hidden_states is none
226
- else:
227
- self.norm2 = None
228
- self.attn2 = None
229
-
230
- # 3. Feed-forward
231
- self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
232
- self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
233
-
234
- # let chunk size default to None
235
- self._chunk_size = None
236
- self._chunk_dim = 0
237
-
238
- def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
239
- # Sets chunk feed-forward
240
- self._chunk_size = chunk_size
241
- self._chunk_dim = dim
242
-
243
- def forward(
244
- self,
245
- hidden_states: torch.FloatTensor,
246
- attention_mask: Optional[torch.FloatTensor] = None,
247
- encoder_hidden_states: Optional[torch.FloatTensor] = None,
248
- encoder_attention_mask: Optional[torch.FloatTensor] = None,
249
- timestep: Optional[torch.LongTensor] = None,
250
- cross_attention_kwargs: Dict[str, Any] = None,
251
- class_labels: Optional[torch.LongTensor] = None,
252
- ):
253
- # Notice that normalization is always applied before the real computation in the following blocks.
254
- # 1. Self-Attention
255
- if self.use_ada_layer_norm:
256
- norm_hidden_states = self.norm1(hidden_states, timestep)
257
- elif self.use_ada_layer_norm_zero:
258
- norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
259
- hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
260
- )
261
- else:
262
- norm_hidden_states = self.norm1(hidden_states)
263
-
264
- cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
265
-
266
- attn_output = self.attn1(
267
- norm_hidden_states,
268
- encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
269
- attention_mask=encoder_attention_mask if self.only_cross_attention else attention_mask,
270
- **cross_attention_kwargs,
271
- )
272
- if self.use_ada_layer_norm_zero:
273
- attn_output = gate_msa.unsqueeze(1) * attn_output
274
- hidden_states = attn_output + hidden_states
275
-
276
- # 2. Cross-Attention
277
- if self.attn2 is not None:
278
- norm_hidden_states = (
279
- self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
280
- )
281
-
282
- attn_output = self.attn2(
283
- norm_hidden_states,
284
- encoder_hidden_states=encoder_hidden_states,
285
- attention_mask=encoder_attention_mask,
286
- **cross_attention_kwargs,
287
- )
288
- hidden_states = attn_output + hidden_states
289
-
290
- # 3. Feed-forward
291
- norm_hidden_states = self.norm3(hidden_states)
292
-
293
- if self.use_ada_layer_norm_zero:
294
- norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
295
-
296
- if self._chunk_size is not None:
297
- # "feed_forward_chunk_size" can be used to save memory
298
- if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
299
- raise ValueError(
300
- f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
301
- )
302
-
303
- num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
304
- ff_output = torch.cat(
305
- [self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)],
306
- dim=self._chunk_dim,
307
- )
308
- else:
309
- ff_output = self.ff(norm_hidden_states)
310
-
311
- if self.use_ada_layer_norm_zero:
312
- ff_output = gate_mlp.unsqueeze(1) * ff_output
313
-
314
- hidden_states = ff_output + hidden_states
315
-
316
- return hidden_states
src/chatterbox/models/s3gen/s3gen.py DELETED
@@ -1,298 +0,0 @@
1
- # Modified from CosyVoice https://github.com/FunAudioLLM/CosyVoice
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- import logging
16
-
17
- import numpy as np
18
- import torch
19
- import torchaudio as ta
20
- from functools import lru_cache
21
- from typing import Optional
22
-
23
- from ..s3tokenizer import S3_SR, SPEECH_VOCAB_SIZE, S3Tokenizer
24
- from .const import S3GEN_SR
25
- from .flow import CausalMaskedDiffWithXvec
26
- from .xvector import CAMPPlus
27
- from .utils.mel import mel_spectrogram
28
- from .f0_predictor import ConvRNNF0Predictor
29
- from .hifigan import HiFTGenerator
30
- from .transformer.upsample_encoder import UpsampleConformerEncoder
31
- from .flow_matching import CausalConditionalCFM
32
- from .decoder import ConditionalDecoder
33
- from .configs import CFM_PARAMS
34
-
35
-
36
- def drop_invalid_tokens(x):
37
- assert len(x.shape) <= 2 and x.shape[0] == 1, "only batch size of one allowed for now"
38
- return x[x < SPEECH_VOCAB_SIZE]
39
-
40
-
41
- # TODO: global resampler cache
42
- @lru_cache(100)
43
- def get_resampler(src_sr, dst_sr, device):
44
- return ta.transforms.Resample(src_sr, dst_sr).to(device)
45
-
46
-
47
- class S3Token2Mel(torch.nn.Module):
48
- """
49
- CosyVoice2's CFM decoder maps S3 speech tokens to mel-spectrograms.
50
-
51
- TODO: make these modules configurable?
52
- """
53
- def __init__(self):
54
- super().__init__()
55
- self.tokenizer = S3Tokenizer("speech_tokenizer_v2_25hz")
56
- self.mel_extractor = mel_spectrogram # TODO: make it a torch module?
57
- self.speaker_encoder = CAMPPlus() # use default args
58
-
59
- encoder = UpsampleConformerEncoder(
60
- output_size=512,
61
- attention_heads=8,
62
- linear_units=2048,
63
- num_blocks=6,
64
- dropout_rate=0.1,
65
- positional_dropout_rate=0.1,
66
- attention_dropout_rate=0.1,
67
- normalize_before=True,
68
- input_layer='linear',
69
- pos_enc_layer_type='rel_pos_espnet',
70
- selfattention_layer_type='rel_selfattn',
71
- input_size=512,
72
- use_cnn_module=False,
73
- macaron_style=False,
74
- )
75
-
76
- estimator = ConditionalDecoder(
77
- in_channels=320,
78
- out_channels=80,
79
- causal=True,
80
- channels=[256],
81
- dropout=0.0,
82
- attention_head_dim=64,
83
- n_blocks=4,
84
- num_mid_blocks=12,
85
- num_heads=8,
86
- act_fn='gelu',
87
- )
88
- cfm_params = CFM_PARAMS
89
- decoder = CausalConditionalCFM(
90
- spk_emb_dim=80,
91
- cfm_params=cfm_params,
92
- estimator=estimator,
93
- )
94
-
95
- self.flow = CausalMaskedDiffWithXvec(
96
- encoder=encoder,
97
- decoder=decoder
98
- )
99
-
100
- self.resamplers = {}
101
-
102
- @property
103
- def device(self):
104
- params = self.tokenizer.parameters()
105
- return next(params).device
106
-
107
- def embed_ref(
108
- self,
109
- ref_wav: torch.Tensor,
110
- ref_sr: int,
111
- device="auto",
112
- ref_fade_out=True,
113
- ):
114
- device = self.device if device == "auto" else device
115
- if isinstance(ref_wav, np.ndarray):
116
- ref_wav = torch.from_numpy(ref_wav).float()
117
-
118
- if ref_wav.device != device:
119
- ref_wav = ref_wav.to(device)
120
-
121
- if len(ref_wav.shape) == 1:
122
- ref_wav = ref_wav.unsqueeze(0) # (B, L)
123
-
124
- if ref_wav.size(1) > 10 * ref_sr:
125
- print("WARNING: cosydec received ref longer than 10s")
126
-
127
- ref_wav_24 = ref_wav
128
- if ref_sr != S3GEN_SR:
129
- ref_wav_24 = get_resampler(ref_sr, S3GEN_SR, device)(ref_wav)
130
-
131
- ref_mels_24 = self.mel_extractor(ref_wav_24).transpose(1, 2).to(device)
132
- ref_mels_24_len = None
133
-
134
- # Resample to 16kHz
135
- ref_wav_16 = get_resampler(ref_sr, S3_SR, device)(ref_wav).to(device)
136
-
137
- # Speaker embedding
138
- ref_x_vector = self.speaker_encoder.inference(ref_wav_16)
139
-
140
- # Tokenize 16khz reference
141
- ref_speech_tokens, ref_speech_token_lens = self.tokenizer(ref_wav_16)
142
-
143
- # Make sure mel_len = 2 * stoken_len (happens when the input is not padded to multiple of 40ms)
144
- if ref_mels_24.shape[1] != 2 * ref_speech_tokens.shape[1]:
145
- logging.warning(
146
- "Reference mel length is not equal to 2 * reference token length.\n"
147
- )
148
- ref_speech_tokens = ref_speech_tokens[:, :ref_mels_24.shape[1] // 2]
149
- ref_speech_token_lens[0] = ref_speech_tokens.shape[1]
150
-
151
- return dict(
152
- prompt_token=ref_speech_tokens.to(device),
153
- prompt_token_len=ref_speech_token_lens,
154
- prompt_feat=ref_mels_24,
155
- prompt_feat_len=ref_mels_24_len,
156
- embedding=ref_x_vector,
157
- )
158
-
159
- def forward(
160
- self,
161
- speech_tokens: torch.LongTensor,
162
- # locally-computed ref embedding (mutex with ref_dict)
163
- ref_wav: Optional[torch.Tensor],
164
- ref_sr: Optional[int],
165
- # pre-computed ref embedding (prod API)
166
- ref_dict: Optional[dict] = None,
167
- finalize: bool = False,
168
- ):
169
- """
170
- Generate waveforms from S3 speech tokens and a reference waveform, which the speaker timbre is inferred from.
171
-
172
- NOTE:
173
- - The speaker encoder accepts 16 kHz waveform.
174
- - S3TokenizerV2 accepts 16 kHz waveform.
175
- - The mel-spectrogram for the reference assumes 24 kHz input signal.
176
- - This function is designed for batch_size=1 only.
177
-
178
- Args
179
- ----
180
- - `speech_tokens`: S3 speech tokens [B=1, T]
181
- - `ref_wav`: reference waveform (`torch.Tensor` with shape=[B=1, T])
182
- - `ref_sr`: reference sample rate
183
- - `finalize`: whether streaming is finished or not. Note that if False, the last 3 tokens will be ignored.
184
- """
185
- assert (ref_wav is None) ^ (ref_dict is None), f"Must provide exactly one of ref_wav or ref_dict (got {ref_wav} and {ref_dict})"
186
-
187
- if ref_dict is None:
188
- ref_dict = self.embed_ref(ref_wav, ref_sr)
189
- else:
190
- # type/device casting (all values will be numpy if it's from a prod API call)
191
- for rk in list(ref_dict):
192
- if isinstance(ref_dict[rk], np.ndarray):
193
- ref_dict[rk] = torch.from_numpy(ref_dict[rk])
194
- if torch.is_tensor(ref_dict[rk]):
195
- ref_dict[rk] = ref_dict[rk].to(self.device)
196
-
197
- if len(speech_tokens.shape) == 1:
198
- speech_tokens = speech_tokens.unsqueeze(0)
199
-
200
- # assert speech_tokens.shape[0] == 1, "only batch size of one allowed for now"
201
- speech_token_lens = torch.LongTensor([speech_tokens.size(1)]).to(self.device)
202
-
203
- output_mels, _ = self.flow.inference(
204
- token=speech_tokens,
205
- token_len=speech_token_lens,
206
- finalize=finalize,
207
- **ref_dict,
208
- )
209
- return output_mels
210
-
211
-
212
- class S3Token2Wav(S3Token2Mel):
213
- """
214
- The decoder of CosyVoice2 is a concat of token-to-mel (CFM) and a mel-to-waveform (HiFiGAN) modules.
215
-
216
- TODO: make these modules configurable?
217
- """
218
-
219
- def __init__(self):
220
- super().__init__()
221
-
222
- f0_predictor = ConvRNNF0Predictor()
223
- self.mel2wav = HiFTGenerator(
224
- sampling_rate=S3GEN_SR,
225
- upsample_rates=[8, 5, 3],
226
- upsample_kernel_sizes=[16, 11, 7],
227
- source_resblock_kernel_sizes=[7, 7, 11],
228
- source_resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
229
- f0_predictor=f0_predictor,
230
- )
231
-
232
- # silence out a few ms and fade audio in to reduce artifacts
233
- n_trim = S3GEN_SR // 50 # 20ms = half of a frame
234
- trim_fade = torch.zeros(2 * n_trim)
235
- trim_fade[n_trim:] = (torch.cos(torch.linspace(torch.pi, 0, n_trim)) + 1) / 2
236
- self.register_buffer("trim_fade", trim_fade, persistent=False) # (buffers get automatic device casting)
237
-
238
- def forward(
239
- self,
240
- speech_tokens,
241
- # locally-computed ref embedding (mutex with ref_dict)
242
- ref_wav: Optional[torch.Tensor],
243
- ref_sr: Optional[int],
244
- # pre-computed ref embedding (prod API)
245
- ref_dict: Optional[dict] = None,
246
- finalize: bool = False
247
- ):
248
- output_mels = super().forward(speech_tokens, ref_wav=ref_wav, ref_sr=ref_sr, ref_dict=ref_dict, finalize=finalize)
249
-
250
- # TODO jrm: ignoring the speed control (mel interpolation) and the HiFTGAN caching mechanisms for now.
251
- hift_cache_source = torch.zeros(1, 1, 0).to(self.device)
252
-
253
- output_wavs, *_ = self.mel2wav.inference(speech_feat=output_mels, cache_source=hift_cache_source)
254
-
255
- if not self.training:
256
- # NOTE: ad-hoc method to reduce "spillover" from the reference clip.
257
- output_wavs[:, :len(self.trim_fade)] *= self.trim_fade
258
-
259
- return output_wavs
260
-
261
- @torch.inference_mode()
262
- def flow_inference(
263
- self,
264
- speech_tokens,
265
- # locally-computed ref embedding (mutex with ref_dict)
266
- ref_wav: Optional[torch.Tensor] = None,
267
- ref_sr: Optional[int] = None,
268
- # pre-computed ref embedding (prod API)
269
- ref_dict: Optional[dict] = None,
270
- finalize: bool = False,
271
- ):
272
- return super().forward(speech_tokens, ref_wav=ref_wav, ref_sr=ref_sr, ref_dict=ref_dict, finalize=finalize)
273
-
274
- @torch.inference_mode()
275
- def hift_inference(self, speech_feat, cache_source: torch.Tensor = None):
276
- if cache_source is None:
277
- cache_source = torch.zeros(1, 1, 0).to(self.device)
278
- return self.mel2wav.inference(speech_feat=speech_feat, cache_source=cache_source)
279
-
280
- @torch.inference_mode()
281
- def inference(
282
- self,
283
- speech_tokens,
284
- # locally-computed ref embedding (mutex with ref_dict)
285
- ref_wav: Optional[torch.Tensor] = None,
286
- ref_sr: Optional[int] = None,
287
- # pre-computed ref embedding (prod API)
288
- ref_dict: Optional[dict] = None,
289
- cache_source: torch.Tensor = None, # NOTE: this arg is for streaming, it can probably be removed here
290
- finalize: bool = True,
291
- ):
292
- output_mels = self.flow_inference(speech_tokens, ref_wav=ref_wav, ref_sr=ref_sr, ref_dict=ref_dict, finalize=finalize)
293
- output_wavs, output_sources = self.hift_inference(output_mels, cache_source)
294
-
295
- # NOTE: ad-hoc method to reduce "spillover" from the reference clip.
296
- output_wavs[:, :len(self.trim_fade)] *= self.trim_fade
297
-
298
- return output_wavs, output_sources
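A rough usage sketch of the S3Token2Wav class removed above, assuming the pre-commit package layout were still importable; the dummy tokens, the dummy reference clip, and the omission of checkpoint loading are assumptions made only for illustration.

import torch
from chatterbox.models.s3gen.s3gen import S3Token2Wav  # module path as it existed before this commit

model = S3Token2Wav().eval()                          # real use would load trained weights first
speech_tokens = torch.zeros(1, 50, dtype=torch.long)  # [B=1, T] S3 speech tokens (dummy values)
ref_wav = torch.zeros(1, 16000)                       # [B=1, L] one second of 16 kHz reference audio (dummy)
wavs, sources = model.inference(speech_tokens, ref_wav=ref_wav, ref_sr=16000)  # 24 kHz waveform out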
src/chatterbox/models/s3gen/transformer/__init__.py DELETED
File without changes
src/chatterbox/models/s3gen/transformer/activation.py DELETED
@@ -1,84 +0,0 @@
1
- # Copyright (c) 2020 Johns Hopkins University (Shinji Watanabe)
2
- # 2020 Northwestern Polytechnical University (Pengcheng Guo)
3
- # 2020 Mobvoi Inc (Binbin Zhang)
4
- # 2024 Alibaba Inc (Xiang Lyu)
5
- #
6
- # Licensed under the Apache License, Version 2.0 (the "License");
7
- # you may not use this file except in compliance with the License.
8
- # You may obtain a copy of the License at
9
- #
10
- # http://www.apache.org/licenses/LICENSE-2.0
11
- #
12
- # Unless required by applicable law or agreed to in writing, software
13
- # distributed under the License is distributed on an "AS IS" BASIS,
14
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
- # See the License for the specific language governing permissions and
16
- # limitations under the License.
17
- """Swish() activation function for Conformer."""
18
-
19
- import torch
20
- from torch import nn, sin, pow
21
- from torch.nn import Parameter
22
-
23
-
24
- class Swish(torch.nn.Module):
25
- """Construct an Swish object."""
26
-
27
- def forward(self, x: torch.Tensor) -> torch.Tensor:
28
- """Return Swish activation function."""
29
- return x * torch.sigmoid(x)
30
-
31
-
32
- # Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license.
33
- # LICENSE is in incl_licenses directory.
34
- class Snake(nn.Module):
35
- '''
36
- Implementation of a sine-based periodic activation function
37
- Shape:
38
- - Input: (B, C, T)
39
- - Output: (B, C, T), same shape as the input
40
- Parameters:
41
- - alpha - trainable parameter
42
- References:
43
- - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
44
- https://arxiv.org/abs/2006.08195
45
- Examples:
46
- >>> a1 = Snake(256)
47
- >>> x = torch.randn(256)
48
- >>> x = a1(x)
49
- '''
50
- def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
51
- '''
52
- Initialization.
53
- INPUT:
54
- - in_features: shape of the input
55
- - alpha: trainable parameter
56
- alpha is initialized to 1 by default, higher values = higher-frequency.
57
- alpha will be trained along with the rest of your model.
58
- '''
59
- super(Snake, self).__init__()
60
- self.in_features = in_features
61
-
62
- # initialize alpha
63
- self.alpha_logscale = alpha_logscale
64
- if self.alpha_logscale: # log scale alphas initialized to zeros
65
- self.alpha = Parameter(torch.zeros(in_features) * alpha)
66
- else: # linear scale alphas initialized to ones
67
- self.alpha = Parameter(torch.ones(in_features) * alpha)
68
-
69
- self.alpha.requires_grad = alpha_trainable
70
-
71
- self.no_div_by_zero = 0.000000001
72
-
73
- def forward(self, x):
74
- '''
75
- Forward pass of the function.
76
- Applies the function to the input elementwise.
77
- Snake(x) := x + (1/a) * sin^2(a * x)
78
- '''
79
- alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
80
- if self.alpha_logscale:
81
- alpha = torch.exp(alpha)
82
- x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
83
-
84
- return x
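
Note: a quick numerical check of the Snake definition above, Snake(x) = x + (1/a) * sin^2(a * x). With alpha = 1 and x = pi/2 the sine term is 1, so the output is x + 1 (up to the 1e-9 stabiliser). Illustrative sketch, assuming the Snake class above is importable.

import math
import torch

snake = Snake(in_features=4, alpha=1.0)
x = torch.full((1, 4, 8), math.pi / 2)   # (B, C, T)
y = snake(x)
print(torch.allclose(y, x + 1.0))        # True
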
src/chatterbox/models/s3gen/transformer/attention.py DELETED
@@ -1,330 +0,0 @@
1
- # Copyright (c) 2019 Shigeki Karita
2
- # 2020 Mobvoi Inc (Binbin Zhang)
3
- # 2022 Xingchen Song (sxc19@mails.tsinghua.edu.cn)
4
- # 2024 Alibaba Inc (Xiang Lyu)
5
- #
6
- # Licensed under the Apache License, Version 2.0 (the "License");
7
- # you may not use this file except in compliance with the License.
8
- # You may obtain a copy of the License at
9
- #
10
- # http://www.apache.org/licenses/LICENSE-2.0
11
- #
12
- # Unless required by applicable law or agreed to in writing, software
13
- # distributed under the License is distributed on an "AS IS" BASIS,
14
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
- # See the License for the specific language governing permissions and
16
- # limitations under the License.
17
- """Multi-Head Attention layer definition."""
18
-
19
- import math
20
- from typing import Tuple
21
-
22
- import torch
23
- from torch import nn
24
-
25
-
26
- class MultiHeadedAttention(nn.Module):
27
- """Multi-Head Attention layer.
28
-
29
- Args:
30
- n_head (int): The number of heads.
31
- n_feat (int): The number of features.
32
- dropout_rate (float): Dropout rate.
33
-
34
- """
35
-
36
- def __init__(self,
37
- n_head: int,
38
- n_feat: int,
39
- dropout_rate: float,
40
- key_bias: bool = True):
41
- """Construct an MultiHeadedAttention object."""
42
- super().__init__()
43
- assert n_feat % n_head == 0
44
- # We assume d_v always equals d_k
45
- self.d_k = n_feat // n_head
46
- self.h = n_head
47
- self.linear_q = nn.Linear(n_feat, n_feat)
48
- self.linear_k = nn.Linear(n_feat, n_feat, bias=key_bias)
49
- self.linear_v = nn.Linear(n_feat, n_feat)
50
- self.linear_out = nn.Linear(n_feat, n_feat)
51
- self.dropout = nn.Dropout(p=dropout_rate)
52
-
53
- def forward_qkv(
54
- self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
55
- ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
56
- """Transform query, key and value.
57
-
58
- Args:
59
- query (torch.Tensor): Query tensor (#batch, time1, size).
60
- key (torch.Tensor): Key tensor (#batch, time2, size).
61
- value (torch.Tensor): Value tensor (#batch, time2, size).
62
-
63
- Returns:
64
- torch.Tensor: Transformed query tensor, size
65
- (#batch, n_head, time1, d_k).
66
- torch.Tensor: Transformed key tensor, size
67
- (#batch, n_head, time2, d_k).
68
- torch.Tensor: Transformed value tensor, size
69
- (#batch, n_head, time2, d_k).
70
-
71
- """
72
- n_batch = query.size(0)
73
- q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
74
- k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
75
- v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
76
- q = q.transpose(1, 2) # (batch, head, time1, d_k)
77
- k = k.transpose(1, 2) # (batch, head, time2, d_k)
78
- v = v.transpose(1, 2) # (batch, head, time2, d_k)
79
-
80
- return q, k, v
81
-
82
- def forward_attention(
83
- self,
84
- value: torch.Tensor,
85
- scores: torch.Tensor,
86
- mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool)
87
- ) -> torch.Tensor:
88
- """Compute attention context vector.
89
-
90
- Args:
91
- value (torch.Tensor): Transformed value, size
92
- (#batch, n_head, time2, d_k).
93
- scores (torch.Tensor): Attention score, size
94
- (#batch, n_head, time1, time2).
95
- mask (torch.Tensor): Mask, size (#batch, 1, time2) or
96
- (#batch, time1, time2), (0, 0, 0) means fake mask.
97
-
98
- Returns:
99
- torch.Tensor: Transformed value (#batch, time1, d_model)
100
- weighted by the attention score (#batch, time1, time2).
101
-
102
- """
103
- n_batch = value.size(0)
104
- # NOTE(xcsong): When will `if mask.size(2) > 0` be True?
105
- # 1. onnx(16/4) [WHY? Because we feed real cache & real mask for the
106
- # 1st chunk to ease the onnx export.]
107
- # 2. pytorch training
108
- if mask.size(2) > 0: # time2 > 0
109
- mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
110
- # For last chunk, time2 might be larger than scores.size(-1)
111
- mask = mask[:, :, :, :scores.size(-1)] # (batch, 1, *, time2)
112
- scores = scores.masked_fill(mask, -float('inf'))
113
- attn = torch.softmax(scores, dim=-1).masked_fill(
114
- mask, 0.0) # (batch, head, time1, time2)
115
- # NOTE(xcsong): When will `if mask.size(2) > 0` be False?
116
- # 1. onnx(16/-1, -1/-1, 16/0)
117
- # 2. jit (16/-1, -1/-1, 16/0, 16/4)
118
- else:
119
- attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
120
-
121
- p_attn = self.dropout(attn)
122
- x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
123
- x = (x.transpose(1, 2).contiguous().view(n_batch, -1,
124
- self.h * self.d_k)
125
- ) # (batch, time1, d_model)
126
-
127
- return self.linear_out(x) # (batch, time1, d_model)
128
-
129
- def forward(
130
- self,
131
- query: torch.Tensor,
132
- key: torch.Tensor,
133
- value: torch.Tensor,
134
- mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
135
- pos_emb: torch.Tensor = torch.empty(0),
136
- cache: torch.Tensor = torch.zeros((0, 0, 0, 0))
137
- ) -> Tuple[torch.Tensor, torch.Tensor]:
138
- """Compute scaled dot product attention.
139
-
140
- Args:
141
- query (torch.Tensor): Query tensor (#batch, time1, size).
142
- key (torch.Tensor): Key tensor (#batch, time2, size).
143
- value (torch.Tensor): Value tensor (#batch, time2, size).
144
- mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
145
- (#batch, time1, time2).
146
- 1.When applying cross attention between decoder and encoder,
147
- the batch padding mask for input is in (#batch, 1, T) shape.
148
- 2.When applying self attention of encoder,
149
- the mask is in (#batch, T, T) shape.
150
- 3.When applying self attention of decoder,
151
- the mask is in (#batch, L, L) shape.
152
- 4.If the different position in decoder see different block
153
- of the encoder, such as Mocha, the passed in mask could be
154
- in (#batch, L, T) shape. But there is no such case in current
155
- CosyVoice.
156
- cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
157
- where `cache_t == chunk_size * num_decoding_left_chunks`
158
- and `head * d_k == size`
159
-
160
-
161
- Returns:
162
- torch.Tensor: Output tensor (#batch, time1, d_model).
163
- torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
164
- where `cache_t == chunk_size * num_decoding_left_chunks`
165
- and `head * d_k == size`
166
-
167
- """
168
- q, k, v = self.forward_qkv(query, key, value)
169
-
170
- # NOTE(xcsong):
171
- # when export onnx model, for 1st chunk, we feed
172
- # cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode)
173
- # or cache(1, head, real_cache_t, d_k * 2) (16/4 mode).
174
- # In all modes, `if cache.size(0) > 0` will always be `True`
175
- # and we will always do splitting and
176
- # concatenation (this will simplify onnx export). Note that
177
- # it's OK to concat & split zero-shaped tensors(see code below).
178
- # when export jit model, for 1st chunk, we always feed
179
- # cache(0, 0, 0, 0) since jit supports dynamic if-branch.
180
- # >>> a = torch.ones((1, 2, 0, 4))
181
- # >>> b = torch.ones((1, 2, 3, 4))
182
- # >>> c = torch.cat((a, b), dim=2)
183
- # >>> torch.equal(b, c) # True
184
- # >>> d = torch.split(a, 2, dim=-1)
185
- # >>> torch.equal(d[0], d[1]) # True
186
- if cache.size(0) > 0:
187
- key_cache, value_cache = torch.split(cache,
188
- cache.size(-1) // 2,
189
- dim=-1)
190
- k = torch.cat([key_cache, k], dim=2)
191
- v = torch.cat([value_cache, v], dim=2)
192
- # NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's
193
- # non-trivial to calculate `next_cache_start` here.
194
- new_cache = torch.cat((k, v), dim=-1)
195
-
196
- scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
197
- return self.forward_attention(v, scores, mask), new_cache
198
-
199
-
200
- class RelPositionMultiHeadedAttention(MultiHeadedAttention):
201
- """Multi-Head Attention layer with relative position encoding.
202
- Paper: https://arxiv.org/abs/1901.02860
203
- Args:
204
- n_head (int): The number of heads.
205
- n_feat (int): The number of features.
206
- dropout_rate (float): Dropout rate.
207
- """
208
-
209
- def __init__(self,
210
- n_head: int,
211
- n_feat: int,
212
- dropout_rate: float,
213
- key_bias: bool = True):
214
- """Construct an RelPositionMultiHeadedAttention object."""
215
- super().__init__(n_head, n_feat, dropout_rate, key_bias)
216
- # linear transformation for positional encoding
217
- self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
218
- # these two learnable bias are used in matrix c and matrix d
219
- # as described in https://arxiv.org/abs/1901.02860 Section 3.3
220
- self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
221
- self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
222
- torch.nn.init.xavier_uniform_(self.pos_bias_u)
223
- torch.nn.init.xavier_uniform_(self.pos_bias_v)
224
-
225
- def rel_shift(self, x: torch.Tensor) -> torch.Tensor:
226
- """Compute relative positional encoding.
227
-
228
- Args:
229
- x (torch.Tensor): Input tensor (batch, head, time1, 2*time1-1).
230
- time1 means the length of query vector.
231
-
232
- Returns:
233
- torch.Tensor: Output tensor.
234
-
235
- """
236
- zero_pad = torch.zeros((x.size()[0], x.size()[1], x.size()[2], 1),
237
- device=x.device,
238
- dtype=x.dtype)
239
- x_padded = torch.cat([zero_pad, x], dim=-1)
240
-
241
- x_padded = x_padded.view(x.size()[0],
242
- x.size()[1],
243
- x.size(3) + 1, x.size(2))
244
- x = x_padded[:, :, 1:].view_as(x)[
245
- :, :, :, : x.size(-1) // 2 + 1
246
- ] # only keep the positions from 0 to time2
247
- return x
248
-
249
- def forward(
250
- self,
251
- query: torch.Tensor,
252
- key: torch.Tensor,
253
- value: torch.Tensor,
254
- mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
255
- pos_emb: torch.Tensor = torch.empty(0),
256
- cache: torch.Tensor = torch.zeros((0, 0, 0, 0))
257
- ) -> Tuple[torch.Tensor, torch.Tensor]:
258
- """Compute 'Scaled Dot Product Attention' with rel. positional encoding.
259
- Args:
260
- query (torch.Tensor): Query tensor (#batch, time1, size).
261
- key (torch.Tensor): Key tensor (#batch, time2, size).
262
- value (torch.Tensor): Value tensor (#batch, time2, size).
263
- mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
264
- (#batch, time1, time2), (0, 0, 0) means fake mask.
265
- pos_emb (torch.Tensor): Positional embedding tensor
266
- (#batch, time2, size).
267
- cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
268
- where `cache_t == chunk_size * num_decoding_left_chunks`
269
- and `head * d_k == size`
270
- Returns:
271
- torch.Tensor: Output tensor (#batch, time1, d_model).
272
- torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
273
- where `cache_t == chunk_size * num_decoding_left_chunks`
274
- and `head * d_k == size`
275
- """
276
- q, k, v = self.forward_qkv(query, key, value)
277
- q = q.transpose(1, 2) # (batch, time1, head, d_k)
278
-
279
- # NOTE(xcsong):
280
- # when export onnx model, for 1st chunk, we feed
281
- # cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode)
282
- # or cache(1, head, real_cache_t, d_k * 2) (16/4 mode).
283
- # In all modes, `if cache.size(0) > 0` will always be `True`
284
- # and we will always do splitting and
285
- # concatenation (this will simplify onnx export). Note that
286
- # it's OK to concat & split zero-shaped tensors(see code below).
287
- # when export jit model, for 1st chunk, we always feed
288
- # cache(0, 0, 0, 0) since jit supports dynamic if-branch.
289
- # >>> a = torch.ones((1, 2, 0, 4))
290
- # >>> b = torch.ones((1, 2, 3, 4))
291
- # >>> c = torch.cat((a, b), dim=2)
292
- # >>> torch.equal(b, c) # True
293
- # >>> d = torch.split(a, 2, dim=-1)
294
- # >>> torch.equal(d[0], d[1]) # True
295
- if cache.size(0) > 0:
296
- key_cache, value_cache = torch.split(cache,
297
- cache.size(-1) // 2,
298
- dim=-1)
299
- k = torch.cat([key_cache, k], dim=2)
300
- v = torch.cat([value_cache, v], dim=2)
301
- # NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's
302
- # non-trivial to calculate `next_cache_start` here.
303
- new_cache = torch.cat((k, v), dim=-1)
304
-
305
- n_batch_pos = pos_emb.size(0)
306
- p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
307
- p = p.transpose(1, 2) # (batch, head, time1, d_k)
308
-
309
- # (batch, head, time1, d_k)
310
- q_with_bias_u = (q + self.pos_bias_u.to(q.device)).transpose(1, 2)
311
- # (batch, head, time1, d_k)
312
- q_with_bias_v = (q + self.pos_bias_v.to(q.device)).transpose(1, 2)
313
-
314
- # compute attention score
315
- # first compute matrix a and matrix c
316
- # as described in https://arxiv.org/abs/1901.02860 Section 3.3
317
- # (batch, head, time1, time2)
318
- matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
319
-
320
- # compute matrix b and matrix d
321
- # (batch, head, time1, time2)
322
- matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
323
- # NOTE(Xiang Lyu): Keep rel_shift since espnet rel_pos_emb is used
324
- if matrix_ac.shape != matrix_bd.shape:
325
- matrix_bd = self.rel_shift(matrix_bd)
326
-
327
- scores = (matrix_ac + matrix_bd) / math.sqrt(
328
- self.d_k) # (batch, head, time1, time2)
329
-
330
- return self.forward_attention(v, scores, mask), new_cache
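
Note: both attention classes above return (output, new_cache), where new_cache holds the keys and values concatenated on the last dimension so it can be fed back for the next decoding chunk. Illustrative sketch of that cache convention, assuming the classes above are importable.

import torch

attn = MultiHeadedAttention(n_head=4, n_feat=256, dropout_rate=0.0)

chunk1 = torch.randn(1, 10, 256)
out1, cache = attn(chunk1, chunk1, chunk1)               # default empty mask and cache
print(cache.shape)                                       # torch.Size([1, 4, 10, 128]), i.e. d_k * 2

chunk2 = torch.randn(1, 10, 256)
out2, cache = attn(chunk2, chunk2, chunk2, cache=cache)  # keys/values from chunk1 are prepended
print(cache.shape)                                       # torch.Size([1, 4, 20, 128])
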
src/chatterbox/models/s3gen/transformer/convolution.py DELETED
@@ -1,145 +0,0 @@
1
- # Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Di Wu)
2
- # 2024 Alibaba Inc (Xiang Lyu)
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
- # Modified from ESPnet(https://github.com/espnet/espnet)
16
- """ConvolutionModule definition."""
17
-
18
- from typing import Tuple
19
-
20
- import torch
21
- from torch import nn
22
-
23
-
24
- class ConvolutionModule(nn.Module):
25
- """ConvolutionModule in Conformer model."""
26
-
27
- def __init__(self,
28
- channels: int,
29
- kernel_size: int = 15,
30
- activation: nn.Module = nn.ReLU(),
31
- norm: str = "batch_norm",
32
- causal: bool = False,
33
- bias: bool = True):
34
- """Construct an ConvolutionModule object.
35
- Args:
36
- channels (int): The number of channels of conv layers.
37
- kernel_size (int): Kernel size of conv layers.
38
- causal (bool): Whether to use causal convolution or not
39
- """
40
- super().__init__()
41
-
42
- self.pointwise_conv1 = nn.Conv1d(
43
- channels,
44
- 2 * channels,
45
- kernel_size=1,
46
- stride=1,
47
- padding=0,
48
- bias=bias,
49
- )
50
- # self.lorder is used to distinguish if it's a causal convolution,
51
- # if self.lorder > 0: it's a causal convolution, the input will be
52
- # padded with self.lorder frames on the left in forward.
53
- # else: it's a symmetrical convolution
54
- if causal:
55
- padding = 0
56
- self.lorder = kernel_size - 1
57
- else:
58
- # kernel_size should be an odd number for non-causal convolution
59
- assert (kernel_size - 1) % 2 == 0
60
- padding = (kernel_size - 1) // 2
61
- self.lorder = 0
62
- self.depthwise_conv = nn.Conv1d(
63
- channels,
64
- channels,
65
- kernel_size,
66
- stride=1,
67
- padding=padding,
68
- groups=channels,
69
- bias=bias,
70
- )
71
-
72
- assert norm in ['batch_norm', 'layer_norm']
73
- if norm == "batch_norm":
74
- self.use_layer_norm = False
75
- self.norm = nn.BatchNorm1d(channels)
76
- else:
77
- self.use_layer_norm = True
78
- self.norm = nn.LayerNorm(channels)
79
-
80
- self.pointwise_conv2 = nn.Conv1d(
81
- channels,
82
- channels,
83
- kernel_size=1,
84
- stride=1,
85
- padding=0,
86
- bias=bias,
87
- )
88
- self.activation = activation
89
-
90
- def forward(
91
- self,
92
- x: torch.Tensor,
93
- mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
94
- cache: torch.Tensor = torch.zeros((0, 0, 0)),
95
- ) -> Tuple[torch.Tensor, torch.Tensor]:
96
- """Compute convolution module.
97
- Args:
98
- x (torch.Tensor): Input tensor (#batch, time, channels).
99
- mask_pad (torch.Tensor): used for batch padding (#batch, 1, time),
100
- (0, 0, 0) means fake mask.
101
- cache (torch.Tensor): left context cache, it is only
102
- used in causal convolution (#batch, channels, cache_t),
103
- (0, 0, 0) means fake cache.
104
- Returns:
105
- torch.Tensor: Output tensor (#batch, time, channels).
106
- """
107
- # exchange the temporal dimension and the feature dimension
108
- x = x.transpose(1, 2) # (#batch, channels, time)
109
-
110
- # mask batch padding
111
- if mask_pad.size(2) > 0: # time > 0
112
- x.masked_fill_(~mask_pad, 0.0)
113
-
114
- if self.lorder > 0:
115
- if cache.size(2) == 0: # cache_t == 0
116
- x = nn.functional.pad(x, (self.lorder, 0), 'constant', 0.0)
117
- else:
118
- assert cache.size(0) == x.size(0) # equal batch
119
- assert cache.size(1) == x.size(1) # equal channel
120
- x = torch.cat((cache, x), dim=2)
121
- assert (x.size(2) > self.lorder)
122
- new_cache = x[:, :, -self.lorder:]
123
- else:
124
- # It's better we just return None if no cache is required,
125
- # However, for JIT export, here we just fake one tensor instead of
126
- # None.
127
- new_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)
128
-
129
- # GLU mechanism
130
- x = self.pointwise_conv1(x) # (batch, 2*channel, dim)
131
- x = nn.functional.glu(x, dim=1) # (batch, channel, dim)
132
-
133
- # 1D Depthwise Conv
134
- x = self.depthwise_conv(x)
135
- if self.use_layer_norm:
136
- x = x.transpose(1, 2)
137
- x = self.activation(self.norm(x))
138
- if self.use_layer_norm:
139
- x = x.transpose(1, 2)
140
- x = self.pointwise_conv2(x)
141
- # mask batch padding
142
- if mask_pad.size(2) > 0: # time > 0
143
- x.masked_fill_(~mask_pad, 0.0)
144
-
145
- return x.transpose(1, 2), new_cache
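
Note: with causal=True the module above left-pads kernel_size - 1 frames and returns those frames as new_cache, so streamed chunks see the same left context as an offline pass. Illustrative sketch, assuming the ConvolutionModule class above is importable.

import torch

conv = ConvolutionModule(channels=64, kernel_size=15, causal=True).eval()

with torch.no_grad():
    chunk1 = torch.randn(1, 20, 64)              # (B, T, C)
    y1, cache = conv(chunk1)                     # first chunk: zero left-padding
    print(cache.shape)                           # torch.Size([1, 64, 14]) == kernel_size - 1 frames

    chunk2 = torch.randn(1, 20, 64)
    y2, cache = conv(chunk2, cache=cache)        # later chunks reuse the cached left context
    print(y2.shape)                              # torch.Size([1, 20, 64])
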
src/chatterbox/models/s3gen/transformer/embedding.py DELETED
@@ -1,294 +0,0 @@
1
- # Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Di Wu)
2
- # 2024 Alibaba Inc (Xiang Lyu)
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
- # Modified from ESPnet(https://github.com/espnet/espnet)
16
- """Positonal Encoding Module."""
17
-
18
- import math
19
- from typing import Tuple, Union
20
-
21
- import torch
22
- import torch.nn.functional as F
23
- import numpy as np
24
-
25
-
26
- class PositionalEncoding(torch.nn.Module):
27
- """Positional encoding.
28
-
29
- :param int d_model: embedding dim
30
- :param float dropout_rate: dropout rate
31
- :param int max_len: maximum input length
32
-
33
- PE(pos, 2i) = sin(pos/(10000^(2i/dmodel)))
34
- PE(pos, 2i+1) = cos(pos/(10000^(2i/dmodel)))
35
- """
36
-
37
- def __init__(self,
38
- d_model: int,
39
- dropout_rate: float,
40
- max_len: int = 5000,
41
- reverse: bool = False):
42
- """Construct an PositionalEncoding object."""
43
- super().__init__()
44
- self.d_model = d_model
45
- self.xscale = math.sqrt(self.d_model)
46
- self.dropout = torch.nn.Dropout(p=dropout_rate)
47
- self.max_len = max_len
48
-
49
- self.pe = torch.zeros(self.max_len, self.d_model)
50
- position = torch.arange(0, self.max_len,
51
- dtype=torch.float32).unsqueeze(1)
52
- div_term = torch.exp(
53
- torch.arange(0, self.d_model, 2, dtype=torch.float32) *
54
- -(math.log(10000.0) / self.d_model))
55
- self.pe[:, 0::2] = torch.sin(position * div_term)
56
- self.pe[:, 1::2] = torch.cos(position * div_term)
57
- self.pe = self.pe.unsqueeze(0)
58
-
59
- def forward(self,
60
- x: torch.Tensor,
61
- offset: Union[int, torch.Tensor] = 0) \
62
- -> Tuple[torch.Tensor, torch.Tensor]:
63
- """Add positional encoding.
64
-
65
- Args:
66
- x (torch.Tensor): Input. Its shape is (batch, time, ...)
67
- offset (int, torch.tensor): position offset
68
-
69
- Returns:
70
- torch.Tensor: Encoded tensor. Its shape is (batch, time, ...)
71
- torch.Tensor: for compatibility to RelPositionalEncoding
72
- """
73
-
74
- self.pe = self.pe.to(x.device)
75
- pos_emb = self.position_encoding(offset, x.size(1), False)
76
- x = x * self.xscale + pos_emb
77
- return self.dropout(x), self.dropout(pos_emb)
78
-
79
- def position_encoding(self,
80
- offset: Union[int, torch.Tensor],
81
- size: int,
82
- apply_dropout: bool = True) -> torch.Tensor:
83
- """ For getting encoding in a streaming fashion
84
-
85
- Attention!!!!!
86
- we apply dropout only once at the whole utterance level in a
87
- non-streaming way, but will call this function several times with
88
- increasing input size in a streaming scenario, so the dropout will
89
- be applied several times.
90
-
91
- Args:
92
- offset (int or torch.tensor): start offset
93
- size (int): required size of position encoding
94
-
95
- Returns:
96
- torch.Tensor: Corresponding encoding
97
- """
98
- # How to subscript a Union type:
99
- # https://github.com/pytorch/pytorch/issues/69434
100
- if isinstance(offset, int):
101
- assert offset + size <= self.max_len
102
- pos_emb = self.pe[:, offset:offset + size]
103
- elif isinstance(offset, torch.Tensor) and offset.dim() == 0: # scalar
104
- assert offset + size <= self.max_len
105
- pos_emb = self.pe[:, offset:offset + size]
106
- else: # for batched streaming decoding on GPU
107
- assert torch.max(offset) + size <= self.max_len
108
- index = offset.unsqueeze(1) + \
109
- torch.arange(0, size).to(offset.device) # B X T
110
- flag = index > 0
111
- # remove negative offset
112
- index = index * flag
113
- pos_emb = F.embedding(index, self.pe[0]) # B X T X d_model
114
-
115
- if apply_dropout:
116
- pos_emb = self.dropout(pos_emb)
117
- return pos_emb
118
-
119
-
120
- class RelPositionalEncoding(PositionalEncoding):
121
- """Relative positional encoding module.
122
- See : Appendix B in https://arxiv.org/abs/1901.02860
123
- Args:
124
- d_model (int): Embedding dimension.
125
- dropout_rate (float): Dropout rate.
126
- max_len (int): Maximum input length.
127
- """
128
-
129
- def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000):
130
- """Initialize class."""
131
- super().__init__(d_model, dropout_rate, max_len, reverse=True)
132
-
133
- def forward(self,
134
- x: torch.Tensor,
135
- offset: Union[int, torch.Tensor] = 0) \
136
- -> Tuple[torch.Tensor, torch.Tensor]:
137
- """Compute positional encoding.
138
- Args:
139
- x (torch.Tensor): Input tensor (batch, time, `*`).
140
- Returns:
141
- torch.Tensor: Encoded tensor (batch, time, `*`).
142
- torch.Tensor: Positional embedding tensor (1, time, `*`).
143
- """
144
- self.pe = self.pe.to(x.device)
145
- x = x * self.xscale
146
- pos_emb = self.position_encoding(offset, x.size(1), False)
147
- return self.dropout(x), self.dropout(pos_emb)
148
-
149
-
150
- class WhisperPositionalEncoding(PositionalEncoding):
151
- """ Sinusoids position encoding used in openai-whisper.encoder
152
- """
153
-
154
- def __init__(self, d_model: int, dropout_rate: float, max_len: int = 1500):
155
- super().__init__(d_model, dropout_rate, max_len)
156
- self.xscale = 1.0
157
- log_timescale_increment = np.log(10000) / (d_model // 2 - 1)
158
- inv_timescales = torch.exp(-log_timescale_increment *
159
- torch.arange(d_model // 2))
160
- scaled_time = torch.arange(max_len)[:, np.newaxis] * \
161
- inv_timescales[np.newaxis, :]
162
- pe = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
163
- delattr(self, "pe")
164
- self.register_buffer("pe", pe.unsqueeze(0))
165
-
166
-
167
- class LearnablePositionalEncoding(PositionalEncoding):
168
- """ Learnable position encoding used in openai-whisper.decoder
169
- """
170
-
171
- def __init__(self, d_model: int, dropout_rate: float, max_len: int = 448):
172
- super().__init__(d_model, dropout_rate, max_len)
173
- # NOTE(xcsong): overwrite self.pe & self.xscale
174
- self.pe = torch.nn.Parameter(torch.empty(1, max_len, d_model))
175
- self.xscale = 1.0
176
-
177
-
178
- class NoPositionalEncoding(torch.nn.Module):
179
- """ No position encoding
180
- """
181
-
182
- def __init__(self, d_model: int, dropout_rate: float):
183
- super().__init__()
184
- self.d_model = d_model
185
- self.dropout = torch.nn.Dropout(p=dropout_rate)
186
-
187
- def forward(self,
188
- x: torch.Tensor,
189
- offset: Union[int, torch.Tensor] = 0) \
190
- -> Tuple[torch.Tensor, torch.Tensor]:
191
- """ Just return zero vector for interface compatibility
192
- """
193
- pos_emb = torch.zeros(1, x.size(1), self.d_model).to(x.device)
194
- return self.dropout(x), pos_emb
195
-
196
- def position_encoding(self, offset: Union[int, torch.Tensor],
197
- size: int) -> torch.Tensor:
198
- return torch.zeros(1, size, self.d_model)
199
-
200
-
201
- class EspnetRelPositionalEncoding(torch.nn.Module):
202
- """Relative positional encoding module (new implementation).
203
-
204
- Details can be found in https://github.com/espnet/espnet/pull/2816.
205
-
206
- See : Appendix B in https://arxiv.org/abs/1901.02860
207
-
208
- Args:
209
- d_model (int): Embedding dimension.
210
- dropout_rate (float): Dropout rate.
211
- max_len (int): Maximum input length.
212
-
213
- """
214
-
215
- def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000):
216
- """Construct an PositionalEncoding object."""
217
- super(EspnetRelPositionalEncoding, self).__init__()
218
- self.d_model = d_model
219
- self.xscale = math.sqrt(self.d_model)
220
- self.dropout = torch.nn.Dropout(p=dropout_rate)
221
- self.pe = None
222
- self.extend_pe(torch.tensor(0.0).expand(1, max_len))
223
-
224
- def extend_pe(self, x: torch.Tensor):
225
- """Reset the positional encodings."""
226
- if self.pe is not None:
227
- # self.pe contains both positive and negative parts
228
- # the length of self.pe is 2 * input_len - 1
229
- if self.pe.size(1) >= x.size(1) * 2 - 1:
230
- if self.pe.dtype != x.dtype or self.pe.device != x.device:
231
- self.pe = self.pe.to(dtype=x.dtype, device=x.device)
232
- return
233
- # Suppose `i` means the position of the query vector and `j` means the
234
- # position of the key vector. We use positive relative positions when keys
235
- # are to the left (i>j) and negative relative positions otherwise (i<j).
236
- pe_positive = torch.zeros(x.size(1), self.d_model)
237
- pe_negative = torch.zeros(x.size(1), self.d_model)
238
- position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
239
- div_term = torch.exp(
240
- torch.arange(0, self.d_model, 2, dtype=torch.float32)
241
- * -(math.log(10000.0) / self.d_model)
242
- )
243
- pe_positive[:, 0::2] = torch.sin(position * div_term)
244
- pe_positive[:, 1::2] = torch.cos(position * div_term)
245
- pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
246
- pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
247
-
248
- # Reverse the order of positive indices and concat both positive and
249
- # negative indices. This is used to support the shifting trick
250
- # as in https://arxiv.org/abs/1901.02860
251
- pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
252
- pe_negative = pe_negative[1:].unsqueeze(0)
253
- pe = torch.cat([pe_positive, pe_negative], dim=1)
254
- self.pe = pe.to(device=x.device, dtype=x.dtype)
255
-
256
- def forward(self, x: torch.Tensor, offset: Union[int, torch.Tensor] = 0) \
257
- -> Tuple[torch.Tensor, torch.Tensor]:
258
- """Add positional encoding.
259
-
260
- Args:
261
- x (torch.Tensor): Input tensor (batch, time, `*`).
262
-
263
- Returns:
264
- torch.Tensor: Encoded tensor (batch, time, `*`).
265
-
266
- """
267
- self.extend_pe(x)
268
- x = x * self.xscale
269
- pos_emb = self.position_encoding(size=x.size(1), offset=offset)
270
- return self.dropout(x), self.dropout(pos_emb)
271
-
272
- def position_encoding(self,
273
- offset: Union[int, torch.Tensor],
274
- size: int) -> torch.Tensor:
275
- """ For getting encoding in a streaming fashion
276
-
277
- Attention!!!!!
278
- we apply dropout only once at the whole utterance level in a
279
- non-streaming way, but will call this function several times with
280
- increasing input size in a streaming scenario, so the dropout will
281
- be applied several times.
282
-
283
- Args:
284
- offset (int or torch.tensor): start offset
285
- size (int): required size of position encoding
286
-
287
- Returns:
288
- torch.Tensor: Corresponding encoding
289
- """
290
- pos_emb = self.pe[
291
- :,
292
- self.pe.size(1) // 2 - size + 1: self.pe.size(1) // 2 + size,
293
- ]
294
- return pos_emb
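
Note: the sinusoidal table above follows PE(pos, 2i) = sin(pos / 10000^(2i/d_model)) and PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model)); the ESPnet-style relative variant keeps both positive and negative offsets, so a length-T input gets 2*T - 1 positions. Illustrative sketch, assuming the classes above are importable.

import math
import torch

d_model, pos, i = 8, 3, 1
enc = PositionalEncoding(d_model=d_model, dropout_rate=0.0)
expected = math.sin(pos / (10000 ** (2 * i / d_model)))
print(abs(enc.pe[0, pos, 2 * i].item() - expected) < 1e-6)   # True

rel = EspnetRelPositionalEncoding(d_model=d_model, dropout_rate=0.0)
x = torch.randn(1, 5, d_model)
_, pos_emb = rel(x)
print(pos_emb.shape)                                         # torch.Size([1, 9, 8]): 2 * 5 - 1 positions
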
src/chatterbox/models/s3gen/transformer/encoder_layer.py DELETED
@@ -1,236 +0,0 @@
1
- # Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)
2
- # 2022 Xingchen Song (sxc19@mails.tsinghua.edu.cn)
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
- # Modified from ESPnet(https://github.com/espnet/espnet)
16
- """Encoder self-attention layer definition."""
17
-
18
- from typing import Optional, Tuple
19
-
20
- import torch
21
- from torch import nn
22
-
23
-
24
- class TransformerEncoderLayer(nn.Module):
25
- """Encoder layer module.
26
-
27
- Args:
28
- size (int): Input dimension.
29
- self_attn (torch.nn.Module): Self-attention module instance.
30
- `MultiHeadedAttention` or `RelPositionMultiHeadedAttention`
31
- instance can be used as the argument.
32
- feed_forward (torch.nn.Module): Feed-forward module instance.
33
- `PositionwiseFeedForward`, instance can be used as the argument.
34
- dropout_rate (float): Dropout rate.
35
- normalize_before (bool):
36
- True: use layer_norm before each sub-block.
37
- False: to use layer_norm after each sub-block.
38
- """
39
-
40
- def __init__(
41
- self,
42
- size: int,
43
- self_attn: torch.nn.Module,
44
- feed_forward: torch.nn.Module,
45
- dropout_rate: float,
46
- normalize_before: bool = True,
47
- ):
48
- """Construct an EncoderLayer object."""
49
- super().__init__()
50
- self.self_attn = self_attn
51
- self.feed_forward = feed_forward
52
- self.norm1 = nn.LayerNorm(size, eps=1e-12)
53
- self.norm2 = nn.LayerNorm(size, eps=1e-12)
54
- self.dropout = nn.Dropout(dropout_rate)
55
- self.size = size
56
- self.normalize_before = normalize_before
57
-
58
- def forward(
59
- self,
60
- x: torch.Tensor,
61
- mask: torch.Tensor,
62
- pos_emb: torch.Tensor,
63
- mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
64
- att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
65
- cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
66
- ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
67
- """Compute encoded features.
68
-
69
- Args:
70
- x (torch.Tensor): (#batch, time, size)
71
- mask (torch.Tensor): Mask tensor for the input (#batch, time,time),
72
- (0, 0, 0) means fake mask.
73
- pos_emb (torch.Tensor): just for interface compatibility
74
- to ConformerEncoderLayer
75
- mask_pad (torch.Tensor): not used in the transformer layer,
76
- just for unified api with conformer.
77
- att_cache (torch.Tensor): Cache tensor of the KEY & VALUE
78
- (#batch=1, head, cache_t1, d_k * 2), head * d_k == size.
79
- cnn_cache (torch.Tensor): Convolution cache in conformer layer
80
- (#batch=1, size, cache_t2), not used here, it's for interface
81
- compatibility to ConformerEncoderLayer.
82
- Returns:
83
- torch.Tensor: Output tensor (#batch, time, size).
84
- torch.Tensor: Mask tensor (#batch, time, time).
85
- torch.Tensor: att_cache tensor,
86
- (#batch=1, head, cache_t1 + time, d_k * 2).
87
- torch.Tensor: cnn_cache tensor (#batch=1, size, cache_t2).
88
-
89
- """
90
- residual = x
91
- if self.normalize_before:
92
- x = self.norm1(x)
93
- x_att, new_att_cache = self.self_attn(x, x, x, mask, pos_emb=pos_emb, cache=att_cache)
94
- x = residual + self.dropout(x_att)
95
- if not self.normalize_before:
96
- x = self.norm1(x)
97
-
98
- residual = x
99
- if self.normalize_before:
100
- x = self.norm2(x)
101
- x = residual + self.dropout(self.feed_forward(x))
102
- if not self.normalize_before:
103
- x = self.norm2(x)
104
-
105
- fake_cnn_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)
106
- return x, mask, new_att_cache, fake_cnn_cache
107
-
108
-
109
- class ConformerEncoderLayer(nn.Module):
110
- """Encoder layer module.
111
- Args:
112
- size (int): Input dimension.
113
- self_attn (torch.nn.Module): Self-attention module instance.
114
- `MultiHeadedAttention` or `RelPositionMultiHeadedAttention`
115
- instance can be used as the argument.
116
- feed_forward (torch.nn.Module): Feed-forward module instance.
117
- `PositionwiseFeedForward` instance can be used as the argument.
118
- feed_forward_macaron (torch.nn.Module): Additional feed-forward module
119
- instance.
120
- `PositionwiseFeedForward` instance can be used as the argument.
121
- conv_module (torch.nn.Module): Convolution module instance.
122
- `ConvolutionModule` instance can be used as the argument.
123
- dropout_rate (float): Dropout rate.
124
- normalize_before (bool):
125
- True: use layer_norm before each sub-block.
126
- False: use layer_norm after each sub-block.
127
- """
128
-
129
- def __init__(
130
- self,
131
- size: int,
132
- self_attn: torch.nn.Module,
133
- feed_forward: Optional[nn.Module] = None,
134
- feed_forward_macaron: Optional[nn.Module] = None,
135
- conv_module: Optional[nn.Module] = None,
136
- dropout_rate: float = 0.1,
137
- normalize_before: bool = True,
138
- ):
139
- """Construct an EncoderLayer object."""
140
- super().__init__()
141
- self.self_attn = self_attn
142
- self.feed_forward = feed_forward
143
- self.feed_forward_macaron = feed_forward_macaron
144
- self.conv_module = conv_module
145
- self.norm_ff = nn.LayerNorm(size, eps=1e-12) # for the FNN module
146
- self.norm_mha = nn.LayerNorm(size, eps=1e-12) # for the MHA module
147
- if feed_forward_macaron is not None:
148
- self.norm_ff_macaron = nn.LayerNorm(size, eps=1e-12)
149
- self.ff_scale = 0.5
150
- else:
151
- self.ff_scale = 1.0
152
- if self.conv_module is not None:
153
- self.norm_conv = nn.LayerNorm(size, eps=1e-12) # for the CNN module
154
- self.norm_final = nn.LayerNorm(
155
- size, eps=1e-12) # for the final output of the block
156
- self.dropout = nn.Dropout(dropout_rate)
157
- self.size = size
158
- self.normalize_before = normalize_before
159
-
160
- def forward(
161
- self,
162
- x: torch.Tensor,
163
- mask: torch.Tensor,
164
- pos_emb: torch.Tensor,
165
- mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
166
- att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
167
- cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
168
- ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
169
- """Compute encoded features.
170
-
171
- Args:
172
- x (torch.Tensor): (#batch, time, size)
173
- mask (torch.Tensor): Mask tensor for the input (#batch, time,time),
174
- (0, 0, 0) means fake mask.
175
- pos_emb (torch.Tensor): positional encoding, must not be None
176
- for ConformerEncoderLayer.
177
- mask_pad (torch.Tensor): batch padding mask used for conv module.
178
- (#batch, 1,time), (0, 0, 0) means fake mask.
179
- att_cache (torch.Tensor): Cache tensor of the KEY & VALUE
180
- (#batch=1, head, cache_t1, d_k * 2), head * d_k == size.
181
- cnn_cache (torch.Tensor): Convolution cache in conformer layer
182
- (#batch=1, size, cache_t2)
183
- Returns:
184
- torch.Tensor: Output tensor (#batch, time, size).
185
- torch.Tensor: Mask tensor (#batch, time, time).
186
- torch.Tensor: att_cache tensor,
187
- (#batch=1, head, cache_t1 + time, d_k * 2).
188
- torch.Tensor: cnn_cache tensor (#batch, size, cache_t2).
189
- """
190
-
191
- # whether to use macaron style
192
- if self.feed_forward_macaron is not None:
193
- residual = x
194
- if self.normalize_before:
195
- x = self.norm_ff_macaron(x)
196
- x = residual + self.ff_scale * self.dropout(
197
- self.feed_forward_macaron(x))
198
- if not self.normalize_before:
199
- x = self.norm_ff_macaron(x)
200
-
201
- # multi-headed self-attention module
202
- residual = x
203
- if self.normalize_before:
204
- x = self.norm_mha(x)
205
- x_att, new_att_cache = self.self_attn(x, x, x, mask, pos_emb,
206
- att_cache)
207
- x = residual + self.dropout(x_att)
208
- if not self.normalize_before:
209
- x = self.norm_mha(x)
210
-
211
- # convolution module
212
- # Fake new cnn cache here, and then change it in conv_module
213
- new_cnn_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)
214
- if self.conv_module is not None:
215
- residual = x
216
- if self.normalize_before:
217
- x = self.norm_conv(x)
218
- x, new_cnn_cache = self.conv_module(x, mask_pad, cnn_cache)
219
- x = residual + self.dropout(x)
220
-
221
- if not self.normalize_before:
222
- x = self.norm_conv(x)
223
-
224
- # feed forward module
225
- residual = x
226
- if self.normalize_before:
227
- x = self.norm_ff(x)
228
-
229
- x = residual + self.ff_scale * self.dropout(self.feed_forward(x))
230
- if not self.normalize_before:
231
- x = self.norm_ff(x)
232
-
233
- if self.conv_module is not None:
234
- x = self.norm_final(x)
235
-
236
- return x, mask, new_att_cache, new_cnn_cache
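
Note: an illustrative sketch of assembling one ConformerEncoderLayer above from the attention, feed-forward, and convolution modules defined in this package (with a macaron feed-forward branch, ff_scale becomes 0.5); it assumes those classes are importable.

import torch

size = 256
layer = ConformerEncoderLayer(
    size=size,
    self_attn=MultiHeadedAttention(n_head=4, n_feat=size, dropout_rate=0.1),
    feed_forward=PositionwiseFeedForward(size, 1024, dropout_rate=0.1),
    feed_forward_macaron=PositionwiseFeedForward(size, 1024, dropout_rate=0.1),
    conv_module=ConvolutionModule(size, kernel_size=15),
    dropout_rate=0.1,
)

x = torch.randn(2, 50, size)
mask = torch.ones(2, 50, 50, dtype=torch.bool)
pos_emb = torch.empty(0)                       # plain MultiHeadedAttention ignores pos_emb
y, mask, att_cache, cnn_cache = layer(x, mask, pos_emb)
print(y.shape)                                 # torch.Size([2, 50, 256])
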
src/chatterbox/models/s3gen/transformer/positionwise_feed_forward.py DELETED
@@ -1,115 +0,0 @@
1
- # Copyright (c) 2019 Shigeki Karita
2
- # 2020 Mobvoi Inc (Binbin Zhang)
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
- """Positionwise feed forward layer definition."""
16
-
17
- import torch
18
-
19
-
20
- class PositionwiseFeedForward(torch.nn.Module):
21
- """Positionwise feed forward layer.
22
-
23
- FeedForward is applied at each position of the sequence.
24
- The output dim is the same as the input dim.
25
-
26
- Args:
27
- idim (int): Input dimension.
28
- hidden_units (int): The number of hidden units.
29
- dropout_rate (float): Dropout rate.
30
- activation (torch.nn.Module): Activation function
31
- """
32
-
33
- def __init__(
34
- self,
35
- idim: int,
36
- hidden_units: int,
37
- dropout_rate: float,
38
- activation: torch.nn.Module = torch.nn.ReLU(),
39
- ):
40
- """Construct a PositionwiseFeedForward object."""
41
- super(PositionwiseFeedForward, self).__init__()
42
- self.w_1 = torch.nn.Linear(idim, hidden_units)
43
- self.activation = activation
44
- self.dropout = torch.nn.Dropout(dropout_rate)
45
- self.w_2 = torch.nn.Linear(hidden_units, idim)
46
-
47
- def forward(self, xs: torch.Tensor) -> torch.Tensor:
48
- """Forward function.
49
-
50
- Args:
51
- xs: input tensor (B, L, D)
52
- Returns:
53
- output tensor, (B, L, D)
54
- """
55
- return self.w_2(self.dropout(self.activation(self.w_1(xs))))
56
-
57
-
58
- class MoEFFNLayer(torch.nn.Module):
59
- """
60
- Mixture of expert with Positionwise feed forward layer
61
- See also figure 1 in https://arxiv.org/pdf/2305.15663.pdf
62
- The output dim is the same as the input dim.
63
-
64
- Modified from https://github.com/Lightning-AI/lit-gpt/pull/823
65
- https://github.com/mistralai/mistral-src/blob/b46d6/moe_one_file_ref.py#L203-L219
66
- Args:
67
- n_expert: number of expert.
68
- n_expert_per_token: The actual number of experts used for each frame
69
- idim (int): Input dimension.
70
- hidden_units (int): The number of hidden units.
71
- dropout_rate (float): Dropout rate.
72
- activation (torch.nn.Module): Activation function
73
- """
74
-
75
- def __init__(
76
- self,
77
- n_expert: int,
78
- n_expert_per_token: int,
79
- idim: int,
80
- hidden_units: int,
81
- dropout_rate: float,
82
- activation: torch.nn.Module = torch.nn.ReLU(),
83
- ):
84
- super(MoEFFNLayer, self).__init__()
85
- self.gate = torch.nn.Linear(idim, n_expert, bias=False)
86
- self.experts = torch.nn.ModuleList(
87
- PositionwiseFeedForward(idim, hidden_units, dropout_rate,
88
- activation) for _ in range(n_expert))
89
- self.n_expert_per_token = n_expert_per_token
90
-
91
- def forward(self, xs: torch.Tensor) -> torch.Tensor:
92
- """Foward function.
93
- Args:
94
- xs: input tensor (B, L, D)
95
- Returns:
96
- output tensor, (B, L, D)
97
-
98
- """
99
- B, L, D = xs.size(
100
- ) # batch size, sequence length, embedding dimension (idim)
101
- xs = xs.view(-1, D) # (B*L, D)
102
- router = self.gate(xs) # (B*L, n_expert)
103
- logits, indices = torch.topk(
104
- router, self.n_expert_per_token
105
- ) # probs:(B*L, n_expert), indices: (B*L, n_expert)
106
- weights = torch.nn.functional.softmax(
107
- logits, dim=1,
108
- dtype=torch.float).to(dtype=xs.dtype) # (B*L, n_expert_per_token)
109
- output = torch.zeros_like(xs) # (B*L, D)
110
- for i, expert in enumerate(self.experts):
111
- mask = indices == i
112
- batch_idx, ith_expert = torch.where(mask)
113
- output[batch_idx] += weights[batch_idx, ith_expert, None] * expert(
114
- xs[batch_idx])
115
- return output.view(B, L, D)
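
Note: MoEFFNLayer above routes each frame to n_expert_per_token experts: the gate scores every expert, torch.topk keeps the strongest ones, and only the kept logits are softmax-normalised into mixing weights. Illustrative sketch of the routing on one frame (the gate logits are made-up values), plus a shape check of the full layer defined above.

import torch

router = torch.tensor([[2.0, -1.0, 0.5, 0.1]])    # hypothetical gate logits: 4 experts, 1 frame
logits, indices = torch.topk(router, k=2)          # keep the two strongest experts
weights = torch.softmax(logits, dim=1)             # renormalise only over the kept experts
print(indices)                                     # tensor([[0, 2]])
print(weights)                                     # tensor([[0.8176, 0.1824]])

moe = MoEFFNLayer(n_expert=4, n_expert_per_token=2, idim=8, hidden_units=16, dropout_rate=0.0)
print(moe(torch.randn(2, 5, 8)).shape)             # torch.Size([2, 5, 8])
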
src/chatterbox/models/s3gen/transformer/subsampling.py DELETED
@@ -1,383 +0,0 @@
1
- # Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)
2
- # 2024 Alibaba Inc (Xiang Lyu)
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
- # Modified from ESPnet(https://github.com/espnet/espnet)
16
- """Subsampling layer definition."""
17
-
18
- from typing import Tuple, Union
19
-
20
- import torch
21
-
22
-
23
- class BaseSubsampling(torch.nn.Module):
24
-
25
- def __init__(self):
26
- super().__init__()
27
- self.right_context = 0
28
- self.subsampling_rate = 1
29
-
30
- def position_encoding(self, offset: Union[int, torch.Tensor],
31
- size: int) -> torch.Tensor:
32
- return self.pos_enc.position_encoding(offset, size)
33
-
34
-
35
- class EmbedinigNoSubsampling(BaseSubsampling):
36
- """Embedding input without subsampling
37
- """
38
-
39
- def __init__(self, idim: int, odim: int, dropout_rate: float,
40
- pos_enc_class: torch.nn.Module):
41
- super().__init__()
42
- self.embed = torch.nn.Embedding(idim, odim)
43
- self.pos_enc = pos_enc_class
44
-
45
- def forward(
46
- self,
47
- x: torch.Tensor,
48
- x_mask: torch.Tensor,
49
- offset: Union[int, torch.Tensor] = 0
50
- ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
51
- """Input x.
52
-
53
- Args:
54
- x (torch.Tensor): Input tensor (#batch, time, idim).
55
- x_mask (torch.Tensor): Input mask (#batch, 1, time).
56
-
57
- Returns:
58
- torch.Tensor: linear input tensor (#batch, time', odim),
59
- where time' = time .
60
- torch.Tensor: linear input mask (#batch, 1, time'),
61
- where time' = time .
62
-
63
- """
64
- x = self.embed(x)
65
- x, pos_emb = self.pos_enc(x, offset)
66
- return x, pos_emb, x_mask
67
-
68
-
69
- class LinearNoSubsampling(BaseSubsampling):
70
- """Linear transform the input without subsampling
71
-
72
- Args:
73
- idim (int): Input dimension.
74
- odim (int): Output dimension.
75
- dropout_rate (float): Dropout rate.
76
-
77
- """
78
-
79
- def __init__(self, idim: int, odim: int, dropout_rate: float,
80
- pos_enc_class: torch.nn.Module):
81
- """Construct an linear object."""
82
- super().__init__()
83
- self.out = torch.nn.Sequential(
84
- torch.nn.Linear(idim, odim),
85
- torch.nn.LayerNorm(odim, eps=1e-5),
86
- torch.nn.Dropout(dropout_rate),
87
- )
88
- self.pos_enc = pos_enc_class
89
- self.right_context = 0
90
- self.subsampling_rate = 1
91
-
92
- def forward(
93
- self,
94
- x: torch.Tensor,
95
- x_mask: torch.Tensor,
96
- offset: Union[int, torch.Tensor] = 0
97
- ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
98
- """Input x.
99
-
100
- Args:
101
- x (torch.Tensor): Input tensor (#batch, time, idim).
102
- x_mask (torch.Tensor): Input mask (#batch, 1, time).
103
-
104
- Returns:
105
- torch.Tensor: linear input tensor (#batch, time', odim),
106
- where time' = time .
107
- torch.Tensor: linear input mask (#batch, 1, time'),
108
- where time' = time .
109
-
110
- """
111
- x = self.out(x)
112
- x, pos_emb = self.pos_enc(x, offset)
113
- return x, pos_emb, x_mask
114
-
115
-
116
- class Conv1dSubsampling2(BaseSubsampling):
117
- """Convolutional 1D subsampling (to 1/2 length).
118
- It is designed for Whisper, ref:
119
- https://github.com/openai/whisper/blob/main/whisper/model.py
120
-
121
- Args:
122
- idim (int): Input dimension.
123
- odim (int): Output dimension.
124
- dropout_rate (float): Dropout rate.
125
-
126
- """
127
-
128
- def __init__(self, idim: int, odim: int, dropout_rate: float,
129
- pos_enc_class: torch.nn.Module):
130
- """Construct an Conv1dSubsampling2 object."""
131
- super().__init__()
132
- self.conv = torch.nn.Sequential(
133
- torch.nn.Conv1d(idim, odim, kernel_size=3, padding=1),
134
- torch.nn.GELU(),
135
- torch.nn.Conv1d(odim, odim, kernel_size=3, stride=2, padding=1),
136
- torch.nn.GELU(),
137
- )
138
- self.pos_enc = pos_enc_class
139
- # The right context for every conv layer is computed by:
140
- # (kernel_size - 1) * frame_rate_of_this_layer
141
- self.subsampling_rate = 2
142
- # 4 = (3 - 1) * 1 + (3 - 1) * 1
143
- self.right_context = 4
144
-
145
- def forward(
146
- self,
147
- x: torch.Tensor,
148
- x_mask: torch.Tensor,
149
- offset: Union[int, torch.Tensor] = 0
150
- ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
151
- """Subsample x.
152
-
153
- Args:
154
- x (torch.Tensor): Input tensor (#batch, time, idim).
155
- x_mask (torch.Tensor): Input mask (#batch, 1, time).
156
-
157
- Returns:
158
- torch.Tensor: Subsampled tensor (#batch, time', odim),
159
- where time' = time // 2.
160
- torch.Tensor: Subsampled mask (#batch, 1, time'),
161
- where time' = time // 2.
162
- torch.Tensor: positional encoding
163
-
164
- """
165
- time = x.size(1)
166
- x = x.transpose(1, 2) # (b, f, t)
167
- x = self.conv(x)
168
- x = x.transpose(1, 2) # (b, t, f)
169
- x, pos_emb = self.pos_enc(x, offset)
170
- return x, pos_emb, x_mask[:, :, (time + 1) % 2::2]
171
-
172
-
173
- class Conv2dSubsampling4(BaseSubsampling):
174
- """Convolutional 2D subsampling (to 1/4 length).
175
-
176
- Args:
177
- idim (int): Input dimension.
178
- odim (int): Output dimension.
179
- dropout_rate (float): Dropout rate.
180
-
181
- """
182
-
183
- def __init__(self, idim: int, odim: int, dropout_rate: float,
184
- pos_enc_class: torch.nn.Module):
185
- """Construct an Conv2dSubsampling4 object."""
186
- super().__init__()
187
- self.conv = torch.nn.Sequential(
188
- torch.nn.Conv2d(1, odim, 3, 2),
189
- torch.nn.ReLU(),
190
- torch.nn.Conv2d(odim, odim, 3, 2),
191
- torch.nn.ReLU(),
192
- )
193
- self.out = torch.nn.Sequential(
194
- torch.nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim))
195
- self.pos_enc = pos_enc_class
196
- # The right context for every conv layer is computed by:
197
- # (kernel_size - 1) * frame_rate_of_this_layer
198
- self.subsampling_rate = 4
199
- # 6 = (3 - 1) * 1 + (3 - 1) * 2
200
- self.right_context = 6
201
-
202
- def forward(
203
- self,
204
- x: torch.Tensor,
205
- x_mask: torch.Tensor,
206
- offset: Union[int, torch.Tensor] = 0
207
- ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
208
- """Subsample x.
209
-
210
- Args:
211
- x (torch.Tensor): Input tensor (#batch, time, idim).
212
- x_mask (torch.Tensor): Input mask (#batch, 1, time).
213
-
214
- Returns:
215
- torch.Tensor: Subsampled tensor (#batch, time', odim),
216
- where time' = time // 4.
217
- torch.Tensor: Subsampled mask (#batch, 1, time'),
218
- where time' = time // 4.
219
- torch.Tensor: positional encoding
220
-
221
- """
222
- x = x.unsqueeze(1) # (b, c=1, t, f)
223
- x = self.conv(x)
224
- b, c, t, f = x.size()
225
- x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
226
- x, pos_emb = self.pos_enc(x, offset)
227
- return x, pos_emb, x_mask[:, :, 2::2][:, :, 2::2]
228
-
229
-
230
- class Conv2dSubsampling6(BaseSubsampling):
231
- """Convolutional 2D subsampling (to 1/6 length).
232
- Args:
233
- idim (int): Input dimension.
234
- odim (int): Output dimension.
235
- dropout_rate (float): Dropout rate.
236
- pos_enc (torch.nn.Module): Custom position encoding layer.
237
- """
238
-
239
- def __init__(self, idim: int, odim: int, dropout_rate: float,
240
- pos_enc_class: torch.nn.Module):
241
- """Construct an Conv2dSubsampling6 object."""
242
- super().__init__()
243
- self.conv = torch.nn.Sequential(
244
- torch.nn.Conv2d(1, odim, 3, 2),
245
- torch.nn.ReLU(),
246
- torch.nn.Conv2d(odim, odim, 5, 3),
247
- torch.nn.ReLU(),
248
- )
249
- self.linear = torch.nn.Linear(odim * (((idim - 1) // 2 - 2) // 3),
250
- odim)
251
- self.pos_enc = pos_enc_class
252
- # 10 = (3 - 1) * 1 + (5 - 1) * 2
253
- self.subsampling_rate = 6
254
- self.right_context = 10
255
-
256
- def forward(
257
- self,
258
- x: torch.Tensor,
259
- x_mask: torch.Tensor,
260
- offset: Union[int, torch.Tensor] = 0
261
- ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
262
- """Subsample x.
263
- Args:
264
- x (torch.Tensor): Input tensor (#batch, time, idim).
265
- x_mask (torch.Tensor): Input mask (#batch, 1, time).
266
-
267
- Returns:
268
- torch.Tensor: Subsampled tensor (#batch, time', odim),
269
- where time' = time // 6.
270
- torch.Tensor: Subsampled mask (#batch, 1, time'),
271
- where time' = time // 6.
272
- torch.Tensor: positional encoding
273
- """
274
- x = x.unsqueeze(1) # (b, c, t, f)
275
- x = self.conv(x)
276
- b, c, t, f = x.size()
277
- x = self.linear(x.transpose(1, 2).contiguous().view(b, t, c * f))
278
- x, pos_emb = self.pos_enc(x, offset)
279
- return x, pos_emb, x_mask[:, :, 2::2][:, :, 4::3]
280
-
281
-
282
- class Conv2dSubsampling8(BaseSubsampling):
283
- """Convolutional 2D subsampling (to 1/8 length).
284
-
285
- Args:
286
- idim (int): Input dimension.
287
- odim (int): Output dimension.
288
- dropout_rate (float): Dropout rate.
289
-
290
- """
291
-
292
- def __init__(self, idim: int, odim: int, dropout_rate: float,
293
- pos_enc_class: torch.nn.Module):
294
- """Construct an Conv2dSubsampling8 object."""
295
- super().__init__()
296
- self.conv = torch.nn.Sequential(
297
- torch.nn.Conv2d(1, odim, 3, 2),
298
- torch.nn.ReLU(),
299
- torch.nn.Conv2d(odim, odim, 3, 2),
300
- torch.nn.ReLU(),
301
- torch.nn.Conv2d(odim, odim, 3, 2),
302
- torch.nn.ReLU(),
303
- )
304
- self.linear = torch.nn.Linear(
305
- odim * ((((idim - 1) // 2 - 1) // 2 - 1) // 2), odim)
306
- self.pos_enc = pos_enc_class
307
- self.subsampling_rate = 8
308
- # 14 = (3 - 1) * 1 + (3 - 1) * 2 + (3 - 1) * 4
309
- self.right_context = 14
310
-
311
- def forward(
312
- self,
313
- x: torch.Tensor,
314
- x_mask: torch.Tensor,
315
- offset: Union[int, torch.Tensor] = 0
316
- ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
317
- """Subsample x.
318
-
319
- Args:
320
- x (torch.Tensor): Input tensor (#batch, time, idim).
321
- x_mask (torch.Tensor): Input mask (#batch, 1, time).
322
-
323
- Returns:
324
- torch.Tensor: Subsampled tensor (#batch, time', odim),
325
- where time' = time // 8.
326
- torch.Tensor: Subsampled mask (#batch, 1, time'),
327
- where time' = time // 8.
328
- torch.Tensor: positional encoding
329
- """
330
- x = x.unsqueeze(1) # (b, c, t, f)
331
- x = self.conv(x)
332
- b, c, t, f = x.size()
333
- x = self.linear(x.transpose(1, 2).contiguous().view(b, t, c * f))
334
- x, pos_emb = self.pos_enc(x, offset)
335
- return x, pos_emb, x_mask[:, :, 2::2][:, :, 2::2][:, :, 2::2]
336
-
337
-
338
- class LegacyLinearNoSubsampling(BaseSubsampling):
339
- """Linear transform the input without subsampling
340
-
341
- Args:
342
- idim (int): Input dimension.
343
- odim (int): Output dimension.
344
- dropout_rate (float): Dropout rate.
345
-
346
- """
347
-
348
- def __init__(self, idim: int, odim: int, dropout_rate: float,
349
- pos_enc_class: torch.nn.Module):
350
- """Construct an linear object."""
351
- super().__init__()
352
- self.out = torch.nn.Sequential(
353
- torch.nn.Linear(idim, odim),
354
- torch.nn.LayerNorm(odim, eps=1e-5),
355
- torch.nn.Dropout(dropout_rate),
356
- torch.nn.ReLU(),
357
- )
358
- self.pos_enc = pos_enc_class
359
- self.right_context = 0
360
- self.subsampling_rate = 1
361
-
362
- def forward(
363
- self,
364
- x: torch.Tensor,
365
- x_mask: torch.Tensor,
366
- offset: Union[int, torch.Tensor] = 0
367
- ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
368
- """Input x.
369
-
370
- Args:
371
- x (torch.Tensor): Input tensor (#batch, time, idim).
372
- x_mask (torch.Tensor): Input mask (#batch, 1, time).
373
-
374
- Returns:
375
- torch.Tensor: linear input tensor (#batch, time', odim),
376
- where time' = time .
377
- torch.Tensor: linear input mask (#batch, 1, time'),
378
- where time' = time .
379
-
380
- """
381
- x = self.out(x)
382
- x, pos_emb = self.pos_enc(x, offset)
383
- return x, pos_emb, x_mask
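
For reference, the Conv2dSubsampling4 block above halves the time and feature axes twice with its stride-2 convolutions; the `Linear(odim * (((idim - 1) // 2 - 1) // 2), odim)` in-features and the double `[:, :, 2::2]` mask slicing encode the same arithmetic. A minimal, self-contained sketch of that shape math (dimensions chosen only for illustration):

    import torch

    idim, time = 80, 100
    conv = torch.nn.Sequential(
        torch.nn.Conv2d(1, 4, 3, 2), torch.nn.ReLU(),
        torch.nn.Conv2d(4, 4, 3, 2), torch.nn.ReLU(),
    )
    b, c, t, f = conv(torch.randn(1, 1, time, idim)).size()
    assert f == ((idim - 1) // 2 - 1) // 2   # matches the Linear in-features above
    assert t == ((time - 1) // 2 - 1) // 2   # same floor arithmetic on the time axis
    mask = torch.ones(1, 1, time, dtype=torch.bool)
    assert mask[:, :, 2::2][:, :, 2::2].size(2) == t   # mask subsampled in lockstep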
 
 
src/chatterbox/models/s3gen/transformer/upsample_encoder.py DELETED
@@ -1,318 +0,0 @@
1
- # Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)
2
- # 2022 Xingchen Song (sxc19@mails.tsinghua.edu.cn)
3
- # 2024 Alibaba Inc (Xiang Lyu)
4
- #
5
- # Licensed under the Apache License, Version 2.0 (the "License");
6
- # you may not use this file except in compliance with the License.
7
- # You may obtain a copy of the License at
8
- #
9
- # http://www.apache.org/licenses/LICENSE-2.0
10
- #
11
- # Unless required by applicable law or agreed to in writing, software
12
- # distributed under the License is distributed on an "AS IS" BASIS,
13
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
- # See the License for the specific language governing permissions and
15
- # limitations under the License.
16
- # Modified from ESPnet(https://github.com/espnet/espnet)
17
- """Encoder definition."""
18
- from typing import Tuple
19
-
20
- import torch
21
- from torch import nn
22
- from torch.nn import functional as F
23
-
24
- from .convolution import ConvolutionModule
25
- from .encoder_layer import ConformerEncoderLayer
26
- from .positionwise_feed_forward import PositionwiseFeedForward
27
- from ..utils.class_utils import (
28
- COSYVOICE_EMB_CLASSES,
29
- COSYVOICE_SUBSAMPLE_CLASSES,
30
- COSYVOICE_ATTENTION_CLASSES,
31
- COSYVOICE_ACTIVATION_CLASSES,
32
- )
33
- from ..utils.mask import make_pad_mask
34
- from ..utils.mask import add_optional_chunk_mask
35
-
36
-
37
- class Upsample1D(nn.Module):
38
- """A 1D upsampling layer followed by a convolution.
39
-
40
- Parameters:
41
- channels (`int`):
42
- number of channels in the inputs and outputs.
43
- stride (`int`, default `2`):
44
- upsampling factor: the input is nearest-neighbour interpolated by this
45
- factor and then refined with a stride-1 convolution (there are no
46
- `use_conv` / `use_conv_transpose` options in this variant).
47
- out_channels (`int`, optional):
48
- number of output channels. Defaults to `channels`.
49
- """
50
-
51
- def __init__(self, channels: int, out_channels: int, stride: int = 2):
52
- super().__init__()
53
- self.channels = channels
54
- self.out_channels = out_channels
55
- self.stride = stride
56
- # In this mode, first repeat-interpolate, then conv with stride=1
57
- self.conv = nn.Conv1d(self.channels, self.out_channels, stride * 2 + 1, stride=1, padding=0)
58
-
59
- def forward(self, inputs: torch.Tensor, input_lengths: torch.Tensor):
60
- outputs = F.interpolate(inputs, scale_factor=float(self.stride), mode="nearest")
61
- outputs = F.pad(outputs, (self.stride * 2, 0), value=0.0)
62
- outputs = self.conv(outputs)
63
- return outputs, input_lengths * self.stride
64
-
65
-
66
- class PreLookaheadLayer(nn.Module):
67
- def __init__(self, channels: int, pre_lookahead_len: int = 1):
68
- super().__init__()
69
- self.channels = channels
70
- self.pre_lookahead_len = pre_lookahead_len
71
- self.conv1 = nn.Conv1d(
72
- channels, channels,
73
- kernel_size=pre_lookahead_len + 1,
74
- stride=1, padding=0,
75
- )
76
- self.conv2 = nn.Conv1d(
77
- channels, channels,
78
- kernel_size=3, stride=1, padding=0,
79
- )
80
-
81
- def forward(self, inputs: torch.Tensor) -> torch.Tensor:
82
- """
83
- inputs: (batch_size, seq_len, channels)
84
- """
85
- outputs = inputs.transpose(1, 2).contiguous()
86
- # look ahead
87
- outputs = F.pad(outputs, (0, self.pre_lookahead_len), mode='constant', value=0.0)
88
- outputs = F.leaky_relu(self.conv1(outputs))
89
- # outputs
90
- outputs = F.pad(outputs, (2, 0), mode='constant', value=0.0)
91
- outputs = self.conv2(outputs)
92
- outputs = outputs.transpose(1, 2).contiguous()
93
-
94
- # residual connection
95
- outputs = outputs + inputs
96
- return outputs
97
-
98
-
99
- class UpsampleConformerEncoder(torch.nn.Module):
100
-
101
- def __init__(
102
- self,
103
- input_size: int = 512,
104
- output_size: int = 512,
105
- attention_heads: int = 8,
106
- linear_units: int = 2048,
107
- num_blocks: int = 6,
108
- dropout_rate: float = 0.1,
109
- positional_dropout_rate: float = 0.1,
110
- attention_dropout_rate: float = 0.1,
111
- input_layer: str = "linear",
112
- pos_enc_layer_type: str = "rel_pos_espnet",
113
- normalize_before: bool = True,
114
- static_chunk_size: int = 0,
115
- use_dynamic_chunk: bool = False,
116
- global_cmvn: torch.nn.Module = None,
117
- use_dynamic_left_chunk: bool = False,
118
- positionwise_conv_kernel_size: int = 1,
119
- macaron_style: bool = False,
120
- selfattention_layer_type: str = "rel_selfattn",
121
- activation_type: str = "swish",
122
- use_cnn_module: bool = False,
123
- cnn_module_kernel: int = 15,
124
- causal: bool = False,
125
- cnn_module_norm: str = "batch_norm",
126
- key_bias: bool = True,
127
- gradient_checkpointing: bool = False,
128
- ):
129
- """
130
- Args:
131
- input_size (int): input dim
132
- output_size (int): dimension of attention
133
- attention_heads (int): the number of heads of multi head attention
134
- linear_units (int): the hidden units number of position-wise feed
135
- forward
136
- num_blocks (int): the number of decoder blocks
137
- dropout_rate (float): dropout rate
138
- attention_dropout_rate (float): dropout rate in attention
139
- positional_dropout_rate (float): dropout rate after adding
140
- positional encoding
141
- input_layer (str): input layer type.
142
- optional [linear, conv2d, conv2d6, conv2d8]
143
- pos_enc_layer_type (str): Encoder positional encoding layer type.
144
- optional [abs_pos, scaled_abs_pos, rel_pos, no_pos]
145
- normalize_before (bool):
146
- True: use layer_norm before each sub-block of a layer.
147
- False: use layer_norm after each sub-block of a layer.
148
- static_chunk_size (int): chunk size for static chunk training and
149
- decoding
150
- use_dynamic_chunk (bool): whether use dynamic chunk size for
151
- training or not. You can only use a fixed chunk (chunk_size > 0)
152
- or dynamic chunk size (use_dynamic_chunk = True)
153
- global_cmvn (Optional[torch.nn.Module]): Optional GlobalCMVN module
154
- use_dynamic_left_chunk (bool): whether use dynamic left chunk in
155
- dynamic chunk training
156
- key_bias: whether use bias in attention.linear_k, False for whisper models.
157
- gradient_checkpointing: rerunning a forward-pass segment for each
158
- checkpointed segment during backward.
159
- """
160
- super().__init__()
161
- self._output_size = output_size
162
-
163
- self.global_cmvn = global_cmvn
164
- self.embed = COSYVOICE_SUBSAMPLE_CLASSES[input_layer](
165
- input_size,
166
- output_size,
167
- dropout_rate,
168
- COSYVOICE_EMB_CLASSES[pos_enc_layer_type](output_size,
169
- positional_dropout_rate),
170
- )
171
-
172
- self.normalize_before = normalize_before
173
- self.after_norm = torch.nn.LayerNorm(output_size, eps=1e-5)
174
- self.static_chunk_size = static_chunk_size
175
- self.use_dynamic_chunk = use_dynamic_chunk
176
- self.use_dynamic_left_chunk = use_dynamic_left_chunk
177
- self.gradient_checkpointing = gradient_checkpointing
178
- activation = COSYVOICE_ACTIVATION_CLASSES[activation_type]()
179
- # self-attention module definition
180
- encoder_selfattn_layer_args = (
181
- attention_heads,
182
- output_size,
183
- attention_dropout_rate,
184
- key_bias,
185
- )
186
- # feed-forward module definition
187
- positionwise_layer_args = (
188
- output_size,
189
- linear_units,
190
- dropout_rate,
191
- activation,
192
- )
193
- # convolution module definition
194
- convolution_layer_args = (output_size, cnn_module_kernel, activation,
195
- cnn_module_norm, causal)
196
- self.pre_lookahead_layer = PreLookaheadLayer(channels=512, pre_lookahead_len=3)
197
- self.encoders = torch.nn.ModuleList([
198
- ConformerEncoderLayer(
199
- output_size,
200
- COSYVOICE_ATTENTION_CLASSES[selfattention_layer_type](
201
- *encoder_selfattn_layer_args),
202
- PositionwiseFeedForward(*positionwise_layer_args),
203
- PositionwiseFeedForward(
204
- *positionwise_layer_args) if macaron_style else None,
205
- ConvolutionModule(
206
- *convolution_layer_args) if use_cnn_module else None,
207
- dropout_rate,
208
- normalize_before,
209
- ) for _ in range(num_blocks)
210
- ])
211
- self.up_layer = Upsample1D(channels=512, out_channels=512, stride=2)
212
- self.up_embed = COSYVOICE_SUBSAMPLE_CLASSES[input_layer](
213
- input_size,
214
- output_size,
215
- dropout_rate,
216
- COSYVOICE_EMB_CLASSES[pos_enc_layer_type](output_size,
217
- positional_dropout_rate),
218
- )
219
- self.up_encoders = torch.nn.ModuleList([
220
- ConformerEncoderLayer(
221
- output_size,
222
- COSYVOICE_ATTENTION_CLASSES[selfattention_layer_type](
223
- *encoder_selfattn_layer_args),
224
- PositionwiseFeedForward(*positionwise_layer_args),
225
- PositionwiseFeedForward(
226
- *positionwise_layer_args) if macaron_style else None,
227
- ConvolutionModule(
228
- *convolution_layer_args) if use_cnn_module else None,
229
- dropout_rate,
230
- normalize_before,
231
- ) for _ in range(4)
232
- ])
233
-
234
- def output_size(self) -> int:
235
- return self._output_size
236
-
237
- def forward(
238
- self,
239
- xs: torch.Tensor,
240
- xs_lens: torch.Tensor,
241
- decoding_chunk_size: int = 0,
242
- num_decoding_left_chunks: int = -1,
243
- ) -> Tuple[torch.Tensor, torch.Tensor]:
244
- """Embed positions in tensor.
245
-
246
- Args:
247
- xs: padded input tensor (B, T, D)
248
- xs_lens: input length (B)
249
- decoding_chunk_size: decoding chunk size for dynamic chunk
250
- 0: default for training, use random dynamic chunk.
251
- <0: for decoding, use full chunk.
252
- >0: for decoding, use fixed chunk size as set.
253
- num_decoding_left_chunks: number of left chunks, this is for decoding,
254
- the chunk size is decoding_chunk_size.
255
- >=0: use num_decoding_left_chunks
256
- <0: use all left chunks
257
- Returns:
258
- encoder output tensor xs, and subsampled masks
259
- xs: padded output tensor (B, T' ~= T/subsample_rate, D)
260
- masks: torch.Tensor batch padding mask after subsample
261
- (B, 1, T' ~= T/subsample_rate)
262
- NOTE(xcsong):
263
- We pass the `__call__` method of the modules instead of `forward` to the
264
- checkpointing API because `__call__` attaches all the hooks of the module.
265
- https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2
266
- """
267
- T = xs.size(1)
268
- masks = ~make_pad_mask(xs_lens, T).unsqueeze(1) # (B, 1, T)
269
- if self.global_cmvn is not None:
270
- xs = self.global_cmvn(xs)
271
- xs, pos_emb, masks = self.embed(xs, masks)
272
- mask_pad = masks # (B, 1, T/subsample_rate)
273
- chunk_masks = add_optional_chunk_mask(xs, masks,
274
- self.use_dynamic_chunk,
275
- self.use_dynamic_left_chunk,
276
- decoding_chunk_size,
277
- self.static_chunk_size,
278
- num_decoding_left_chunks)
279
- # lookahead + conformer encoder
280
- xs = self.pre_lookahead_layer(xs)
281
- xs = self.forward_layers(xs, chunk_masks, pos_emb, mask_pad)
282
-
283
- # upsample + conformer encoder
284
- xs = xs.transpose(1, 2).contiguous()
285
- xs, xs_lens = self.up_layer(xs, xs_lens)
286
- xs = xs.transpose(1, 2).contiguous()
287
- T = xs.size(1)
288
- masks = ~make_pad_mask(xs_lens, T).unsqueeze(1) # (B, 1, T)
289
- xs, pos_emb, masks = self.up_embed(xs, masks)
290
- mask_pad = masks # (B, 1, T/subsample_rate)
291
- chunk_masks = add_optional_chunk_mask(xs, masks,
292
- self.use_dynamic_chunk,
293
- self.use_dynamic_left_chunk,
294
- decoding_chunk_size,
295
- self.static_chunk_size * self.up_layer.stride,
296
- num_decoding_left_chunks)
297
- xs = self.forward_up_layers(xs, chunk_masks, pos_emb, mask_pad)
298
-
299
- if self.normalize_before:
300
- xs = self.after_norm(xs)
301
- # Here we assume the mask is not changed in encoder layers, so just
302
- # return the masks before encoder layers, and the masks will be used
303
- # for cross attention with decoder later
304
- return xs, masks
305
-
306
- def forward_layers(self, xs: torch.Tensor, chunk_masks: torch.Tensor,
307
- pos_emb: torch.Tensor,
308
- mask_pad: torch.Tensor) -> torch.Tensor:
309
- for layer in self.encoders:
310
- xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
311
- return xs
312
-
313
- def forward_up_layers(self, xs: torch.Tensor, chunk_masks: torch.Tensor,
314
- pos_emb: torch.Tensor,
315
- mask_pad: torch.Tensor) -> torch.Tensor:
316
- for layer in self.up_encoders:
317
- xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
318
- return xs
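
A note on the Upsample1D arithmetic above: nearest-neighbour interpolation by `stride`, a left pad of `2 * stride`, and a stride-1 Conv1d with kernel size `2 * stride + 1` together yield exactly `T * stride` output frames, which is why the layer returns `input_lengths * self.stride`. A small standalone check (shapes chosen arbitrarily):

    import torch
    import torch.nn.functional as F

    stride, channels, T = 2, 8, 50
    conv = torch.nn.Conv1d(channels, channels, stride * 2 + 1, stride=1, padding=0)
    y = F.interpolate(torch.randn(1, channels, T), scale_factor=float(stride), mode="nearest")  # (1, 8, 100)
    y = F.pad(y, (stride * 2, 0), value=0.0)   # (1, 8, 104)
    y = conv(y)                                # (1, 8, 100)
    assert y.shape[-1] == T * stride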
 
 
src/chatterbox/models/s3gen/utils/class_utils.py DELETED
@@ -1,71 +0,0 @@
1
- # Copyright [2023-11-28] <sxc19@mails.tsinghua.edu.cn, Xingchen Song>
2
- # 2024 Alibaba Inc (authors: Xiang Lyu)
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
- import torch
16
-
17
- from ..transformer.activation import Swish
18
- from ..transformer.subsampling import (
19
- LinearNoSubsampling,
20
- EmbedinigNoSubsampling,
21
- Conv1dSubsampling2,
22
- Conv2dSubsampling4,
23
- Conv2dSubsampling6,
24
- Conv2dSubsampling8,
25
- )
26
- from ..transformer.embedding import (
27
- PositionalEncoding,
28
- RelPositionalEncoding,
29
- WhisperPositionalEncoding,
30
- LearnablePositionalEncoding,
31
- NoPositionalEncoding)
32
- from ..transformer.attention import (MultiHeadedAttention,
33
- RelPositionMultiHeadedAttention)
34
- from ..transformer.embedding import EspnetRelPositionalEncoding
35
- from ..transformer.subsampling import LegacyLinearNoSubsampling
36
-
37
-
38
- COSYVOICE_ACTIVATION_CLASSES = {
39
- "hardtanh": torch.nn.Hardtanh,
40
- "tanh": torch.nn.Tanh,
41
- "relu": torch.nn.ReLU,
42
- "selu": torch.nn.SELU,
43
- "swish": getattr(torch.nn, "SiLU", Swish),
44
- "gelu": torch.nn.GELU,
45
- }
46
-
47
- COSYVOICE_SUBSAMPLE_CLASSES = {
48
- "linear": LinearNoSubsampling,
49
- "linear_legacy": LegacyLinearNoSubsampling,
50
- "embed": EmbedinigNoSubsampling,
51
- "conv1d2": Conv1dSubsampling2,
52
- "conv2d": Conv2dSubsampling4,
53
- "conv2d6": Conv2dSubsampling6,
54
- "conv2d8": Conv2dSubsampling8,
55
- 'paraformer_dummy': torch.nn.Identity
56
- }
57
-
58
- COSYVOICE_EMB_CLASSES = {
59
- "embed": PositionalEncoding,
60
- "abs_pos": PositionalEncoding,
61
- "rel_pos": RelPositionalEncoding,
62
- "rel_pos_espnet": EspnetRelPositionalEncoding,
63
- "no_pos": NoPositionalEncoding,
64
- "abs_pos_whisper": WhisperPositionalEncoding,
65
- "embed_learnable_pe": LearnablePositionalEncoding,
66
- }
67
-
68
- COSYVOICE_ATTENTION_CLASSES = {
69
- "selfattn": MultiHeadedAttention,
70
- "rel_selfattn": RelPositionMultiHeadedAttention,
71
- }
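
These tables are plain string-to-class registries; the encoder configs above select components by name (e.g. input_layer="linear", activation_type="swish"). A rough sketch of the lookup pattern, using a local stand-in dictionary rather than the deleted module:

    import torch

    ACTIVATIONS = {"relu": torch.nn.ReLU, "swish": getattr(torch.nn, "SiLU", torch.nn.ReLU)}
    act = ACTIVATIONS["swish"]()                 # nn.SiLU on any recent torch
    print(act(torch.tensor([-1.0, 0.0, 1.0])))   # SiLU(x) = x * sigmoid(x)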
 
 
src/chatterbox/models/s3gen/utils/mask.py DELETED
@@ -1,193 +0,0 @@
1
- # Copyright (c) 2019 Shigeki Karita
2
- # 2020 Mobvoi Inc (Binbin Zhang)
3
- # 2024 Alibaba Inc (authors: Xiang Lyu)
4
- #
5
- # Licensed under the Apache License, Version 2.0 (the "License");
6
- # you may not use this file except in compliance with the License.
7
- # You may obtain a copy of the License at
8
- #
9
- # http://www.apache.org/licenses/LICENSE-2.0
10
- #
11
- # Unless required by applicable law or agreed to in writing, software
12
- # distributed under the License is distributed on an "AS IS" BASIS,
13
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
- # See the License for the specific language governing permissions and
15
- # limitations under the License.
16
-
17
- import logging
- import torch
18
-
19
- '''
20
- def subsequent_mask(
21
- size: int,
22
- device: torch.device = torch.device("cpu"),
23
- ) -> torch.Tensor:
24
- """Create mask for subsequent steps (size, size).
25
-
26
- This mask is used only in decoder which works in an auto-regressive mode.
27
- This means the current step could only do attention with its left steps.
28
-
29
- In encoder, fully attention is used when streaming is not necessary and
30
- the sequence is not long. In this case, no attention mask is needed.
31
-
32
- When streaming is need, chunk-based attention is used in encoder. See
33
- subsequent_chunk_mask for the chunk-based attention mask.
34
-
35
- Args:
36
- size (int): size of mask
37
- str device (str): "cpu" or "cuda" or torch.Tensor.device
38
- dtype (torch.device): result dtype
39
-
40
- Returns:
41
- torch.Tensor: mask
42
-
43
- Examples:
44
- >>> subsequent_mask(3)
45
- [[1, 0, 0],
46
- [1, 1, 0],
47
- [1, 1, 1]]
48
- """
49
- ret = torch.ones(size, size, device=device, dtype=torch.bool)
50
- return torch.tril(ret)
51
- '''
52
-
53
-
54
- def subsequent_chunk_mask(
55
- size: int,
56
- chunk_size: int,
57
- num_left_chunks: int = -1,
58
- device: torch.device = torch.device("cpu"),
59
- ) -> torch.Tensor:
60
- """Create mask for subsequent steps (size, size) with chunk size,
61
- this is for streaming encoder
62
-
63
- Args:
64
- size (int): size of mask
65
- chunk_size (int): size of chunk
66
- num_left_chunks (int): number of left chunks
67
- <0: use full chunk
68
- >=0: use num_left_chunks
69
- device (torch.device): "cpu" or "cuda" or torch.Tensor.device
70
-
71
- Returns:
72
- torch.Tensor: mask
73
-
74
- Examples:
75
- >>> subsequent_chunk_mask(4, 2)
76
- [[1, 1, 0, 0],
77
- [1, 1, 0, 0],
78
- [1, 1, 1, 1],
79
- [1, 1, 1, 1]]
80
- """
81
- # NOTE this modified implementation meets onnx export requirements, but it doesn't support num_left_chunks
82
- # actually this is not needed after we have inference cache implemented, will remove it later
83
- pos_idx = torch.arange(size, device=device)
84
- block_value = (torch.div(pos_idx, chunk_size, rounding_mode='trunc') + 1) * chunk_size
85
- ret = pos_idx.unsqueeze(0) < block_value.unsqueeze(1)
86
- return ret
87
-
88
-
89
- def add_optional_chunk_mask(xs: torch.Tensor,
90
- masks: torch.Tensor,
91
- use_dynamic_chunk: bool,
92
- use_dynamic_left_chunk: bool,
93
- decoding_chunk_size: int,
94
- static_chunk_size: int,
95
- num_decoding_left_chunks: int,
96
- enable_full_context: bool = True):
97
- """ Apply optional mask for encoder.
98
-
99
- Args:
100
- xs (torch.Tensor): padded input, (B, L, D), L for max length
101
- mask (torch.Tensor): mask for xs, (B, 1, L)
102
- use_dynamic_chunk (bool): whether to use dynamic chunk or not
103
- use_dynamic_left_chunk (bool): whether to use dynamic left chunk for
104
- training.
105
- decoding_chunk_size (int): decoding chunk size for dynamic chunk, it's
106
- 0: default for training, use random dynamic chunk.
107
- <0: for decoding, use full chunk.
108
- >0: for decoding, use fixed chunk size as set.
109
- static_chunk_size (int): chunk size for static chunk training/decoding
110
- if it's greater than 0, if use_dynamic_chunk is true,
111
- this parameter will be ignored
112
- num_decoding_left_chunks: number of left chunks, this is for decoding,
113
- the chunk size is decoding_chunk_size.
114
- >=0: use num_decoding_left_chunks
115
- <0: use all left chunks
116
- enable_full_context (bool):
117
- True: chunk size is either [1, 25] or full context(max_len)
118
- False: chunk size ~ U[1, 25]
119
-
120
- Returns:
121
- torch.Tensor: chunk mask of the input xs.
122
- """
123
- # Whether to use chunk mask or not
124
- if use_dynamic_chunk:
125
- max_len = xs.size(1)
126
- if decoding_chunk_size < 0:
127
- chunk_size = max_len
128
- num_left_chunks = -1
129
- elif decoding_chunk_size > 0:
130
- chunk_size = decoding_chunk_size
131
- num_left_chunks = num_decoding_left_chunks
132
- else:
133
- # chunk size is either [1, 25] or full context(max_len).
134
- # Since we use 4 times subsampling and allow up to 1s(100 frames)
135
- # delay, the maximum frame is 100 / 4 = 25.
136
- chunk_size = torch.randint(1, max_len, (1, )).item()
137
- num_left_chunks = -1
138
- if chunk_size > max_len // 2 and enable_full_context:
139
- chunk_size = max_len
140
- else:
141
- chunk_size = chunk_size % 25 + 1
142
- if use_dynamic_left_chunk:
143
- max_left_chunks = (max_len - 1) // chunk_size
144
- num_left_chunks = torch.randint(0, max_left_chunks,
145
- (1, )).item()
146
- chunk_masks = subsequent_chunk_mask(xs.size(1), chunk_size,
147
- num_left_chunks,
148
- xs.device) # (L, L)
149
- chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L)
150
- chunk_masks = masks & chunk_masks # (B, L, L)
151
- elif static_chunk_size > 0:
152
- num_left_chunks = num_decoding_left_chunks
153
- chunk_masks = subsequent_chunk_mask(xs.size(1), static_chunk_size,
154
- num_left_chunks,
155
- xs.device) # (L, L)
156
- chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L)
157
- chunk_masks = masks & chunk_masks # (B, L, L)
158
- else:
159
- chunk_masks = masks
160
- assert chunk_masks.dtype == torch.bool
161
- if (chunk_masks.sum(dim=-1) == 0).sum().item() != 0:
162
- logging.warning('chunk_masks is all false at some timestep; forcing it to true, make sure these positions are masked in future computation!')
163
- chunk_masks[chunk_masks.sum(dim=-1)==0] = True
164
- return chunk_masks
165
-
166
-
167
- def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
168
- """Make mask tensor containing indices of padded part.
169
-
170
- See description of make_non_pad_mask.
171
-
172
- Args:
173
- lengths (torch.Tensor): Batch of lengths (B,).
174
- Returns:
175
- torch.Tensor: Mask tensor containing indices of padded part.
176
-
177
- Examples:
178
- >>> lengths = [5, 3, 2]
179
- >>> make_pad_mask(lengths)
180
- masks = [[0, 0, 0, 0 ,0],
181
- [0, 0, 0, 1, 1],
182
- [0, 0, 1, 1, 1]]
183
- """
184
- batch_size = lengths.size(0)
185
- max_len = max_len if max_len > 0 else lengths.max().item()
186
- seq_range = torch.arange(0,
187
- max_len,
188
- dtype=torch.int64,
189
- device=lengths.device)
190
- seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
191
- seq_length_expand = lengths.unsqueeze(-1)
192
- mask = seq_range_expand >= seq_length_expand
193
- return mask
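
The docstring examples above are easy to reproduce with plain torch, which is a handy sanity check before `make_pad_mask` and `subsequent_chunk_mask` are combined inside `add_optional_chunk_mask`:

    import torch

    # make_pad_mask([5, 3, 2]): True marks padded positions
    lengths = torch.tensor([5, 3, 2])
    seq = torch.arange(int(lengths.max())).unsqueeze(0).expand(len(lengths), -1)
    pad_mask = seq >= lengths.unsqueeze(-1)
    assert pad_mask.tolist()[1] == [False, False, False, True, True]

    # subsequent_chunk_mask(4, 2): each frame sees its own chunk and everything before it
    pos = torch.arange(4)
    block = (torch.div(pos, 2, rounding_mode="trunc") + 1) * 2
    chunk_mask = pos.unsqueeze(0) < block.unsqueeze(1)
    assert chunk_mask.tolist()[0] == [True, True, False, False]
    assert chunk_mask.tolist()[2] == [True, True, True, True]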
 
 
src/chatterbox/models/s3gen/utils/mel.py DELETED
@@ -1,85 +0,0 @@
1
- """mel-spectrogram extraction in Matcha-TTS"""
2
- import logging
3
- from librosa.filters import mel as librosa_mel_fn
4
- import torch
5
- import numpy as np
6
-
7
- logger = logging.getLogger(__name__)
8
-
9
-
10
- # NOTE: they declared these global vars
11
- mel_basis = {}
12
- hann_window = {}
13
-
14
-
15
- def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
16
- return torch.log(torch.clamp(x, min=clip_val) * C)
17
-
18
-
19
- def spectral_normalize_torch(magnitudes):
20
- output = dynamic_range_compression_torch(magnitudes)
21
- return output
22
-
23
- """
24
- feat_extractor: !name:matcha.utils.audio.mel_spectrogram
25
- n_fft: 1920
26
- num_mels: 80
27
- sampling_rate: 24000
28
- hop_size: 480
29
- win_size: 1920
30
- fmin: 0
31
- fmax: 8000
32
- center: False
33
-
34
- """
35
-
36
- def mel_spectrogram(y, n_fft=1920, num_mels=80, sampling_rate=24000, hop_size=480, win_size=1920,
37
- fmin=0, fmax=8000, center=False):
38
- """Copied from https://github.com/shivammehta25/Matcha-TTS/blob/main/matcha/utils/audio.py
39
- Set default values according to Cosyvoice's config.
40
- """
41
-
42
- if isinstance(y, np.ndarray):
43
- y = torch.tensor(y).float()
44
-
45
- if len(y.shape) == 1:
46
- y = y[None, ]
47
-
48
- # Debug: Check for audio clipping (values outside [-1.0, 1.0] range)
49
- min_val = torch.min(y)
50
- max_val = torch.max(y)
51
- if min_val < -1.0 or max_val > 1.0:
52
- logger.warning(f"Audio values outside normalized range: min={min_val.item():.4f}, max={max_val.item():.4f}")
53
-
54
- global mel_basis, hann_window # pylint: disable=global-statement,global-variable-not-assigned
55
- if f"{str(fmax)}_{str(y.device)}" not in mel_basis:
56
- mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
57
- mel_basis[str(fmax) + "_" + str(y.device)] = torch.from_numpy(mel).float().to(y.device)
58
- hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)
59
-
60
- y = torch.nn.functional.pad(
61
- y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect"
62
- )
63
- y = y.squeeze(1)
64
-
65
- spec = torch.view_as_real(
66
- torch.stft(
67
- y,
68
- n_fft,
69
- hop_length=hop_size,
70
- win_length=win_size,
71
- window=hann_window[str(y.device)],
72
- center=center,
73
- pad_mode="reflect",
74
- normalized=False,
75
- onesided=True,
76
- return_complex=True,
77
- )
78
- )
79
-
80
- spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))
81
-
82
- spec = torch.matmul(mel_basis[str(fmax) + "_" + str(y.device)], spec)
83
- spec = spectral_normalize_torch(spec)
84
-
85
- return spec
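
With the CosyVoice defaults documented above (24 kHz audio, n_fft 1920, hop 480, center=False), the helper yields 50 mel frames per second; the explicit reflect pad of `(n_fft - hop_size) / 2` on each side is what makes the count come out exact. A quick arithmetic check using those assumed defaults:

    n_fft, hop, sr = 1920, 480, 24_000
    n_samples = sr                          # one second of audio
    padded = n_samples + (n_fft - hop)      # (n_fft - hop) / 2 added on each side
    n_frames = (padded - n_fft) // hop + 1
    assert n_frames == 50                   # -> mel output of shape (1, 80, 50)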
 
 
src/chatterbox/models/s3gen/xvector.py DELETED
@@ -1,428 +0,0 @@
1
- #!/usr/bin/env python3
2
- # -*- encoding: utf-8 -*-
3
- # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
4
- # MIT License (https://opensource.org/licenses/MIT)
5
- # Modified from 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker)
6
-
7
-
8
- from collections import OrderedDict
9
- import torch
10
- import torch.nn.functional as F
11
- import torch.utils.checkpoint as cp
12
- import torchaudio.compliance.kaldi as Kaldi
13
-
14
-
15
- def pad_list(xs, pad_value):
16
- """Perform padding for the list of tensors.
17
-
18
- Args:
19
- xs (List): List of Tensors [(T_1, `*`), (T_2, `*`), ..., (T_B, `*`)].
20
- pad_value (float): Value for padding.
21
-
22
- Returns:
23
- Tensor: Padded tensor (B, Tmax, `*`).
24
-
25
- Examples:
26
- >>> x = [torch.ones(4), torch.ones(2), torch.ones(1)]
27
- >>> x
28
- [tensor([1., 1., 1., 1.]), tensor([1., 1.]), tensor([1.])]
29
- >>> pad_list(x, 0)
30
- tensor([[1., 1., 1., 1.],
31
- [1., 1., 0., 0.],
32
- [1., 0., 0., 0.]])
33
-
34
- """
35
- n_batch = len(xs)
36
- max_len = max(x.size(0) for x in xs)
37
- pad = xs[0].new(n_batch, max_len, *xs[0].size()[1:]).fill_(pad_value)
38
-
39
- for i in range(n_batch):
40
- pad[i, : xs[i].size(0)] = xs[i]
41
-
42
- return pad
43
-
44
-
45
- def extract_feature(audio):
46
- features = []
47
- feature_times = []
48
- feature_lengths = []
49
- for au in audio:
50
- feature = Kaldi.fbank(au.unsqueeze(0), num_mel_bins=80)
51
- feature = feature - feature.mean(dim=0, keepdim=True)
52
- features.append(feature)
53
- feature_times.append(au.shape[0])
54
- feature_lengths.append(feature.shape[0])
55
- # padding for batch inference
56
- features_padded = pad_list(features, pad_value=0)
57
- # features = torch.cat(features)
58
- return features_padded, feature_lengths, feature_times
59
-
60
-
61
- class BasicResBlock(torch.nn.Module):
62
- expansion = 1
63
-
64
- def __init__(self, in_planes, planes, stride=1):
65
- super(BasicResBlock, self).__init__()
66
- self.conv1 = torch.nn.Conv2d(
67
- in_planes, planes, kernel_size=3, stride=(stride, 1), padding=1, bias=False
68
- )
69
- self.bn1 = torch.nn.BatchNorm2d(planes)
70
- self.conv2 = torch.nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
71
- self.bn2 = torch.nn.BatchNorm2d(planes)
72
-
73
- self.shortcut = torch.nn.Sequential()
74
- if stride != 1 or in_planes != self.expansion * planes:
75
- self.shortcut = torch.nn.Sequential(
76
- torch.nn.Conv2d(
77
- in_planes,
78
- self.expansion * planes,
79
- kernel_size=1,
80
- stride=(stride, 1),
81
- bias=False,
82
- ),
83
- torch.nn.BatchNorm2d(self.expansion * planes),
84
- )
85
-
86
- def forward(self, x):
87
- out = F.relu(self.bn1(self.conv1(x)))
88
- out = self.bn2(self.conv2(out))
89
- out += self.shortcut(x)
90
- out = F.relu(out)
91
- return out
92
-
93
-
94
- class FCM(torch.nn.Module):
95
- def __init__(self, block=BasicResBlock, num_blocks=[2, 2], m_channels=32, feat_dim=80):
96
- super(FCM, self).__init__()
97
- self.in_planes = m_channels
98
- self.conv1 = torch.nn.Conv2d(1, m_channels, kernel_size=3, stride=1, padding=1, bias=False)
99
- self.bn1 = torch.nn.BatchNorm2d(m_channels)
100
-
101
- self.layer1 = self._make_layer(block, m_channels, num_blocks[0], stride=2)
102
- self.layer2 = self._make_layer(block, m_channels, num_blocks[0], stride=2)
103
-
104
- self.conv2 = torch.nn.Conv2d(
105
- m_channels, m_channels, kernel_size=3, stride=(2, 1), padding=1, bias=False
106
- )
107
- self.bn2 = torch.nn.BatchNorm2d(m_channels)
108
- self.out_channels = m_channels * (feat_dim // 8)
109
-
110
- def _make_layer(self, block, planes, num_blocks, stride):
111
- strides = [stride] + [1] * (num_blocks - 1)
112
- layers = []
113
- for stride in strides:
114
- layers.append(block(self.in_planes, planes, stride))
115
- self.in_planes = planes * block.expansion
116
- return torch.nn.Sequential(*layers)
117
-
118
- def forward(self, x):
119
- x = x.unsqueeze(1)
120
- out = F.relu(self.bn1(self.conv1(x)))
121
- out = self.layer1(out)
122
- out = self.layer2(out)
123
- out = F.relu(self.bn2(self.conv2(out)))
124
-
125
- shape = out.shape
126
- out = out.reshape(shape[0], shape[1] * shape[2], shape[3])
127
- return out
128
-
129
-
130
- def get_nonlinear(config_str, channels):
131
- nonlinear = torch.nn.Sequential()
132
- for name in config_str.split("-"):
133
- if name == "relu":
134
- nonlinear.add_module("relu", torch.nn.ReLU(inplace=True))
135
- elif name == "prelu":
136
- nonlinear.add_module("prelu", torch.nn.PReLU(channels))
137
- elif name == "batchnorm":
138
- nonlinear.add_module("batchnorm", torch.nn.BatchNorm1d(channels))
139
- elif name == "batchnorm_":
140
- nonlinear.add_module("batchnorm", torch.nn.BatchNorm1d(channels, affine=False))
141
- else:
142
- raise ValueError("Unexpected module ({}).".format(name))
143
- return nonlinear
144
-
145
-
146
- def statistics_pooling(x, dim=-1, keepdim=False, unbiased=True, eps=1e-2):
147
- mean = x.mean(dim=dim)
148
- std = x.std(dim=dim, unbiased=unbiased)
149
- stats = torch.cat([mean, std], dim=-1)
150
- if keepdim:
151
- stats = stats.unsqueeze(dim=dim)
152
- return stats
153
-
154
-
155
- class StatsPool(torch.nn.Module):
156
- def forward(self, x):
157
- return statistics_pooling(x)
158
-
159
-
160
- class TDNNLayer(torch.nn.Module):
161
- def __init__(
162
- self,
163
- in_channels,
164
- out_channels,
165
- kernel_size,
166
- stride=1,
167
- padding=0,
168
- dilation=1,
169
- bias=False,
170
- config_str="batchnorm-relu",
171
- ):
172
- super(TDNNLayer, self).__init__()
173
- if padding < 0:
174
- assert (
175
- kernel_size % 2 == 1
176
- ), "Expect equal paddings, but got even kernel size ({})".format(kernel_size)
177
- padding = (kernel_size - 1) // 2 * dilation
178
- self.linear = torch.nn.Conv1d(
179
- in_channels,
180
- out_channels,
181
- kernel_size,
182
- stride=stride,
183
- padding=padding,
184
- dilation=dilation,
185
- bias=bias,
186
- )
187
- self.nonlinear = get_nonlinear(config_str, out_channels)
188
-
189
- def forward(self, x):
190
- x = self.linear(x)
191
- x = self.nonlinear(x)
192
- return x
193
-
194
-
195
- class CAMLayer(torch.nn.Module):
196
- def __init__(
197
- self, bn_channels, out_channels, kernel_size, stride, padding, dilation, bias, reduction=2
198
- ):
199
- super(CAMLayer, self).__init__()
200
- self.linear_local = torch.nn.Conv1d(
201
- bn_channels,
202
- out_channels,
203
- kernel_size,
204
- stride=stride,
205
- padding=padding,
206
- dilation=dilation,
207
- bias=bias,
208
- )
209
- self.linear1 = torch.nn.Conv1d(bn_channels, bn_channels // reduction, 1)
210
- self.relu = torch.nn.ReLU(inplace=True)
211
- self.linear2 = torch.nn.Conv1d(bn_channels // reduction, out_channels, 1)
212
- self.sigmoid = torch.nn.Sigmoid()
213
-
214
- def forward(self, x):
215
- y = self.linear_local(x)
216
- context = x.mean(-1, keepdim=True) + self.seg_pooling(x)
217
- context = self.relu(self.linear1(context))
218
- m = self.sigmoid(self.linear2(context))
219
- return y * m
220
-
221
- def seg_pooling(self, x, seg_len=100, stype="avg"):
222
- if stype == "avg":
223
- seg = F.avg_pool1d(x, kernel_size=seg_len, stride=seg_len, ceil_mode=True)
224
- elif stype == "max":
225
- seg = F.max_pool1d(x, kernel_size=seg_len, stride=seg_len, ceil_mode=True)
226
- else:
227
- raise ValueError("Wrong segment pooling type.")
228
- shape = seg.shape
229
- seg = seg.unsqueeze(-1).expand(*shape, seg_len).reshape(*shape[:-1], -1)
230
- seg = seg[..., : x.shape[-1]]
231
- return seg
232
-
233
-
234
- class CAMDenseTDNNLayer(torch.nn.Module):
235
- def __init__(
236
- self,
237
- in_channels,
238
- out_channels,
239
- bn_channels,
240
- kernel_size,
241
- stride=1,
242
- dilation=1,
243
- bias=False,
244
- config_str="batchnorm-relu",
245
- memory_efficient=False,
246
- ):
247
- super(CAMDenseTDNNLayer, self).__init__()
248
- assert kernel_size % 2 == 1, "Expect equal paddings, but got even kernel size ({})".format(
249
- kernel_size
250
- )
251
- padding = (kernel_size - 1) // 2 * dilation
252
- self.memory_efficient = memory_efficient
253
- self.nonlinear1 = get_nonlinear(config_str, in_channels)
254
- self.linear1 = torch.nn.Conv1d(in_channels, bn_channels, 1, bias=False)
255
- self.nonlinear2 = get_nonlinear(config_str, bn_channels)
256
- self.cam_layer = CAMLayer(
257
- bn_channels,
258
- out_channels,
259
- kernel_size,
260
- stride=stride,
261
- padding=padding,
262
- dilation=dilation,
263
- bias=bias,
264
- )
265
-
266
- def bn_function(self, x):
267
- return self.linear1(self.nonlinear1(x))
268
-
269
- def forward(self, x):
270
- if self.training and self.memory_efficient:
271
- x = cp.checkpoint(self.bn_function, x)
272
- else:
273
- x = self.bn_function(x)
274
- x = self.cam_layer(self.nonlinear2(x))
275
- return x
276
-
277
-
278
- class CAMDenseTDNNBlock(torch.nn.ModuleList):
279
- def __init__(
280
- self,
281
- num_layers,
282
- in_channels,
283
- out_channels,
284
- bn_channels,
285
- kernel_size,
286
- stride=1,
287
- dilation=1,
288
- bias=False,
289
- config_str="batchnorm-relu",
290
- memory_efficient=False,
291
- ):
292
- super(CAMDenseTDNNBlock, self).__init__()
293
- for i in range(num_layers):
294
- layer = CAMDenseTDNNLayer(
295
- in_channels=in_channels + i * out_channels,
296
- out_channels=out_channels,
297
- bn_channels=bn_channels,
298
- kernel_size=kernel_size,
299
- stride=stride,
300
- dilation=dilation,
301
- bias=bias,
302
- config_str=config_str,
303
- memory_efficient=memory_efficient,
304
- )
305
- self.add_module("tdnnd%d" % (i + 1), layer)
306
-
307
- def forward(self, x):
308
- for layer in self:
309
- x = torch.cat([x, layer(x)], dim=1)
310
- return x
311
-
312
-
313
- class TransitLayer(torch.nn.Module):
314
- def __init__(self, in_channels, out_channels, bias=True, config_str="batchnorm-relu"):
315
- super(TransitLayer, self).__init__()
316
- self.nonlinear = get_nonlinear(config_str, in_channels)
317
- self.linear = torch.nn.Conv1d(in_channels, out_channels, 1, bias=bias)
318
-
319
- def forward(self, x):
320
- x = self.nonlinear(x)
321
- x = self.linear(x)
322
- return x
323
-
324
-
325
- class DenseLayer(torch.nn.Module):
326
- def __init__(self, in_channels, out_channels, bias=False, config_str="batchnorm-relu"):
327
- super(DenseLayer, self).__init__()
328
- self.linear = torch.nn.Conv1d(in_channels, out_channels, 1, bias=bias)
329
- self.nonlinear = get_nonlinear(config_str, out_channels)
330
-
331
- def forward(self, x):
332
- if len(x.shape) == 2:
333
- x = self.linear(x.unsqueeze(dim=-1)).squeeze(dim=-1)
334
- else:
335
- x = self.linear(x)
336
- x = self.nonlinear(x)
337
- return x
338
-
339
- # @tables.register("model_classes", "CAMPPlus")
340
- class CAMPPlus(torch.nn.Module):
341
- def __init__(
342
- self,
343
- feat_dim=80,
344
- embedding_size=192,
345
- growth_rate=32,
346
- bn_size=4,
347
- init_channels=128,
348
- config_str="batchnorm-relu",
349
- memory_efficient=True,
350
- output_level="segment",
351
- **kwargs,
352
- ):
353
- super().__init__()
354
-
355
- self.head = FCM(feat_dim=feat_dim)
356
- channels = self.head.out_channels
357
- self.output_level = output_level
358
-
359
- self.xvector = torch.nn.Sequential(
360
- OrderedDict(
361
- [
362
- (
363
- "tdnn",
364
- TDNNLayer(
365
- channels,
366
- init_channels,
367
- 5,
368
- stride=2,
369
- dilation=1,
370
- padding=-1,
371
- config_str=config_str,
372
- ),
373
- ),
374
- ]
375
- )
376
- )
377
- channels = init_channels
378
- for i, (num_layers, kernel_size, dilation) in enumerate(
379
- zip((12, 24, 16), (3, 3, 3), (1, 2, 2))
380
- ):
381
- block = CAMDenseTDNNBlock(
382
- num_layers=num_layers,
383
- in_channels=channels,
384
- out_channels=growth_rate,
385
- bn_channels=bn_size * growth_rate,
386
- kernel_size=kernel_size,
387
- dilation=dilation,
388
- config_str=config_str,
389
- memory_efficient=memory_efficient,
390
- )
391
- self.xvector.add_module("block%d" % (i + 1), block)
392
- channels = channels + num_layers * growth_rate
393
- self.xvector.add_module(
394
- "transit%d" % (i + 1),
395
- TransitLayer(channels, channels // 2, bias=False, config_str=config_str),
396
- )
397
- channels //= 2
398
-
399
- self.xvector.add_module("out_nonlinear", get_nonlinear(config_str, channels))
400
-
401
- if self.output_level == "segment":
402
- self.xvector.add_module("stats", StatsPool())
403
- self.xvector.add_module(
404
- "dense", DenseLayer(channels * 2, embedding_size, config_str="batchnorm_")
405
- )
406
- else:
407
- assert (
408
- self.output_level == "frame"
409
- ), "`output_level` should be set to 'segment' or 'frame'. "
410
-
411
- for m in self.modules():
412
- if isinstance(m, (torch.nn.Conv1d, torch.nn.Linear)):
413
- torch.nn.init.kaiming_normal_(m.weight.data)
414
- if m.bias is not None:
415
- torch.nn.init.zeros_(m.bias)
416
-
417
- def forward(self, x):
418
- x = x.permute(0, 2, 1) # (B,T,F) => (B,F,T)
419
- x = self.head(x)
420
- x = self.xvector(x)
421
- if self.output_level == "frame":
422
- x = x.transpose(1, 2)
423
- return x
424
-
425
- def inference(self, audio_list):
426
- speech, speech_lengths, speech_times = extract_feature(audio_list)
427
- results = self.forward(speech.to(torch.float32))
428
- return results
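
One shape detail worth noting in the deleted CAMPPlus model: StatsPool concatenates the per-channel mean and standard deviation over time, which is why the final DenseLayer is built with `channels * 2` input features. A minimal check of that doubling:

    import torch

    x = torch.randn(2, 512, 100)            # (B, C, T) frame-level features
    stats = torch.cat([x.mean(dim=-1), x.std(dim=-1)], dim=-1)
    assert stats.shape == (2, 1024)         # mean ++ std -> channels * 2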
 
 
src/chatterbox/models/s3tokenizer/__init__.py DELETED
@@ -1,30 +0,0 @@
1
- from .s3tokenizer import (
2
- S3_SR,
3
- S3_HOP,
4
- S3_TOKEN_HOP,
5
- S3_TOKEN_RATE,
6
- SPEECH_VOCAB_SIZE,
7
- S3Tokenizer,
8
- )
9
-
10
-
11
- SOS = SPEECH_VOCAB_SIZE
12
- EOS = SPEECH_VOCAB_SIZE + 1
13
-
14
-
15
-
16
- def drop_invalid_tokens(x):
17
- """Drop SoS and EoS"""
18
- assert len(x.shape) == 1 or (len(x.shape) == 2 and x.shape[0] == 1), "only batch size of one allowed for now"
19
- if SOS in x:
20
- s = (x == SOS).nonzero(as_tuple=True)[0].squeeze(0) + 1
21
- else:
22
- s = 0
23
-
24
- if EOS in x:
25
- e = (x == EOS).nonzero(as_tuple=True)[0].squeeze(0)
26
- else:
27
- e = None
28
-
29
- x = x[s: e]
30
- return x
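
Assuming SPEECH_VOCAB_SIZE == 6561 as defined in s3tokenizer.py, SOS is 6561 and EOS is 6562, and drop_invalid_tokens keeps only the tokens strictly between them. A small hypothetical check of that behaviour:

    import torch

    SOS, EOS = 6561, 6562
    x = torch.tensor([SOS, 10, 20, 30, EOS])
    s = (x == SOS).nonzero(as_tuple=True)[0].squeeze(0) + 1   # index just after SOS
    e = (x == EOS).nonzero(as_tuple=True)[0].squeeze(0)       # index of EOS
    assert x[s:e].tolist() == [10, 20, 30]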
 
 
src/chatterbox/models/s3tokenizer/s3tokenizer.py DELETED
@@ -1,168 +0,0 @@
1
- from typing import List, Tuple
2
-
3
- import numpy as np
4
- import librosa
5
- import torch
6
- import torch.nn.functional as F
7
- from s3tokenizer.utils import padding
8
- from s3tokenizer.model_v2 import (
9
- S3TokenizerV2,
10
- ModelConfig,
11
- )
12
-
13
-
14
- # Sampling rate of the inputs to S3TokenizerV2
15
- S3_SR = 16_000
16
- S3_HOP = 160 # 100 frames/sec
17
- S3_TOKEN_HOP = 640 # 25 tokens/sec
18
- S3_TOKEN_RATE = 25
19
- SPEECH_VOCAB_SIZE = 6561
20
-
21
-
22
- class S3Tokenizer(S3TokenizerV2):
23
- """
24
- s3tokenizer.S3TokenizerV2 with the following changes:
25
- - a more integrated `forward`
26
- - compute `log_mel_spectrogram` using `_mel_filters` and `window` in `register_buffers`
27
- """
28
-
29
- ignore_state_dict_missing = ("_mel_filters", "window")
30
-
31
- def __init__(
32
- self,
33
- name: str="speech_tokenizer_v2_25hz",
34
- config: ModelConfig = ModelConfig()
35
- ):
36
- super().__init__(name)
37
-
38
- self.n_fft = 400
39
- _mel_filters = librosa.filters.mel(
40
- sr=S3_SR,
41
- n_fft=self.n_fft,
42
- n_mels=config.n_mels
43
- )
44
- self.register_buffer(
45
- "_mel_filters",
46
- torch.FloatTensor(_mel_filters),
47
- )
48
-
49
- self.register_buffer(
50
- "window",
51
- torch.hann_window(self.n_fft),
52
- )
53
-
54
- def pad(self, wavs, sr) -> List[torch.Tensor]:
55
- """
56
- Given a list of wavs with the same `sample_rate`, pad them so that the length is a multiple of 40 ms (S3 runs at 25 tokens/sec).
57
- """
58
- processed_wavs = []
59
- for wav in wavs:
60
- if isinstance(wav, np.ndarray):
61
- wav = torch.from_numpy(wav)
62
- if wav.dim() == 1:
63
- wav = wav.unsqueeze(0)
64
-
65
- n_tokens = (wav.shape[1] / sr) * S3_TOKEN_RATE
66
- n_tokens = np.ceil(n_tokens)
67
- intended_wav_len = n_tokens * (sr / S3_TOKEN_RATE)
68
- intended_wav_len = int(intended_wav_len)
69
- wav = torch.nn.functional.pad(
70
- wav,
71
- (0, intended_wav_len - wav.shape[-1]),
72
- mode="constant",
73
- value=0
74
- )
75
- processed_wavs.append(wav)
76
- return processed_wavs
77
-
78
- def _prepare_audio(self, wavs):
79
- """Prepare a list of audios for s3tokenizer processing."""
80
- processed_wavs = []
81
- for wav in wavs:
82
- if isinstance(wav, np.ndarray):
83
- wav = torch.from_numpy(wav)
84
- if wav.dim() == 1:
85
- wav = wav.unsqueeze(0)
86
-
87
- processed_wavs.append(wav)
88
- return processed_wavs
89
-
90
- @torch.no_grad()
91
- def forward(
92
- self,
93
- wavs: torch.Tensor,
94
- accelerator: 'Accelerator'=None,
95
- max_len: int=None,
96
- ) -> Tuple[torch.Tensor, torch.LongTensor]:
97
- """
98
- NOTE: mel-spec has a hop size of 160 points (100 frame/sec).
99
- FIXME: this class inherits `nn.Module` but doesn't accept `torch.Tensor` and handles a list of wavs one by one, which is unexpected.
100
-
101
- Args
102
- ----
103
- - `wavs`: 16 kHz speech audio
104
- - `max_len` max length to truncate the output sequence to (25 token/sec).
105
- NOTE: please pad the waveform if longer sequence is needed.
106
- """
107
- processed_wavs = self._prepare_audio(wavs)
108
- mels, mel_lens = [], []
109
- for wav in processed_wavs:
110
- wav = wav.to(self.device)
111
- mel = self.log_mel_spectrogram(wav) # [B=1, F, T]
112
- if max_len is not None:
113
- mel = mel[..., :max_len * 4] # num_mel_frames = 4 * num_tokens
114
- mels.append(mel.squeeze(0))
115
-
116
- mels, mel_lens = padding(mels)
117
- if accelerator is None:
118
- tokenizer = self
119
- else:
120
- tokenizer = accelerator.unwrap_model(self)
121
-
122
- speech_tokens, speech_token_lens = tokenizer.quantize(mels, mel_lens.to(self.device))
123
- return (
124
- speech_tokens.long().detach(),
125
- speech_token_lens.long().detach(),
126
- )
127
-
128
- def log_mel_spectrogram(
129
- self,
130
- audio: torch.Tensor,
131
- padding: int = 0,
132
- ):
133
- """
134
- Compute the log-Mel spectrogram of a 16 kHz waveform.
135
-
136
- Parameters
137
- ----------
138
- audio: torch.Tensor, shape = (*)
139
- A NumPy array or Tensor containing the
140
- audio waveform in 16 kHz
141
-
142
- padding: int
143
- Number of zero samples to pad to the right
144
-
145
- Returns
146
- -------
147
- torch.Tensor, shape = (128, n_frames)
148
- A Tensor that contains the Mel spectrogram
149
- """
150
- if not torch.is_tensor(audio):
151
- audio = torch.from_numpy(audio)
152
-
153
- audio = audio.to(self.device)
154
- if padding > 0:
155
- audio = F.pad(audio, (0, padding))
156
- stft = torch.stft(
157
- audio, self.n_fft, S3_HOP,
158
- window=self.window.to(self.device),
159
- return_complex=True
160
- )
161
- magnitudes = stft[..., :-1].abs()**2
162
-
163
- mel_spec = self._mel_filters.to(self.device) @ magnitudes
164
-
165
- log_spec = torch.clamp(mel_spec, min=1e-10).log10()
166
- log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
167
- log_spec = (log_spec + 4.0) / 4.0
168
- return log_spec
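
The `pad` method above rounds every waveform up to a whole number of 25 Hz tokens; at 16 kHz that is a multiple of 640 samples (S3_TOKEN_HOP). A short arithmetic sketch with an arbitrary input length:

    import numpy as np

    sr, n_samples = 16_000, 1_000
    n_tokens = int(np.ceil(n_samples / sr * 25))   # 1_000 samples -> 2 tokens
    intended = int(n_tokens * (sr / 25))           # 1_280 samples after padding
    assert intended % 640 == 0 and intended >= n_samples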
 
 
src/chatterbox/models/t3/__init__.py DELETED
@@ -1 +0,0 @@
1
- from .t3 import T3
 
 
src/chatterbox/models/t3/inference/alignment_stream_analyzer.py DELETED
@@ -1,178 +0,0 @@
1
- # Copyright (c) 2025 Resemble AI
2
- # Author: John Meade, Jeremy Hsu
3
- # MIT License
4
- import logging
5
- import torch
6
- from dataclasses import dataclass
7
- from types import MethodType
8
-
9
-
10
- logger = logging.getLogger(__name__)
11
-
12
-
13
- LLAMA_ALIGNED_HEADS = [(12, 15), (13, 11), (9, 2)]
14
-
15
-
16
- @dataclass
17
- class AlignmentAnalysisResult:
18
- # was this frame detected as being part of a noisy beginning chunk with potential hallucinations?
19
- false_start: bool
20
- # was this frame detected as being part of a long tail with potential hallucinations?
21
- long_tail: bool
22
- # was this frame detected as repeating existing text content?
23
- repetition: bool
24
- # was the alignment position of this frame too far from the previous frame?
25
- discontinuity: bool
26
- # has inference reached the end of the text tokens? eg, this remains false if inference stops early
27
- complete: bool
28
- # approximate position in the text token sequence. Can be used for generating online timestamps.
29
- position: int
30
-
31
-
32
- class AlignmentStreamAnalyzer:
33
- def __init__(self, tfmr, queue, text_tokens_slice, alignment_layer_idx=9, eos_idx=0):
34
- """
35
- Some transformer TTS models implicitly solve text-speech alignment in one or more of their self-attention
36
- activation maps. This module exploits this to perform online integrity checks while streaming.
37
- A hook is injected into the specified attention layer, and heuristics are used to determine alignment
38
- position, repetition, etc.
39
-
40
- NOTE: currently requires no queues.
41
- """
42
- # self.queue = queue
43
- self.text_tokens_slice = (i, j) = text_tokens_slice
44
- self.eos_idx = eos_idx
45
- self.alignment = torch.zeros(0, j-i)
46
- # self.alignment_bin = torch.zeros(0, j-i)
47
- self.curr_frame_pos = 0
48
- self.text_position = 0
49
-
50
- self.started = False
51
- self.started_at = None
52
-
53
- self.complete = False
54
- self.completed_at = None
55
-
56
- # Track generated tokens for repetition detection
57
- self.generated_tokens = []
58
-
59
- # Using `output_attentions=True` is incompatible with optimized attention kernels, so
60
- # using it for all layers slows things down too much. We can apply it to just one layer
61
- # by intercepting the kwargs and adding a forward hook (credit: jrm)
62
- self.last_aligned_attns = []
63
- for i, (layer_idx, head_idx) in enumerate(LLAMA_ALIGNED_HEADS):
64
- self.last_aligned_attns += [None]
65
- self._add_attention_spy(tfmr, i, layer_idx, head_idx)
66
-
67
- def _add_attention_spy(self, tfmr, buffer_idx, layer_idx, head_idx):
68
- """
69
- Adds a forward hook to a specific attention layer to collect outputs.
70
- """
71
- def attention_forward_hook(module, input, output):
72
- """
73
- See `LlamaAttention.forward`; the output is a 3-tuple: `attn_output, attn_weights, past_key_value`.
74
- NOTE:
75
- - When `output_attentions=True`, `LlamaSdpaAttention.forward` calls `LlamaAttention.forward`.
76
- - `attn_weights` has shape [B, H, T0, T0] for the 0th step, and [B, H, 1, T0+i] for each subsequent step i.
77
- """
78
- if isinstance(output, tuple) and len(output) > 1 and output[1] is not None:
79
- step_attention = output[1].cpu() # (B, n_heads, T0, Ti)
80
- self.last_aligned_attns[buffer_idx] = step_attention[0, head_idx] # (T0, Ti)
81
-
82
- target_layer = tfmr.layers[layer_idx].self_attn
83
- # Register hook and store the handle
84
- target_layer.register_forward_hook(attention_forward_hook)
85
- if hasattr(tfmr, 'config') and hasattr(tfmr.config, 'output_attentions'):
86
- self.original_output_attentions = tfmr.config.output_attentions
87
- tfmr.config.output_attentions = True
88
-
89
- def step(self, logits, next_token=None):
90
- """
91
- Emits an AlignmentAnalysisResult into the output queue, and potentially modifies the logits to force an EOS.
92
- """
93
- # extract approximate alignment matrix chunk (1 frame at a time after the first chunk)
94
- aligned_attn = torch.stack(self.last_aligned_attns).mean(dim=0) # (N, N)
95
- i, j = self.text_tokens_slice
96
- if self.curr_frame_pos == 0:
97
- # first chunk has conditioning info, text tokens, and BOS token
98
- A_chunk = aligned_attn[j:, i:j].clone().cpu() # (T, S)
99
- else:
100
- # subsequent chunks have 1 frame due to KV-caching
101
- A_chunk = aligned_attn[:, i:j].clone().cpu() # (1, S)
102
-
103
- # TODO: monotonic masking; could have issue b/c spaces are often skipped.
104
- A_chunk[:, self.curr_frame_pos + 1:] = 0
105
-
106
-
107
- self.alignment = torch.cat((self.alignment, A_chunk), dim=0)
108
-
109
- A = self.alignment
110
- T, S = A.shape
111
-
112
- # update position
113
- cur_text_posn = A_chunk[-1].argmax()
114
- discontinuity = not(-4 < cur_text_posn - self.text_position < 7) # NOTE: very lenient!
115
- if not discontinuity:
116
- self.text_position = cur_text_posn
117
-
118
- # Hallucinations at the start of speech show up as activations at the bottom of the attention maps!
119
- # To mitigate this, we just wait until there are no activations far off-diagonal in the last 2 tokens,
120
- # and there are some strong activations in the first few tokens.
121
- false_start = (not self.started) and (A[-2:, -2:].max() > 0.1 or A[:, :4].max() < 0.5)
122
- self.started = not false_start
123
- if self.started and self.started_at is None:
124
- self.started_at = T
125
-
126
- # Is generation likely complete?
127
- self.complete = self.complete or self.text_position >= S - 3
128
- if self.complete and self.completed_at is None:
129
- self.completed_at = T
130
-
131
- # NOTE: EOS rarely assigned activations, and second-last token is often punctuation, so use last 3 tokens.
132
- # NOTE: due to the false-start behaviour, we need to make sure we skip activations for the first few tokens.
133
- last_text_token_duration = A[15:, -3:].sum()
134
-
135
- # Activations for the final token that last too long are likely hallucinations.
136
- long_tail = self.complete and (A[self.completed_at:, -3:].sum(dim=0).max() >= 5) # 200ms
137
-
138
- # If there are activations in previous tokens after generation has completed, assume this is a repetition error.
139
- alignment_repetition = self.complete and (A[self.completed_at:, :-5].max(dim=1).values.sum() > 5)
140
-
141
- # Track generated tokens for repetition detection
142
- if next_token is not None:
143
- # Convert tensor to scalar if needed
144
- if isinstance(next_token, torch.Tensor):
145
- token_id = next_token.item() if next_token.numel() == 1 else next_token.view(-1)[0].item()
146
- else:
147
- token_id = next_token
148
- self.generated_tokens.append(token_id)
149
-
150
- # Keep only last 8 tokens to prevent memory issues
151
- if len(self.generated_tokens) > 8:
152
- self.generated_tokens = self.generated_tokens[-8:]
153
-
154
- # Check for excessive token repetition (same token emitted twice in a row)
155
- token_repetition = (
156
- # self.complete and
157
- len(self.generated_tokens) >= 3 and
158
- len(set(self.generated_tokens[-2:])) == 1
159
- )
160
-
161
- if token_repetition:
162
- repeated_token = self.generated_tokens[-1]
163
- logger.warning(f"🚨 Detected 2x repetition of token {repeated_token}")
164
-
165
- # Suppress EoS to prevent early termination
166
- if cur_text_posn < S - 3 and S > 5: # Only suppress if text is longer than 5 tokens
167
- logits[..., self.eos_idx] = -2**15
168
-
169
- # If a bad ending is detected, force emit EOS by modifying logits
170
- # NOTE: this means logits may be inconsistent with latents!
171
- if long_tail or alignment_repetition or token_repetition:
172
- logger.warning(f"forcing EOS token, {long_tail=}, {alignment_repetition=}, {token_repetition=}")
173
- # (±2**15 is safe for all dtypes >= 16bit)
174
- logits = -(2**15) * torch.ones_like(logits)
175
- logits[..., self.eos_idx] = 2**15
176
-
177
- self.curr_frame_pos += 1
178
- return logits
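
A minimal standalone sketch of the forced-EOS logit override used in `step` above. The vocab size and EOS index below are illustrative values (they happen to match the T3 defaults later in this commit):

import torch

def force_eos(logits: torch.Tensor, eos_idx: int) -> torch.Tensor:
    # Suppress every token, then boost EOS; +/-2**15 stays finite for >=16-bit float dtypes.
    forced = -(2**15) * torch.ones_like(logits)
    forced[..., eos_idx] = 2**15
    return forced

logits = torch.randn(1, 8194)                       # (B, vocab) dummy speech logits
probs = torch.softmax(force_eos(logits, eos_idx=6562), dim=-1)
assert probs.argmax(dim=-1).item() == 6562          # sampling now emits EOS with probability ~1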
 
src/chatterbox/models/t3/inference/t3_hf_backend.py DELETED
@@ -1,116 +0,0 @@
1
- from typing import Optional
2
-
3
- import torch
4
- from torch import nn as nn
5
- from transformers import LlamaConfig, LlamaModel, LlamaPreTrainedModel, GenerationMixin
6
- from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
7
-
8
-
9
- class T3HuggingfaceBackend(LlamaPreTrainedModel, GenerationMixin):
10
- """
11
- Override some HuggingFace interface methods so we can use the standard `generate` method with our
12
- custom embedding / logit layers.
13
-
14
- NOTE: need to extend "*PreTrainedModel" to avoid re-initializing weights!
15
- """
16
-
17
- def __init__(
18
- self,
19
- config: LlamaConfig,
20
- llama: LlamaModel,
21
- *,
22
- speech_enc,
23
- speech_head,
24
- latents_queue=None,
25
- logits_queue=None,
26
- alignment_stream_analyzer: 'AlignmentStreamAnalyzer'=None,
27
- ):
28
- super().__init__(config)
29
- self.model = llama
30
- self.speech_enc = speech_enc
31
- self.speech_head = speech_head
32
- self._added_cond = False
33
- self.alignment_stream_analyzer = alignment_stream_analyzer
34
-
35
- @torch.inference_mode()
36
- def prepare_inputs_for_generation(
37
- self, input_ids: torch.Tensor, decoder_cond: torch.Tensor, use_cache: bool, past_key_values=None,
38
- # This argument was introduced in some recent version of transformers (>=4.29.1)
39
- cache_position=None
40
- ):
41
- """
42
- This is a method used by huggingface's generate() method.
43
- Overridden here to apply our custom speech token embedding layer.
44
-
45
- :param input_ids: (B, S) int64 tensors of input tokens.
46
- :param decoder_cond: (B, T, C) float32 tensor of conditioning (prefixed to <input_embeds>)
47
- """
48
-
49
- # Make use of the kv cache: only the last input ID is new, we trim away all the ones before
50
- if not use_cache:
51
- past_key_values = None
52
- if past_key_values is not None:
53
- input_ids = input_ids[:, -1:]
54
-
55
- # custom speech token embedding layer
56
- inputs_embeds = self.speech_enc(input_ids)
57
-
58
- # prefix decoder conditioning if applicable
59
- if not self._added_cond:
60
- assert past_key_values is not None # should be first step
61
- if decoder_cond.size(0) != inputs_embeds.size(0):
62
- decoder_cond = decoder_cond.expand(inputs_embeds.size(0), -1, -1)
63
- inputs_embeds = torch.cat([decoder_cond, inputs_embeds], dim=1)
64
- self._added_cond = True
65
-
66
- return {
67
- "inputs_embeds": inputs_embeds,
68
- "past_key_values": past_key_values,
69
- "use_cache": use_cache,
70
- }
71
-
72
- @torch.inference_mode()
73
- def forward(
74
- self,
75
- inputs_embeds: torch.Tensor,
76
- past_key_values: Optional[torch.Tensor]=None,
77
- use_cache=True,
78
- output_attentions=False,
79
- output_hidden_states=True,
80
- return_dict=True,
81
- ):
82
- """
83
- This is a method used by huggingface's generate() method.
84
- Overridden here to apply our custom layer norm and speech logit projection layers.
85
-
86
- :param inputs_embeds: (B, S, C) float32 tensor of conditioning inputs. If past key values are given,
87
- S should be 1.
88
- """
89
- is_large_input = inputs_embeds.size(1) != 1
90
- has_cache = past_key_values is not None and len(past_key_values) > 0
91
- assert not (is_large_input and has_cache)
92
- assert return_dict
93
- assert output_hidden_states
94
-
95
- tfmr_out = self.model(
96
- inputs_embeds=inputs_embeds,
97
- past_key_values=past_key_values,
98
- use_cache=use_cache,
99
- output_attentions=output_attentions,
100
- output_hidden_states=output_hidden_states,
101
- return_dict=True,
102
- )
103
- hidden_states = tfmr_out.hidden_states[-1] # (B, seq, dim)
104
-
105
- logits = self.speech_head(hidden_states)
106
- # assert inputs_embeds.size(0) == 1 # (disabled for CFG)
107
-
108
- # NOTE: hallucination handler may modify logits to force emit an EOS token
109
- # logits = self.alignment_stream_analyzer.step(logits)
110
-
111
- return CausalLMOutputWithCrossAttentions(
112
- logits=logits,
113
- past_key_values=tfmr_out.past_key_values,
114
- hidden_states=tfmr_out.hidden_states,
115
- attentions=tfmr_out.attentions,
116
- )
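
A small sketch of the KV-cache convention that `prepare_inputs_for_generation` above relies on: once a cache exists, only the newest token id needs to be embedded and fed forward. The helper name below is hypothetical:

import torch

def trim_for_cache(input_ids: torch.Tensor, past_key_values) -> torch.Tensor:
    # First step: feed the whole prefix; later steps: only the newest token is new.
    if past_key_values is not None:
        return input_ids[:, -1:]
    return input_ids

ids = torch.tensor([[10, 11, 12, 13]])
print(trim_for_cache(ids, past_key_values=None).shape)      # torch.Size([1, 4])
print(trim_for_cache(ids, past_key_values=object()).shape)  # torch.Size([1, 1])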
 
src/chatterbox/models/t3/llama_configs.py DELETED
@@ -1,37 +0,0 @@
1
- LLAMA_520M_CONFIG_DICT = dict(
2
- # Arbitrary small number that won't cause problems when loading.
3
- # These params are unused due to custom input layers.
4
- vocab_size=8,
5
- # default params needed for loading most pretrained 1B weights
6
- max_position_embeddings=131072,
7
- hidden_size=1024,
8
- intermediate_size=4096,
9
- num_hidden_layers=30,
10
- num_attention_heads=16,
11
- attn_implementation="sdpa",
12
- head_dim=64,
13
- tie_word_embeddings=False,
14
- hidden_act="silu",
15
- attention_bias=False,
16
- attention_dropout=0.0,
17
- initializer_range=0.02,
18
- mlp_bias=False,
19
- model_type="llama",
20
- num_key_value_heads=16,
21
- pretraining_tp=1,
22
- rms_norm_eps=1e-05,
23
- rope_scaling=dict(
24
- factor=8.0,
25
- high_freq_factor=4.0,
26
- low_freq_factor=1.0,
27
- original_max_position_embeddings=8192,
28
- rope_type="llama3"
29
- ),
30
- rope_theta=500000.0,
31
- torch_dtype="bfloat16",
32
- use_cache=True,
33
- )
34
-
35
- LLAMA_CONFIGS = {
36
- "Llama_520M": LLAMA_520M_CONFIG_DICT,
37
- }
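
As a usage note, `T3.__init__` later in this commit unpacks this dict straight into a `transformers.LlamaConfig`; a minimal sketch, assuming the dict above is in scope:

from transformers import LlamaConfig

cfg = LlamaConfig(**LLAMA_CONFIGS["Llama_520M"])
print(cfg.hidden_size, cfg.num_hidden_layers, cfg.num_attention_heads)  # 1024 30 16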
 
src/chatterbox/models/t3/modules/cond_enc.py DELETED
@@ -1,97 +0,0 @@
1
- from dataclasses import dataclass
2
- from typing import Optional
3
-
4
- import torch
5
- from torch import nn, Tensor
6
-
7
- from .perceiver import Perceiver
8
- from .t3_config import T3Config
9
-
10
-
11
- @dataclass
12
- class T3Cond:
13
- """
14
- Dataclass container for most / all conditioning info.
15
- TODO: serialization methods aren't used, keeping them around for convenience
16
- """
17
-
18
- speaker_emb: Tensor
19
- clap_emb: Optional[Tensor] = None
20
- cond_prompt_speech_tokens: Optional[Tensor] = None
21
- cond_prompt_speech_emb: Optional[Tensor] = None
22
- emotion_adv: Optional[Tensor] = 0.5
23
-
24
- def to(self, *, device=None, dtype=None):
25
- "Cast to a device and dtype. Dtype casting is ignored for long/int tensors."
26
- for k, v in self.__dict__.items():
27
- if torch.is_tensor(v):
28
- is_fp = type(v.view(-1)[0].item()) is not int
29
- setattr(self, k, v.to(device=device, dtype=dtype if is_fp else None))
30
- return self
31
-
32
- def save(self, fpath):
33
- torch.save(self.__dict__, fpath)
34
-
35
- @staticmethod
36
- def load(fpath, map_location="cpu"):
37
- kwargs = torch.load(fpath, map_location=map_location, weights_only=True)
38
- return T3Cond(**kwargs)
39
-
40
-
41
- class T3CondEnc(nn.Module):
42
- """
43
- Handle all non-text conditioning, like speaker embeddings / prompts, CLAP, emotion, etc.
44
- """
45
-
46
- def __init__(self, hp: T3Config):
47
- super().__init__()
48
- self.hp = hp
49
- if hp.encoder_type == "voice_encoder":
50
- self.spkr_enc = nn.Linear(hp.speaker_embed_size, hp.n_channels)
51
- else:
52
- raise NotImplementedError(str(hp.encoder_type))
53
-
54
- # emotion adv
55
- self.emotion_adv_fc = None
56
- if hp.emotion_adv:
57
- self.emotion_adv_fc = nn.Linear(1, hp.n_channels, bias=False)
58
-
59
- # perceiver resampler
60
- self.perceiver = None
61
- if hp.use_perceiver_resampler:
62
- self.perceiver = Perceiver()
63
-
64
- def forward(self, cond: T3Cond):
65
- # Validate
66
- assert (cond.cond_prompt_speech_tokens is None) == (cond.cond_prompt_speech_emb is None), \
67
- "no embeddings for cond_prompt_speech_tokens"
68
-
69
- # Speaker embedding projection
70
- cond_spkr = self.spkr_enc(cond.speaker_emb.view(-1, self.hp.speaker_embed_size))[:, None] # (B, 1, dim)
71
- empty = torch.zeros_like(cond_spkr[:, :0]) # (B, 0, dim)
72
-
73
- # TODO CLAP
74
- assert cond.clap_emb is None, "clap_embed not implemented"
75
- cond_clap = empty # (B, 0, dim)
76
-
77
- # Cond prompt
78
- cond_prompt_speech_emb = cond.cond_prompt_speech_emb
79
- if cond_prompt_speech_emb is None:
80
- cond_prompt_speech_emb = empty # (B, 0, dim)
81
- elif self.hp.use_perceiver_resampler:
82
- cond_prompt_speech_emb = self.perceiver(cond_prompt_speech_emb)
83
-
84
- # Emotion Adv: must provide a value if this model uses emotion conditioning
85
- cond_emotion_adv = empty # (B, 0, dim)
86
- if self.hp.emotion_adv:
87
- assert cond.emotion_adv is not None
88
- cond_emotion_adv = self.emotion_adv_fc(cond.emotion_adv.view(-1, 1, 1))
89
-
90
- # Concat and return
91
- cond_embeds = torch.cat((
92
- cond_spkr,
93
- cond_clap,
94
- cond_prompt_speech_emb,
95
- cond_emotion_adv,
96
- ), dim=1)
97
- return cond_embeds
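
Shape bookkeeping for the concatenation at the end of `T3CondEnc.forward`, as a standalone sketch; the sizes (1024-dim hidden, 32 resampled prompt tokens) are the defaults used elsewhere in this commit and are assumptions here:

import torch

B, dim = 2, 1024
cond_spkr   = torch.zeros(B, 1, dim)    # projected speaker embedding
cond_clap   = torch.zeros(B, 0, dim)    # CLAP not implemented -> zero-length placeholder
cond_prompt = torch.zeros(B, 32, dim)   # speech prompt after the Perceiver resampler
cond_emo    = torch.zeros(B, 1, dim)    # emotion_adv scalar projected to one frame
cond_embeds = torch.cat((cond_spkr, cond_clap, cond_prompt, cond_emo), dim=1)
print(cond_embeds.shape)                # torch.Size([2, 34, 1024])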
 
src/chatterbox/models/t3/modules/learned_pos_emb.py DELETED
@@ -1,32 +0,0 @@
1
- from typing import Union
2
-
3
- import torch
4
- from torch import nn, Tensor
5
-
6
-
7
- class LearnedPositionEmbeddings(nn.Module):
8
- def __init__(self, seq_len, model_dim, init=.02):
9
- super().__init__()
10
- self.emb = nn.Embedding(seq_len, model_dim)
11
- # Initializing this way is standard for GPT-2
12
- self.emb.weight.data.normal_(mean=0.0, std=init)
13
-
14
- def forward(self, x):
15
- """
16
- Returns positional embeddings for index 0 up to the length of x
17
- """
18
- sl = x.shape[1]
19
- return self.emb(torch.arange(0, sl, device=x.device))
20
-
21
- def get_fixed_embedding(self, idx: 'Union[int, Tensor]'):
22
- """
23
- Args:
24
- idx: scalar int or an integer tensor of shape (T,) or (B, T)
25
- Returns:
26
- positional embeddings for given indices, shape (B, T, dim), ie (1, 1, dim) for int input
27
- """
28
- device = self.emb.weight.device
29
- idx = idx.to(device) if torch.is_tensor(idx) else torch.tensor(idx, device=device)
30
- idx = torch.atleast_2d(idx)
31
- assert idx.ndim == 2
32
- return self.emb(idx) # (B, T, dim)
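
A quick shape check for `LearnedPositionEmbeddings` above (a sketch; assumes the class definition above, values are random):

import torch

pos_emb = LearnedPositionEmbeddings(seq_len=2050, model_dim=1024)
x = torch.zeros(2, 7, 1024)                  # (B, T, dim) token embeddings
print(pos_emb(x).shape)                      # torch.Size([7, 1024]); broadcasts over the batch when added
print(pos_emb.get_fixed_embedding(3).shape)  # torch.Size([1, 1, 1024]) for a scalar index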
 
src/chatterbox/models/t3/modules/perceiver.py DELETED
@@ -1,212 +0,0 @@
1
- # Copyright (c) 2025 Resemble AI
2
- # Author: Manmay Nakhashi
3
- # MIT License
4
- import math
5
-
6
- import torch
7
- from torch import nn
8
- import torch.nn.functional as F
9
- from einops import rearrange
10
-
11
-
12
- class RelativePositionBias(nn.Module):
13
- def __init__(self, scale, causal=False, num_buckets=32, max_distance=128, heads=8):
14
- super().__init__()
15
- self.scale = scale
16
- self.causal = causal
17
- self.num_buckets = num_buckets
18
- self.max_distance = max_distance
19
- self.relative_attention_bias = nn.Embedding(num_buckets, heads)
20
-
21
- @staticmethod
22
- def _relative_position_bucket(relative_position, causal=True, num_buckets=32, max_distance=128):
23
- ret = 0
24
- n = -relative_position
25
- if not causal:
26
- num_buckets //= 2
27
- ret += (n < 0).long() * num_buckets
28
- n = torch.abs(n)
29
- else:
30
- n = torch.max(n, torch.zeros_like(n))
31
-
32
- max_exact = num_buckets // 2
33
- is_small = n < max_exact
34
-
35
- val_if_large = max_exact + (
36
- torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
37
- ).long()
38
- val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))
39
-
40
- ret += torch.where(is_small, n, val_if_large)
41
- return ret
42
-
43
- def forward(self, qk_dots):
44
- i, j, device = *qk_dots.shape[-2:], qk_dots.device
45
- q_pos = torch.arange(i, dtype=torch.long, device=device)
46
- k_pos = torch.arange(j, dtype=torch.long, device=device)
47
- rel_pos = k_pos[None, :] - q_pos[:, None]
48
- rp_bucket = self._relative_position_bucket(rel_pos, causal=self.causal, num_buckets=self.num_buckets,
49
- max_distance=self.max_distance)
50
- values = self.relative_attention_bias(rp_bucket)
51
- bias = rearrange(values, 'i j h -> () h i j')
52
- return qk_dots + (bias * self.scale)
53
-
54
-
55
- class AttentionQKV(nn.Module):
56
- def __init__(self, n_heads, head_dim, dropout_rate=0.1, scale=None, flash=False):
57
- super().__init__()
58
- self.n_heads = n_heads
59
- self.head_dim = head_dim
60
- self.scale = scale if scale is not None else head_dim ** -0.5
61
- self.flash = flash
62
- self.dropout_rate = dropout_rate
63
- self.dropout = nn.Dropout(dropout_rate)
64
- self.flash_config = self.setup_flash_config() if flash else None
65
-
66
- def setup_flash_config(self):
67
- # Setup flash attention configuration
68
- flash_config = {
69
- 'enable_flash': True,
70
- 'enable_math': True,
71
- 'enable_mem_efficient': True
72
- }
73
- return flash_config
74
-
75
- def forward(self, q, k, v, mask=None):
76
- q, k, v = [self.split_heads(tensor) for tensor in [q, k, v]]
77
- if self.flash:
78
- out = self.flash_attention(q, k, v, mask=mask)
79
- else:
80
- out = self.scaled_dot_product_attention(q, k, v, mask=mask)
81
-
82
- return self.combine_heads(out)
83
-
84
- def scaled_dot_product_attention(self, q, k, v, mask=None):
85
- sim = torch.einsum("bhlt,bhls->bhts", q, k) * self.scale
86
- if mask is not None:
87
- sim = sim.masked_fill(mask == 0, float('-inf'))
88
- attn = torch.softmax(sim, dim=-1)
89
- attn = self.dropout(attn)
90
- return torch.einsum("bhts,bhls->bhlt", attn, v)
91
-
92
- def flash_attention(self, q, k, v, mask=None):
93
- config = self.flash_config if self.flash_config else {}
94
- with torch.backends.cuda.sdp_kernel(**config):
95
- out = F.scaled_dot_product_attention(
96
- q, k, v,
97
- attn_mask=mask,
98
- dropout_p=self.dropout_rate if self.training else 0.
99
- )
100
- return out
101
-
102
- def split_heads(self, x):
103
- bs, length, _ = x.shape
104
- x = x.view(bs, length, self.n_heads, self.head_dim)
105
- return x.permute(0, 2, 1, 3)
106
-
107
- def combine_heads(self, x):
108
- bs, _, length, _ = x.shape
109
- x = x.permute(0, 2, 1, 3).contiguous()
110
- return x.view(bs, length, -1)
111
-
112
-
113
- class AttentionBlock2(nn.Module):
114
- """
115
- An attention block that allows spatial positions to attend to each other,
116
- using AttentionQKV and separate linear transformations for Q, K, and V.
117
- """
118
-
119
- def __init__(
120
- self,
121
- channels,
122
- num_heads=1,
123
- num_head_channels=-1,
124
- relative_pos_embeddings=False,
125
- flash_attention=True,
126
- dropout_rate=0.2,
127
- scale=None
128
- ):
129
- super().__init__()
130
- self.channels = channels
131
-
132
- if num_head_channels == -1:
133
- self.num_heads = num_heads
134
- else:
135
- assert (
136
- channels % num_head_channels == 0
137
- ), f"channels {channels} is not divisible by num_head_channels {num_head_channels}"
138
- self.num_heads = channels // num_head_channels
139
-
140
- self.norm = nn.LayerNorm(channels)
141
-
142
- # Separate linear layers for Q, K, and V
143
- self.to_q = nn.Linear(channels, channels)
144
- self.to_k = nn.Linear(channels, channels)
145
- self.to_v = nn.Linear(channels, channels)
146
-
147
- self.attention = AttentionQKV(self.num_heads, channels // self.num_heads, dropout_rate=dropout_rate, flash=flash_attention, scale=scale)
148
-
149
- self.proj_out = nn.Linear(channels, channels)
150
-
151
- if relative_pos_embeddings:
152
- self.relative_pos_embeddings = RelativePositionBias(scale=(channels // self.num_heads) ** .5, causal=False, heads=num_heads, num_buckets=32, max_distance=64)
153
- else:
154
- self.relative_pos_embeddings = None
155
-
156
- def forward(self, x1, x2, mask=None):
157
- b1, c1, *spatial1 = x1.shape
158
- b2, c2, *spatial2 = x2.shape
159
-
160
- x1_norm = self.norm(x1)
161
- x2_norm = self.norm(x2)
162
-
163
- q = self.to_q(x1_norm)
164
- k = self.to_k(x2_norm)
165
- v = self.to_v(x2_norm)
166
-
167
- h = self.attention(q, k, v, mask=mask)
168
- h = self.proj_out(h)
169
-
170
- return (x1 + h).reshape(b1, c1, *spatial1)
171
-
172
-
173
- class Perceiver(nn.Module):
174
- """Inspired by https://arxiv.org/abs/2103.03206"""
175
- def __init__(self, pre_attention_query_token=32, pre_attention_query_size=1024, embedding_dim=1024, num_attn_heads=4):
176
- """
177
- Initialize the perceiver module.
178
-
179
- :param pre_attention_query_token: Number of query tokens for pre-attention
180
- :param pre_attention_query_size: Size of each query token
181
- :param embedding_dim: Dimension of the embedding space
182
- :param num_attn_heads: Number of attention heads
183
- """
184
- super().__init__()
185
-
186
- # Initialize the pre-attention query parameter
187
- self.pre_attention_query = torch.nn.Parameter(
188
- torch.empty(1, pre_attention_query_token, pre_attention_query_size)
189
- )
190
-
191
- # Calculate the variance for uniform initialization
192
- query_variance = math.sqrt(3.0) * math.sqrt(2.0 / (pre_attention_query_token + pre_attention_query_token))
193
-
194
- # Initialize the pre-attention query with uniform distribution
195
- self.pre_attention_query.data.uniform_(-query_variance, query_variance)
196
-
197
- # Initialize the attention block
198
- self.attn = AttentionBlock2(embedding_dim, num_attn_heads)
199
-
200
- def forward(self, h):
201
- """
202
- Forward pass of the perceiver module.
203
- :param h: Input tensor
204
- :return: Output after applying attention mechanisms
205
- """
206
- # Expand the pre-attention query to match the batch size of the input
207
- query_ = self.pre_attention_query.expand(h.shape[0], -1, -1)
208
- # Apply the first attention mechanism (cross-attention)
209
- pre_att = self.attn(query_, h)
210
- # Apply the second attention mechanism (self-attention)
211
- attn = self.attn(pre_att, pre_att)
212
- return attn
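
A hedged usage sketch for the Perceiver above: with the default constructor it summarizes a variable-length prompt embedding sequence into a fixed 32-token sequence (assumes the classes above; the input is random):

import torch

resampler = Perceiver()                 # 32 learned queries, 1024-dim, 4 heads by default
prompt = torch.randn(2, 150, 1024)      # e.g. a 150-frame speech-prompt embedding
print(resampler(prompt).shape)          # torch.Size([2, 32, 1024])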
 
src/chatterbox/models/t3/modules/t3_config.py DELETED
@@ -1,41 +0,0 @@
1
- from ..llama_configs import LLAMA_CONFIGS
2
-
3
-
4
- class T3Config:
5
- def __init__(self, text_tokens_dict_size=704):
6
- self.start_text_token = 255
7
- self.stop_text_token = 0
8
- self.text_tokens_dict_size = text_tokens_dict_size
9
- self.max_text_tokens = 2048
10
-
11
- self.start_speech_token = 6561
12
- self.stop_speech_token = 6562
13
- self.speech_tokens_dict_size = 8194
14
- self.max_speech_tokens = 4096
15
-
16
- self.llama_config_name = "Llama_520M"
17
- self.input_pos_emb = "learned"
18
- self.speech_cond_prompt_len = 150
19
-
20
- self.encoder_type = "voice_encoder"
21
- self.speaker_embed_size = 256
22
- self.use_perceiver_resampler = True
23
- self.emotion_adv = True
24
-
25
- @property
26
- def n_channels(self):
27
- return LLAMA_CONFIGS[self.llama_config_name]["hidden_size"]
28
-
29
- @property
30
- def is_multilingual(self):
31
- return self.text_tokens_dict_size == 2454
32
-
33
- @classmethod
34
- def english_only(cls):
35
- """Create configuration for English-only TTS model."""
36
- return cls(text_tokens_dict_size=704)
37
-
38
- @classmethod
39
- def multilingual(cls):
40
- """Create configuration for multilingual TTS model."""
41
- return cls(text_tokens_dict_size=2454)
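
A quick check of the two factory configs above (a sketch; assumes `T3Config` is in scope):

hp_en = T3Config.english_only()
hp_ml = T3Config.multilingual()
print(hp_en.is_multilingual, hp_ml.is_multilingual)  # False True
print(hp_en.n_channels)                              # 1024, read from the Llama_520M config dict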
 
src/chatterbox/models/t3/t3.py DELETED
@@ -1,394 +0,0 @@
1
- # Copyright (c) 2025 Resemble AI
2
- # MIT License
3
- import logging
4
- from typing import Union, Optional, List
5
-
6
- logger = logging.getLogger(__name__)
7
-
8
- from tqdm import tqdm
9
- import torch
10
- import torch.nn.functional as F
11
- from torch import nn, Tensor
12
- from transformers import LlamaModel, LlamaConfig
13
- from transformers.generation.logits_process import TopPLogitsWarper, RepetitionPenaltyLogitsProcessor, MinPLogitsWarper
14
-
15
- from .modules.learned_pos_emb import LearnedPositionEmbeddings
16
-
17
- from .modules.cond_enc import T3CondEnc, T3Cond
18
- from .modules.t3_config import T3Config
19
- from .llama_configs import LLAMA_CONFIGS
20
- from .inference.t3_hf_backend import T3HuggingfaceBackend
21
- from .inference.alignment_stream_analyzer import AlignmentStreamAnalyzer
22
- from ..utils import AttrDict
23
-
24
-
25
- logger = logging.getLogger(__name__)
26
-
27
-
28
- def _ensure_BOT_EOT(text_tokens: Tensor, hp):
29
- B = text_tokens.size(0)
30
- assert (text_tokens == hp.start_text_token).int().sum() >= B, "missing start_text_token"
31
- assert (text_tokens == hp.stop_text_token).int().sum() >= B, "missing stop_text_token"
32
-
33
-
34
- class T3(nn.Module):
35
- """
36
- Token-To-Token (T3) TTS model using huggingface transformer models as backbones,
37
- * tokenization, including start / stop tokens are always added externally to this class
38
- * conditioning data like CLAP, emotion, etc are all in a separate file for more modularity
39
- * careful! this class assumes relative positional encoding -- with absolute PE, we would at
40
- least want to reset the position to 0 when speech tokens begin, and optionally use a
41
- different PE embedding space for speech.
42
- """
43
-
44
- def __init__(self, hp=None):
45
- if hp is None:
46
- hp = T3Config.english_only() # Default to English-only config for backward compatibility
47
- super().__init__()
48
- self.hp = hp
49
- self.cfg = LlamaConfig(**LLAMA_CONFIGS[hp.llama_config_name])
50
- self.tfmr = LlamaModel(self.cfg)
51
- self.dim = self.cfg.hidden_size
52
- self.deepspeed_patch_applied = False
53
-
54
- # conditioning / embedding
55
- self.cond_enc = T3CondEnc(hp)
56
- self.text_emb = nn.Embedding(hp.text_tokens_dict_size, self.dim)
57
- self.speech_emb = nn.Embedding(hp.speech_tokens_dict_size, self.dim)
58
-
59
- # custom position embedding
60
- if hp.input_pos_emb == "learned":
61
- max_text_seq_len = hp.max_text_tokens + 2
62
- self.text_pos_emb = LearnedPositionEmbeddings(max_text_seq_len, self.dim)
63
-
64
- max_mel_seq_len = hp.max_speech_tokens + 2 + 2
65
- self.speech_pos_emb = LearnedPositionEmbeddings(max_mel_seq_len, self.dim)
66
-
67
- # logit projection
68
- self.text_head = nn.Linear(self.cfg.hidden_size, hp.text_tokens_dict_size, bias=False)
69
- self.speech_head = nn.Linear(self.cfg.hidden_size, hp.speech_tokens_dict_size, bias=False)
70
- self.compiled = False
71
-
72
- @property
73
- def device(self):
74
- return self.speech_head.weight.device
75
-
76
- def prepare_conditioning(self, t3_cond: T3Cond):
77
- """
78
- Token cond data needs to be embedded, so that needs to be here instead of in `T3CondEnc`.
79
- """
80
- if t3_cond.cond_prompt_speech_tokens is not None and t3_cond.cond_prompt_speech_emb is None:
81
- t3_cond.cond_prompt_speech_emb = self.speech_emb(t3_cond.cond_prompt_speech_tokens) + \
82
- self.speech_pos_emb(t3_cond.cond_prompt_speech_tokens)
83
- return self.cond_enc(t3_cond) # (B, len_cond, dim)
84
-
85
- def prepare_input_embeds(
86
- self,
87
- *,
88
- t3_cond: T3Cond,
89
- text_tokens: torch.LongTensor,
90
- speech_tokens: torch.LongTensor,
91
- cfg_weight: float = 0.0,
92
- ):
93
- # prepare input embeddings (skip backbone transformer embeddings)
94
- cond_emb = self.prepare_conditioning(t3_cond) # (B, len_cond, dim)
95
- text_emb = self.text_emb(text_tokens) # (B, len_text, dim)
96
- if cfg_weight > 0.0:
97
- text_emb[1].zero_() # CFG uncond
98
-
99
- speech_emb = self.speech_emb(speech_tokens) # (B, len_speech, dim)
100
- if self.hp.input_pos_emb == "learned":
101
- text_emb = text_emb + self.text_pos_emb(text_tokens)
102
- speech_emb = speech_emb + self.speech_pos_emb(speech_tokens)
103
- len_cond = cond_emb.size(1)
104
-
105
- if cond_emb.size(0) != text_emb.size(0):
106
- cond_emb = cond_emb.expand(text_emb.size(0), -1, -1)
107
-
108
- # concat
109
- embeds = torch.stack([
110
- torch.cat((ce, te, se))
111
- for ce, te, se in zip(cond_emb, text_emb, speech_emb)
112
- ]) # (B, length, dim)
113
- return embeds, len_cond
114
-
115
- def forward(
116
- self,
117
- *,
118
- t3_cond: T3Cond,
119
- text_tokens: torch.LongTensor,
120
- text_token_lens: torch.LongTensor,
121
- speech_tokens: torch.LongTensor,
122
- speech_token_lens: torch.LongTensor,
123
- training=False,
124
- ):
125
- _ensure_BOT_EOT(text_tokens, self.hp)
126
-
127
- # prepare custom input embeds
128
- embeds, len_cond = self.prepare_input_embeds(
129
- t3_cond=t3_cond,
130
- text_tokens=text_tokens,
131
- speech_tokens=speech_tokens,
132
- )
133
-
134
- # backbone transformer forward
135
- tfmr_out = self.tfmr.forward(
136
- input_ids=None,
137
- # position_ids=position_ids, # TODO? ROPE should be fine?
138
- inputs_embeds=embeds,
139
- output_hidden_states=True,
140
- return_dict=True,
141
- use_cache=(not training),
142
- )
143
- hidden_states = tfmr_out.hidden_states[-1] # final tfmr layer output, (B, seq, dim)
144
-
145
- # post-processing: splice out text and speech parts of hidden states
146
- len_text = text_tokens.size(1)
147
- len_speech = speech_tokens.size(1)
148
- B, _, dim = hidden_states.shape
149
- device, dtype = hidden_states.device, hidden_states.dtype
150
- text_latents = torch.zeros(B, len_text, dim, dtype=dtype, device=device)
151
- speech_latents = torch.zeros(B, len_speech, dim, dtype=dtype, device=device)
152
- ttl, stl = text_token_lens, speech_token_lens
153
- for i in range(B):
154
- text_end = len_cond + ttl[i].item()
155
- speech_start = len_cond + text_tokens.size(1)
156
- speech_end = speech_start + stl[i].item()
157
- text_latents[i, :ttl[i]] = hidden_states[i, len_cond:text_end]
158
- speech_latents[i, :stl[i]] = hidden_states[i, speech_start:speech_end]
159
-
160
- # logit projection
161
- text_logits = self.text_head(text_latents)
162
- speech_logits = self.speech_head(speech_latents)
163
-
164
- return AttrDict(
165
- text_logits=text_logits,
166
- text_latents=text_latents,
167
- speech_logits=speech_logits,
168
- speech_latents=speech_latents,
169
- hidden_states=hidden_states,
170
- )
171
-
172
- def loss(
173
- self,
174
- *,
175
- t3_cond: T3Cond,
176
- text_tokens: torch.LongTensor,
177
- text_token_lens: torch.LongTensor,
178
- speech_tokens: torch.LongTensor,
179
- speech_token_lens: torch.LongTensor,
180
- ):
181
- "training method"
182
- len_text = text_tokens.size(1)
183
- len_speech = speech_tokens.size(1)
184
- assert len_text == text_token_lens.max()
185
- assert len_speech == speech_token_lens.max()
186
-
187
- out = self.forward(
188
- t3_cond=t3_cond,
189
- text_tokens=text_tokens,
190
- text_token_lens=text_token_lens,
191
- speech_tokens=speech_tokens,
192
- speech_token_lens=speech_token_lens,
193
- training=True,
194
- ) # (B, seq, vocab_size)
195
-
196
- # Calc CCE losses
197
- IGNORE_ID = -100
198
- device = out.text_logits.device
199
- mask_text = torch.arange(len_text, device=device)[None] >= text_token_lens[:, None] # (B, len_text)
200
- mask_speech = torch.arange(len_speech, device=device)[None] >= speech_token_lens[:, None] # (B, len_speech)
201
- masked_text = text_tokens.masked_fill(mask_text, IGNORE_ID)
202
- masked_speech = speech_tokens.masked_fill(mask_speech, IGNORE_ID)
203
- loss_text = F.cross_entropy(out.text_logits.transpose(1, 2), masked_text, ignore_index=IGNORE_ID)  # CE expects (B, C, T)
204
- loss_speech = F.cross_entropy(out.speech_logits.transpose(1, 2), masked_speech, ignore_index=IGNORE_ID)
205
-
206
- return loss_text, loss_speech
207
-
208
- @torch.inference_mode()
209
- def inference(
210
- self,
211
- *,
212
- t3_cond: T3Cond,
213
- text_tokens: Tensor,
214
- initial_speech_tokens: Optional[Tensor]=None,
215
-
216
- # misc conditioning
217
- prepend_prompt_speech_tokens: Optional[Tensor]=None,
218
-
219
- # HF generate args
220
- num_return_sequences=1,
221
- max_new_tokens=None,
222
- stop_on_eos=True,
223
- do_sample=True,
224
- temperature=0.8,
225
- top_p=0.95,
226
- min_p=0.05,
227
- length_penalty=1.0,
228
- repetition_penalty=1.2,
229
- cfg_weight=0.5,
230
- ):
231
- """
232
- Args:
233
- text_tokens: a 1D (unbatched) or 2D (batched) tensor.
234
- """
235
- # Validate / sanitize inputs
236
- assert prepend_prompt_speech_tokens is None, "not implemented"
237
- _ensure_BOT_EOT(text_tokens, self.hp)
238
- text_tokens = torch.atleast_2d(text_tokens).to(dtype=torch.long, device=self.device)
239
-
240
- # Default initial speech to a single start-of-speech token
241
- if initial_speech_tokens is None:
242
- initial_speech_tokens = self.hp.start_speech_token * torch.ones_like(text_tokens[:, :1])
243
-
244
- # Prepare custom input embeds
245
- embeds, len_cond = self.prepare_input_embeds(
246
- t3_cond=t3_cond,
247
- text_tokens=text_tokens,
248
- speech_tokens=initial_speech_tokens,
249
- cfg_weight=cfg_weight,
250
- )
251
-
252
- # In order to use the standard HF generate method, we need to extend some methods to inject our custom logic
253
- # Note the llama-specific logic. Other tfmr types can be added later.
254
-
255
- self.compiled = False
256
-
257
- # TODO? synchronize the expensive compile function
258
- # with self.compile_lock:
259
- if not self.compiled:
260
- # Default to None for English models, only create for multilingual
261
- alignment_stream_analyzer = None
262
- if self.hp.is_multilingual:
263
- alignment_stream_analyzer = AlignmentStreamAnalyzer(
264
- self.tfmr,
265
- None,
266
- text_tokens_slice=(len_cond, len_cond + text_tokens.size(-1)),
267
- alignment_layer_idx=9, # TODO: hparam or something?
268
- eos_idx=self.hp.stop_speech_token,
269
- )
270
- assert alignment_stream_analyzer.eos_idx == self.hp.stop_speech_token
271
-
272
- patched_model = T3HuggingfaceBackend(
273
- config=self.cfg,
274
- llama=self.tfmr,
275
- speech_enc=self.speech_emb,
276
- speech_head=self.speech_head,
277
- alignment_stream_analyzer=alignment_stream_analyzer,
278
- )
279
- self.patched_model = patched_model
280
- self.compiled = True
281
-
282
- # # Run normal generate method, which calls our custom extended methods
283
- # return self.patched_model.generate(
284
- # inputs=initial_speech_tokens,
285
- # decoder_cond=embeds,
286
- # bos_token_id=self.hp.start_speech_token,
287
- # eos_token_id=(self.hp.stop_speech_token if stop_on_eos else -1),
288
- # pad_token_id=self.hp.stop_speech_token,
289
- # max_new_tokens=max_new_tokens or self.hp.max_speech_tokens,
290
- # num_return_sequences=num_return_sequences,
291
- # temperature=temperature,
292
- # min_p=min_p,
293
- # length_penalty=length_penalty,
294
- # repetition_penalty=repetition_penalty,
295
- # do_sample=do_sample,
296
- # # cache_implementation=None if not self.compiled else "static",
297
- # )
298
-
299
- device = embeds.device
300
-
301
- bos_token = torch.tensor([[self.hp.start_speech_token]], dtype=torch.long, device=device)
302
- bos_embed = self.speech_emb(bos_token) # shape: (B, 1, embed_dim)
303
- bos_embed = bos_embed + self.speech_pos_emb.get_fixed_embedding(0)
304
-
305
- # batch_size=2 for CFG
306
- bos_embed = torch.cat([bos_embed, bos_embed])
307
-
308
- # Combine condition and BOS token for the initial input
309
- inputs_embeds = torch.cat([embeds, bos_embed], dim=1)
310
-
311
- # Track generated token ids; start with the BOS token.
312
- generated_ids = bos_token.clone()
313
- predicted = [] # To store the predicted tokens
314
-
315
- # Instantiate the logits processors.
316
- top_p_warper = TopPLogitsWarper(top_p=top_p)
317
- min_p_warper = MinPLogitsWarper(min_p=min_p)
319
- repetition_penalty_processor = RepetitionPenaltyLogitsProcessor(penalty=float(repetition_penalty))
320
-
321
- # ---- Initial Forward Pass (no kv_cache yet) ----
322
- output = self.patched_model(
323
- inputs_embeds=inputs_embeds,
324
- past_key_values=None,
325
- use_cache=True,
326
- output_attentions=True,
327
- output_hidden_states=True,
328
- return_dict=True,
329
- )
330
- # Initialize kv_cache with the full context.
331
- past = output.past_key_values
332
-
333
- # ---- Generation Loop using kv_cache ----
334
- for i in tqdm(range(max_new_tokens), desc="Sampling", dynamic_ncols=True):
335
- logits_step = output.logits[:, -1, :]
336
- # CFG combine → (1, V)
337
- cond = logits_step[0:1, :]
338
- uncond = logits_step[1:2, :]
339
- cfg = torch.as_tensor(cfg_weight, device=cond.device, dtype=cond.dtype)
340
- logits = cond + cfg * (cond - uncond)
341
-
342
- # Apply alignment stream analyzer integrity checks
343
- if self.patched_model.alignment_stream_analyzer is not None:
344
- if logits.dim() == 1: # guard in case something upstream squeezed
345
- logits = logits.unsqueeze(0) # (1, V)
346
- # Pass the last generated token for repetition tracking
347
- last_token = generated_ids[0, -1].item() if len(generated_ids[0]) > 0 else None
348
- logits = self.patched_model.alignment_stream_analyzer.step(logits, next_token=last_token) # (1, V)
349
-
350
- # Apply repetition penalty
351
- ids_for_proc = generated_ids[:1, ...] # batch = 1
352
- logits = repetition_penalty_processor(ids_for_proc, logits) # expects (B,V)
353
-
354
- # Apply temperature scaling.
355
- if temperature != 1.0:
356
- logits = logits / temperature
357
-
358
- # Apply min_p and top_p filtering
359
- logits = min_p_warper(ids_for_proc, logits)
360
- logits = top_p_warper(ids_for_proc, logits)
361
-
362
- # Convert logits to probabilities and sample the next token.
363
- probs = torch.softmax(logits, dim=-1)
364
- next_token = torch.multinomial(probs, num_samples=1) # shape: (B, 1)
365
-
366
- predicted.append(next_token)
367
- generated_ids = torch.cat([generated_ids, next_token], dim=1)
368
-
369
- # Check for EOS token.
370
- if next_token.view(-1) == self.hp.stop_speech_token:
371
- logger.info(f"✅ EOS token detected! Stopping generation at step {i+1}")
372
- break
373
-
374
- # Get embedding for the new token.
375
- next_token_embed = self.speech_emb(next_token)
376
- next_token_embed = next_token_embed + self.speech_pos_emb.get_fixed_embedding(i + 1)
377
-
378
- # For CFG
379
- next_token_embed = torch.cat([next_token_embed, next_token_embed])
380
-
381
- # Forward pass with only the new token and the cached past.
382
- output = self.patched_model(
383
- inputs_embeds=next_token_embed,
384
- past_key_values=past,
385
- output_attentions=True,
386
- output_hidden_states=True,
387
- return_dict=True,
388
- )
389
- # Update the kv_cache.
390
- past = output.past_key_values
391
-
392
- # Concatenate all predicted tokens along the sequence dimension.
393
- predicted_tokens = torch.cat(predicted, dim=1) # shape: (B, num_tokens)
394
- return predicted_tokens
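
The classifier-free-guidance combination used in the sampling loop above (`logits = cond + cfg_weight * (cond - uncond)`), written out as a standalone numeric sketch with made-up values:

import torch

cond   = torch.tensor([[2.0, 0.0, 1.0]])   # logits from the conditional batch entry
uncond = torch.tensor([[1.0, 0.0, 2.0]])   # logits from the unconditional (CFG) entry
cfg_weight = 0.5
guided = cond + cfg_weight * (cond - uncond)
print(guided)                               # tensor([[2.5000, 0.0000, 0.5000]])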
 
src/chatterbox/models/tokenizers/__init__.py DELETED
@@ -1 +0,0 @@
1
- from .tokenizer import EnTokenizer, MTLTokenizer
 
 
src/chatterbox/models/tokenizers/tokenizer.py DELETED
@@ -1,312 +0,0 @@
1
- import logging
2
- import json
3
-
4
- import torch
5
- from pathlib import Path
6
- from unicodedata import category, normalize
7
- from tokenizers import Tokenizer
8
- from huggingface_hub import hf_hub_download
9
-
10
-
11
- # Special tokens
12
- SOT = "[START]"
13
- EOT = "[STOP]"
14
- UNK = "[UNK]"
15
- SPACE = "[SPACE]"
16
- SPECIAL_TOKENS = [SOT, EOT, UNK, SPACE, "[PAD]", "[SEP]", "[CLS]", "[MASK]"]
17
-
18
- logger = logging.getLogger(__name__)
19
-
20
- class EnTokenizer:
21
- def __init__(self, vocab_file_path):
22
- self.tokenizer: Tokenizer = Tokenizer.from_file(vocab_file_path)
23
- self.check_vocabset_sot_eot()
24
-
25
- def check_vocabset_sot_eot(self):
26
- voc = self.tokenizer.get_vocab()
27
- assert SOT in voc
28
- assert EOT in voc
29
-
30
- def text_to_tokens(self, text: str):
31
- text_tokens = self.encode(text)
32
- text_tokens = torch.IntTensor(text_tokens).unsqueeze(0)
33
- return text_tokens
34
-
35
- def encode(self, txt: str):
36
- """
37
- clean_text > (append `lang_id`) > replace SPACE > encode text using Tokenizer
38
- """
39
- txt = txt.replace(' ', SPACE)
40
- code = self.tokenizer.encode(txt)
41
- ids = code.ids
42
- return ids
43
-
44
- def decode(self, seq):
45
- if isinstance(seq, torch.Tensor):
46
- seq = seq.cpu().numpy()
47
-
48
- txt: str = self.tokenizer.decode(seq, skip_special_tokens=False)
49
- txt = txt.replace(' ', '')
50
- txt = txt.replace(SPACE, ' ')
51
- txt = txt.replace(EOT, '')
52
- txt = txt.replace(UNK, '')
53
- return txt
54
-
55
-
56
- # Model repository
57
- REPO_ID = "ResembleAI/chatterbox"
58
-
59
- # Global instances for optional dependencies
60
- _kakasi = None
61
- _dicta = None
62
- _russian_stresser = None
63
-
64
-
65
- def is_kanji(c: str) -> bool:
66
- """Check if character is kanji."""
67
- return 19968 <= ord(c) <= 40959
68
-
69
-
70
- def is_katakana(c: str) -> bool:
71
- """Check if character is katakana."""
72
- return 12449 <= ord(c) <= 12538
73
-
74
-
75
- def hiragana_normalize(text: str) -> str:
76
- """Japanese text normalization: converts kanji to hiragana; katakana remains the same."""
77
- global _kakasi
78
-
79
- try:
80
- if _kakasi is None:
81
- import pykakasi
82
- _kakasi = pykakasi.kakasi()
83
-
84
- result = _kakasi.convert(text)
85
- out = []
86
-
87
- for r in result:
88
- inp = r['orig']
89
- hira = r["hira"]
90
-
91
- # Any kanji in the phrase
92
- if any([is_kanji(c) for c in inp]):
93
- if hira and hira[0] in ["は", "へ"]: # Safety check for empty hira
94
- hira = " " + hira
95
- out.append(hira)
96
-
97
- # All katakana
98
- elif all([is_katakana(c) for c in inp]) if inp else False: # Safety check for empty inp
99
- out.append(r['orig'])
100
-
101
- else:
102
- out.append(inp)
103
-
104
- normalized_text = "".join(out)
105
-
106
- # Decompose Japanese characters for tokenizer compatibility
107
- import unicodedata
108
- normalized_text = unicodedata.normalize('NFKD', normalized_text)
109
-
110
- return normalized_text
111
-
112
- except ImportError:
113
- logger.warning("pykakasi not available - Japanese text processing skipped")
114
- return text
115
-
116
-
117
- def add_hebrew_diacritics(text: str) -> str:
118
- """Hebrew text normalization: adds diacritics to Hebrew text."""
119
- global _dicta
120
-
121
- try:
122
- if _dicta is None:
123
- from dicta_onnx import Dicta
124
- _dicta = Dicta()
125
-
126
- return _dicta.add_diacritics(text)
127
-
128
- except ImportError:
129
- logger.warning("dicta_onnx not available - Hebrew text processing skipped")
130
- return text
131
- except Exception as e:
132
- logger.warning(f"Hebrew diacritization failed: {e}")
133
- return text
134
-
135
-
136
- def korean_normalize(text: str) -> str:
137
- """Korean text normalization: decompose syllables into Jamo for tokenization."""
138
-
139
- def decompose_hangul(char):
140
- """Decompose Korean syllable into Jamo components."""
141
- if not ('\uac00' <= char <= '\ud7af'):
142
- return char
143
-
144
- # Hangul decomposition formula
145
- base = ord(char) - 0xAC00
146
- initial = chr(0x1100 + base // (21 * 28))
147
- medial = chr(0x1161 + (base % (21 * 28)) // 28)
148
- final = chr(0x11A7 + base % 28) if base % 28 > 0 else ''
149
-
150
- return initial + medial + final
151
-
152
- # Decompose syllables and normalize punctuation
153
- result = ''.join(decompose_hangul(char) for char in text)
154
- return result.strip()
155
-
156
-
157
- class ChineseCangjieConverter:
158
- """Converts Chinese characters to Cangjie codes for tokenization."""
159
-
160
- def __init__(self, model_dir=None):
161
- self.word2cj = {}
162
- self.cj2word = {}
163
- self.segmenter = None
164
- self._load_cangjie_mapping(model_dir)
165
- self._init_segmenter()
166
-
167
- def _load_cangjie_mapping(self, model_dir=None):
168
- """Load Cangjie mapping from HuggingFace model repository."""
169
- try:
170
- cangjie_file = hf_hub_download(
171
- repo_id=REPO_ID,
172
- filename="Cangjie5_TC.json",
173
- cache_dir=model_dir
174
- )
175
-
176
- with open(cangjie_file, "r", encoding="utf-8") as fp:
177
- data = json.load(fp)
178
-
179
- for entry in data:
180
- word, code = entry.split("\t")[:2]
181
- self.word2cj[word] = code
182
- if code not in self.cj2word:
183
- self.cj2word[code] = [word]
184
- else:
185
- self.cj2word[code].append(word)
186
-
187
- except Exception as e:
188
- logger.warning(f"Could not load Cangjie mapping: {e}")
189
-
190
- def _init_segmenter(self):
191
- """Initialize pkuseg segmenter."""
192
- try:
193
- from spacy_pkuseg import pkuseg
194
- self.segmenter = pkuseg()
195
- except ImportError:
196
- logger.warning("pkuseg not available - Chinese segmentation will be skipped")
197
- self.segmenter = None
198
-
199
- def _cangjie_encode(self, glyph: str):
200
- """Encode a single Chinese glyph to Cangjie code."""
201
- normed_glyph = glyph
202
- code = self.word2cj.get(normed_glyph, None)
203
- if code is None: # e.g. Japanese hiragana
204
- return None
205
- index = self.cj2word[code].index(normed_glyph)
206
- index = str(index) if index > 0 else ""
207
- return code + str(index)
208
-
209
-
210
- def __call__(self, text):
211
- """Convert Chinese characters in text to Cangjie tokens."""
212
- output = []
213
- if self.segmenter is not None:
214
- segmented_words = self.segmenter.cut(text)
215
- full_text = " ".join(segmented_words)
216
- else:
217
- full_text = text
218
-
219
- for t in full_text:
220
- if category(t) == "Lo":
221
- cangjie = self._cangjie_encode(t)
222
- if cangjie is None:
223
- output.append(t)
224
- continue
225
- code = []
226
- for c in cangjie:
227
- code.append(f"[cj_{c}]")
228
- code.append("[cj_.]")
229
- code = "".join(code)
230
- output.append(code)
231
- else:
232
- output.append(t)
233
- return "".join(output)
234
-
235
-
236
- def add_russian_stress(text: str) -> str:
237
- """Russian text normalization: adds stress marks to Russian text."""
238
- global _russian_stresser
239
-
240
- try:
241
- if _russian_stresser is None:
242
- from russian_text_stresser.text_stresser import RussianTextStresser
243
- _russian_stresser = RussianTextStresser()
244
-
245
- return _russian_stresser.stress_text(text)
246
-
247
- except ImportError:
248
- logger.warning("russian_text_stresser not available - Russian stress labeling skipped")
249
- return text
250
- except Exception as e:
251
- logger.warning(f"Russian stress labeling failed: {e}")
252
- return text
253
-
254
-
255
- class MTLTokenizer:
256
- def __init__(self, vocab_file_path):
257
- self.tokenizer: Tokenizer = Tokenizer.from_file(vocab_file_path)
258
- model_dir = Path(vocab_file_path).parent
259
- self.cangjie_converter = ChineseCangjieConverter(model_dir)
260
- self.check_vocabset_sot_eot()
261
-
262
- def check_vocabset_sot_eot(self):
263
- voc = self.tokenizer.get_vocab()
264
- assert SOT in voc
265
- assert EOT in voc
266
-
267
- def preprocess_text(self, raw_text: str, language_id: str = None, lowercase: bool = True, nfkd_normalize: bool = True):
268
- """
269
- Text preprocessor that handles lowercase conversion and NFKD normalization.
270
- """
271
- preprocessed_text = raw_text
272
- if lowercase:
273
- preprocessed_text = preprocessed_text.lower()
274
- if nfkd_normalize:
275
- preprocessed_text = normalize("NFKD", preprocessed_text)
276
-
277
- return preprocessed_text
278
-
279
- def text_to_tokens(self, text: str, language_id: str = None, lowercase: bool = True, nfkd_normalize: bool = True):
280
- text_tokens = self.encode(text, language_id=language_id, lowercase=lowercase, nfkd_normalize=nfkd_normalize)
281
- text_tokens = torch.IntTensor(text_tokens).unsqueeze(0)
282
- return text_tokens
283
-
284
- def encode(self, txt: str, language_id: str = None, lowercase: bool = True, nfkd_normalize: bool = True):
285
- txt = self.preprocess_text(txt, language_id=language_id, lowercase=lowercase, nfkd_normalize=nfkd_normalize)
286
-
287
- # Language-specific text processing
288
- if language_id == 'zh':
289
- txt = self.cangjie_converter(txt)
290
- elif language_id == 'ja':
291
- txt = hiragana_normalize(txt)
292
- elif language_id == 'he':
293
- txt = add_hebrew_diacritics(txt)
294
- elif language_id == 'ko':
295
- txt = korean_normalize(txt)
296
- elif language_id == 'ru':
297
- txt = add_russian_stress(txt)
298
-
299
- # Prepend language token
300
- if language_id:
301
- txt = f"[{language_id.lower()}]{txt}"
302
-
303
- txt = txt.replace(' ', SPACE)
304
- return self.tokenizer.encode(txt).ids
305
-
306
- def decode(self, seq):
307
- if isinstance(seq, torch.Tensor):
308
- seq = seq.cpu().numpy()
309
-
310
- txt = self.tokenizer.decode(seq, skip_special_tokens=False)
311
- txt = txt.replace(' ', '').replace(SPACE, ' ').replace(EOT, '').replace(UNK, '')
312
- return txt
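
A worked example of the Hangul decomposition arithmetic used in `korean_normalize` above, for the syllable '한' (U+D55C):

char = '한'
base = ord(char) - 0xAC00                                    # 54620 - 44032 = 10588
initial = chr(0x1100 + base // (21 * 28))                    # 10588 // 588 = 18 -> U+1112 'ᄒ'
medial  = chr(0x1161 + (base % (21 * 28)) // 28)             # (10588 % 588) // 28 = 0 -> U+1161 'ᅡ'
final   = chr(0x11A7 + base % 28) if base % 28 > 0 else ''   # 10588 % 28 = 4 -> U+11AB 'ᆫ'
print(initial + medial + final)                              # '한', the decomposed Jamo sequence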
 
src/chatterbox/models/utils.py DELETED
@@ -1,4 +0,0 @@
1
- class AttrDict(dict):
2
- def __init__(self, *args, **kwargs):
3
- super(AttrDict, self).__init__(*args, **kwargs)
4
- self.__dict__ = self
 
 
 
 
 
src/chatterbox/models/voice_encoder/__init__.py DELETED
@@ -1 +0,0 @@
1
- from .voice_encoder import VoiceEncoder, VoiceEncConfig
 
 
src/chatterbox/models/voice_encoder/config.py DELETED
@@ -1,18 +0,0 @@
1
- class VoiceEncConfig:
2
- num_mels = 40
3
- sample_rate = 16000
4
- speaker_embed_size = 256
5
- ve_hidden_size = 256
6
- flatten_lstm_params = False
7
- n_fft = 400
8
- hop_size = 160
9
- win_size = 400
10
- fmax = 8000
11
- fmin = 0
12
- preemphasis = 0.
13
- mel_power = 2.0
14
- mel_type = "amp"
15
- normalized_mels = False
16
- ve_partial_frames = 160
17
- ve_final_relu = True
18
- stft_magnitude_min = 1e-4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/chatterbox/models/voice_encoder/melspec.py DELETED
@@ -1,78 +0,0 @@
1
- from functools import lru_cache
2
-
3
- from scipy import signal
4
- import numpy as np
5
- import librosa
6
-
7
-
8
- @lru_cache()
9
- def mel_basis(hp):
10
- assert hp.fmax <= hp.sample_rate // 2
11
- return librosa.filters.mel(
12
- sr=hp.sample_rate,
13
- n_fft=hp.n_fft,
14
- n_mels=hp.num_mels,
15
- fmin=hp.fmin,
16
- fmax=hp.fmax) # -> (nmel, nfreq)
17
-
18
-
19
- def preemphasis(wav, hp):
20
- assert hp.preemphasis != 0
21
- wav = signal.lfilter([1, -hp.preemphasis], [1], wav)
22
- wav = np.clip(wav, -1, 1)
23
- return wav
24
-
25
-
26
- def melspectrogram(wav, hp, pad=True):
27
- # Run through pre-emphasis
28
- if hp.preemphasis > 0:
29
- wav = preemphasis(wav, hp)
30
- assert np.abs(wav).max() - 1 < 1e-07
31
-
32
- # Do the stft
33
- spec_complex = _stft(wav, hp, pad=pad)
34
-
35
- # Get the magnitudes
36
- spec_magnitudes = np.abs(spec_complex)
37
-
38
- if hp.mel_power != 1.0:
39
- spec_magnitudes **= hp.mel_power
40
-
41
- # Get the mel and convert magnitudes->db
42
- mel = np.dot(mel_basis(hp), spec_magnitudes)
43
- if hp.mel_type == "db":
44
- mel = _amp_to_db(mel, hp)
45
-
46
- # Normalise the mel from db to 0,1
47
- if hp.normalized_mels:
48
- mel = _normalize(mel, hp).astype(np.float32)
49
-
50
- assert not pad or mel.shape[1] == 1 + len(wav) // hp.hop_size # Sanity check
51
- return mel # (M, T)
52
-
53
-
54
- def _stft(y, hp, pad=True):
55
- # NOTE: after 0.8, pad mode defaults to constant, setting this to reflect for
56
- # historical consistency and streaming-version consistency
57
- return librosa.stft(
58
- y,
59
- n_fft=hp.n_fft,
60
- hop_length=hp.hop_size,
61
- win_length=hp.win_size,
62
- center=pad,
63
- pad_mode="reflect",
64
- )
65
-
66
-
67
- def _amp_to_db(x, hp):
68
- return 20 * np.log10(np.maximum(hp.stft_magnitude_min, x))
69
-
70
-
71
- def _db_to_amp(x):
72
- return np.power(10.0, x * 0.05)
73
-
74
-
75
- def _normalize(s, hp, headroom_db=15):
76
- min_level_db = 20 * np.log10(hp.stft_magnitude_min)
77
- s = (s - min_level_db) / (-min_level_db + headroom_db)
78
- return s
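
A frame-count sanity check mirroring the assert in `melspectrogram` above, using the 16 kHz sample rate and hop size 160 from `VoiceEncConfig` (roughly 100 mel frames per second):

sample_rate, hop_size = 16000, 160
n_samples = int(sample_rate * 2.0)      # a 2-second clip
n_frames = 1 + n_samples // hop_size    # center-padded STFT yields one extra frame
print(n_frames)                         # 201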
 
src/chatterbox/models/voice_encoder/voice_encoder.py DELETED
@@ -1,274 +0,0 @@
1
- # Adapted from https://github.com/CorentinJ/Real-Time-Voice-Cloning
2
- # MIT License
3
- from typing import List, Union, Optional
4
-
5
- import numpy as np
6
- from numpy.lib.stride_tricks import as_strided
7
- import librosa
8
- import torch
9
- import torch.nn.functional as F
10
- from torch import nn, Tensor
11
-
12
- from .config import VoiceEncConfig
13
- from .melspec import melspectrogram
14
-
15
-
16
- def pack(arrays, seq_len: int=None, pad_value=0):
17
- """
18
- Given a list of length B of array-like objects of shapes (Ti, ...), packs them in a single tensor of
19
- shape (B, T, ...) by padding each individual array on the right.
20
-
21
- :param arrays: a list of array-like objects of matching shapes except for the first axis.
22
- :param seq_len: the value of T. It must be the maximum of the lengths Ti of the arrays at
23
- minimum. Will default to that value if None.
24
- :param pad_value: the value to pad the arrays with.
25
- :return: a (B, T, ...) tensor
26
- """
27
- if seq_len is None:
28
- seq_len = max(len(array) for array in arrays)
29
- else:
30
- assert seq_len >= max(len(array) for array in arrays)
31
-
32
- # Convert lists to np.array
33
- if isinstance(arrays[0], list):
34
- arrays = [np.array(array) for array in arrays]
35
-
36
- # Convert to tensor and handle device
37
- device = None
38
- if isinstance(arrays[0], torch.Tensor):
39
- tensors = arrays
40
- device = tensors[0].device
41
- else:
42
- tensors = [torch.as_tensor(array) for array in arrays]
43
-
44
- # Fill the packed tensor with the array data
45
- packed_shape = (len(tensors), seq_len, *tensors[0].shape[1:])
46
- packed_tensor = torch.full(packed_shape, pad_value, dtype=tensors[0].dtype, device=device)
47
-
48
- for i, tensor in enumerate(tensors):
49
- packed_tensor[i, :tensor.size(0)] = tensor
50
-
51
- return packed_tensor
52
-
53
-
54
- def get_num_wins(
55
- n_frames: int,
56
- step: int,
57
- min_coverage: float,
58
- hp: VoiceEncConfig,
59
- ):
60
- assert n_frames > 0
61
- win_size = hp.ve_partial_frames
62
- n_wins, remainder = divmod(max(n_frames - win_size + step, 0), step)
63
- if n_wins == 0 or (remainder + (win_size - step)) / win_size >= min_coverage:
64
- n_wins += 1
65
- target_n = win_size + step * (n_wins - 1)
66
- return n_wins, target_n
67
-
68
-
69
- def get_frame_step(
70
- overlap: float,
71
- rate: float,
72
- hp: VoiceEncConfig,
73
- ):
74
- # Compute how many frames separate two partial utterances
75
- assert 0 <= overlap < 1
76
- if rate is None:
77
- frame_step = int(np.round(hp.ve_partial_frames * (1 - overlap)))
78
- else:
79
- frame_step = int(np.round((hp.sample_rate / rate) / hp.ve_partial_frames))
80
- assert 0 < frame_step <= hp.ve_partial_frames
81
- return frame_step
82
-
83
-
84
- def stride_as_partials(
85
- mel: np.ndarray,
86
- hp: VoiceEncConfig,
87
- overlap=0.5,
88
- rate: float=None,
89
- min_coverage=0.8,
90
- ):
91
- """
92
- Takes unscaled mels in (T, M) format
93
- TODO: doc
94
- """
95
- assert 0 < min_coverage <= 1
96
- frame_step = get_frame_step(overlap, rate, hp)
97
-
98
- # Compute how many partials can fit in the mel
99
- n_partials, target_len = get_num_wins(len(mel), frame_step, min_coverage, hp)
100
-
101
- # Trim or pad the mel spectrogram to match the number of partials
102
- if target_len > len(mel):
103
- mel = np.concatenate((mel, np.full((target_len - len(mel), hp.num_mels), 0)))
104
- elif target_len < len(mel):
105
- mel = mel[:target_len]
106
-
107
- # Ensure the numpy array data is float32 and contiguous in memory
108
- mel = mel.astype(np.float32, order="C")
109
-
110
- # Re-arrange the array in memory to be of shape (N, P, M) with partials overlapping each other,
111
- # where N is the number of partials, P is the number of frames of each partial and M the
112
- # number of channels of the mel spectrograms.
113
- shape = (n_partials, hp.ve_partial_frames, hp.num_mels)
114
- strides = (mel.strides[0] * frame_step, mel.strides[0], mel.strides[1])
115
- partials = as_strided(mel, shape, strides)
116
- return partials
117
-
118
-
119
- class VoiceEncoder(nn.Module):
120
- def __init__(self, hp=VoiceEncConfig()):
121
- super().__init__()
122
-
123
- self.hp = hp
124
-
125
- # Network definition
126
- self.lstm = nn.LSTM(self.hp.num_mels, self.hp.ve_hidden_size, num_layers=3, batch_first=True)
127
- if hp.flatten_lstm_params:
128
- self.lstm.flatten_parameters()
129
- self.proj = nn.Linear(self.hp.ve_hidden_size, self.hp.speaker_embed_size)
130
-
131
- # Cosine similarity scaling (fixed initial parameter values)
132
- self.similarity_weight = nn.Parameter(torch.tensor([10.]), requires_grad=True)
133
- self.similarity_bias = nn.Parameter(torch.tensor([-5.]), requires_grad=True)
134
-
135
- @property
136
- def device(self):
137
- return next(self.parameters()).device
138
-
139
- def forward(self, mels: torch.FloatTensor):
140
- """
141
- Computes the embeddings of a batch of partial utterances.
142
-
143
- :param mels: a batch of unscaled mel spectrograms of same duration as a float32 tensor
144
- of shape (B, T, M) where T is hp.ve_partial_frames
145
- :return: the embeddings as a float32 tensor of shape (B, E) where E is
146
- hp.speaker_embed_size. Embeddings are L2-normed and thus lay in the range [-1, 1].
147
- """
148
- if self.hp.normalized_mels and (mels.min() < 0 or mels.max() > 1):
149
- raise Exception(f"Mels outside [0, 1]. Min={mels.min()}, Max={mels.max()}")
150
-
151
- # Pass the input through the LSTM layers
152
- _, (hidden, _) = self.lstm(mels)
153
-
154
- # Project the final hidden state
155
- raw_embeds = self.proj(hidden[-1])
156
- if self.hp.ve_final_relu:
157
- raw_embeds = F.relu(raw_embeds)
158
-
159
- # L2 normalize the embeddings.
160
- return raw_embeds / torch.linalg.norm(raw_embeds, dim=1, keepdim=True)
161
-
162
- def inference(self, mels: torch.Tensor, mel_lens, overlap=0.5, rate: float=None, min_coverage=0.8, batch_size=None):
163
- """
164
- Computes the embeddings of a batch of full utterances with gradients.
165
-
166
- :param mels: (B, T, M) unscaled mels
167
- :return: (B, E) embeddings on CPU
168
- """
169
- mel_lens = mel_lens.tolist() if torch.is_tensor(mel_lens) else mel_lens
170
-
171
- # Compute where to split the utterances into partials
172
- frame_step = get_frame_step(overlap, rate, self.hp)
173
- n_partials, target_lens = zip(*(get_num_wins(l, frame_step, min_coverage, self.hp) for l in mel_lens))
174
-
175
- # Possibly pad the mels to reach the target lengths
176
- len_diff = max(target_lens) - mels.size(1)
177
- if len_diff > 0:
178
- pad = torch.full((mels.size(0), len_diff, self.hp.num_mels), 0, dtype=torch.float32)
179
- mels = torch.cat((mels, pad.to(mels.device)), dim=1)
180
-
181
- # Group all partials together so that we can batch them easily
182
- partials = [
183
- mel[i * frame_step: i * frame_step + self.hp.ve_partial_frames]
184
- for mel, n_partial in zip(mels, n_partials) for i in range(n_partial)
185
- ]
186
- assert all(partials[0].shape == partial.shape for partial in partials)
187
- partials = torch.stack(partials)
188
-
189
- # Forward the partials
190
- n_chunks = int(np.ceil(len(partials) / (batch_size or len(partials))))
191
- partial_embeds = torch.cat([self(batch) for batch in partials.chunk(n_chunks)], dim=0).cpu()
192
-
193
- # Reduce the partial embeds into full embeds and L2-normalize them
194
- slices = np.concatenate(([0], np.cumsum(n_partials)))
195
- raw_embeds = [torch.mean(partial_embeds[start:end], dim=0) for start, end in zip(slices[:-1], slices[1:])]
196
- raw_embeds = torch.stack(raw_embeds)
197
- embeds = raw_embeds / torch.linalg.norm(raw_embeds, dim=1, keepdim=True)
198
-
199
- return embeds
200
-
201
- @staticmethod
202
- def utt_to_spk_embed(utt_embeds: np.ndarray):
203
- """
204
- Takes an array of L2-normalized utterance embeddings, computes the mean embedding and L2-normalize it to get a
205
- speaker embedding.
206
- """
207
- assert utt_embeds.ndim == 2
208
- utt_embeds = np.mean(utt_embeds, axis=0)
209
- return utt_embeds / np.linalg.norm(utt_embeds, 2)
210
-
211
- @staticmethod
212
- def voice_similarity(embeds_x: np.ndarray, embeds_y: np.ndarray):
213
- """
214
- Cosine similarity for L2-normalized utterance embeddings or speaker embeddings
215
- """
216
- embeds_x = embeds_x if embeds_x.ndim == 1 else VoiceEncoder.utt_to_spk_embed(embeds_x)
217
- embeds_y = embeds_y if embeds_y.ndim == 1 else VoiceEncoder.utt_to_spk_embed(embeds_y)
218
- return embeds_x @ embeds_y
219
-
220
- def embeds_from_mels(
221
- self, mels: Union[Tensor, List[np.ndarray]], mel_lens=None, as_spk=False, batch_size=32, **kwargs
222
- ):
223
- """
224
- Convenience function for deriving utterance or speaker embeddings from mel spectrograms.
225
-
226
- :param mels: unscaled mels strictly within [0, 1] as either a (B, T, M) tensor or a list of (Ti, M) arrays.
227
- :param mel_lens: if passing mels as a tensor, individual mel lengths
228
- :param as_spk: whether to return utterance embeddings or a single speaker embedding
229
- :param kwargs: args for inference()
230
-
231
- :returns: embeds as a (B, E) float32 numpy array if <as_spk> is False, else as a (E,) array
232
- """
233
- # Load mels in memory and pack them
234
- if isinstance(mels, List):
235
- mels = [np.asarray(mel) for mel in mels]
236
- assert all(m.shape[1] == mels[0].shape[1] for m in mels), "Mels aren't in (B, T, M) format"
237
- mel_lens = [mel.shape[0] for mel in mels]
238
- mels = pack(mels)
239
-
240
- # Embed them
241
- with torch.inference_mode():
242
- utt_embeds = self.inference(mels.to(self.device), mel_lens, batch_size=batch_size, **kwargs).numpy()
243
-
244
- return self.utt_to_spk_embed(utt_embeds) if as_spk else utt_embeds
245
-
246
- def embeds_from_wavs(
247
- self,
248
- wavs: List[np.ndarray],
249
- sample_rate,
250
- as_spk=False,
251
- batch_size=32,
252
- trim_top_db: Optional[float]=20,
253
- **kwargs
254
- ):
255
- """
256
- Wrapper around embeds_from_mels
257
-
258
- :param trim_top_db: this argument was only added for the sake of compatibility with metavoice's implementation
259
- """
260
- if sample_rate != self.hp.sample_rate:
261
- wavs = [
262
- librosa.resample(wav, orig_sr=sample_rate, target_sr=self.hp.sample_rate, res_type="kaiser_fast")
263
- for wav in wavs
264
- ]
265
-
266
- if trim_top_db:
267
- wavs = [librosa.effects.trim(wav, top_db=trim_top_db)[0] for wav in wavs]
268
-
269
- if "rate" not in kwargs:
270
- kwargs["rate"] = 1.3 # Resemble's default value.
271
-
272
- mels = [melspectrogram(w, self.hp).T for w in wavs]
273
-
274
- return self.embeds_from_mels(mels, as_spk=as_spk, batch_size=batch_size, **kwargs)
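
For reference, a minimal usage sketch of the VoiceEncoder defined in the file above. The import path assumes the pre-removal package layout, the wav file names are hypothetical placeholders, and in practice the encoder weights would be loaded from the checkpoint (as done in tts.py further below) rather than using a freshly initialised module:

```python
# Hypothetical sketch: derive speaker embeddings from two reference clips and compare them.
# Paths and the bare VoiceEncoder() are placeholders; real use loads ve weights from the checkpoint.
import librosa
from chatterbox.models.voice_encoder import VoiceEncoder

ve = VoiceEncoder().eval()

wav_a, sr_a = librosa.load("speaker_a.wav", sr=None, mono=True)
wav_b, sr_b = librosa.load("speaker_b.wav", sr=None, mono=True)

# embeds_from_wavs resamples to hp.sample_rate, trims silence, computes mels,
# and averages the per-partial utterance embeddings into one vector per clip.
emb_a = ve.embeds_from_wavs([wav_a], sample_rate=sr_a, as_spk=True)
emb_b = ve.embeds_from_wavs([wav_b], sample_rate=sr_b, as_spk=True)

# Cosine similarity of two L2-normalised speaker embeddings, in [-1, 1].
print("voice similarity:", VoiceEncoder.voice_similarity(emb_a, emb_b))
```

Because every embedding is L2-normalised, the plain dot product in voice_similarity is already a cosine similarity.
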
src/chatterbox/mtl_tts.py DELETED
@@ -1,301 +0,0 @@
1
- from dataclasses import dataclass
2
- from pathlib import Path
3
- import os
4
-
5
- import librosa
6
- import torch
7
- import perth
8
- import torch.nn.functional as F
9
- from safetensors.torch import load_file as load_safetensors
10
- from huggingface_hub import snapshot_download
11
-
12
- from .models.t3 import T3
13
- from .models.t3.modules.t3_config import T3Config
14
- from .models.s3tokenizer import S3_SR, drop_invalid_tokens
15
- from .models.s3gen import S3GEN_SR, S3Gen
16
- from .models.tokenizers import MTLTokenizer
17
- from .models.voice_encoder import VoiceEncoder
18
- from .models.t3.modules.cond_enc import T3Cond
19
-
20
-
21
- REPO_ID = "ResembleAI/chatterbox"
22
-
23
- # Supported languages for the multilingual model
24
- SUPPORTED_LANGUAGES = {
25
- "ar": "Arabic",
26
- "da": "Danish",
27
- "de": "German",
28
- "el": "Greek",
29
- "en": "English",
30
- "es": "Spanish",
31
- "fi": "Finnish",
32
- "fr": "French",
33
- "he": "Hebrew",
34
- "hi": "Hindi",
35
- "it": "Italian",
36
- "ja": "Japanese",
37
- "ko": "Korean",
38
- "ms": "Malay",
39
- "nl": "Dutch",
40
- "no": "Norwegian",
41
- "pl": "Polish",
42
- "pt": "Portuguese",
43
- "ru": "Russian",
44
- "sv": "Swedish",
45
- "sw": "Swahili",
46
- "tr": "Turkish",
47
- "zh": "Chinese",
48
- }
49
-
50
-
51
- def punc_norm(text: str) -> str:
52
- """
53
- Quick cleanup func for punctuation from LLMs or
54
- containing chars not seen often in the dataset
55
- """
56
- if len(text) == 0:
57
- return "You need to add some text for me to talk."
58
-
59
- # Capitalise first letter
60
- if text[0].islower():
61
- text = text[0].upper() + text[1:]
62
-
63
- # Remove multiple space chars
64
- text = " ".join(text.split())
65
-
66
- # Replace uncommon/llm punc
67
- punc_to_replace = [
68
- ("...", ", "),
69
- ("…", ", "),
70
- (":", ","),
71
- (" - ", ", "),
72
- (";", ", "),
73
- ("—", "-"),
74
- ("–", "-"),
75
- (" ,", ","),
76
- ("“", "\""),
77
- ("”", "\""),
78
- ("‘", "'"),
79
- ("’", "'"),
80
- ]
81
- for old_char_sequence, new_char in punc_to_replace:
82
- text = text.replace(old_char_sequence, new_char)
83
-
84
- # Add full stop if no ending punc
85
- text = text.rstrip(" ")
86
- sentence_enders = {".", "!", "?", "-", ",","、",",","。","?","!"}
87
- if not any(text.endswith(p) for p in sentence_enders):
88
- text += "."
89
-
90
- return text
91
-
92
-
93
- @dataclass
94
- class Conditionals:
95
- """
96
- Conditionals for T3 and S3Gen
97
- - T3 conditionals:
98
- - speaker_emb
99
- - clap_emb
100
- - cond_prompt_speech_tokens
101
- - cond_prompt_speech_emb
102
- - emotion_adv
103
- - S3Gen conditionals:
104
- - prompt_token
105
- - prompt_token_len
106
- - prompt_feat
107
- - prompt_feat_len
108
- - embedding
109
- """
110
- t3: T3Cond
111
- gen: dict
112
-
113
- def to(self, device):
114
- self.t3 = self.t3.to(device=device)
115
- for k, v in self.gen.items():
116
- if torch.is_tensor(v):
117
- self.gen[k] = v.to(device=device)
118
- return self
119
-
120
- def save(self, fpath: Path):
121
- arg_dict = dict(
122
- t3=self.t3.__dict__,
123
- gen=self.gen
124
- )
125
- torch.save(arg_dict, fpath)
126
-
127
- @classmethod
128
- def load(cls, fpath, map_location="cpu"):
129
- kwargs = torch.load(fpath, map_location=map_location, weights_only=True)
130
- return cls(T3Cond(**kwargs['t3']), kwargs['gen'])
131
-
132
-
133
- class ChatterboxMultilingualTTS:
134
- ENC_COND_LEN = 6 * S3_SR
135
- DEC_COND_LEN = 10 * S3GEN_SR
136
-
137
- def __init__(
138
- self,
139
- t3: T3,
140
- s3gen: S3Gen,
141
- ve: VoiceEncoder,
142
- tokenizer: MTLTokenizer,
143
- device: str,
144
- conds: Conditionals = None,
145
- ):
146
- self.sr = S3GEN_SR # sample rate of synthesized audio
147
- self.t3 = t3
148
- self.s3gen = s3gen
149
- self.ve = ve
150
- self.tokenizer = tokenizer
151
- self.device = device
152
- self.conds = conds
153
- self.watermarker = perth.PerthImplicitWatermarker()
154
-
155
- @classmethod
156
- def get_supported_languages(cls):
157
- """Return dictionary of supported language codes and names."""
158
- return SUPPORTED_LANGUAGES.copy()
159
-
160
- @classmethod
161
- def from_local(cls, ckpt_dir, device) -> 'ChatterboxMultilingualTTS':
162
- ckpt_dir = Path(ckpt_dir)
163
-
164
- ve = VoiceEncoder()
165
- ve.load_state_dict(
166
- torch.load(ckpt_dir / "ve.pt", weights_only=True)
167
- )
168
- ve.to(device).eval()
169
-
170
- t3 = T3(T3Config.multilingual())
171
- t3_state = load_safetensors(ckpt_dir / "t3_mtl23ls_v2.safetensors")
172
- if "model" in t3_state.keys():
173
- t3_state = t3_state["model"][0]
174
- t3.load_state_dict(t3_state)
175
- t3.to(device).eval()
176
-
177
- s3gen = S3Gen()
178
- s3gen.load_state_dict(
179
- torch.load(ckpt_dir / "s3gen.pt", weights_only=True)
180
- )
181
- s3gen.to(device).eval()
182
-
183
- tokenizer = MTLTokenizer(
184
- str(ckpt_dir / "grapheme_mtl_merged_expanded_v1.json")
185
- )
186
-
187
- conds = None
188
- if (builtin_voice := ckpt_dir / "conds.pt").exists():
189
- conds = Conditionals.load(builtin_voice).to(device)
190
-
191
- return cls(t3, s3gen, ve, tokenizer, device, conds=conds)
192
-
193
- @classmethod
194
- def from_pretrained(cls, device: torch.device) -> 'ChatterboxMultilingualTTS':
195
- ckpt_dir = Path(
196
- snapshot_download(
197
- repo_id=REPO_ID,
198
- repo_type="model",
199
- revision="main",
200
- allow_patterns=["ve.pt", "t3_mtl23ls_v2.safetensors", "s3gen.pt", "grapheme_mtl_merged_expanded_v1.json", "conds.pt", "Cangjie5_TC.json"],
201
- token=os.getenv("HF_TOKEN"),
202
- )
203
- )
204
- return cls.from_local(ckpt_dir, device)
205
-
206
- def prepare_conditionals(self, wav_fpath, exaggeration=0.5):
207
- ## Load reference wav
208
- s3gen_ref_wav, _sr = librosa.load(wav_fpath, sr=S3GEN_SR)
209
-
210
- ref_16k_wav = librosa.resample(s3gen_ref_wav, orig_sr=S3GEN_SR, target_sr=S3_SR)
211
-
212
- s3gen_ref_wav = s3gen_ref_wav[:self.DEC_COND_LEN]
213
- s3gen_ref_dict = self.s3gen.embed_ref(s3gen_ref_wav, S3GEN_SR, device=self.device)
214
-
215
- # Speech cond prompt tokens
216
- t3_cond_prompt_tokens = None
217
- if plen := self.t3.hp.speech_cond_prompt_len:
218
- s3_tokzr = self.s3gen.tokenizer
219
- t3_cond_prompt_tokens, _ = s3_tokzr.forward([ref_16k_wav[:self.ENC_COND_LEN]], max_len=plen)
220
- t3_cond_prompt_tokens = torch.atleast_2d(t3_cond_prompt_tokens).to(self.device)
221
-
222
- # Voice-encoder speaker embedding
223
- ve_embed = torch.from_numpy(self.ve.embeds_from_wavs([ref_16k_wav], sample_rate=S3_SR))
224
- ve_embed = ve_embed.mean(axis=0, keepdim=True).to(self.device)
225
-
226
- t3_cond = T3Cond(
227
- speaker_emb=ve_embed,
228
- cond_prompt_speech_tokens=t3_cond_prompt_tokens,
229
- emotion_adv=exaggeration * torch.ones(1, 1, 1),
230
- ).to(device=self.device)
231
- self.conds = Conditionals(t3_cond, s3gen_ref_dict)
232
-
233
- def generate(
234
- self,
235
- text,
236
- language_id,
237
- audio_prompt_path=None,
238
- exaggeration=0.5,
239
- cfg_weight=0.5,
240
- temperature=0.8,
241
- repetition_penalty=2.0,
242
- min_p=0.05,
243
- top_p=1.0,
244
- ):
245
- # Validate language_id
246
- if language_id and language_id.lower() not in SUPPORTED_LANGUAGES:
247
- supported_langs = ", ".join(SUPPORTED_LANGUAGES.keys())
248
- raise ValueError(
249
- f"Unsupported language_id '{language_id}'. "
250
- f"Supported languages: {supported_langs}"
251
- )
252
-
253
- if audio_prompt_path:
254
- self.prepare_conditionals(audio_prompt_path, exaggeration=exaggeration)
255
- else:
256
- assert self.conds is not None, "Please `prepare_conditionals` first or specify `audio_prompt_path`"
257
-
258
- # Update exaggeration if needed
259
- if float(exaggeration) != float(self.conds.t3.emotion_adv[0, 0, 0].item()):
260
- _cond: T3Cond = self.conds.t3
261
- self.conds.t3 = T3Cond(
262
- speaker_emb=_cond.speaker_emb,
263
- cond_prompt_speech_tokens=_cond.cond_prompt_speech_tokens,
264
- emotion_adv=exaggeration * torch.ones(1, 1, 1),
265
- ).to(device=self.device)
266
-
267
- # Norm and tokenize text
268
- text = punc_norm(text)
269
- text_tokens = self.tokenizer.text_to_tokens(text, language_id=language_id.lower() if language_id else None).to(self.device)
270
- text_tokens = torch.cat([text_tokens, text_tokens], dim=0) # Need two seqs for CFG
271
-
272
- sot = self.t3.hp.start_text_token
273
- eot = self.t3.hp.stop_text_token
274
- text_tokens = F.pad(text_tokens, (1, 0), value=sot)
275
- text_tokens = F.pad(text_tokens, (0, 1), value=eot)
276
-
277
- with torch.inference_mode():
278
- speech_tokens = self.t3.inference(
279
- t3_cond=self.conds.t3,
280
- text_tokens=text_tokens,
281
- max_new_tokens=1000, # TODO: use the value in config
282
- temperature=temperature,
283
- cfg_weight=cfg_weight,
284
- repetition_penalty=repetition_penalty,
285
- min_p=min_p,
286
- top_p=top_p,
287
- )
288
- # Extract only the conditional batch.
289
- speech_tokens = speech_tokens[0]
290
-
291
- # TODO: output becomes 1D
292
- speech_tokens = drop_invalid_tokens(speech_tokens)
293
- speech_tokens = speech_tokens.to(self.device)
294
-
295
- wav, _ = self.s3gen.inference(
296
- speech_tokens=speech_tokens,
297
- ref_dict=self.conds.gen,
298
- )
299
- wav = wav.squeeze(0).detach().cpu().numpy()
300
- watermarked_wav = self.watermarker.apply_watermark(wav, sample_rate=self.sr)
301
- return torch.from_numpy(watermarked_wav).unsqueeze(0)
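
A hedged end-to-end sketch of how the ChatterboxMultilingualTTS class above is typically driven; `reference.wav` and `out_de.wav` are placeholder paths, and the weights are fetched from the ResembleAI/chatterbox repo exactly as from_pretrained() does:

```python
# Hypothetical usage of the multilingual pipeline; file paths are placeholders.
import torch
import torchaudio
from chatterbox.mtl_tts import ChatterboxMultilingualTTS, SUPPORTED_LANGUAGES

device = "cuda" if torch.cuda.is_available() else "cpu"
model = ChatterboxMultilingualTTS.from_pretrained(device)

print(sorted(SUPPORTED_LANGUAGES))  # 23 language codes, e.g. "de", "ja", "zh"

wav = model.generate(
    "Guten Morgen, wie geht es dir?",
    language_id="de",                    # must be a key of SUPPORTED_LANGUAGES
    audio_prompt_path="reference.wav",   # reference speaker clip (placeholder)
    exaggeration=0.5,
    cfg_weight=0.5,
    temperature=0.8,
)
# generate() returns a (1, T) float tensor at model.sr (S3GEN_SR), already watermarked.
torchaudio.save("out_de.wav", wav, model.sr)
```

Passing an audio_prompt_path re-runs prepare_conditionals for that reference voice; omitting it reuses whatever conditionals were prepared (or loaded from conds.pt) before.
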
src/chatterbox/tts.py DELETED
@@ -1,272 +0,0 @@
1
- from dataclasses import dataclass
2
- from pathlib import Path
3
-
4
- import librosa
5
- import torch
6
- import perth
7
- import torch.nn.functional as F
8
- from huggingface_hub import hf_hub_download
9
- from safetensors.torch import load_file
10
-
11
- from .models.t3 import T3
12
- from .models.s3tokenizer import S3_SR, drop_invalid_tokens
13
- from .models.s3gen import S3GEN_SR, S3Gen
14
- from .models.tokenizers import EnTokenizer
15
- from .models.voice_encoder import VoiceEncoder
16
- from .models.t3.modules.cond_enc import T3Cond
17
-
18
-
19
- REPO_ID = "ResembleAI/chatterbox"
20
-
21
-
22
- def punc_norm(text: str) -> str:
23
- """
24
- Quick cleanup func for punctuation from LLMs or
25
- containing chars not seen often in the dataset
26
- """
27
- if len(text) == 0:
28
- return "You need to add some text for me to talk."
29
-
30
- # Capitalise first letter
31
- if text[0].islower():
32
- text = text[0].upper() + text[1:]
33
-
34
- # Remove multiple space chars
35
- text = " ".join(text.split())
36
-
37
- # Replace uncommon/llm punc
38
- punc_to_replace = [
39
- ("...", ", "),
40
- ("…", ", "),
41
- (":", ","),
42
- (" - ", ", "),
43
- (";", ", "),
44
- ("—", "-"),
45
- ("–", "-"),
46
- (" ,", ","),
47
- ("“", "\""),
48
- ("”", "\""),
49
- ("‘", "'"),
50
- ("’", "'"),
51
- ]
52
- for old_char_sequence, new_char in punc_to_replace:
53
- text = text.replace(old_char_sequence, new_char)
54
-
55
- # Add full stop if no ending punc
56
- text = text.rstrip(" ")
57
- sentence_enders = {".", "!", "?", "-", ","}
58
- if not any(text.endswith(p) for p in sentence_enders):
59
- text += "."
60
-
61
- return text
62
-
63
-
64
- @dataclass
65
- class Conditionals:
66
- """
67
- Conditionals for T3 and S3Gen
68
- - T3 conditionals:
69
- - speaker_emb
70
- - clap_emb
71
- - cond_prompt_speech_tokens
72
- - cond_prompt_speech_emb
73
- - emotion_adv
74
- - S3Gen conditionals:
75
- - prompt_token
76
- - prompt_token_len
77
- - prompt_feat
78
- - prompt_feat_len
79
- - embedding
80
- """
81
- t3: T3Cond
82
- gen: dict
83
-
84
- def to(self, device):
85
- self.t3 = self.t3.to(device=device)
86
- for k, v in self.gen.items():
87
- if torch.is_tensor(v):
88
- self.gen[k] = v.to(device=device)
89
- return self
90
-
91
- def save(self, fpath: Path):
92
- arg_dict = dict(
93
- t3=self.t3.__dict__,
94
- gen=self.gen
95
- )
96
- torch.save(arg_dict, fpath)
97
-
98
- @classmethod
99
- def load(cls, fpath, map_location="cpu"):
100
- if isinstance(map_location, str):
101
- map_location = torch.device(map_location)
102
- kwargs = torch.load(fpath, map_location=map_location, weights_only=True)
103
- return cls(T3Cond(**kwargs['t3']), kwargs['gen'])
104
-
105
-
106
- class ChatterboxTTS:
107
- ENC_COND_LEN = 6 * S3_SR
108
- DEC_COND_LEN = 10 * S3GEN_SR
109
-
110
- def __init__(
111
- self,
112
- t3: T3,
113
- s3gen: S3Gen,
114
- ve: VoiceEncoder,
115
- tokenizer: EnTokenizer,
116
- device: str,
117
- conds: Conditionals = None,
118
- ):
119
- self.sr = S3GEN_SR # sample rate of synthesized audio
120
- self.t3 = t3
121
- self.s3gen = s3gen
122
- self.ve = ve
123
- self.tokenizer = tokenizer
124
- self.device = device
125
- self.conds = conds
126
- self.watermarker = perth.PerthImplicitWatermarker()
127
-
128
- @classmethod
129
- def from_local(cls, ckpt_dir, device) -> 'ChatterboxTTS':
130
- ckpt_dir = Path(ckpt_dir)
131
-
132
- # Always load to CPU first for non-CUDA devices to handle CUDA-saved models
133
- if device in ["cpu", "mps"]:
134
- map_location = torch.device('cpu')
135
- else:
136
- map_location = None
137
-
138
- ve = VoiceEncoder()
139
- ve.load_state_dict(
140
- load_file(ckpt_dir / "ve.safetensors")
141
- )
142
- ve.to(device).eval()
143
-
144
- t3 = T3()
145
- t3_state = load_file(ckpt_dir / "t3_cfg.safetensors")
146
- if "model" in t3_state.keys():
147
- t3_state = t3_state["model"][0]
148
- t3.load_state_dict(t3_state)
149
- t3.to(device).eval()
150
-
151
- s3gen = S3Gen()
152
- s3gen.load_state_dict(
153
- load_file(ckpt_dir / "s3gen.safetensors"), strict=False
154
- )
155
- s3gen.to(device).eval()
156
-
157
- tokenizer = EnTokenizer(
158
- str(ckpt_dir / "tokenizer.json")
159
- )
160
-
161
- conds = None
162
- if (builtin_voice := ckpt_dir / "conds.pt").exists():
163
- conds = Conditionals.load(builtin_voice, map_location=map_location).to(device)
164
-
165
- return cls(t3, s3gen, ve, tokenizer, device, conds=conds)
166
-
167
- @classmethod
168
- def from_pretrained(cls, device) -> 'ChatterboxTTS':
169
- # Check if MPS is available on macOS
170
- if device == "mps" and not torch.backends.mps.is_available():
171
- if not torch.backends.mps.is_built():
172
- print("MPS not available because the current PyTorch install was not built with MPS enabled.")
173
- else:
174
- print("MPS not available because the current MacOS version is not 12.3+ and/or you do not have an MPS-enabled device on this machine.")
175
- device = "cpu"
176
-
177
- for fpath in ["ve.safetensors", "t3_cfg.safetensors", "s3gen.safetensors", "tokenizer.json", "conds.pt"]:
178
- local_path = hf_hub_download(repo_id=REPO_ID, filename=fpath)
179
-
180
- return cls.from_local(Path(local_path).parent, device)
181
-
182
- def prepare_conditionals(self, wav_fpath, exaggeration=0.5):
183
- ## Load reference wav
184
- s3gen_ref_wav, _sr = librosa.load(wav_fpath, sr=S3GEN_SR)
185
-
186
- ref_16k_wav = librosa.resample(s3gen_ref_wav, orig_sr=S3GEN_SR, target_sr=S3_SR)
187
-
188
- s3gen_ref_wav = s3gen_ref_wav[:self.DEC_COND_LEN]
189
- s3gen_ref_dict = self.s3gen.embed_ref(s3gen_ref_wav, S3GEN_SR, device=self.device)
190
-
191
- # Speech cond prompt tokens
192
- if plen := self.t3.hp.speech_cond_prompt_len:
193
- s3_tokzr = self.s3gen.tokenizer
194
- t3_cond_prompt_tokens, _ = s3_tokzr.forward([ref_16k_wav[:self.ENC_COND_LEN]], max_len=plen)
195
- t3_cond_prompt_tokens = torch.atleast_2d(t3_cond_prompt_tokens).to(self.device)
196
-
197
- # Voice-encoder speaker embedding
198
- ve_embed = torch.from_numpy(self.ve.embeds_from_wavs([ref_16k_wav], sample_rate=S3_SR))
199
- ve_embed = ve_embed.mean(axis=0, keepdim=True).to(self.device)
200
-
201
- t3_cond = T3Cond(
202
- speaker_emb=ve_embed,
203
- cond_prompt_speech_tokens=t3_cond_prompt_tokens,
204
- emotion_adv=exaggeration * torch.ones(1, 1, 1),
205
- ).to(device=self.device)
206
- self.conds = Conditionals(t3_cond, s3gen_ref_dict)
207
-
208
- def generate(
209
- self,
210
- text,
211
- repetition_penalty=1.2,
212
- min_p=0.05,
213
- top_p=1.0,
214
- audio_prompt_path=None,
215
- exaggeration=0.5,
216
- cfg_weight=0.5,
217
- temperature=0.8,
218
- ):
219
- if audio_prompt_path:
220
- self.prepare_conditionals(audio_prompt_path, exaggeration=exaggeration)
221
- else:
222
- assert self.conds is not None, "Please `prepare_conditionals` first or specify `audio_prompt_path`"
223
-
224
- # Update exaggeration if needed
225
- if exaggeration != self.conds.t3.emotion_adv[0, 0, 0]:
226
- _cond: T3Cond = self.conds.t3
227
- self.conds.t3 = T3Cond(
228
- speaker_emb=_cond.speaker_emb,
229
- cond_prompt_speech_tokens=_cond.cond_prompt_speech_tokens,
230
- emotion_adv=exaggeration * torch.ones(1, 1, 1),
231
- ).to(device=self.device)
232
-
233
- # Norm and tokenize text
234
- text = punc_norm(text)
235
- text_tokens = self.tokenizer.text_to_tokens(text).to(self.device)
236
-
237
- if cfg_weight > 0.0:
238
- text_tokens = torch.cat([text_tokens, text_tokens], dim=0) # Need two seqs for CFG
239
-
240
- sot = self.t3.hp.start_text_token
241
- eot = self.t3.hp.stop_text_token
242
- text_tokens = F.pad(text_tokens, (1, 0), value=sot)
243
- text_tokens = F.pad(text_tokens, (0, 1), value=eot)
244
-
245
- with torch.inference_mode():
246
- speech_tokens = self.t3.inference(
247
- t3_cond=self.conds.t3,
248
- text_tokens=text_tokens,
249
- max_new_tokens=1000, # TODO: use the value in config
250
- temperature=temperature,
251
- cfg_weight=cfg_weight,
252
- repetition_penalty=repetition_penalty,
253
- min_p=min_p,
254
- top_p=top_p,
255
- )
256
- # Extract only the conditional batch.
257
- speech_tokens = speech_tokens[0]
258
-
259
- # TODO: output becomes 1D
260
- speech_tokens = drop_invalid_tokens(speech_tokens)
261
-
262
- speech_tokens = speech_tokens[speech_tokens < 6561]
263
-
264
- speech_tokens = speech_tokens.to(self.device)
265
-
266
- wav, _ = self.s3gen.inference(
267
- speech_tokens=speech_tokens,
268
- ref_dict=self.conds.gen,
269
- )
270
- wav = wav.squeeze(0).detach().cpu().numpy()
271
- watermarked_wav = self.watermarker.apply_watermark(wav, sample_rate=self.sr)
272
- return torch.from_numpy(watermarked_wav).unsqueeze(0)
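
The equivalent sketch for the English-only ChatterboxTTS above; the output path is a placeholder, and the built-in voice from conds.pt is used when no audio prompt is given:

```python
# Hypothetical usage of the English TTS entry point; output path is a placeholder.
import torch
import torchaudio
from chatterbox.tts import ChatterboxTTS

device = "cuda" if torch.cuda.is_available() else "cpu"
model = ChatterboxTTS.from_pretrained(device)

# With no audio_prompt_path, the built-in voice loaded from conds.pt is used.
wav = model.generate(
    "Hello! This sentence was synthesised by Chatterbox.",
    exaggeration=0.7,   # emotion/intensity control
    cfg_weight=0.5,
    temperature=0.8,
)
torchaudio.save("out_en.wav", wav, model.sr)
```
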
src/chatterbox/vc.py DELETED
@@ -1,104 +0,0 @@
1
- from pathlib import Path
2
-
3
- import librosa
4
- import torch
5
- import perth
6
- from huggingface_hub import hf_hub_download
7
- from safetensors.torch import load_file
8
-
9
- from .models.s3tokenizer import S3_SR
10
- from .models.s3gen import S3GEN_SR, S3Gen
11
-
12
-
13
- REPO_ID = "ResembleAI/chatterbox"
14
-
15
-
16
- class ChatterboxVC:
17
- ENC_COND_LEN = 6 * S3_SR
18
- DEC_COND_LEN = 10 * S3GEN_SR
19
-
20
- def __init__(
21
- self,
22
- s3gen: S3Gen,
23
- device: str,
24
- ref_dict: dict=None,
25
- ):
26
- self.sr = S3GEN_SR
27
- self.s3gen = s3gen
28
- self.device = device
29
- self.watermarker = perth.PerthImplicitWatermarker()
30
- if ref_dict is None:
31
- self.ref_dict = None
32
- else:
33
- self.ref_dict = {
34
- k: v.to(device) if torch.is_tensor(v) else v
35
- for k, v in ref_dict.items()
36
- }
37
-
38
- @classmethod
39
- def from_local(cls, ckpt_dir, device) -> 'ChatterboxVC':
40
- ckpt_dir = Path(ckpt_dir)
41
-
42
- # Always load to CPU first for non-CUDA devices to handle CUDA-saved models
43
- if device in ["cpu", "mps"]:
44
- map_location = torch.device('cpu')
45
- else:
46
- map_location = None
47
-
48
- ref_dict = None
49
- if (builtin_voice := ckpt_dir / "conds.pt").exists():
50
- states = torch.load(builtin_voice, map_location=map_location)
51
- ref_dict = states['gen']
52
-
53
- s3gen = S3Gen()
54
- s3gen.load_state_dict(
55
- load_file(ckpt_dir / "s3gen.safetensors"), strict=False
56
- )
57
- s3gen.to(device).eval()
58
-
59
- return cls(s3gen, device, ref_dict=ref_dict)
60
-
61
- @classmethod
62
- def from_pretrained(cls, device) -> 'ChatterboxVC':
63
- # Check if MPS is available on macOS
64
- if device == "mps" and not torch.backends.mps.is_available():
65
- if not torch.backends.mps.is_built():
66
- print("MPS not available because the current PyTorch install was not built with MPS enabled.")
67
- else:
68
- print("MPS not available because the current MacOS version is not 12.3+ and/or you do not have an MPS-enabled device on this machine.")
69
- device = "cpu"
70
-
71
- for fpath in ["s3gen.safetensors", "conds.pt"]:
72
- local_path = hf_hub_download(repo_id=REPO_ID, filename=fpath)
73
-
74
- return cls.from_local(Path(local_path).parent, device)
75
-
76
- def set_target_voice(self, wav_fpath):
77
- ## Load reference wav
78
- s3gen_ref_wav, _sr = librosa.load(wav_fpath, sr=S3GEN_SR)
79
-
80
- s3gen_ref_wav = s3gen_ref_wav[:self.DEC_COND_LEN]
81
- self.ref_dict = self.s3gen.embed_ref(s3gen_ref_wav, S3GEN_SR, device=self.device)
82
-
83
- def generate(
84
- self,
85
- audio,
86
- target_voice_path=None,
87
- ):
88
- if target_voice_path:
89
- self.set_target_voice(target_voice_path)
90
- else:
91
- assert self.ref_dict is not None, "Please call `set_target_voice` first or specify `target_voice_path`"
92
-
93
- with torch.inference_mode():
94
- audio_16, _ = librosa.load(audio, sr=S3_SR)
95
- audio_16 = torch.from_numpy(audio_16).float().to(self.device)[None, ]
96
-
97
- s3_tokens, _ = self.s3gen.tokenizer(audio_16)
98
- wav, _ = self.s3gen.inference(
99
- speech_tokens=s3_tokens,
100
- ref_dict=self.ref_dict,
101
- )
102
- wav = wav.squeeze(0).detach().cpu().numpy()
103
- watermarked_wav = self.watermarker.apply_watermark(wav, sample_rate=self.sr)
104
- return torch.from_numpy(watermarked_wav).unsqueeze(0)
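
Finally, a minimal sketch for the ChatterboxVC voice-conversion wrapper above; both wav paths are hypothetical:

```python
# Hypothetical voice-conversion usage of ChatterboxVC; both paths are placeholders.
import torch
import torchaudio
from chatterbox.vc import ChatterboxVC

device = "cuda" if torch.cuda.is_available() else "cpu"
vc = ChatterboxVC.from_pretrained(device)

# Re-speak "source_utterance.wav" in the voice of "target_speaker.wav".
wav = vc.generate("source_utterance.wav", target_voice_path="target_speaker.wav")
torchaudio.save("converted.wav", wav, vc.sr)
```
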