SAM committed on
Commit
3eae7ea
·
unverified ·
1 Parent(s): f6beba0

update contrastive loss

Browse files
speech/cosyvoice/flow/flow_matching.py CHANGED
@@ -270,21 +270,10 @@ class ConditionalCFM(BASECFM):
270
  # sample noise p(x_0)
271
  z = torch.randn_like(x1)
272
 
273
- y = (1 - (1 - self.sigma_min) * t) * z + t * x1
274
- u = x1 - (1 - self.sigma_min) * z
275
 
276
- # during training, we randomly drop condition to trade off mode coverage and sample fidelity
277
- if self.training_cfg_rate > 0:
278
- cfg_mask = torch.rand(b, device=x1.device) > self.training_cfg_rate
279
- mu = mu * cfg_mask.view(-1, 1, 1)
280
- spks = spks * cfg_mask.view(-1, 1)
281
- cond = cond * cfg_mask.view(-1, 1, 1)
282
-
283
- pred = self.estimator(y, mask, mu, t.squeeze(), spks, cond, streaming=streaming)
284
- fm_loss = F.mse_loss(pred * mask, u * mask, reduction="sum") / (torch.sum(mask) * u.shape[1])
285
 
286
-
287
-
288
  # Get negative targets from shifted indices
289
  if b > 1:
290
  perm = torch.randperm(b, device=x1.device)
@@ -296,32 +285,34 @@ class ConditionalCFM(BASECFM):
296
 
297
  # Get negative samples
298
  x1_neg = x1[perm]
299
- mask_neg = mask[perm]
300
-
301
- # Generate independent noise for negatives
302
- z_neg = torch.randn_like(x1_neg)
303
-
304
- # Compute negative velocities
305
- u_neg = x1_neg - (1 - self.sigma_min) * z_neg
306
-
307
- # Contrastive loss
308
- contrastive_loss = F.mse_loss(
309
- pred * mask_neg,
310
- u_neg * mask_neg,
311
- reduction="sum"
312
- ) / (torch.sum(mask_neg) * d)
313
-
314
- # print('before contrastive_loss: ', contrastive_loss)
315
  else:
316
- contrastive_loss = torch.tensor(0.0, device=fm_loss.device)
317
- # print("fm_loss: ", fm_loss)
318
-
319
- contrastive_loss = self.lambda_weight * contrastive_loss
320
- # print('contrastive_loss: ', contrastive_loss)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
321
 
322
- loss = fm_loss - contrastive_loss
323
 
324
- return loss, y
325
 
326
 
327
  class CausalConditionalCFM(ConditionalCFM):
 
270
  # sample noise p(x_0)
271
  z = torch.randn_like(x1)
272
 
273
+ x_t = (1 - (1 - self.sigma_min) * t) * z + t * x1
 
274
 
275
+ u_positive = x1 - (1 - self.sigma_min) * z
 
 
 
 
 
 
 
 
276
 
 
 
277
  # Get negative targets from shifted indices
278
  if b > 1:
279
  perm = torch.randperm(b, device=x1.device)
 
285
 
286
  # Get negative samples
287
  x1_neg = x1[perm]
288
+
289
+ # KEY: Use the SAME z that created x_t (not new noise)
290
+ # This asks: "what if x_t came from x1_neg instead?"
291
+ u_negative = x1_neg - (1 - self.sigma_min) * z
 
 
 
 
 
 
 
 
 
 
 
 
292
  else:
293
+ u_negative = u_positive
294
+
295
+ # during training, we randomly drop condition to trade off mode coverage and sample fidelity
296
+ if self.training_cfg_rate > 0:
297
+ cfg_mask = torch.rand(b, device=x1.device) > self.training_cfg_rate
298
+ mu = mu * cfg_mask.view(-1, 1, 1)
299
+ spks = spks * cfg_mask.view(-1, 1)
300
+ cond = cond * cfg_mask.view(-1, 1, 1)
301
+
302
+ pred = self.estimator(x_t, mask, mu, t.squeeze(), spks, cond, streaming=streaming)
303
+
304
+ positive_loss = F.mse_loss(pred * mask, u_positive * mask, reduction="sum") / (torch.sum(mask) * d)
305
+
306
+ if b > 1:
307
+ # Negative loss: pred should NOT match velocities from other trajectories
308
+ negative_loss = F.mse_loss(pred * mask, u_negative * mask, reduction="sum") / (torch.sum(mask) * d)
309
+ else:
310
+ negative_loss = torch.tensor(0.0, device=positive_loss.device)
311
+
312
 
313
+ loss = positive_loss - self.lambda_weight * negative_loss
314
 
315
+ return loss, x_t
316
 
317
 
318
  class CausalConditionalCFM(ConditionalCFM):