dipsmom committed on
Commit
8f441e2
·
verified ·
1 Parent(s): 3ef6cdb

Update audio_feature_extraction.py

Browse files
Files changed (1) hide show
  1. audio_feature_extraction.py +17 -35
audio_feature_extraction.py CHANGED
@@ -1,41 +1,23 @@
1
- import gradio as gr
2
- import numpy as np
3
- import matplotlib.pyplot as plt
4
- from audio_feature_extraction import extract_features
5
 
6
def plot_spectrum(audio_path):
    """Render the magnitude spectrum of the mean Wav2Vec2 feature vector and save it as a PNG.

    Returns the path of the saved image so it can feed a Gradio image output.
    """
    # Mean-pooled feature vector extracted from the uploaded file.
    feature_vector = extract_features(audio_path)

    # FFT of the feature vector; keep only the non-negative half of the spectrum.
    # NOTE(review): d=1/16000 labels the axis as if these were 16 kHz time
    # samples, but the FFT here runs over feature dimensions — confirm intent.
    spectrum = np.fft.fft(feature_vector)
    freqs = np.fft.fftfreq(len(spectrum), d=1/16000)
    magnitudes = np.abs(spectrum)
    half = len(freqs) // 2

    plt.figure(figsize=(12, 6))
    plt.plot(freqs[:half], magnitudes[:half])
    plt.xlabel("Frequency (Hz)")
    plt.ylabel("Magnitude")
    plt.title("Frequency Spectrum of the Audio File")
    plt.grid()
    plt.tight_layout()

    # Write the figure to disk and release it so repeated calls don't leak figures.
    output_file = "spectrum_plot.png"
    plt.savefig(output_file)
    plt.close()
    return output_file
 
30
 
31
# Define the Gradio interface: an uploaded audio file goes through
# plot_spectrum and the resulting PNG path is shown as an image.
# NOTE(review): gr.Audio(source=...) is the Gradio 3.x keyword; Gradio 4.x
# renamed it to sources=[...] — confirm the pinned gradio version.
iface = gr.Interface(
    fn=plot_spectrum,
    inputs=gr.Audio(source="upload", type="filepath"),
    outputs=gr.Image(type="filepath"),
    title="Audio Feature Extraction with Wav2Vec2",
    description="Upload an audio file to extract features and view the frequency spectrum."
)

# Launch the web app only when run as a script, not when imported.
if __name__ == "__main__":
    iface.launch()
 
1
+ import torch
2
+ import librosa
3
+ from transformers import Wav2Vec2Processor, Wav2Vec2Model
 
4
 
5
# Default Hugging Face model id, kept in one place.
_DEFAULT_WAV2VEC2 = "facebook/wav2vec2-base-960h"

# Cache of (processor, model) pairs keyed by model name. The original code
# called from_pretrained on every invocation, re-downloading / re-building
# the network each time; caching makes repeated calls cheap.
_WAV2VEC2_CACHE = {}


def _get_wav2vec2(model_name):
    """Return a cached (processor, model) pair, loading it on first use."""
    if model_name not in _WAV2VEC2_CACHE:
        processor = Wav2Vec2Processor.from_pretrained(model_name)
        model = Wav2Vec2Model.from_pretrained(model_name)
        model.eval()  # inference only; ensure dropout etc. are disabled
        _WAV2VEC2_CACHE[model_name] = (processor, model)
    return _WAV2VEC2_CACHE[model_name]


def extract_features(audio_path, model_name=_DEFAULT_WAV2VEC2):
    """Extract a mean-pooled Wav2Vec2 feature vector from an audio file.

    Parameters
    ----------
    audio_path : str
        Path to an audio file readable by librosa.
    model_name : str, optional
        Hugging Face model id; defaults to "facebook/wav2vec2-base-960h"
        (backward-compatible with the original hard-coded model).

    Returns
    -------
    numpy.ndarray
        1-D vector: the last hidden state averaged over time frames
        (length equals the model's hidden size).
    """
    # Resample to 16 kHz — the rate the Wav2Vec2 base model expects.
    audio_data, _ = librosa.load(audio_path, sr=16000)

    processor, model = _get_wav2vec2(model_name)

    # Normalize/pack the waveform into model input tensors.
    input_values = processor(
        audio_data, return_tensors="pt", sampling_rate=16000
    ).input_values

    # Forward pass without gradient tracking (inference only).
    with torch.no_grad():
        features = model(input_values).last_hidden_state  # (1, frames, hidden)

    # Mean-pool over the time axis and return a plain NumPy vector.
    return features.mean(dim=1).squeeze().cpu().numpy()