Pj12 commited on
Commit
501b9c4
·
verified ·
1 Parent(s): f6675ab

Update extract_feature_print.py

Browse files
Files changed (1) hide show
  1. extract_feature_print.py +4 -1
extract_feature_print.py CHANGED
@@ -190,7 +190,10 @@ os.makedirs(outPath, exist_ok=True)
190
  def readwave(wav_path, normalize=False):
191
  wav, sr = sf.read(wav_path)
192
  assert sr == 16000
193
- feats = torch.from_numpy(wav).float()
 
 
 
194
  if feats.dim() == 2: # double channels
195
  feats = feats.mean(-1)
196
  assert feats.dim() == 1, feats.dim()
 
190
  def readwave(wav_path, normalize=False):
191
  wav, sr = sf.read(wav_path)
192
  assert sr == 16000
193
+ if Custom_Embed:
194
+ feats = torch.from_numpy(wav).float()
195
+ else:
196
+ feats = torch.from_numpy(load_audio(wav, sr)).to(dtype).to(device)
197
  if feats.dim() == 2: # double channels
198
  feats = feats.mean(-1)
199
  assert feats.dim() == 1, feats.dim()