lord-reso commited on
Commit
5c2c92b
·
verified ·
1 Parent(s): 7956bf6

Update stft.py

Browse files
Files changed (1) hide show
  1. stft.py +1 -3
stft.py CHANGED
@@ -38,8 +38,6 @@ from scipy.signal import get_window
38
  from librosa.util import pad_center, tiny
39
  from audio_processing import window_sumsquare
40
 
41
- use_cuda = torch.cuda.is_available()
42
- device = torch.device('cuda' if use_cuda else 'cpu')
43
 
44
  class STFT(torch.nn.Module):
45
  """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft"""
@@ -126,7 +124,7 @@ class STFT(torch.nn.Module):
126
  np.where(window_sum > tiny(window_sum))[0])
127
  window_sum = torch.autograd.Variable(
128
  torch.from_numpy(window_sum), requires_grad=False)
129
- window_sum = window_sum.to(device) if magnitude.is_cuda else window_sum
130
  inverse_transform[:, :, approx_nonzero_indices] /= window_sum[approx_nonzero_indices]
131
 
132
  # scale by hop ratio
 
38
  from librosa.util import pad_center, tiny
39
  from audio_processing import window_sumsquare
40
 
 
 
41
 
42
  class STFT(torch.nn.Module):
43
  """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft"""
 
124
  np.where(window_sum > tiny(window_sum))[0])
125
  window_sum = torch.autograd.Variable(
126
  torch.from_numpy(window_sum), requires_grad=False)
127
+ window_sum = window_sum.cuda() if magnitude.is_cuda else window_sum
128
  inverse_transform[:, :, approx_nonzero_indices] /= window_sum[approx_nonzero_indices]
129
 
130
  # scale by hop ratio