primepake committed
Commit 4eed684 · 1 Parent(s): a8d6b21

adding flowvae
flowae/models/diffusion/fm.py CHANGED
@@ -12,6 +12,7 @@ class FM:
         self.timescale = timescale
         self.use_immiscible = use_immiscible
         self.k_candidates = k_candidates
+        print('use_immiscible: ', use_immiscible, 'k_candidates: ', k_candidates)
 
     def alpha(self, t):
         return 1.0 - t
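This hunk only exposes the new use_immiscible / k_candidates attributes and the linear schedule alpha(t) = 1 - t; the rest of the FM class is not part of the diff. For context, a generic flow-matching interpolation with that schedule, plus one common reading of immiscible noise assignment (keep the closest of k_candidates noise draws per example), might look like the sketch below. The helper name and the regression target are illustrative assumptions, not this repository's implementation.

import torch

def sample_xt(x0, t, alpha, use_immiscible=False, k_candidates=1):
    # Illustrative only: x0 is clean data shaped [B, C, T], t is [B] in [0, 1],
    # and alpha is the schedule from fm.py, i.e. alpha(t) = 1 - t.
    b = x0.shape[0]
    if use_immiscible and k_candidates > 1:
        # One plausible "immiscible" rule: draw k noise candidates per example
        # and keep the one closest to that example, so data/noise pairs mix less.
        cands = torch.randn(k_candidates, *x0.shape, device=x0.device)
        dists = (cands - x0.unsqueeze(0)).flatten(2).pow(2).sum(-1)   # [k, B]
        noise = cands[dists.argmin(dim=0), torch.arange(b, device=x0.device)]
    else:
        noise = torch.randn_like(x0)
    a = alpha(t).view(b, *([1] * (x0.dim() - 1)))   # broadcast over C, T
    x_t = a * x0 + (1.0 - a) * noise                # linear interpolant
    v_target = noise - x0                           # d x_t / d t for this schedule
    return x_t, v_target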
flowae/models/networks/consistency_audio_decoder_unet.py CHANGED
@@ -274,9 +274,9 @@ class AudioDiffusionUNet(nn.Module):
             size=x.shape[-1],
             mode='linear' # or 'linear' for smoother interpolation
         )
-
+        print('shape of z_proj: ', z_proj.shape)
         # Add latent conditioning to audio features
-        return x + z_proj
+        return torch.cat([x, z_proj], dim=1)
 
     def forward(self, x, t=None, z_dec=None) -> torch.Tensor:
         """
@@ -288,11 +288,13 @@ class AudioDiffusionUNet(nn.Module):
             z_dec: [batch, 64, n_frames] - latent conditioning (any length)
         """
         # Embed audio input
+        print('shape of x: ', x.shape, 'shape of z_dec: ', z_dec.shape)
         x = self.embed_audio(x) # [batch, c0, samples]
-
+        print('shape of x: ', x.shape)
         # Add latent conditioning
         if z_dec is not None:
             x = self.condition_with_latents(x, z_dec)
+            print('shape of x: ', x.shape)
 
         # Embed timestep
         if t is None:
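The substantive change here is the fusion rule in condition_with_latents: the projected latent is no longer added to the audio features but concatenated along the channel dimension, so whatever block consumes the result now sees the audio channels plus the projected latent channels and must be sized accordingly (the added prints just trace the shapes). A minimal sketch of that pattern, with a hypothetical module name and placeholder channel sizes rather than the repository's definitions:

import torch
import torch.nn as nn
import torch.nn.functional as F

class LatentConcatConditioner(nn.Module):
    # Hypothetical stand-in for condition_with_latents; c0 and z_channels
    # are placeholder sizes, not values taken from this repository.
    def __init__(self, c0=128, z_channels=64):
        super().__init__()
        self.proj = nn.Conv1d(z_channels, c0, kernel_size=1)

    def forward(self, x, z_dec):
        # x: [batch, c0, samples]; z_dec: [batch, z_channels, n_frames]
        z_proj = self.proj(z_dec)
        z_proj = F.interpolate(z_proj, size=x.shape[-1], mode='linear')
        # Concatenating (instead of the previous additive x + z_proj) changes
        # the channel count from c0 to 2 * c0, so the next layer must expect it.
        return torch.cat([x, z_proj], dim=1)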
flowae/trainers/audio_ldm_trainer.py CHANGED
@@ -412,15 +412,30 @@ class AudioLDMTrainer(BaseTrainer):
         for i in range(min(signal.batch_size, 5)): # Save up to 5 per batch
             idx = int(os.environ['RANK']) + cnt * int(os.environ['WORLD_SIZE'])
             if max_samples is None or idx < max_samples:
+                tmp_recon = recons[i].audio_data.cpu().numpy()
+                if tmp_recon.ndim == 3:
+                    tmp_recon = tmp_recon.squeeze(0)
+                elif tmp_recon.ndim == 1:
+                    tmp_recon = tmp_recon[None, :]
+                tmp_recon = tmp_recon.T
+
+                tmp_signal = signal[i].audio_data.cpu().numpy()
+                if tmp_signal.ndim == 3:
+                    tmp_signal = tmp_signal.squeeze(0)
+                elif tmp_signal.ndim == 1:
+                    tmp_signal = tmp_signal[None, :]
+                tmp_signal = tmp_signal.T
+
+
                 # Save as wav files
                 sf.write(
                     os.path.join(cache_gen_dir, f'{idx}.wav'),
-                    recons[i].audio_data.cpu().numpy().T,
+                    tmp_recon,
                     int(recons[i].sample_rate)
                 )
                 sf.write(
                     os.path.join(cache_gt_dir, f'{idx}.wav'),
-                    signal[i].audio_data.cpu().numpy().T,
+                    tmp_signal,
                     int(signal[i].sample_rate)
                 )
             cnt += 1
@@ -493,14 +508,29 @@ class AudioLDMTrainer(BaseTrainer):
         for i in range(min(gt_signal.batch_size, 5)):
             idx = int(os.environ['RANK']) + cnt * int(os.environ['WORLD_SIZE'])
             if max_samples is None or idx < max_samples:
+                tmp_recon = pred_signal[i].audio_data.cpu().numpy()
+                if tmp_recon.ndim == 3:
+                    tmp_recon = tmp_recon.squeeze(0)
+                elif tmp_recon.ndim == 1:
+                    tmp_recon = tmp_recon[None, :]
+                tmp_recon = tmp_recon.T
+
+                tmp_signal = gt_signal[i].audio_data.cpu().numpy()
+                if tmp_signal.ndim == 3:
+                    tmp_signal = tmp_signal.squeeze(0)
+                elif tmp_signal.ndim == 1:
+                    tmp_signal = tmp_signal[None, :]
+                tmp_signal = tmp_signal.T
+
+
                 sf.write(
                     os.path.join(cache_gen_dir, f'{idx}.wav'),
-                    pred_signal[i].audio_data.cpu().numpy().T,
+                    tmp_recon,
                     int(pred_signal[i].sample_rate)
                 )
                 sf.write(
                     os.path.join(cache_gt_dir, f'{idx}.wav'),
-                    gt_signal[i].audio_data.cpu().numpy().T,
+                    tmp_signal,
                     int(gt_signal[i].sample_rate)
                 )
             cnt += 1
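In the trainer, the one-liner audio_data.cpu().numpy().T passed to sf.write is replaced by explicit shape handling: soundfile.write expects data shaped (frames,) or (frames, channels), while audio_data is typically [batch, channels, samples], and by that point the array is NumPy, so the checks use ndim rather than tensor methods. The same logic, factored as a standalone helper (the function name is made up):

import numpy as np
import soundfile as sf

def write_signal_wav(path, audio_data, sample_rate):
    # Hypothetical helper mirroring the shape handling in the diff.
    data = np.asarray(audio_data)
    if data.ndim == 3:            # [1, channels, samples] -> [channels, samples]
        data = data.squeeze(0)
    elif data.ndim == 1:          # [samples] -> [1, samples]
        data = data[None, :]
    # soundfile expects [frames, channels], so transpose [channels, samples]
    sf.write(path, data.T, int(sample_rate))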