Commit 4eed684 (parent: a8d6b21)
primepake committed: adding flowvae
flowae/models/diffusion/fm.py
CHANGED
@@ -12,6 +12,7 @@ class FM:
         self.timescale = timescale
         self.use_immiscible = use_immiscible
         self.k_candidates = k_candidates
+        print('use_immiscible: ', use_immiscible, 'k_candidates: ', k_candidates)
 
     def alpha(self, t):
         return 1.0 - t
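
Note on this hunk: alpha(t) = 1.0 - t is the data coefficient of a linear (rectified-flow style) flow-matching schedule, and the printed use_immiscible / k_candidates flags point at immiscible noise assignment (pairing each sample with the nearest of k noise candidates), though none of that logic is visible here. A minimal sketch of the interpolant such an alpha() usually implies; the companion coefficient sigma(t) = t and the x0/x1 roles below are assumptions, not shown in this diff:

import torch

def alpha(t):
    return 1.0 - t

def sigma(t):
    # assumed companion coefficient; not shown in the diff
    return t

def interpolate(x0, x1, t):
    # x_t = alpha(t) * x0 + sigma(t) * x1, one scalar t per batch element
    t = t.view(-1, *([1] * (x0.dim() - 1)))
    return alpha(t) * x0 + sigma(t) * x1
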
flowae/models/networks/consistency_audio_decoder_unet.py
CHANGED
@@ -274,9 +274,9 @@ class AudioDiffusionUNet(nn.Module):
             size=x.shape[-1],
             mode='linear' # or 'linear' for smoother interpolation
         )
-
+        print('shape of z_proj: ', z_proj.shape)
         # Add latent conditioning to audio features
-        return x
+        return torch.cat([x, z_proj], dim=1)
 
     def forward(self, x, t=None, z_dec=None) -> torch.Tensor:
         """
@@ -288,11 +288,13 @@ class AudioDiffusionUNet(nn.Module):
             z_dec: [batch, 64, n_frames] - latent conditioning (any length)
         """
         # Embed audio input
+        print('shape of x: ', x.shape, 'shape of z_dec: ', z_dec.shape)
         x = self.embed_audio(x) # [batch, c0, samples]
-
+        print('shape of x: ', x.shape)
         # Add latent conditioning
         if z_dec is not None:
             x = self.condition_with_latents(x, z_dec)
+        print('shape of x: ', x.shape)
 
         # Embed timestep
         if t is None:
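
Note on these hunks: before this commit, condition_with_latents computed the interpolated projection and then returned x unchanged, so the latent conditioning was silently dropped; the new return torch.cat([x, z_proj], dim=1) actually injects it, at the cost of widening the channel dimension, so the following block must now expect c0 plus the projected channels. Two smaller observations: the inline comment "or 'linear' for smoother interpolation" is stale now that mode is already 'linear', and the new debug print in forward reads z_dec.shape before the if z_dec is not None guard, so calling forward with z_dec=None will raise AttributeError. A sketch of the method as it plausibly stands after this commit (debug print omitted); the class, self.latent_proj, and the channel sizes are assumptions, while the F.interpolate kwargs and the return line come from the hunk:

import torch
import torch.nn as nn
import torch.nn.functional as F

class LatentConditioner(nn.Module):
    # stand-in for the relevant piece of AudioDiffusionUNet
    def __init__(self, z_channels=64, z_proj_channels=128):
        super().__init__()
        self.latent_proj = nn.Conv1d(z_channels, z_proj_channels, 1)  # assumed

    def condition_with_latents(self, x, z_dec):
        # x: [batch, c0, samples]; z_dec: [batch, 64, n_frames]
        z_proj = self.latent_proj(z_dec)
        z_proj = F.interpolate(
            z_proj,
            size=x.shape[-1],
            mode='linear'
        )
        # after this commit the latents are concatenated instead of dropped
        return torch.cat([x, z_proj], dim=1)  # [batch, c0 + z_proj_channels, samples]
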
flowae/trainers/audio_ldm_trainer.py
CHANGED
@@ -412,15 +412,30 @@ class AudioLDMTrainer(BaseTrainer):
         for i in range(min(signal.batch_size, 5)): # Save up to 5 per batch
             idx = int(os.environ['RANK']) + cnt * int(os.environ['WORLD_SIZE'])
             if max_samples is None or idx < max_samples:
+                tmp_recon = recons[i].audio_data.cpu().numpy()
+                if tmp_recon.dim() == 3:
+                    tmp_recon = tmp_recon.squeeze(0)
+                elif tmp_recon.dim() == 1:
+                    tmp_recon = tmp_recon.unsqueeze(0)
+                tmp_recon = tmp_recon.T
+
+                tmp_signal = signal[i].audio_data.cpu().numpy()
+                if tmp_signal.dim() == 3:
+                    tmp_signal = tmp_signal.squeeze(0)
+                elif tmp_signal.dim() == 1:
+                    tmp_signal = tmp_signal.unsqueeze(0)
+                tmp_signal = tmp_signal.T
+
+
                 # Save as wav files
                 sf.write(
                     os.path.join(cache_gen_dir, f'{idx}.wav'),
-
+                    tmp_recon,
                     int(recons[i].sample_rate)
                 )
                 sf.write(
                     os.path.join(cache_gt_dir, f'{idx}.wav'),
-
+                    tmp_signal,
                     int(signal[i].sample_rate)
                 )
             cnt += 1
@@ -493,14 +508,29 @@
         for i in range(min(gt_signal.batch_size, 5)):
             idx = int(os.environ['RANK']) + cnt * int(os.environ['WORLD_SIZE'])
             if max_samples is None or idx < max_samples:
+                tmp_recon = pred_signal[i].audio_data.cpu().numpy()
+                if tmp_recon.dim() == 3:
+                    tmp_recon = tmp_recon.squeeze(0)
+                elif tmp_recon.dim() == 1:
+                    tmp_recon = tmp_recon.unsqueeze(0)
+                tmp_recon = tmp_recon.T
+
+                tmp_signal = gt_signal[i].audio_data.cpu().numpy()
+                if tmp_signal.dim() == 3:
+                    tmp_signal = tmp_signal.squeeze(0)
+                elif tmp_signal.dim() == 1:
+                    tmp_signal = tmp_signal.unsqueeze(0)
+                tmp_signal = tmp_signal.T
+
+
                 sf.write(
                     os.path.join(cache_gen_dir, f'{idx}.wav'),
-
+                    tmp_recon,
                     int(pred_signal[i].sample_rate)
                 )
                 sf.write(
                     os.path.join(cache_gt_dir, f'{idx}.wav'),
-
+                    tmp_signal,
                     int(gt_signal[i].sample_rate)
                 )
             cnt += 1
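
Note on these hunks: the final .T is right in spirit, since soundfile.write expects data shaped (frames, channels), but .dim() and .unsqueeze() are torch.Tensor methods with no NumPy counterpart (NumPy arrays expose .ndim as an attribute and have no unsqueeze). After .cpu().numpy() the value is a NumPy array, so the very first tmp_recon.dim() check raises AttributeError as written, and the same block is duplicated verbatim across the two hunks. A sketch of a helper that does the reshaping on the tensor before conversion; the helper name is mine, not from the commit:

import torch

def to_wav_array(audio_data: torch.Tensor):
    # Normalize [1, channels, samples] / [channels, samples] / [samples]
    # to the (frames, channels) layout soundfile.write expects.
    t = audio_data.detach().cpu()
    if t.dim() == 3:
        t = t.squeeze(0)
    elif t.dim() == 1:
        t = t.unsqueeze(0)
    return t.numpy().T

# usage, mirroring the committed call sites:
# sf.write(os.path.join(cache_gen_dir, f'{idx}.wav'),
#          to_wav_array(recons[i].audio_data),
#          int(recons[i].sample_rate))

Factoring the duplicated reshaping into one helper like this would also keep the generation and evaluation save paths from drifting apart.
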