h3rsh commited on
Commit
e912d0a
Β·
verified Β·
1 Parent(s): 4a70560

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +282 -0
app.py ADDED
@@ -0,0 +1,282 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import librosa
4
+ import pickle
5
+ import tensorflow as tf
6
+ import gradio as gr
7
+ from scipy import signal
8
+ import warnings
9
+ import tempfile
10
+
11
+ warnings.filterwarnings("ignore", message="Trying to estimate tuning from empty frequency set.")
12
+
13
+ # Common parameters (must match training parameters)
14
+ target_sr = 22050
15
+ target_duration = 4
16
+ n_fft = 512
17
+ hop_length = 512
18
+
19
+ class RespiratoryPredictor:
20
+ def __init__(self, model_path='respiratory_model.keras', scalers_path='scalers.pkl',
21
+ norm_params_path='norm_params.pkl', class_names_path='class_names.pkl'):
22
+ """Initialize the predictor with trained model and scalers."""
23
+ self.target_sr = target_sr
24
+ self.target_duration = target_duration
25
+ self.n_fft = n_fft
26
+ self.hop_length = hop_length
27
+
28
+ # Load model
29
+ try:
30
+ self.model = tf.keras.models.load_model(model_path)
31
+ print(f"βœ“ Model loaded from {model_path}")
32
+ except Exception as e:
33
+ print(f"βœ— Error loading model: {e}")
34
+ raise
35
+
36
+ # Load scalers
37
+ try:
38
+ with open(scalers_path, 'rb') as f:
39
+ self.scalers = pickle.load(f)
40
+ print(f"βœ“ Scalers loaded from {scalers_path}")
41
+ except Exception as e:
42
+ print(f"βœ— Error loading scalers: {e}")
43
+ raise
44
+
45
+ # Load normalization parameters
46
+ try:
47
+ with open(norm_params_path, 'rb') as f:
48
+ self.norm_params = pickle.load(f)
49
+ print(f"βœ“ Normalization parameters loaded from {norm_params_path}")
50
+ except Exception as e:
51
+ print(f"βœ— Error loading normalization parameters: {e}")
52
+ raise
53
+
54
+ # Load class names
55
+ try:
56
+ with open(class_names_path, 'rb') as f:
57
+ self.class_names = pickle.load(f)
58
+ print(f"βœ“ Class names loaded from {class_names_path}")
59
+ except Exception as e:
60
+ print(f"βœ— Error loading class names: {e}")
61
+ raise
62
+
63
+ def denoise_audio(self, audio, sr, methods=['adaptive_median', 'bandpass']):
64
+ """Denoise audio signal"""
65
+ denoised_audio = audio.copy()
66
+
67
+ for method in methods:
68
+ if method == 'adaptive_median':
69
+ window_size = int(sr * 0.01) # 10 ms window
70
+ if window_size % 2 == 0:
71
+ window_size += 1
72
+ denoised_audio = signal.medfilt(denoised_audio, kernel_size=window_size)
73
+ elif method == 'bandpass':
74
+ low_freq = 50
75
+ high_freq = 2000
76
+ nyquist = sr / 2
77
+ low = low_freq / nyquist
78
+ high = high_freq / nyquist
79
+ b, a = signal.butter(4, [low, high], btype='band')
80
+ denoised_audio = signal.filtfilt(b, a, denoised_audio)
81
+
82
+ return denoised_audio
83
+
84
+ def extract_features(self, audio_data, sr):
85
+ """Extract features from audio in the same format as during training"""
86
+ # Mel spectrogram
87
+ mel_spec = librosa.feature.melspectrogram(
88
+ y=audio_data, sr=sr, n_mels=128, n_fft=self.n_fft, hop_length=self.hop_length)
89
+ mel_spec_db = librosa.power_to_db(mel_spec, ref=np.max)
90
+
91
+ # MFCC
92
+ mfcc = librosa.feature.mfcc(y=audio_data, sr=sr, n_mfcc=20, hop_length=self.hop_length)
93
+
94
+ # Chroma
95
+ chroma = librosa.feature.chroma_stft(y=audio_data, sr=sr, hop_length=self.hop_length)
96
+
97
+ features = {
98
+ 'mel_spec': mel_spec_db,
99
+ 'mfcc': mfcc,
100
+ 'chroma': chroma
101
+ }
102
+ return features
103
+
104
+ def pad_or_crop(self, arr, shape):
105
+ """Pad or crop array to target shape"""
106
+ out = np.zeros(shape, dtype=arr.dtype)
107
+ n_feat, n_fr = arr.shape
108
+ out[:min(n_feat, shape[0]), :min(n_fr, shape[1])] = arr[:shape[0], :shape[1]]
109
+ return out
110
+
111
+ def prepare_input_data(self, features, n_frames=259):
112
+ """Prepare input data for the multi-input model"""
113
+ mfcc = self.pad_or_crop(features['mfcc'], (20, n_frames))
114
+ chroma = self.pad_or_crop(features['chroma'], (12, n_frames))
115
+ mspec = self.pad_or_crop(features['mel_spec'], (128, n_frames))
116
+
117
+ # Add channel dimension
118
+ X_mfcc = mfcc[..., np.newaxis]
119
+ X_chroma = chroma[..., np.newaxis]
120
+ X_mspec = mspec[..., np.newaxis]
121
+
122
+ return X_mfcc, X_chroma, X_mspec
123
+
124
+ def normalize_features(self, X_mfcc, X_chroma, X_mspec):
125
+ """Normalize features using the same parameters as training"""
126
+ def norm(X, mean, std):
127
+ Xf = X.reshape(X.shape[0], -1)
128
+ Xn = (Xf - mean) / (std + 1e-8)
129
+ return Xn.reshape(X.shape)
130
+
131
+ X_mfcc_norm = norm(X_mfcc, self.norm_params['mfcc_mean'], self.norm_params['mfcc_std'])
132
+ X_chroma_norm = norm(X_chroma, self.norm_params['chroma_mean'], self.norm_params['chroma_std'])
133
+ X_mspec_norm = norm(X_mspec, self.norm_params['mspec_mean'], self.norm_params['mspec_std'])
134
+
135
+ return X_mfcc_norm, X_chroma_norm, X_mspec_norm
136
+
137
+ def predict_audio(self, audio_file_path):
138
+ """
139
+ Predict the class of an audio file for Gradio interface.
140
+
141
+ Args:
142
+ audio_file_path: Path to the uploaded audio file
143
+
144
+ Returns:
145
+ tuple: (prediction_text, confidence_text, probabilities_dict)
146
+ """
147
+ try:
148
+ # Load and process audio
149
+ audio, sr = librosa.load(audio_file_path, sr=self.target_sr, duration=self.target_duration)
150
+
151
+ # Ensure audio is the right length
152
+ target_samples = self.target_sr * self.target_duration
153
+ if len(audio) < target_samples:
154
+ audio = np.pad(audio, (0, target_samples - len(audio)), mode='constant')
155
+ elif len(audio) > target_samples:
156
+ audio = audio[:target_samples]
157
+
158
+ # Denoise audio
159
+ denoised_audio = self.denoise_audio(audio, self.target_sr)
160
+
161
+ # Extract features
162
+ features = self.extract_features(denoised_audio, self.target_sr)
163
+
164
+ # Prepare input data
165
+ X_mfcc, X_chroma, X_mspec = self.prepare_input_data(features)
166
+
167
+ # Normalize features
168
+ X_mfcc_norm, X_chroma_norm, X_mspec_norm = self.normalize_features(X_mfcc, X_chroma, X_mspec)
169
+
170
+ # Add batch dimension
171
+ X_mfcc_batch = np.expand_dims(X_mfcc_norm, axis=0)
172
+ X_chroma_batch = np.expand_dims(X_chroma_norm, axis=0)
173
+ X_mspec_batch = np.expand_dims(X_mspec_norm, axis=0)
174
+
175
+ # Make prediction
176
+ prediction_prob = self.model.predict([X_mfcc_batch, X_chroma_batch, X_mspec_batch], verbose=0)
177
+ prediction = int(np.argmax(prediction_prob[0]))
178
+ confidence = float(np.max(prediction_prob[0]))
179
+
180
+ # Get class name
181
+ class_name = self.class_names[prediction] if prediction < len(self.class_names) else f"Class {prediction}"
182
+
183
+ # Format results for Gradio
184
+ prediction_text = f"🎯 **Prediction**: {class_name}"
185
+ confidence_text = f"πŸ“Š **Confidence**: {confidence:.2%}"
186
+
187
+ # Create probabilities dictionary for all classes
188
+ probabilities_dict = {}
189
+ for i, (class_name_item, prob) in enumerate(zip(self.class_names, prediction_prob[0])):
190
+ probabilities_dict[class_name_item] = float(prob)
191
+
192
+ return prediction_text, confidence_text, probabilities_dict
193
+
194
+ except Exception as e:
195
+ error_msg = f"❌ Error processing audio: {str(e)}"
196
+ return error_msg, "", {}
197
+
198
+ # Initialize the predictor
199
+ print("Loading model and components...")
200
+ try:
201
+ predictor = RespiratoryPredictor()
202
+ print("βœ… All components loaded successfully!")
203
+ except Exception as e:
204
+ print(f"❌ Failed to initialize predictor: {e}")
205
+ raise
206
+
207
+ def predict_respiratory_sound(audio_file):
208
+ """
209
+ Gradio interface function for respiratory sound prediction.
210
+
211
+ Args:
212
+ audio_file: Uploaded audio file from Gradio
213
+
214
+ Returns:
215
+ tuple: (prediction, confidence, probabilities)
216
+ """
217
+ if audio_file is None:
218
+ return "⚠️ Please upload an audio file", "", {}
219
+
220
+ return predictor.predict_audio(audio_file)
221
+
222
+ # Create Gradio interface
223
+ with gr.Blocks(title="Respiratory Sound Classifier", theme=gr.themes.Soft()) as demo:
224
+ gr.Markdown(
225
+ """
226
+ # 🫁 Respiratory Sound Classification
227
+
228
+ Upload an audio file containing respiratory sounds to classify the type of breathing pattern.
229
+
230
+ **Supported formats**: WAV, MP3, M4A, FLAC
231
+ **Duration**: Audio will be processed as 4-second segments
232
+ """
233
+ )
234
+
235
+ with gr.Row():
236
+ with gr.Column():
237
+ audio_input = gr.Audio(
238
+ label="πŸ“€ Upload Respiratory Sound",
239
+ type="filepath",
240
+ sources=["upload"]
241
+ )
242
+
243
+ predict_btn = gr.Button("πŸ” Analyze Sound", variant="primary")
244
+
245
+ with gr.Column():
246
+ prediction_output = gr.Markdown(label="🎯 Prediction")
247
+ confidence_output = gr.Markdown(label="πŸ“Š Confidence")
248
+
249
+ probabilities_output = gr.Label(
250
+ label="πŸ“ˆ Class Probabilities",
251
+ num_top_classes=len(predictor.class_names)
252
+ )
253
+
254
+ # Event handlers
255
+ predict_btn.click(
256
+ fn=predict_respiratory_sound,
257
+ inputs=[audio_input],
258
+ outputs=[prediction_output, confidence_output, probabilities_output]
259
+ )
260
+
261
+ # Auto-predict when file is uploaded
262
+ audio_input.change(
263
+ fn=predict_respiratory_sound,
264
+ inputs=[audio_input],
265
+ outputs=[prediction_output, confidence_output, probabilities_output]
266
+ )
267
+
268
+ gr.Markdown(
269
+ """
270
+ ---
271
+
272
+ ### ℹ️ About
273
+ This model classifies respiratory sounds into different categories.
274
+ Upload clear audio recordings of breathing sounds for best results.
275
+
276
+ **Note**: This is for research/educational purposes only and should not be used for medical diagnosis.
277
+ """
278
+ )
279
+
280
+ # Launch the app
281
+ if __name__ == "__main__":
282
+ demo.launch()