import streamlit as st
import numpy as np
import os
import librosa
import matplotlib.pyplot as plt
from sklearn.preprocessing import normalize
from tensorflow.keras.models import load_model
import logging
from prometheus_client import Counter, Histogram, start_http_server
import time
from scipy.signal import butter, sosfilt
import pandas as pd
# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger("audio_classifier_dep")
# Paths and Constants
MODEL_PATH = "./models"
MODELS = {
"binary": {
"augmented": "final_model_binary_augmented.h5",
"log_mel": "final_model_binary_log_mel.h5",
"mfcc": "final_model_binary_mfcc.h5",
},
"multi": {
"augmented": "final_model_multi_augmented.h5",
"log_mel": "final_model_multi_log_mel.h5",
"mfcc": "final_model_multi_mfcc.h5",
}
}
CLASS_NAMES = {
"binary": ["Abnormal", "Normal"],
"multi": ["Chronic Respiratory Diseases", "Normal", "Respiratory Infections"]
}
# Define Prometheus metrics
REQUEST_COUNT = Counter('audio_classifier_requests_total', 'Total number of requests to the classifier')
RESPONSE_TIME = Histogram('audio_classifier_response_time_seconds', 'Time taken to process requests')
ERROR_COUNT = Counter('audio_classifier_errors_total', 'Total number of errors during classification')
# Counters start at zero; resetting them through the private _value attribute is unnecessary.
# Start Prometheus HTTP server (guarded so a second call in the same process,
# e.g. on a module reload, does not crash on the already-bound port).
try:
    start_http_server(9100, addr="0.0.0.0")
except OSError:
    pass  # exporter already running in this process
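# The exporter serves plain-text metrics at http://<host>:9100/metrics. A minimal
# Prometheus scrape config for it might look like this (the hostname is an assumption):
#
#   scrape_configs:
#     - job_name: "audio_classifier"
#       static_configs:
#         - targets: ["localhost:9100"]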
# In-process response-time log backing the metrics dashboard tab.
individual_response_times = []
def filtering(audio, sr):
"""
Apply a bandpass filter to audio data.
Args:
audio: The input audio signal.
sr: The sampling rate of the audio.
Returns:
Filtered audio signal.
"""
# Define cutoff frequencies
low_cutoff = 50 # 50 Hz
high_cutoff = min(5000, sr / 2 - 1) # Ensure it is below Nyquist frequency
if low_cutoff >= high_cutoff:
raise ValueError(
f"Invalid filter range: low_cutoff={low_cutoff}, high_cutoff={high_cutoff} for sampling rate {sr}"
)
# Design a bandpass filter
sos = butter(N=10, Wn=[low_cutoff, high_cutoff], btype='band', fs=sr, output='sos')
# Apply the filter
filtered_audio = sosfilt(sos, audio)
return filtered_audio
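# Minimal usage sketch (hypothetical signal): band-pass one second of a 440 Hz
# tone sampled at 16 kHz.
#
#   t = np.linspace(0, 1, 16000, endpoint=False)
#   clean = filtering(np.sin(2 * np.pi * 440 * t), 16000)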
def save_uploaded_file(uploaded_file):
"""Save the uploaded file temporarily."""
temp_file_path = os.path.join("temp_audio", uploaded_file.name)
os.makedirs("temp_audio", exist_ok=True)
with open(temp_file_path, "wb") as f:
f.write(uploaded_file.getbuffer())
return temp_file_path
def display_results(predicted_class, probabilities, model_type):
"""Display the classification results."""
class_label = CLASS_NAMES[model_type][predicted_class]
st.success(f"Classification Complete! Predicted Class: **{class_label}**")
st.write("### Prediction Probabilities")
class_probabilities = {
CLASS_NAMES[model_type][i]: prob for i, prob in enumerate(probabilities)
}
st.bar_chart(class_probabilities)
## Augmentation Functions
def add_noise(data, noise_factor=0.001):
    """Add low-amplitude Gaussian noise to the signal."""
    noise = np.random.randn(len(data))
    return data + noise_factor * noise
def shift(data, shift_factor=1600):
    """Circularly shift the signal by shift_factor samples (0.1 s at 16 kHz)."""
    return np.roll(data, shift_factor)
def stretch(data, rate=1.2):
    """Time-stretch the signal by the given rate without changing pitch."""
    return librosa.effects.time_stretch(data, rate=rate)
def pitch_shift(data, sr, n_steps=3):
    """Shift the pitch up by n_steps semitones without changing duration."""
    return librosa.effects.pitch_shift(data, sr=sr, n_steps=n_steps)
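# Each augmentation maps a 1-D waveform to another waveform, so they can be
# applied independently and their MFCCs averaged (see preprocess_audio below).
# Hypothetical example, with x sampled at 16 kHz:
#
#   noisy     = add_noise(x, 0.001)        # additive Gaussian noise
#   rolled    = shift(x, 1600)             # 0.1 s circular shift
#   stretched = stretch(x, 1.2)            # played ~20% faster, pitch kept
#   higher    = pitch_shift(x, 16000, 3)   # 3 semitones up, length kept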
def preprocess_audio(audio_file, mode="augmented", input_shape=None):
"""
Preprocess an audio file for classification by resampling, padding/truncating,
and extracting features (e.g., MFCC, Log-Mel spectrogram, or Augmented features).
Args:
audio_file: Path to the audio file.
mode: Feature extraction mode ('mfcc', 'log_mel', or 'augmented').
input_shape: Expected input shape of the model for feature alignment.
Returns:
Extracted features as per the mode.
"""
try:
sr_new = 16000 # Resample audio to 16 kHz
x, sr = librosa.load(audio_file, sr=sr_new)
x = filtering(x, sr)
logger.info(f"Loaded audio file '{audio_file}' with shape {x.shape} and sampling rate {sr}.")
        max_len = 5 * sr_new  # 5 seconds at 16 kHz = 80,000 samples
        if x.shape[0] < max_len:
            x = np.pad(x, (0, max_len - x.shape[0]))
            logger.info(f"Audio padded to {max_len} samples.")
        else:
            x = x[:max_len]
            logger.info(f"Audio truncated to {max_len} samples.")
# Handle each mode separately
if mode == 'mfcc':
feature = librosa.feature.mfcc(y=x, sr=sr_new, n_mfcc=20) # Extract MFCC
feature = normalize(feature, axis=1)
elif mode == 'log_mel':
mel_spec = librosa.feature.melspectrogram(y=x, sr=sr_new, n_mels=20, fmax=8000)
feature = librosa.power_to_db(mel_spec, ref=np.max) # Extract Log-Mel spectrogram
feature = normalize(feature, axis=1)
elif mode == 'augmented':
features = []
# Base MFCC
base_mfcc = np.mean(librosa.feature.mfcc(y=x, sr=sr_new, n_mfcc=52).T, axis=0)
features.append(base_mfcc)
# Augmented features
for augmentation in [
lambda d: add_noise(d, 0.001),
lambda d: shift(d, 1600),
lambda d: stretch(d, 1.2),
lambda d: pitch_shift(d, sr_new, 3)
]:
augmented_data = augmentation(x)
aug_mfcc = np.mean(librosa.feature.mfcc(y=augmented_data, sr=sr_new, n_mfcc=52).T, axis=0)
features.append(aug_mfcc)
# Average augmented features
feature = np.mean(features, axis=0)
feature = normalize(feature.reshape(1, -1), axis=1).flatten() # Normalize
else:
raise ValueError(f"Unknown mode: {mode}")
# Reshape for model input if required
if input_shape:
feature = _reshape_feature(feature, input_shape)
logger.info(f"Feature extracted with shape {feature.shape}.")
return np.expand_dims(feature, axis=-1) # Add channel dimension
except Exception as e:
logger.error(f"Error in preprocessing audio: {e}")
raise
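# Usage sketch (hypothetical file; assumes a model expecting input shape (None, 52, 1)):
#
#   feats = preprocess_audio("cough.wav", mode="augmented", input_shape=(None, 52, 1))
#   # feats.shape -> (52, 1); classify_audio() adds the batch dimension later.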
def _reshape_feature(feature, input_shape):
"""
Reshape the feature to match the expected input shape of the model.
Args:
feature: The extracted feature.
input_shape: The expected input shape of the model.
Returns:
Reshaped feature.
"""
expected_time_frames = input_shape[1]
if len(feature) > expected_time_frames:
feature = feature[:expected_time_frames]
elif len(feature) < expected_time_frames:
feature = np.pad(feature, (0, expected_time_frames - len(feature)))
return feature
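# For example, against input_shape (None, 52, 1) a 60-element vector is cropped
# to 52 entries, and a 40-element vector is zero-padded to 52.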
def classify_audio(model_type, feature_type, file_path):
"""
Classify an audio file using the specified model.
Args:
model_type: Type of model ('binary' or 'multi').
feature_type: Type of feature extraction ('mfcc', 'log_mel', or 'augmented').
file_path: Path to the audio file.
Returns:
Predicted class and prediction probabilities.
"""
if model_type not in MODELS or feature_type not in MODELS[model_type]:
raise ValueError(f"Invalid combination of model type and feature type: {model_type}, {feature_type}")
# Load the correct model based on the type and feature
model_file = os.path.join(MODEL_PATH, MODELS[model_type][feature_type])
if not os.path.exists(model_file):
raise FileNotFoundError(f"Model file not found: {model_file}")
logger.info(f"Loading model from {model_file} for feature type '{feature_type}' and model type '{model_type}'...")
model = load_model(model_file)
# Get input shape from the model
input_shape = model.input_shape
# Preprocess audio
processed_audio = preprocess_audio(file_path, mode=feature_type, input_shape=input_shape)
# Add batch dimension
processed_audio = np.expand_dims(processed_audio, axis=0)
# Predict
predictions = model.predict(processed_audio)
predicted_class = np.argmax(predictions, axis=1)[0]
probabilities = predictions[0].tolist()
logger.info(f"Prediction complete. Predicted class: {predicted_class}, Probabilities: {probabilities}")
return predicted_class, probabilities
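# Usage sketch (hypothetical path):
#
#   predicted_class, probabilities = classify_audio("multi", "augmented", "temp_audio/cough.wav")
#   print(CLASS_NAMES["multi"][predicted_class], max(probabilities))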
def classify_audio_with_metrics(model_type, feature_type, file_path):
global individual_response_times
logger.info("Audio classification request received.")
REQUEST_COUNT.inc()
start_time = time.time()
try:
result = classify_audio(model_type, feature_type, file_path)
return result
except Exception as e:
ERROR_COUNT.inc()
logger.error("Error during classification: %s", e)
raise
finally:
response_time = time.time() - start_time
RESPONSE_TIME.observe(response_time)
individual_response_times.append(response_time)
logger.info("Request processed. Response time: %.3f seconds", response_time)
def run():
st.title("Respiratory Sound Classifier: Inference and Deployment")
st.markdown("""
Welcome to the **Inference and Deployment** page! This tool allows you to classify respiratory sounds
    into various categories using pre-trained models. Choose one of the tabs below based on your needs:
- **Quick Multiclass Mode:** A fast and straightforward way to classify audio files using a multiclass model with augmented features.
- **Flexible Mode:** Customize the classification process by selecting your preferred model type (binary/multi) and feature type (MFCC, Log-Mel, or Augmented).
- **Metrics Dashboard:** Monitor live metrics including request counts, response times, and error rates.
""")
# Tabs for three modes
tab1, tab2, tab3 = st.tabs(["Quick Multiclass Mode", "Flexible Mode", "Metrics Dashboard"])
# Tab 1: Quick Multiclass (Augmented) Mode
with tab1:
st.subheader("Quick Multiclass (Augmented) Mode")
st.markdown("""
This mode is optimized for quick classification of respiratory sounds into multiple categories
(e.g., Chronic Respiratory Diseases, Normal, Respiratory Infections). It automatically uses the
multiclass model with augmented features for robust and accurate results.
""")
uploaded_file = st.file_uploader(
"Upload an Audio File for Multiclass Classification",
type=["wav", "mp3"],
help="Supported formats: WAV, MP3",
)
if uploaded_file is not None:
temp_file_path = save_uploaded_file(uploaded_file)
st.audio(temp_file_path, format="audio/wav", start_time=0)
try:
with st.spinner("Classifying the audio file, please wait..."):
predicted_class, probabilities = classify_audio_with_metrics(
model_type="multi", feature_type="augmented", file_path=temp_file_path
)
# Display results
display_results(predicted_class, probabilities, "multi")
except Exception as e:
st.error(f"Error: {e}")
finally:
os.remove(temp_file_path)
# Tab 2: Flexible Mode
with tab2:
st.subheader("Flexible Mode")
st.markdown("""
The Flexible Mode gives you control over the classification process. Select the model type
(binary or multiclass) and the feature type (MFCC, Log-Mel, or Augmented) to suit your specific requirements.
""")
model_type = st.selectbox(
"Select Model Type",
["binary", "multi"],
help="Choose between binary or multi-class classification.",
)
feature_type = st.selectbox(
"Select Feature Type",
["mfcc", "log_mel", "augmented"],
help="Choose the feature extraction type.",
)
uploaded_file = st.file_uploader(
"Upload an Audio File",
type=["wav", "mp3"],
help="Supported formats: WAV, MP3",
)
if uploaded_file is not None:
temp_file_path = save_uploaded_file(uploaded_file)
st.audio(temp_file_path, format="audio/wav", start_time=0)
try:
with st.spinner("Classifying the audio file, please wait..."):
predicted_class, probabilities = classify_audio_with_metrics(
model_type, feature_type, temp_file_path
)
# Display results
display_results(predicted_class, probabilities, model_type)
except Exception as e:
st.error(f"Error: {e}")
finally:
os.remove(temp_file_path)
# Tab 3: Metrics Dashboard
with tab3:
st.subheader("Metrics Dashboard")
st.markdown("""
This dashboard shows live metrics for the application, including request counts, response times,
        and error counts. These metrics are tracked in-process and refresh each time the page reruns.
""")
        col1, col2, col3 = st.columns(3)
        # Reading the private _value attribute is a shortcut acceptable for this
        # in-app dashboard; a production setup would query Prometheus instead.
        col1.metric("Total Requests", int(REQUEST_COUNT._value.get()))
        col2.metric("Total Errors", int(ERROR_COUNT._value.get()))
if individual_response_times:
avg_response_time = sum(individual_response_times) / len(individual_response_times)
else:
avg_response_time = 0
col3.metric("Avg Response Time (s)", f"{avg_response_time:.3f}")
st.markdown("### Individual Response Times")
if individual_response_times:
df = pd.DataFrame({
"Request Index": range(1, len(individual_response_times) + 1),
"Response Time (s)": individual_response_times
})
st.dataframe(df)
else:
st.warning("No response time data available.")
if __name__ == "__main__":
run()