primepake committed on
Commit
9f4fc9f
·
1 Parent(s): ca7dd21

add Immiscible training code

Browse files
speech/cosyvoice/flow/flow_matching.py CHANGED
@@ -32,6 +32,8 @@ class ConditionalCFM(BASECFM):
32
  in_channels = in_channels + (spk_emb_dim if n_spks > 0 else 0)
33
  # Just change the architecture of the estimator here
34
  self.estimator = estimator
 
 
35
 
36
  @torch.inference_mode()
37
  def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, prompt_len=0, cache=torch.zeros(1, 80, 0, 2)):
@@ -169,14 +171,73 @@ class ConditionalCFM(BASECFM):
169
  y: conditional flow
170
  shape: (batch_size, n_feats, mel_timesteps)
171
  """
172
- b, _, t = mu.shape
173
 
174
  # random timestep
175
  t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype)
176
  if self.t_scheduler == 'cosine':
177
  t = 1 - torch.cos(t * 0.5 * torch.pi)
178
- # sample noise p(x_0)
179
- z = torch.randn_like(x1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
180
 
181
  y = (1 - (1 - self.sigma_min) * t) * z + t * x1
182
  u = x1 - (1 - self.sigma_min) * z
@@ -187,13 +248,7 @@ class ConditionalCFM(BASECFM):
187
  mu = mu * cfg_mask.view(-1, 1, 1)
188
  spks = spks * cfg_mask.view(-1, 1)
189
  cond = cond * cfg_mask.view(-1, 1, 1)
190
- # print('y shape: ', y.shape)
191
- # print('mask shape: ', mask.shape)
192
- # print('mu shape: ', mu.shape)
193
- # print('t shape: ', t.shape)
194
- # print('spks shape: ', spks.shape)
195
- # print('cond shape: ', cond.shape)
196
- # print('streaming: ', streaming)
197
  pred = self.estimator(y, mask, mu, t.squeeze(), spks, cond, streaming=streaming)
198
  loss = F.mse_loss(pred * mask, u * mask, reduction="sum") / (torch.sum(mask) * u.shape[1])
199
  return loss, y
 
32
  in_channels = in_channels + (spk_emb_dim if n_spks > 0 else 0)
33
  # Just change the architecture of the estimator here
34
  self.estimator = estimator
35
+ self.use_immiscible = cfm_params.use_immiscible
36
+ self.immiscible_k = cfm_params.immiscible_k
37
 
38
  @torch.inference_mode()
39
  def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, prompt_len=0, cache=torch.zeros(1, 80, 0, 2)):
 
171
  y: conditional flow
172
  shape: (batch_size, n_feats, mel_timesteps)
173
  """
174
+ b, d, T = mu.shape
175
 
176
  # random timestep
177
  t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype)
178
  if self.t_scheduler == 'cosine':
179
  t = 1 - torch.cos(t * 0.5 * torch.pi)
180
+
181
+
182
+ print(f"\n=== Immiscible Diffusion Debug ===")
183
+ print(f"x1 shape: {x1.shape}")
184
+ print(f"mu shape: {mu.shape}")
185
+ print(f"t shape: {t.shape}")
186
+ print(f"Device: {x1.device}")
187
+ print(f"Dtype: {x1.dtype}")
188
+
189
+ # Apply immiscible diffusion with KNN
190
+ if self.use_immiscible:
191
+ k = getattr(self, 'immiscible_k', 4)
192
+
193
+ # Generate k noise samples for each data point
194
+ z_candidates = torch.randn(b, k, d, T, device=x1.device, dtype=x1.dtype)
195
+
196
+ print(f"z_candidates shape: {z_candidates.shape}")
197
+ print(f"z_candidates stats - mean: {z_candidates.mean():.4f}, std: {z_candidates.std():.4f}")
198
+
199
+ # Flatten for distance computation
200
+ x1_flat = x1.flatten(start_dim=1).to(torch.float16)
201
+ z_candidates_flat = z_candidates.flatten(start_dim=2).to(torch.float16)
202
+
203
+ print(f"x1_flat shape: {x1_flat.shape}")
204
+ print(f"z_candidates_flat shape: {z_candidates_flat.shape}")
205
+
206
+ # Calculate distances
207
+ distances = torch.norm(x1_flat.unsqueeze(1) - z_candidates_flat, dim=2)
208
+
209
+ print(f"distances shape: {distances.shape}")
210
+ print(f"distances stats - mean: {distances.mean():.4f}, std: {distances.std():.4f}")
211
+ print(f"distances min: {distances.min():.4f}, max: {distances.max():.4f}")
212
+
213
+ # Pick the nearest noise for each data point
214
+ min_distances, min_indices = torch.min(distances, dim=1)
215
+
216
+
217
+ print(f"min_indices: {min_indices[:10]}") # First 10 indices
218
+ print(f"min_distances stats - mean: {min_distances.mean():.4f}, std: {min_distances.std():.4f}")
219
+
220
+ # Gather the selected noise samples
221
+ z = torch.gather(
222
+ z_candidates,
223
+ 1,
224
+ min_indices.unsqueeze(1).unsqueeze(2).unsqueeze(3).expand(-1, 1, d, T)
225
+ )[:, 0, :, :]
226
+
227
+ print(f"Selected z shape: {z.shape}")
228
+ print(f"Selected z stats - mean: {z.mean():.4f}, std: {z.std():.4f}")
229
+
230
+ # Calculate distance reduction
231
+ with torch.no_grad():
232
+ orig_distance = distances[:, 0].mean()
233
+ selected_distance = min_distances.mean()
234
+ reduction_rate = (orig_distance - selected_distance) / orig_distance
235
+ print(f"Distance reduction: {reduction_rate:.3%}")
236
+ print(f"Original distance: {orig_distance:.4f}")
237
+ print(f"Selected distance: {selected_distance:.4f}")
238
+ else:
239
+ # sample noise p(x_0)
240
+ z = torch.randn_like(x1)
241
 
242
  y = (1 - (1 - self.sigma_min) * t) * z + t * x1
243
  u = x1 - (1 - self.sigma_min) * z
 
248
  mu = mu * cfg_mask.view(-1, 1, 1)
249
  spks = spks * cfg_mask.view(-1, 1)
250
  cond = cond * cfg_mask.view(-1, 1, 1)
251
+
 
 
 
 
 
 
252
  pred = self.estimator(y, mask, mu, t.squeeze(), spks, cond, streaming=streaming)
253
  loss = F.mse_loss(pred * mask, u * mask, reduction="sum") / (torch.sum(mask) * u.shape[1])
254
  return loss, y