Chatanya committed on
Commit
1348175
·
verified ·
1 Parent(s): 629b847

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +146 -0
app.py ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ Outline:
3
+ - Create animation: animate charts (potentially using streamlit)
4
+ '''
5
+ import librosa
6
+ import streamlit as st
7
+ import numpy as np
8
+ import matplotlib.pyplot as plt
9
+ import pandas as pd
10
+ import pickle
11
+ import keras
12
+ import tensorflow
13
+ import matplotlib.animation as animation
14
+
15
model_path = "model_simple.sav" #Defines the path to the model file

# Maps emotion names to the integer class indices used during training.
# NOTE(review): 'Saddness' is misspelled, but these keys are the display labels
# tied to the trained model's class order — confirm against the data
# preprocessing pipeline before correcting the spelling.
emotion_map = {
    'Disgust': 0,
    'Happiness': 1,
    'Saddness': 2,
    'Neutral': 3,
    'Fear': 4,
    'Anger': 5,
    'Surprise': 6
} #Maps emotions to integers: taken from data preprocessing

reversed_emotion_map = {value:key for key, value in emotion_map.items()}
#Reverses emotion mapping such that integers can be mapped into emotions
29
+
30
#Uses librosa to load the inputted audio file as a list of frequency values
@st.cache_data
def _load_audio(input_file):
    """Decode the uploaded audio file; cached so reruns with the same file
    skip the (slow) librosa decode."""
    audio_signal, sample_rate = librosa.load(input_file)
    return audio_signal, sample_rate

def process_audio(input_file):
    """Render an audio player for *input_file* and return its decoded signal.

    Parameters:
        input_file: file-like object (e.g. the result of st.file_uploader).
    Returns:
        (audio_signal, sample_rate) as produced by librosa.load.
    """
    # Keep the UI element outside the cached function: st.cache_data skips the
    # body on a cache hit, and replay of elements inside cached functions is
    # version-dependent in Streamlit, so the player could vanish on rerun.
    st.audio(input_file)
    return _load_audio(input_file)
36
+
37
#Creates a line chart displaying the audio frequency using librosa
def display_spectrum_animation(audio_signal, sample_rate):
    """Animate the STFT magnitude spectrum frame-by-frame and show it as a GIF.

    Parameters:
        audio_signal: 1-D array of audio samples.
        sample_rate: sampling rate in Hz (sets the frequency axis).
    """
    S = np.abs(librosa.stft(audio_signal))  # (n_freq_bins, n_frames) magnitudes
    frequencies = librosa.fft_frequencies(sr=sample_rate)

    fig, ax = plt.subplots()

    def update_spectrum(num, S, ax):
        # Redraw one STFT frame; clearing keeps labels consistent per frame.
        ax.clear()
        ax.plot(frequencies, S[:, num])
        ax.set_xlabel("Frequency (Hz)")
        ax.set_ylabel("Amplitude")

    ani = animation.FuncAnimation(fig, update_spectrum, frames=S.shape[1], fargs=[S, ax], blit=False)
    try:
        ani.save("spectrum_animation.gif", writer="imagemagick")
    except (RuntimeError, FileNotFoundError):
        # ImageMagick is an external binary that is often absent on hosting
        # platforms; fall back to matplotlib's bundled Pillow GIF writer.
        ani.save("spectrum_animation.gif", writer="pillow")
    finally:
        # Close the figure so repeated Streamlit reruns don't leak figures.
        plt.close(fig)
    st.image("spectrum_animation.gif")
53
+
54
+
55
@st.cache_data
def display_frequency(audio_signal, sample_rate):
    """Plot the raw waveform (amplitude over time) and render it in Streamlit."""
    # Draw on an explicit figure instead of the implicit plt.gcf(): the
    # original could pick up a stale figure left behind by other plotting
    # code, and it also leaked the figure and bound an unused local.
    fig, ax = plt.subplots()
    librosa.display.waveshow(audio_signal, sr = sample_rate, ax = ax)
    st.pyplot(fig)
    plt.close(fig)
59
+
60
#Creates and displays a mel spectrogram using librosa
@st.cache_data
def display_mel_spectogram(audio_signal, sample_rate):
    """Animate a log-frequency spectrogram growing left-to-right and show it as a GIF.

    Parameters:
        audio_signal: 1-D array of audio samples.
        sample_rate: sampling rate in Hz.
    """
    fig, ax = plt.subplots()
    audio_time = audio_signal.shape[0]/sample_rate  # clip duration in seconds
    D = librosa.amplitude_to_db(np.abs(librosa.stft(audio_signal)), ref = np.max)

    # Roughly one second's worth of STFT columns — each animation step reveals
    # this many extra columns so the animation keeps pace with the audio.
    amt_to_add = int(D.shape[-1]/audio_time)

    def update_spectrogram (num, D, ax, plus):
        ax.clear()
        # Slicing past the array's width is safe in numpy — it just clamps.
        librosa.display.specshow(D[:, :num + plus], sr = sample_rate, x_axis = "time", y_axis = "log", ax = ax)

    ani = animation.FuncAnimation(fig, update_spectrogram, frames = np.arange(1, D.shape[1]), fargs = [D, ax, amt_to_add], blit = False)
    try:
        ani.save("spectrogram_animation.gif", writer = "imagemagick")
    except (RuntimeError, FileNotFoundError):
        # ImageMagick may not be installed; Pillow is matplotlib's built-in
        # GIF writer and needs no external binary.
        ani.save("spectrogram_animation.gif", writer = "pillow")
    finally:
        # Close the figure so repeated Streamlit reruns don't leak figures.
        plt.close(fig)
    st.image("spectrogram_animation.gif")
78
+
79
#Creates the interface allowing users to select which plot they want displayed
def create_selections(audio_signal, sample_rate):
    """Show a radio selector of chart types and draw whichever one is chosen."""
    # Label -> plotting function; register additional charts here.
    charts = {
        "Spectrum": display_spectrum_animation,
        "Mel-Spectogram": display_mel_spectogram,
    }
    chosen_chart = st.radio(
        label = "",
        options = list(charts),
        horizontal = True
    )
    charts[chosen_chart](audio_signal, sample_rate)
90
+
91
#Helper function to force the length of a given frequency array into a specific length
@st.cache_data
def standardize_waveform_length(waveform, audio_length = 66150):
    """Truncate or zero-pad *waveform* to exactly *audio_length* samples.

    Parameters:
        waveform: 1-D array of audio samples.
        audio_length: target sample count. Defaults to 66150 — presumably 3 s
            at librosa's default 22050 Hz rate, the length the model expects;
            now a parameter so the length is no longer hard-coded.
    Returns:
        An array of exactly audio_length samples.
    """
    if len(waveform) > audio_length:
        # Too long: keep only the leading audio_length samples.
        waveform = waveform[:audio_length]
    else:
        # Too short (or exact): right-pad with zeros. The pad amount cannot be
        # negative here, so the original max(0, ...) guard was redundant.
        waveform = np.pad(waveform, (0, audio_length - len(waveform)), "constant")
    return waveform
101
+
102
#Takes in a given audio signal and returns its mel-frequency cepstral coefficients
@st.cache_data
def preprocess_audio_for_prediction(audio_signal, sample_rate):
    """Convert an audio signal into the flat MFCC feature vector the model expects."""
    fixed_length_signal = standardize_waveform_length(waveform = audio_signal)
    coefficients = librosa.feature.mfcc(y = fixed_length_signal, sr = sample_rate, n_mels = 128)
    # The classifier takes a flat vector, so collapse the 2-D MFCC matrix.
    return coefficients.reshape(-1)
109
+
110
#Loads the model given in model_path and returns a Keras Sequential model
@st.cache_data
def load_model(model_path):
    """Unpickle and return the trained model stored at *model_path*.

    NOTE(review): pickle.load executes arbitrary code from the file — only
    ever load model files that ship with this app, never user-supplied ones.
    """
    # 'with' guarantees the handle is closed even if unpickling raises;
    # the original opened the file without ever closing it.
    with open(model_path, "rb") as model_file:
        model = pickle.load(model_file)
    return model
115
+
116
#Uses the model to predict the speaker's emotion in the given audio clip
@st.cache_data
def get_emotion_prediction(mfcc):
    """Run the classifier on a flat MFCC vector and return the emotion name."""
    classifier = load_model(model_path)
    # mfcc[None] prepends the batch axis the model expects: (features,) -> (1, features)
    scores = classifier.predict(mfcc[None])
    best_class = np.argmax(scores)
    return reversed_emotion_map[best_class]
124
+
125
#Combines all model functions and displays the model output as a subheader
@st.cache_data
def display_prediction(audio_signal, sample_rate):
    """Predict the emotion in the given audio and render it as a subheader."""
    features = preprocess_audio_for_prediction(audio_signal, sample_rate)
    emotion = get_emotion_prediction(features)
    st.subheader("Predicted Emotion: " + emotion, divider = True)
131
+
132
#Defines the entire process of inputting audio, displaying the model's predictions, and displaying graphs
def run(input_file):
    """Full pipeline for one uploaded file: decode, predict, then chart UI."""
    signal, rate = process_audio(input_file)
    display_prediction(signal, rate)
    create_selections(signal, rate)
137
+
138
#Creates an input area to upload the file
def main():
    """Entry point: show the uploader and run the pipeline once a file arrives."""
    st.header("Upload your file here")
    uploaded = st.file_uploader("", type = "wav")
    # file_uploader returns None until the user actually provides a file.
    if uploaded is not None:
        run(uploaded)

if __name__ == "__main__":
    main()