Upload feature extractor
Browse files
feature_extraction_gramt_binaural_frame.py
CHANGED
|
@@ -48,6 +48,7 @@ class BinauralFeatureExtractor(SequenceFeatureExtractor):
|
|
| 48 |
def _extract_fbank_features(
|
| 49 |
self,
|
| 50 |
waveform: np.ndarray,
|
|
|
|
| 51 |
) -> np.ndarray:
|
| 52 |
"""
|
| 53 |
Get mel-filter bank features using TorchAudio. Note that TorchAudio requires 16-bit signed integers as inputs
|
|
@@ -65,7 +66,9 @@ class BinauralFeatureExtractor(SequenceFeatureExtractor):
|
|
| 65 |
)
|
| 66 |
|
| 67 |
waveform = torch.tensor(waveform.clone().detach())
|
| 68 |
-
|
|
|
|
|
|
|
| 69 |
# If waveform has two channels, but the channel information is not the first dimension, transpose.
|
| 70 |
if (waveform.ndim == 2) and (waveform.shape[0] > 100):
|
| 71 |
waveform = waveform.transpose(1, 0)
|
|
@@ -106,6 +109,7 @@ class BinauralFeatureExtractor(SequenceFeatureExtractor):
|
|
| 106 |
raw_speech: Union[np.ndarray, list[float], list[np.ndarray], list[list[float]]],
|
| 107 |
sampling_rate: Optional[int] = None,
|
| 108 |
return_tensors: Optional[Union[str, TensorType]] = None,
|
|
|
|
| 109 |
**kwargs,
|
| 110 |
) -> BatchFeature:
|
| 111 |
"""
|
|
@@ -136,7 +140,7 @@ class BinauralFeatureExtractor(SequenceFeatureExtractor):
|
|
| 136 |
)
|
| 137 |
|
| 138 |
# extract fbank features and pad/truncate to max_length
|
| 139 |
-
features = [self._extract_fbank_features(waveform) for waveform in raw_speech]
|
| 140 |
features = torch.nn.utils.rnn.pad_sequence(features, batch_first=True)
|
| 141 |
inputs = BatchFeature({"input_values": features})
|
| 142 |
return inputs
|
|
|
|
| 48 |
def _extract_fbank_features(
|
| 49 |
self,
|
| 50 |
waveform: np.ndarray,
|
| 51 |
+
normalize : bool,
|
| 52 |
) -> np.ndarray:
|
| 53 |
"""
|
| 54 |
Get mel-filter bank features using TorchAudio. Note that TorchAudio requires 16-bit signed integers as inputs
|
|
|
|
| 66 |
)
|
| 67 |
|
| 68 |
waveform = torch.tensor(waveform.clone().detach())
|
| 69 |
+
melspec.to(waveform.device)
|
| 70 |
+
if normalize:
|
| 71 |
+
waveform = self._normalize_audio(waveform)
|
| 72 |
# If waveform has two channels, but the channel information is not the first dimension, transpose.
|
| 73 |
if (waveform.ndim == 2) and (waveform.shape[0] > 100):
|
| 74 |
waveform = waveform.transpose(1, 0)
|
|
|
|
| 109 |
raw_speech: Union[np.ndarray, list[float], list[np.ndarray], list[list[float]]],
|
| 110 |
sampling_rate: Optional[int] = None,
|
| 111 |
return_tensors: Optional[Union[str, TensorType]] = None,
|
| 112 |
+
normalize : bool = True,
|
| 113 |
**kwargs,
|
| 114 |
) -> BatchFeature:
|
| 115 |
"""
|
|
|
|
| 140 |
)
|
| 141 |
|
| 142 |
# extract fbank features and pad/truncate to max_length
|
| 143 |
+
features = [self._extract_fbank_features(waveform, normalize) for waveform in raw_speech]
|
| 144 |
features = torch.nn.utils.rnn.pad_sequence(features, batch_first=True)
|
| 145 |
inputs = BatchFeature({"input_values": features})
|
| 146 |
return inputs
|