primepake committed
Commit 55ac664 · Parent: 9f4fc9f

add contrastive loss

speech/config.yaml CHANGED
@@ -73,6 +73,10 @@ flow: !new:cosyvoice.flow.flow.CausalMaskedDiffWithXvec
     training_cfg_rate: 0.2
     inference_cfg_rate: 0.7
     reg_loss_type: 'l1'
+    use_immiscible: True
+    immiscible_k: 8
+    use_contrastive_fm: True
+    contrastive_lambda: 0.05
     estimator: !new:cosyvoice.flow.decoder.CausalConditionalDecoder
         in_channels: 320
         out_channels: 80
@@ -161,6 +165,7 @@ feat_extractor: !name:matcha.utils.audio.mel_spectrogram
     center: False
 compute_fbank: !name:cosyvoice.dataset.processor.compute_fbank
     feat_extractor: !ref <feat_extractor>
+    token_mel_ratio: !ref <token_mel_ratio>
 compute_f0: !name:cosyvoice.dataset.processor.compute_f0
     sample_rate: !ref <sample_rate>
     hop_size: 480
@@ -172,7 +177,7 @@ sort: !name:cosyvoice.dataset.processor.sort
     sort_size: 500 # sort_size should be less than shuffle_size
 batch: !name:cosyvoice.dataset.processor.batch
     batch_type: 'dynamic'
-    max_frames_in_batch: 2000
+    max_frames_in_batch: 25000
 padding: !name:cosyvoice.dataset.processor.padding
     use_spk_embedding: False # change to True during sft

@@ -195,12 +200,12 @@ data_pipeline: [
 train_conf:
     optim: adamw
     optim_conf:
-        lr: 1e-5 # change to 1e-5 during sft
+        lr: 2e-6 # change to 1e-5 during sft
     scheduler: constantlr # change to constantlr during sft
     scheduler_conf:
         warmup_steps: 2500
     max_epoch: 200
     grad_clip: 1
     accum_grad: 1
-    log_interval: 100
+    log_interval: 5
     save_per_step: -1
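
For reference, the four new flow entries surface as attributes of cfm_params, which ConditionalCFM reads in flow_matching.py below. A minimal sketch of that round trip, assuming the omegaconf DictConfig wrapper these files already use (the values simply mirror this config):

    from omegaconf import DictConfig

    cfm_params = DictConfig({
        "sigma_min": 1e-06, "solver": "euler", "t_scheduler": "cosine",
        "training_cfg_rate": 0.2, "inference_cfg_rate": 0.7, "reg_loss_type": "l1",
        "use_immiscible": True,      # enable nearest-of-k noise-to-data assignment
        "immiscible_k": 8,           # noise candidates drawn per sample
        "use_contrastive_fm": True,  # route training through compute_loss_contrastive
        "contrastive_lambda": 0.05,  # weight of the repulsive term
    })

    # ConditionalCFM.__init__ then picks these up as attributes:
    assert cfm_params.use_immiscible is True
    assert cfm_params.immiscible_k == 8
    lambda_weight = cfm_params.contrastive_lambda  # 0.05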
speech/cosyvoice/flow/flow.py CHANGED
@@ -11,9 +11,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from typing import List  # fixed: the commit imported List from ast, which is the AST node, not a type
 import logging
 import random
-from typing import Dict, Optional
+from typing import Dict, Optional, Tuple
 import torch
 import torch.nn as nn
 from torch.nn import functional as F
@@ -22,24 +23,57 @@ from cosyvoice.utils.mask import make_pad_mask


 class MaskedDiffWithXvec(torch.nn.Module):
-    def __init__(self,
-                 input_size: int = 512,
-                 output_size: int = 80,
-                 spk_embed_dim: int = 192,
-                 output_type: str = "mel",
-                 vocab_size: int = 4096,
-                 input_frame_rate: int = 50,
-                 only_mask_loss: bool = True,
-                 encoder: torch.nn.Module = None,
-                 length_regulator: torch.nn.Module = None,
-                 decoder: torch.nn.Module = None,
-                 decoder_conf: Dict = {'in_channels': 240, 'out_channel': 80, 'spk_emb_dim': 80, 'n_spks': 1,
-                                       'cfm_params': DictConfig({'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine',
-                                                                 'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}),
-                                       'decoder_params': {'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64,
-                                                          'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}},
-                 mel_feat_conf: Dict = {'n_fft': 1024, 'num_mels': 80, 'sampling_rate': 22050,
-                                        'hop_size': 256, 'win_size': 1024, 'fmin': 0, 'fmax': 8000}):
+    def __init__(
+        self,
+        input_size: int = 512,
+        output_size: int = 80,
+        spk_embed_dim: int = 192,
+        output_type: str = "mel",
+        vocab_size: int = 4096,
+        input_frame_rate: int = 50,
+        only_mask_loss: bool = True,
+        encoder: torch.nn.Module = None,
+        length_regulator: torch.nn.Module = None,
+        decoder: torch.nn.Module = None,
+        decoder_conf: Dict = {
+            "in_channels": 240,
+            "out_channel": 80,
+            "spk_emb_dim": 80,
+            "n_spks": 1,
+            "cfm_params": DictConfig(
+                {
+                    "sigma_min": 1e-06,
+                    "solver": "euler",
+                    "t_scheduler": "cosine",
+                    "training_cfg_rate": 0.2,
+                    "inference_cfg_rate": 0.7,
+                    "reg_loss_type": "l1",
+                    "use_immiscible": True,
+                    "immiscible_k": 8,
+                    "use_contrastive_fm": False,
+                    "contrastive_lambda": 0.05,
+                }
+            ),
+            "decoder_params": {
+                "channels": [256, 256],
+                "dropout": 0.0,
+                "attention_head_dim": 64,
+                "n_blocks": 4,
+                "num_mid_blocks": 12,
+                "num_heads": 8,
+                "act_fn": "gelu",
+            },
+        },
+        mel_feat_conf: Dict = {
+            "n_fft": 1024,
+            "num_mels": 80,
+            "sampling_rate": 22050,
+            "hop_size": 256,
+            "win_size": 1024,
+            "fmin": 0,
+            "fmax": 8000,
+        },
+    ):
         super().__init__()
         self.input_size = input_size
         self.output_size = output_size
@@ -58,22 +92,22 @@ class MaskedDiffWithXvec(torch.nn.Module):
         self.only_mask_loss = only_mask_loss

     def forward(
         self,
         batch: dict,
         device: torch.device,
     ) -> Dict[str, Optional[torch.Tensor]]:
-        token = batch['speech_token'].to(device)
-        token_len = batch['speech_token_len'].to(device)
-        feat = batch['speech_feat'].to(device)
-        feat_len = batch['speech_feat_len'].to(device)
-        embedding = batch['embedding'].to(device)
+        token = batch["speech_token"].to(device)
+        token_len = batch["speech_token_len"].to(device)
+        feat = batch["speech_feat"].to(device)
+        feat_len = batch["speech_feat_len"].to(device)
+        embedding = batch["embedding"].to(device)

         # xvec projection
         embedding = F.normalize(embedding, dim=1)
         embedding = self.spk_embed_affine_layer(embedding)

         # concat text and prompt_text
-        print('token_len values: ', token_len)
+        print("token_len values: ", token_len)
         mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(device)
         token = self.input_embedding(torch.clamp(token, min=0)) * mask
@@ -98,20 +132,22 @@ class MaskedDiffWithXvec(torch.nn.Module):
             mask.unsqueeze(1),
             h.transpose(1, 2).contiguous(),
             embedding,
-            cond=conds
+            cond=conds,
         )
-        return {'loss': loss}
+        return {"loss": loss}

     @torch.inference_mode()
-    def inference(self,
-                  token,
-                  token_len,
-                  prompt_token,
-                  prompt_token_len,
-                  prompt_feat,
-                  prompt_feat_len,
-                  embedding,
-                  flow_cache):
+    def inference(
+        self,
+        token,
+        token_len,
+        prompt_token,
+        prompt_token_len,
+        prompt_feat,
+        prompt_feat_len,
+        embedding,
+        flow_cache,
+    ):
         assert token.shape[0] == 1
         # xvec projection
         embedding = F.normalize(embedding, dim=1)
@@ -119,18 +155,31 @@ class MaskedDiffWithXvec(torch.nn.Module):

         # concat speech token and prompt speech token
         token_len1, token_len2 = prompt_token.shape[1], token.shape[1]
-        token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
+        token, token_len = (
+            torch.concat([prompt_token, token], dim=1),
+            prompt_token_len + token_len,
+        )
         mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding)
         token = self.input_embedding(torch.clamp(token, min=0)) * mask

         # text encode
         h, h_lengths = self.encoder(token, token_len)
         h = self.encoder_proj(h)
-        mel_len1, mel_len2 = prompt_feat.shape[1], int(token_len2 / self.input_frame_rate * 22050 / 256)
-        h, h_lengths = self.length_regulator.inference(h[:, :token_len1], h[:, token_len1:], mel_len1, mel_len2, self.input_frame_rate)
+        mel_len1, mel_len2 = prompt_feat.shape[1], int(
+            token_len2 / self.input_frame_rate * 22050 / 256
+        )
+        h, h_lengths = self.length_regulator.inference(
+            h[:, :token_len1],
+            h[:, token_len1:],
+            mel_len1,
+            mel_len2,
+            self.input_frame_rate,
+        )

         # get conditions
-        conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device).to(h.dtype)
+        conds = torch.zeros(
+            [1, mel_len1 + mel_len2, self.output_size], device=token.device
+        ).to(h.dtype)
         conds[:, :mel_len1] = prompt_feat
         conds = conds.transpose(1, 2)
@@ -142,7 +191,7 @@ class MaskedDiffWithXvec(torch.nn.Module):
             cond=conds,
             n_timesteps=10,
             prompt_len=mel_len1,
-            cache=flow_cache
+            cache=flow_cache,
         )
         feat = feat[:, :, mel_len1:]
         assert feat.shape[2] == mel_len2
@@ -150,25 +199,58 @@ class MaskedDiffWithXvec(torch.nn.Module):


 class CausalMaskedDiffWithXvec(torch.nn.Module):
-    def __init__(self,
-                 input_size: int = 512,
-                 output_size: int = 80,
-                 spk_embed_dim: int = 192,
-                 output_type: str = "mel",
-                 vocab_size: int = 4096,
-                 input_frame_rate: int = 50,
-                 only_mask_loss: bool = True,
-                 token_mel_ratio: int = 2,
-                 pre_lookahead_len: int = 3,
-                 encoder: torch.nn.Module = None,
-                 decoder: torch.nn.Module = None,
-                 decoder_conf: Dict = {'in_channels': 240, 'out_channel': 80, 'spk_emb_dim': 80, 'n_spks': 1,
-                                       'cfm_params': DictConfig({'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine',
-                                                                 'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}),
-                                       'decoder_params': {'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64,
-                                                          'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}},
-                 mel_feat_conf: Dict = {'n_fft': 1024, 'num_mels': 80, 'sampling_rate': 22050,
-                                        'hop_size': 256, 'win_size': 1024, 'fmin': 0, 'fmax': 8000}):
+    def __init__(
+        self,
+        input_size: int = 512,
+        output_size: int = 80,
+        spk_embed_dim: int = 192,
+        output_type: str = "mel",
+        vocab_size: int = 4096,
+        input_frame_rate: int = 50,
+        only_mask_loss: bool = True,
+        token_mel_ratio: int = 2,
+        pre_lookahead_len: int = 3,
+        encoder: torch.nn.Module = None,
+        decoder: torch.nn.Module = None,
+        decoder_conf: Dict = {
+            "in_channels": 240,
+            "out_channel": 80,
+            "spk_emb_dim": 80,
+            "n_spks": 1,
+            "cfm_params": DictConfig(
+                {
+                    "sigma_min": 1e-06,
+                    "solver": "euler",
+                    "t_scheduler": "cosine",
+                    "training_cfg_rate": 0.2,
+                    "inference_cfg_rate": 0.7,
+                    "reg_loss_type": "l1",
+                    "use_immiscible": True,
+                    "immiscible_k": 8,
+                    "use_contrastive_fm": True,
+                    "contrastive_lambda": 0.05,
+                }
+            ),
+            "decoder_params": {
+                "channels": [256, 256],
+                "dropout": 0.0,
+                "attention_head_dim": 64,
+                "n_blocks": 4,
+                "num_mid_blocks": 12,
+                "num_heads": 8,
+                "act_fn": "gelu",
+            },
+        },
+        mel_feat_conf: Dict = {
+            "n_fft": 1024,
+            "num_mels": 80,
+            "sampling_rate": 22050,
+            "hop_size": 256,
+            "win_size": 1024,
+            "fmin": 0,
+            "fmax": 8000,
+        },
+    ):
         super().__init__()
         self.input_size = input_size
         self.output_size = output_size
@@ -186,32 +268,26 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
         self.only_mask_loss = only_mask_loss
         self.token_mel_ratio = token_mel_ratio
         self.pre_lookahead_len = pre_lookahead_len
+        print(" decoder_conf['cfm_params']: ", decoder_conf["cfm_params"])
+        self.use_contrastive_fm = decoder_conf["cfm_params"]["use_contrastive_fm"]

     def forward(
         self,
         batch: dict,
         device: torch.device,
     ) -> Dict[str, Optional[torch.Tensor]]:
-        token = batch['speech_token'].to(device)
-        token_len = batch['speech_token_len'].to(device)
-        feat = batch['speech_feat'].to(device)
-        feat_len = batch['speech_feat_len'].to(device)
-        embedding = batch['embedding'].to(device)
-
-        # print('token: ', token.shape)
-        # print('token_len: ', token_len.shape)
-        # print('feat: ', feat.shape)
-        # print('feat_len: ', feat_len.shape)
-        # print('embedding: ', embedding.shape)
+        token = batch["speech_token"].to(device)
+        token_len = batch["speech_token_len"].to(device)
+        feat = batch["speech_feat"].to(device)
+        feat_len = batch["speech_feat_len"].to(device)
+        embedding = batch["embedding"].to(device)

         # NOTE unified training, static_chunk_size > 0 or = 0
-        streaming = False# if random.random() < 0.5 else False
+        streaming = False  # if random.random() < 0.5 else False

         # xvec projection
         embedding = F.normalize(embedding, dim=1)
         embedding = self.spk_embed_affine_layer(embedding)
-        # print('token_len values: ', token_len)
-        # concat text and prompt_text
         mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(device)
         token = self.input_embedding(torch.clamp(token, min=0)) * mask
@@ -229,42 +305,50 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
         conds = conds.transpose(1, 2)

         mask = (~make_pad_mask(h_lengths.sum(dim=-1).squeeze(dim=1))).to(h)
-
-        # print('feat shape: ', feat.shape)
-        # print('mask shape: ', mask.shape)
-        # print('h shape: ', h.shape)
-        # print('embedding shape: ', embedding.shape)
-        # print('conds shape: ', conds.shape)
-        # print('streaming: ', streaming)
-
-        loss, _ = self.decoder.compute_loss(
-            feat.transpose(1, 2).contiguous(),
-            mask.unsqueeze(1),
-            h.transpose(1, 2).contiguous(),
-            embedding,
-            cond=conds,
-            streaming=streaming,
-        )
-        return {'loss': loss}
+        if not self.use_contrastive_fm:
+            loss, _ = self.decoder.compute_loss(
+                feat.transpose(1, 2).contiguous(),
+                mask.unsqueeze(1),
+                h.transpose(1, 2).contiguous(),
+                embedding,
+                cond=conds,
+                streaming=streaming,
+            )
+        else:
+            # print("use contrastive fm")
+            loss, _ = self.decoder.compute_loss_contrastive(
+                feat.transpose(1, 2).contiguous(),
+                mask.unsqueeze(1),
+                h.transpose(1, 2).contiguous(),
+                embedding,
+                cond=conds,
+                streaming=streaming,
+            )
+        return {"loss": loss}

     @torch.inference_mode()
-    def inference(self,
-                  token,
-                  token_len,
-                  prompt_token,
-                  prompt_token_len,
-                  prompt_feat,
-                  prompt_feat_len,
-                  embedding,
-                  streaming,
-                  finalize):
+    def inference(
+        self,
+        token,
+        token_len,
+        prompt_token,
+        prompt_token_len,
+        prompt_feat,
+        prompt_feat_len,
+        embedding,
+        streaming,
+        finalize,
+    ):
         assert token.shape[0] == 1
         # xvec projection
         embedding = F.normalize(embedding, dim=1)
         embedding = self.spk_embed_affine_layer(embedding)

         # concat text and prompt_text
-        token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
+        token, token_len = (
+            torch.concat([prompt_token, token], dim=1),
+            prompt_token_len + token_len,
+        )
         mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding)
         token = self.input_embedding(torch.clamp(token, min=0)) * mask
@@ -272,13 +356,20 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
         if finalize is True:
             h, h_lengths = self.encoder(token, token_len, streaming=streaming)
         else:
-            token, context = token[:, :-self.pre_lookahead_len], token[:, -self.pre_lookahead_len:]
-            h, h_lengths = self.encoder(token, token_len, context=context, streaming=streaming)
+            token, context = (
+                token[:, : -self.pre_lookahead_len],
+                token[:, -self.pre_lookahead_len :],
+            )
+            h, h_lengths = self.encoder(
+                token, token_len, context=context, streaming=streaming
+            )
         mel_len1, mel_len2 = prompt_feat.shape[1], h.shape[1] - prompt_feat.shape[1]
         h = self.encoder_proj(h)

         # get conditions
-        conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device).to(h.dtype)
+        conds = torch.zeros(
+            [1, mel_len1 + mel_len2, self.output_size], device=token.device
+        ).to(h.dtype)
         conds[:, :mel_len1] = prompt_feat
         conds = conds.transpose(1, 2)
@@ -289,7 +380,7 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
             spks=embedding,
             cond=conds,
             n_timesteps=10,
-            streaming=streaming
+            streaming=streaming,
         )
         feat = feat[:, :, mel_len1:]
         assert feat.shape[2] == mel_len2
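
One detail worth a worked check: the non-causal MaskedDiffWithXvec.inference estimates the output mel length from the audio settings, int(token_len2 / self.input_frame_rate * 22050 / 256), while the causal class derives it from the encoder output, which upsamples tokens by token_mel_ratio. With the defaults in the signatures above (50 tokens/s, 22.05 kHz audio, hop 256):

    # Worked example of the mel length estimate (values taken from the defaults above)
    input_frame_rate = 50          # speech tokens per second
    mel_fps = 22050 / 256          # ~86.13 mel frames per second

    token_len2 = 100               # two seconds of speech tokens
    mel_len2 = int(token_len2 / input_frame_rate * 22050 / 256)
    print(mel_len2)                # 172, i.e. ~1.72 mel frames per token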
speech/cosyvoice/flow/flow_matching.py CHANGED
@@ -34,6 +34,7 @@ class ConditionalCFM(BASECFM):
         self.estimator = estimator
         self.use_immiscible = cfm_params.use_immiscible
         self.immiscible_k = cfm_params.immiscible_k
+        self.lambda_weight = cfm_params.contrastive_lambda

     @torch.inference_mode()
     def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, prompt_len=0, cache=torch.zeros(1, 80, 0, 2)):
@@ -177,14 +178,6 @@ class ConditionalCFM(BASECFM):
         t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype)
         if self.t_scheduler == 'cosine':
             t = 1 - torch.cos(t * 0.5 * torch.pi)
-
-
-        print(f"\n=== Immiscible Diffusion Debug ===")
-        print(f"x1 shape: {x1.shape}")
-        print(f"mu shape: {mu.shape}")
-        print(f"t shape: {t.shape}")
-        print(f"Device: {x1.device}")
-        print(f"Dtype: {x1.dtype}")

         # Apply immiscible diffusion with KNN
         if self.use_immiscible:
@@ -192,49 +185,87 @@ class ConditionalCFM(BASECFM):

             # Generate k noise samples for each data point
             z_candidates = torch.randn(b, k, d, T, device=x1.device, dtype=x1.dtype)
-
-            print(f"z_candidates shape: {z_candidates.shape}")
-            print(f"z_candidates stats - mean: {z_candidates.mean():.4f}, std: {z_candidates.std():.4f}")
-
-            # Flatten for distance computation
+
             x1_flat = x1.flatten(start_dim=1).to(torch.float16)
             z_candidates_flat = z_candidates.flatten(start_dim=2).to(torch.float16)

-            print(f"x1_flat shape: {x1_flat.shape}")
-            print(f"z_candidates_flat shape: {z_candidates_flat.shape}")
-
-            # Calculate distances
             distances = torch.norm(x1_flat.unsqueeze(1) - z_candidates_flat, dim=2)

-            print(f"distances shape: {distances.shape}")
-            print(f"distances stats - mean: {distances.mean():.4f}, std: {distances.std():.4f}")
-            print(f"distances min: {distances.min():.4f}, max: {distances.max():.4f}")
-
-            # Pick the nearest noise for each data point
             min_distances, min_indices = torch.min(distances, dim=1)

-            print(f"min_indices: {min_indices[:10]}")  # First 10 indices
-            print(f"min_distances stats - mean: {min_distances.mean():.4f}, std: {min_distances.std():.4f}")
-
-            # Gather the selected noise samples
             z = torch.gather(
                 z_candidates,
                 1,
                 min_indices.unsqueeze(1).unsqueeze(2).unsqueeze(3).expand(-1, 1, d, T)
             )[:, 0, :, :]

-            print(f"Selected z shape: {z.shape}")
-            print(f"Selected z stats - mean: {z.mean():.4f}, std: {z.std():.4f}")
-
-            # Calculate distance reduction
-            with torch.no_grad():
-                orig_distance = distances[:, 0].mean()
-                selected_distance = min_distances.mean()
-                reduction_rate = (orig_distance - selected_distance) / orig_distance
-                print(f"Distance reduction: {reduction_rate:.3%}")
-                print(f"Original distance: {orig_distance:.4f}")
-                print(f"Selected distance: {selected_distance:.4f}")
         else:
             # sample noise p(x_0)
             z = torch.randn_like(x1)
@@ -250,7 +281,27 @@ class ConditionalCFM(BASECFM):
         cond = cond * cfg_mask.view(-1, 1, 1)

         pred = self.estimator(y, mask, mu, t.squeeze(), spks, cond, streaming=streaming)
         loss = F.mse_loss(pred * mask, u * mask, reduction="sum") / (torch.sum(mask) * u.shape[1])
         return loss, y

+    def compute_loss_contrastive(self, x1, mask, mu, spks=None, cond=None, streaming=False):
+        """Computes the contrastive conditional flow matching loss
+
+        Args:
+            x1 (torch.Tensor): Target
+                shape: (batch_size, n_feats, mel_timesteps)
+            mask (torch.Tensor): target mask
+                shape: (batch_size, 1, mel_timesteps)
+            mu (torch.Tensor): output of encoder
+                shape: (batch_size, n_feats, mel_timesteps)
+            spks (torch.Tensor, optional): speaker embedding. Defaults to None.
+                shape: (batch_size, spk_emb_dim)
+
+        Returns:
+            loss: conditional flow matching loss
+            y: conditional flow
+                shape: (batch_size, n_feats, mel_timesteps)
+        """
+        b, d, T = mu.shape
+
+        # random timestep
+        t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype)
+        if self.t_scheduler == 'cosine':
+            t = 1 - torch.cos(t * 0.5 * torch.pi)
+
+        # Apply immiscible diffusion with KNN
+        if self.use_immiscible:
+            k = getattr(self, 'immiscible_k', 4)
+
+            # Generate k noise samples for each data point
+            z_candidates = torch.randn(b, k, d, T, device=x1.device, dtype=x1.dtype)
+
+            x1_flat = x1.flatten(start_dim=1).to(torch.float16)
+            z_candidates_flat = z_candidates.flatten(start_dim=2).to(torch.float16)
+
+            distances = torch.norm(x1_flat.unsqueeze(1) - z_candidates_flat, dim=2)
+
+            min_distances, min_indices = torch.min(distances, dim=1)
+
+            z = torch.gather(
+                z_candidates,
+                1,
+                min_indices.unsqueeze(1).unsqueeze(2).unsqueeze(3).expand(-1, 1, d, T)
+            )[:, 0, :, :]
+        else:
+            # sample noise p(x_0)
+            z = torch.randn_like(x1)
+
+        y = (1 - (1 - self.sigma_min) * t) * z + t * x1
+        u = x1 - (1 - self.sigma_min) * z
+
+        # during training, we randomly drop condition to trade off mode coverage and sample fidelity
+        if self.training_cfg_rate > 0:
+            cfg_mask = torch.rand(b, device=x1.device) > self.training_cfg_rate
+            mu = mu * cfg_mask.view(-1, 1, 1)
+            spks = spks * cfg_mask.view(-1, 1)
+            cond = cond * cfg_mask.view(-1, 1, 1)
+
+        pred = self.estimator(y, mask, mu, t.squeeze(), spks, cond, streaming=streaming)
+        fm_loss = F.mse_loss(pred * mask, u * mask, reduction="sum") / (torch.sum(mask) * u.shape[1])
+
+        # Pair each sample with another sample's target (circular shift by one)
+        neg_indices = torch.roll(torch.arange(b, device=x1.device), shifts=1)
+
+        # Get negative targets from shifted indices
+        if b > 1:
+            u_neg = u[neg_indices]
+            neg_mask = mask[neg_indices]
+
+            # Contrastive loss: error against mismatched targets
+            contrastive_loss = F.mse_loss(
+                pred * neg_mask,
+                u_neg * neg_mask,
+                reduction="sum"
+            ) / (torch.sum(neg_mask) * d)
+            print('contrastive_loss: ', contrastive_loss)
+        else:
+            contrastive_loss = torch.tensor(0.0, device=fm_loss.device)
+
+        loss = fm_loss - self.lambda_weight * contrastive_loss
+
+        return loss, y
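Both training-time ingredients can be exercised in isolation. Below is a self-contained sketch on toy tensors; the shapes, the nearest-of-k noise selection, and the roll-by-one negative pairing follow the diff, while the fp16 cast and CFG dropout are omitted for brevity and pred stands in for the estimator output:

    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)
    b, d, T = 4, 80, 120           # batch, mel bins, frames
    k, sigma_min, lam = 8, 1e-6, 0.05

    x1 = torch.randn(b, d, T)      # target mel features
    mask = torch.ones(b, 1, T)     # no padding in this toy batch
    pred = torch.randn(b, d, T)    # stand-in for the estimator's prediction

    # Immiscible assignment: draw k noise candidates per sample, keep the nearest.
    z_candidates = torch.randn(b, k, d, T)
    distances = torch.norm(
        x1.flatten(start_dim=1).unsqueeze(1) - z_candidates.flatten(start_dim=2), dim=2
    )                               # (b, k) L2 distances
    z = z_candidates[torch.arange(b), distances.argmin(dim=1)]  # nearest noise, (b, d, T)

    # Conditional flow matching interpolant and regression target.
    t = torch.rand(b, 1, 1)
    y = (1 - (1 - sigma_min) * t) * z + t * x1
    u = x1 - (1 - sigma_min) * z

    fm_loss = F.mse_loss(pred * mask, u * mask, reduction="sum") / (mask.sum() * d)

    # Contrastive term: score each prediction against the next sample's target.
    neg = torch.roll(torch.arange(b), shifts=1)
    contrastive = F.mse_loss(
        pred * mask[neg], u[neg] * mask[neg], reduction="sum"
    ) / (mask[neg].sum() * d)

    loss = fm_loss - lam * contrastive

Note the sign: because the contrastive term enters with weight -lambda, minimizing the total increases the error against mismatched targets, which is the repulsion contrastive flow matching aims for; contrastive_lambda (0.05 in the config) sets how strong that push is.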
speech/cosyvoice/utils/executor.py CHANGED
@@ -33,6 +33,7 @@ class Executor:
         gan: bool = False,
         ref_model: torch.nn.Module = None,
         dpo_loss: torch.nn.Module = None,
+        use_contrastive_fm: bool = False
     ):
         self.gan = gan
         self.ref_model = ref_model
@@ -41,6 +42,7 @@ class Executor:
         self.epoch = 0
         self.rank = int(os.environ.get("RANK", 0))
         self.device = torch.device(f"cuda:{self.rank}")
+        self.use_contrastive_fm = use_contrastive_fm

     def train_one_epoc(
         self,
@@ -69,16 +71,20 @@ class Executor:

         use_ddp = info_dict["train_engine"] == "torch_ddp"

+
         for batch_idx, batch_dict in enumerate(train_data_loader):
             info_dict["tag"] = "TRAIN"
             info_dict["step"] = self.step
             info_dict["epoch"] = self.epoch
             info_dict["batch_idx"] = batch_idx

+
             if use_ddp and (batch_idx + 1) % info_dict["accum_grad"] != 0:
                 context = model.no_sync
             else:
                 context = nullcontext
+
+
             with context():
                 info_dict = batch_forward(
                     model,
@@ -88,6 +94,7 @@ class Executor:
                     ref_model=self.ref_model,
                     dpo_loss=self.dpo_loss,
                 )
+
                 info_dict = batch_backward(model, scaler, info_dict)

             info_dict = update_parameter_and_lr(
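
The new keyword only threads the flag into the trainer; in the hunks shown it is stored but not yet consumed, since the actual dispatch happens inside CausalMaskedDiffWithXvec.forward via decoder_conf. A hypothetical construction call (the Executor's remaining arguments are elided here):

    # Hypothetical wiring; remaining Executor arguments elided.
    executor = Executor(gan=False, ref_model=None, dpo_loss=None, use_contrastive_fm=True)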
speech/cosyvoice/utils/train_utils.py CHANGED
@@ -250,6 +250,7 @@ def batch_forward(model, batch, scaler, info_dict, ref_model=None, dpo_loss=None):

     with autocast:
         info_dict['loss_dict'] = model(batch, device)
+        # print('info_dict loss_dict: ', info_dict['loss_dict'])
         if ref_model is not None and dpo_loss is not None:
             chosen_logps = info_dict['loss_dict']["chosen_logps"]
             rejected_logps = info_dict['loss_dict']["rejected_logps"]