# app.py — Gradio app for respiratory sound classification.
import os
import numpy as np
import librosa
import pickle
import tensorflow as tf
import gradio as gr
from scipy import signal
import warnings
import tempfile
# Suppress the tuning-estimation warning raised for silent/empty frames.
warnings.filterwarnings("ignore", message="Trying to estimate tuning from empty frequency set.")
# Common parameters (must match training parameters)
target_sr = 22050       # sample rate (Hz) every clip is resampled to
target_duration = 4     # clip length in seconds
n_fft = 512             # FFT window size for the mel spectrogram
hop_length = 512        # hop between analysis frames, in samples
class RespiratoryPredictor:
    """End-to-end classifier for respiratory sound recordings.

    Bundles a trained multi-input Keras model with the preprocessing used
    during training: fixed-length resampling, denoising, mel-spectrogram /
    MFCC / chroma feature extraction, and per-feature normalization.
    """

    def __init__(self, model_path='respiratory_model.keras', scalers_path='scalers.pkl',
                 norm_params_path='norm_params.pkl', class_names_path='class_names.pkl'):
        """Initialize the predictor with trained model and scalers.

        Args:
            model_path: Saved Keras model file.
            scalers_path: Pickled scalers (loaded for parity with training;
                not referenced by the inference path in this class).
            norm_params_path: Pickled dict of per-feature mean/std values
                consumed by normalize_features().
            class_names_path: Pickled sequence of human-readable labels.

        Raises:
            Exception: any artifact that fails to load is reported on stdout
                and the original exception is re-raised.
        """
        # These must match the values the model was trained with.
        self.target_sr = target_sr
        self.target_duration = target_duration
        self.n_fft = n_fft
        self.hop_length = hop_length
        # Load model (Keras has its own loader, so it is handled separately).
        try:
            self.model = tf.keras.models.load_model(model_path)
            print(f"βœ“ Model loaded from {model_path}")
        except Exception as e:
            print(f"βœ— Error loading model: {e}")
            raise
        # The three pickled artifacts share one load/report pattern.
        self.scalers = self._load_pickle(scalers_path, "Scalers", "scalers")
        self.norm_params = self._load_pickle(
            norm_params_path, "Normalization parameters", "normalization parameters")
        self.class_names = self._load_pickle(class_names_path, "Class names", "class names")

    @staticmethod
    def _load_pickle(path, loaded_label, error_label):
        """Unpickle *path*, printing the same status messages as the inline loaders did.

        NOTE(review): pickle.load on untrusted files can execute arbitrary
        code — these artifacts must ship with the app, never be user-supplied.
        """
        try:
            with open(path, 'rb') as f:
                obj = pickle.load(f)
            print(f"βœ“ {loaded_label} loaded from {path}")
            return obj
        except Exception as e:
            print(f"βœ— Error loading {error_label}: {e}")
            raise

    def denoise_audio(self, audio, sr, methods=('adaptive_median', 'bandpass')):
        """Denoise audio signal.

        Args:
            audio: 1-D waveform (numpy array).
            sr: Sample rate of *audio* in Hz.
            methods: Ordered iterable of steps to apply; supported values are
                'adaptive_median' and 'bandpass'. The default is a tuple
                (not a list) to avoid the mutable-default-argument pitfall.

        Returns:
            A denoised copy of *audio*; the input array is left unmodified.
        """
        denoised_audio = audio.copy()
        for method in methods:
            if method == 'adaptive_median':
                # ~10 ms median filter suppresses impulsive clicks/pops.
                window_size = int(sr * 0.01)  # 10 ms window
                if window_size % 2 == 0:
                    window_size += 1  # medfilt requires an odd kernel size
                denoised_audio = signal.medfilt(denoised_audio, kernel_size=window_size)
            elif method == 'bandpass':
                # Keep only the 50-2000 Hz band.
                low_freq = 50
                high_freq = 2000
                nyquist = sr / 2
                low = low_freq / nyquist
                high = high_freq / nyquist
                b, a = signal.butter(4, [low, high], btype='band')
                # filtfilt = zero-phase (forward-backward) filtering.
                denoised_audio = signal.filtfilt(b, a, denoised_audio)
        return denoised_audio

    def extract_features(self, audio_data, sr):
        """Extract features from audio in the same format as during training.

        Returns:
            dict with 'mel_spec' (128 x frames, dB scale), 'mfcc'
            (20 x frames) and 'chroma' (12 x frames) arrays.
        """
        mel_spec = librosa.feature.melspectrogram(
            y=audio_data, sr=sr, n_mels=128, n_fft=self.n_fft, hop_length=self.hop_length)
        # Convert power spectrogram to dB relative to its peak.
        mel_spec_db = librosa.power_to_db(mel_spec, ref=np.max)
        mfcc = librosa.feature.mfcc(y=audio_data, sr=sr, n_mfcc=20, hop_length=self.hop_length)
        chroma = librosa.feature.chroma_stft(y=audio_data, sr=sr, hop_length=self.hop_length)
        return {
            'mel_spec': mel_spec_db,
            'mfcc': mfcc,
            'chroma': chroma,
        }

    def pad_or_crop(self, arr, shape):
        """Zero-pad and/or crop a 2-D array to exactly *shape*.

        Args:
            arr: 2-D array (features x frames).
            shape: Target (n_features, n_frames).

        Returns:
            A new array of *shape*; overflow is cropped, missing cells are zero.
        """
        out = np.zeros(shape, dtype=arr.dtype)
        n_feat, n_fr = arr.shape
        out[:min(n_feat, shape[0]), :min(n_fr, shape[1])] = arr[:shape[0], :shape[1]]
        return out

    def prepare_input_data(self, features, n_frames=259):
        """Shape the feature dict into the three model inputs.

        Args:
            features: dict produced by extract_features().
            n_frames: Time dimension the model was trained with.

        Returns:
            (X_mfcc, X_chroma, X_mspec) with a trailing channel axis:
            shapes (20, n_frames, 1), (12, n_frames, 1), (128, n_frames, 1).
        """
        mfcc = self.pad_or_crop(features['mfcc'], (20, n_frames))
        chroma = self.pad_or_crop(features['chroma'], (12, n_frames))
        mspec = self.pad_or_crop(features['mel_spec'], (128, n_frames))
        # Trailing channel dimension expected by the model's input layers.
        return mfcc[..., np.newaxis], chroma[..., np.newaxis], mspec[..., np.newaxis]

    def normalize_features(self, X_mfcc, X_chroma, X_mspec):
        """Normalize features using the same mean/std parameters as training.

        The 1e-8 epsilon guards against division by a zero std.
        """
        def norm(X, mean, std):
            flat = X.reshape(X.shape[0], -1)
            return ((flat - mean) / (std + 1e-8)).reshape(X.shape)

        return (
            norm(X_mfcc, self.norm_params['mfcc_mean'], self.norm_params['mfcc_std']),
            norm(X_chroma, self.norm_params['chroma_mean'], self.norm_params['chroma_std']),
            norm(X_mspec, self.norm_params['mspec_mean'], self.norm_params['mspec_std']),
        )

    def predict_audio(self, audio_file_path):
        """
        Predict the class of an audio file for the Gradio interface.

        Args:
            audio_file_path: Path to the uploaded audio file.

        Returns:
            tuple: (prediction_text, confidence_text, probabilities_dict).
            On failure the first element carries the error message and the
            other two are empty — the UI shows the error instead of crashing.
        """
        try:
            # Load at the training sample rate, truncated to target_duration.
            audio, sr = librosa.load(audio_file_path, sr=self.target_sr, duration=self.target_duration)
            # Pad short clips with silence / crop long ones to the exact length.
            target_samples = self.target_sr * self.target_duration
            if len(audio) < target_samples:
                audio = np.pad(audio, (0, target_samples - len(audio)), mode='constant')
            elif len(audio) > target_samples:
                audio = audio[:target_samples]
            # Same preprocessing chain as training.
            denoised_audio = self.denoise_audio(audio, self.target_sr)
            features = self.extract_features(denoised_audio, self.target_sr)
            X_mfcc, X_chroma, X_mspec = self.prepare_input_data(features)
            X_mfcc, X_chroma, X_mspec = self.normalize_features(X_mfcc, X_chroma, X_mspec)
            # Batch dimension of 1 for a single clip.
            inputs = [np.expand_dims(x, axis=0) for x in (X_mfcc, X_chroma, X_mspec)]
            prediction_prob = self.model.predict(inputs, verbose=0)
            prediction = int(np.argmax(prediction_prob[0]))
            confidence = float(np.max(prediction_prob[0]))
            # Fall back to a generic label if the index is out of range.
            class_name = self.class_names[prediction] if prediction < len(self.class_names) else f"Class {prediction}"
            # Format results for Gradio.
            prediction_text = f"🎯 **Prediction**: {class_name}"
            confidence_text = f"πŸ“Š **Confidence**: {confidence:.2%}"
            # Per-class probabilities for the gr.Label component.
            probabilities_dict = {
                name: float(prob)
                for name, prob in zip(self.class_names, prediction_prob[0])
            }
            return prediction_text, confidence_text, probabilities_dict
        except Exception as e:
            error_msg = f"❌ Error processing audio: {str(e)}"
            return error_msg, "", {}
