Update extract_feature_print.py
Browse files- extract_feature_print.py +1 -1
extract_feature_print.py
CHANGED
|
@@ -206,7 +206,7 @@ def readwave(wav_path, normalize=False):
|
|
| 206 |
if Custom_Embed == False:
|
| 207 |
feats = torch.from_numpy(wav).float()
|
| 208 |
else:
|
| 209 |
-
feats = torch.from_numpy(load_audio(
|
| 210 |
if feats.dim() == 2: # double channels
|
| 211 |
feats = feats.mean(-1)
|
| 212 |
assert feats.dim() == 1, feats.dim()
|
|
|
|
| 206 |
if Custom_Embed == False:
|
| 207 |
feats = torch.from_numpy(wav).float()
|
| 208 |
else:
|
| 209 |
+
feats = torch.from_numpy(load_audio(wav_path, sr)).to(dtype).to(device)
|
| 210 |
if feats.dim() == 2: # double channels
|
| 211 |
feats = feats.mean(-1)
|
| 212 |
assert feats.dim() == 1, feats.dim()
|