GokseninYuksel committed on
Commit
ec039f2
·
verified ·
1 Parent(s): f0497da

Upload feature extractor

Browse files
feature_extraction_gramt_binaural_frame.py CHANGED
@@ -48,6 +48,7 @@ class BinauralFeatureExtractor(SequenceFeatureExtractor):
48
  def _extract_fbank_features(
49
  self,
50
  waveform: np.ndarray,
 
51
  ) -> np.ndarray:
52
  """
53
  Get mel-filter bank features using TorchAudio. Note that TorchAudio requires 16-bit signed integers as inputs
@@ -65,7 +66,9 @@ class BinauralFeatureExtractor(SequenceFeatureExtractor):
65
  )
66
 
67
  waveform = torch.tensor(waveform.clone().detach())
68
- waveform = self._normalize_audio(waveform)
 
 
69
  # If waveform has two channels, but the channel information is not the first dimension, transpose.
70
  if (waveform.ndim == 2) and (waveform.shape[0] > 100):
71
  waveform = waveform.transpose(1, 0)
@@ -106,6 +109,7 @@ class BinauralFeatureExtractor(SequenceFeatureExtractor):
106
  raw_speech: Union[np.ndarray, list[float], list[np.ndarray], list[list[float]]],
107
  sampling_rate: Optional[int] = None,
108
  return_tensors: Optional[Union[str, TensorType]] = None,
 
109
  **kwargs,
110
  ) -> BatchFeature:
111
  """
@@ -136,7 +140,7 @@ class BinauralFeatureExtractor(SequenceFeatureExtractor):
136
  )
137
 
138
  # extract fbank features and pad/truncate to max_length
139
- features = [self._extract_fbank_features(waveform) for waveform in raw_speech]
140
  features = torch.nn.utils.rnn.pad_sequence(features, batch_first=True)
141
  inputs = BatchFeature({"input_values": features})
142
  return inputs
 
48
  def _extract_fbank_features(
49
  self,
50
  waveform: np.ndarray,
51
+ normalize : bool,
52
  ) -> np.ndarray:
53
  """
54
  Get mel-filter bank features using TorchAudio. Note that TorchAudio requires 16-bit signed integers as inputs
 
66
  )
67
 
68
  waveform = torch.tensor(waveform.clone().detach())
69
+ melspec.to(waveform.device)
70
+ if normalize:
71
+ waveform = self._normalize_audio(waveform)
72
  # If waveform has two channels, but the channel information is not the first dimension, transpose.
73
  if (waveform.ndim == 2) and (waveform.shape[0] > 100):
74
  waveform = waveform.transpose(1, 0)
 
109
  raw_speech: Union[np.ndarray, list[float], list[np.ndarray], list[list[float]]],
110
  sampling_rate: Optional[int] = None,
111
  return_tensors: Optional[Union[str, TensorType]] = None,
112
+ normalize : bool = True,
113
  **kwargs,
114
  ) -> BatchFeature:
115
  """
 
140
  )
141
 
142
  # extract fbank features and pad/truncate to max_length
143
+ features = [self._extract_fbank_features(waveform, normalize) for waveform in raw_speech]
144
  features = torch.nn.utils.rnn.pad_sequence(features, batch_first=True)
145
  inputs = BatchFeature({"input_values": features})
146
  return inputs