# Build the module-level predictor used by the Gradio callbacks below.
print("Loading model and components...")
try:
    predictor = RespiratoryPredictor()
except Exception as e:
    # A missing/corrupt artifact is fatal: report it and abort startup.
    print(f"❌ Failed to initialize predictor: {e}")
    raise
else:
    print("βœ… All components loaded successfully!")
def predict_respiratory_sound(audio_file):
    """Gradio callback: classify an uploaded respiratory recording.

    Args:
        audio_file: Filesystem path handed over by the gr.Audio component,
            or None when nothing has been uploaded yet.

    Returns:
        tuple: (prediction markdown, confidence markdown, probabilities dict).
    """
    if audio_file is not None:
        return predictor.predict_audio(audio_file)
    # No file yet — show a gentle prompt instead of an error.
    return "⚠️ Please upload an audio file", "", {}
# Create Gradio interface
with gr.Blocks(title="Respiratory Sound Classifier", theme=gr.themes.Soft()) as demo:
    # Header and usage instructions.
    gr.Markdown(
        """
        # 🫁 Respiratory Sound Classification
        Upload an audio file containing respiratory sounds to classify the type of breathing pattern.
        **Supported formats**: WAV, MP3, M4A, FLAC
        **Duration**: Audio will be processed as 4-second segments
        """
    )
    with gr.Row():
        with gr.Column():
            # Left column: file input and a manual trigger button.
            audio_input = gr.Audio(
                label="πŸ“€ Upload Respiratory Sound",
                type="filepath",  # callback receives a path, not raw samples
                sources=["upload"]
            )
            predict_btn = gr.Button("πŸ” Analyze Sound", variant="primary")
        with gr.Column():
            # Right column: prediction outputs.
            prediction_output = gr.Markdown(label="🎯 Prediction")
            confidence_output = gr.Markdown(label="πŸ“Š Confidence")
            probabilities_output = gr.Label(
                label="πŸ“ˆ Class Probabilities",
                num_top_classes=len(predictor.class_names)  # show every class
            )
    # Event handlers
    predict_btn.click(
        fn=predict_respiratory_sound,
        inputs=[audio_input],
        outputs=[prediction_output, confidence_output, probabilities_output]
    )
    # Auto-predict when file is uploaded
    audio_input.change(
        fn=predict_respiratory_sound,
        inputs=[audio_input],
        outputs=[prediction_output, confidence_output, probabilities_output]
    )
    # Footer: about text and medical-use disclaimer.
    gr.Markdown(
        """
        ---
        ### ℹ️ About
        This model classifies respiratory sounds into different categories.
        Upload clear audio recordings of breathing sounds for best results.
        **Note**: This is for research/educational purposes only and should not be used for medical diagnosis.
        """
    )
# Launch the app
if __name__ == "__main__":
    demo.launch()  # starts the Gradio server; blocks until interrupted