Update bs_roformer.py
Browse files- bs_roformer.py +8 -1
bs_roformer.py
CHANGED
|
@@ -982,6 +982,13 @@ class BSRoformer(Module):
|
|
| 982 |
|
| 983 |
# istft
|
| 984 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 985 |
recon_audio = rearrange(recon_audio, '(b n s) t -> b n s t', s=self.audio_channels, n=num_stems)
|
| 986 |
|
| 987 |
if num_stems == 1:
|
|
@@ -1025,4 +1032,4 @@ class BSRoformer(Module):
|
|
| 1025 |
if not return_loss_breakdown:
|
| 1026 |
return total_loss
|
| 1027 |
|
| 1028 |
-
return total_loss, (loss, multi_stft_resolution_loss)
|
|
|
|
| 982 |
|
| 983 |
# istft
|
| 984 |
|
| 985 |
+
stft_repr = rearrange(stft_repr, 'b n (f s) t -> (b n s) f t', s=self.audio_channels)
|
| 986 |
+
|
| 987 |
+
try:
|
| 988 |
+
recon_audio = torch.istft(stft_repr, **self.stft_kwargs, window=stft_window, return_complex=False, length=raw_audio.shape[-1])
|
| 989 |
+
except:
|
| 990 |
+
recon_audio = torch.istft(stft_repr.cpu() if x_is_mps else stft_repr, **self.stft_kwargs, window=stft_window.cpu() if x_is_mps else stft_window, return_complex=False, length=raw_audio.shape[-1]).to(device)
|
| 991 |
+
|
| 992 |
recon_audio = rearrange(recon_audio, '(b n s) t -> b n s t', s=self.audio_channels, n=num_stems)
|
| 993 |
|
| 994 |
if num_stems == 1:
|
|
|
|
| 1032 |
if not return_loss_breakdown:
|
| 1033 |
return total_loss
|
| 1034 |
|
| 1035 |
+
return total_loss, (loss, multi_stft_resolution_loss)
|