pcunwa committed on
Commit
d0d17ea
·
verified ·
1 Parent(s): 72707bf

Update bs_roformer.py

Browse files
Files changed (1) hide show
  1. bs_roformer.py +8 -1
bs_roformer.py CHANGED
@@ -982,6 +982,13 @@ class BSRoformer(Module):
982
 
983
  # istft
984
 
 
 
 
 
 
 
 
985
  recon_audio = rearrange(recon_audio, '(b n s) t -> b n s t', s=self.audio_channels, n=num_stems)
986
 
987
  if num_stems == 1:
@@ -1025,4 +1032,4 @@ class BSRoformer(Module):
1025
  if not return_loss_breakdown:
1026
  return total_loss
1027
 
1028
- return total_loss, (loss, multi_stft_resolution_loss)
 
982
 
983
  # istft
984
 
985
+ stft_repr = rearrange(stft_repr, 'b n (f s) t -> (b n s) f t', s=self.audio_channels)
986
+
987
+ try:
988
+ recon_audio = torch.istft(stft_repr, **self.stft_kwargs, window=stft_window, return_complex=False, length=raw_audio.shape[-1])
989
+ except:
990
+ recon_audio = torch.istft(stft_repr.cpu() if x_is_mps else stft_repr, **self.stft_kwargs, window=stft_window.cpu() if x_is_mps else stft_window, return_complex=False, length=raw_audio.shape[-1]).to(device)
991
+
992
  recon_audio = rearrange(recon_audio, '(b n s) t -> b n s t', s=self.audio_channels, n=num_stems)
993
 
994
  if num_stems == 1:
 
1032
  if not return_loss_breakdown:
1033
  return total_loss
1034
 
1035
+ return total_loss, (loss, multi_stft_resolution_loss)