Commit ·
f9c55bd
1
Parent(s): 8093841
Updated to pass the WAV file as direct input
Browse files- feature_extractor.py +10 -3
feature_extractor.py
CHANGED
|
@@ -40,24 +40,31 @@ def change_sample_rate(y, sample_rate, new_sample_rate):
|
|
| 40 |
value = librosa.resample(y, sample_rate, new_sample_rate)
|
| 41 |
return value
|
| 42 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
def get_wav2vecembeddings_from_audiofile(wav_file):
|
| 44 |
print("the file is", wav_file)
|
| 45 |
speech, sample_rate = sf.read(wav_file)
|
|
|
|
|
|
|
|
|
|
| 46 |
# change sample rate to 16 000 hertz
|
| 47 |
resampled = change_sample_rate(speech, sample_rate, new_sample_rate)
|
| 48 |
print("the speech is", speech)
|
| 49 |
-
input_values = processor(
|
| 50 |
print("input values", input_values)
|
| 51 |
# import pdb
|
| 52 |
# pdb.set_trace()
|
| 53 |
|
| 54 |
with torch.no_grad():
|
| 55 |
encoded_states = model(
|
| 56 |
-
|
| 57 |
# attention_mask=input_values["attention_mask"],
|
| 58 |
output_hidden_states=True
|
| 59 |
)
|
| 60 |
-
|
| 61 |
last_hidden_state = encoded_states.hidden_states[-1] # The last hidden-state is the first element of the output tuple
|
| 62 |
print("getting wav2vec2 embeddings")
|
| 63 |
print(last_hidden_state)
|
|
|
|
| 40 |
value = librosa.resample(y, sample_rate, new_sample_rate)
|
| 41 |
return value
|
| 42 |
|
| 43 |
+
def stereo_to_mono(audio_input):
    """Downmix a multi-channel signal to mono by averaging the channels.

    Args:
        audio_input: 2-D array of samples, one column per channel.

    Returns:
        1-D array (or 0-d for a single frame) of per-frame channel means.
    """
    channel_mean = audio_input.mean(axis=1, keepdims=True)
    return np.squeeze(channel_mean)
|
| 47 |
+
|
| 48 |
def get_wav2vecembeddings_from_audiofile(wav_file):
    """Extract wav2vec2 embeddings from an audio file.

    Reads the file, downmixes stereo to mono, resamples to the module-level
    target rate, runs the wav2vec2 model, and returns the last hidden state.

    Args:
        wav_file: path to an audio file readable by soundfile.

    Returns:
        torch.Tensor: the model's last hidden state — presumably shaped
        (batch, frames, hidden_dim); confirm against the loaded model config.
    """
    print("the file is", wav_file)
    speech, sample_rate = sf.read(wav_file)

    # The model expects a 1-D signal; collapse multi-channel audio first.
    if len(speech.shape) > 1:
        speech = stereo_to_mono(speech)

    # change sample rate to 16 000 hertz (new_sample_rate is module-level)
    resampled = change_sample_rate(speech, sample_rate, new_sample_rate)
    print("the speech is", speech)

    # BUG FIX: feed the resampled audio array to the processor, not the
    # file-path string — previously `resampled` was computed and unused.
    input_values = processor(
        resampled,
        return_tensors="pt",
        padding=True,
        sampling_rate=new_sample_rate,
    )  # there is no truncation param anymore
    print("input values", input_values)
    # import pdb
    # pdb.set_trace()

    with torch.no_grad():
        encoded_states = model(
            # BUG FIX: Wav2Vec2 processors return the key "input_values",
            # not "input_ids" — the old lookup raised KeyError.
            input_values=input_values["input_values"],
            # attention_mask=input_values["attention_mask"],
            output_hidden_states=True,
        )

    # The last hidden-state is the final element of the hidden_states tuple.
    last_hidden_state = encoded_states.hidden_states[-1]
    print("getting wav2vec2 embeddings")
    print(last_hidden_state)
    # BUG FIX: return the embeddings (previously computed and dropped).
    # Backward compatible: callers that ignored the old None still work.
    return last_hidden_state
|