# amp / streamlit_ui /model_performance.py
# magnumical's picture
# Upload 103 files
# 9c6b905 verified
import streamlit as st
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from tensorflow.keras.models import load_model
from sklearn.metrics import (
accuracy_score, precision_score, recall_score, f1_score, roc_auc_score
)
import numpy as np
import os
# Paths and Constants
MODEL_PATH = "./models"
DATASET_PATH = "./processed_datasets"
SUMMARY_FILE = "./streamlit_ui/model_evaluation_summary.csv"

# Every model / test-set file name is derived from one of these mode keys
# ("<task>_<feature mode>"), so the key list is the single source of truth.
_MODE_KEYS = (
    "binary_augmented",
    "binary_log_mel",
    "binary_mfcc",
    "multi_augmented",
    "multi_log_mel",
    "multi_mfcc",
)

# Saved Keras model files, looked up inside MODEL_PATH.
MODELS = [f"final_model_{mode_key}.h5" for mode_key in _MODE_KEYS]

# mode key -> (features file, labels file), looked up inside DATASET_PATH.
DATASETS = {
    mode_key: (f"X_test_{mode_key}.npy", f"y_test_{mode_key}.npy")
    for mode_key in _MODE_KEYS
}
# Column order of the summary table produced by evaluate_models().
_METRIC_COLUMNS = ["Model", "Accuracy", "Precision", "Recall", "F1 Score", "ROC-AUC"]


# Helper: turn raw labels and model outputs into scalar metrics
def _compute_metrics(y_test, y_pred_prob):
    """Compute accuracy, weighted precision/recall/F1 and ROC-AUC.

    Args:
        y_test: Ground-truth labels, either class indices (1-D or a single
            column) or a one-hot matrix with more than one column.
        y_pred_prob: Model output, either one score per sample (1-D or a
            single column) or one column of scores per class.

    Returns:
        dict with the keys "Accuracy", "Precision", "Recall", "F1 Score"
        and "ROC-AUC".
    """
    y_test = np.asarray(y_test)
    y_pred_prob = np.asarray(y_pred_prob)

    # Only a matrix with more than one column is one-hot. A single-column
    # label array must be flattened instead: argmax over one column would
    # map every label to 0.
    one_hot_labels = y_test.ndim > 1 and y_test.shape[1] > 1
    y_true = np.argmax(y_test, axis=1) if one_hot_labels else y_test.ravel()

    if y_pred_prob.ndim > 1 and y_pred_prob.shape[1] > 1:
        # One score column per class: predicted class is the highest score.
        y_pred = np.argmax(y_pred_prob, axis=1)
        auc_target = y_test if one_hot_labels else y_true
        auc = roc_auc_score(auc_target, y_pred_prob, multi_class='ovr')
    else:
        # Single score per sample: flatten so labels, predictions and scores
        # are all 1-D, then threshold at 0.5.
        scores = y_pred_prob.ravel()
        y_pred = (scores > 0.5).astype(int)
        auc = roc_auc_score(y_true, scores)

    return {
        "Accuracy": accuracy_score(y_true, y_pred),
        "Precision": precision_score(y_true, y_pred, average='weighted'),
        "Recall": recall_score(y_true, y_pred, average='weighted'),
        "F1 Score": f1_score(y_true, y_pred, average='weighted'),
        "ROC-AUC": auc,
    }


# Function to evaluate models
def evaluate_models():
    """Evaluate every model in MODELS on its test set and cache the results.

    For each model file the matching entry of DATASETS is loaded from
    DATASET_PATH, the model predicts on it, and the metrics are collected.
    Any model that cannot be evaluated (missing file, load error, shape
    mismatch, metric error) is reported through Streamlit and skipped.

    Returns:
        pandas.DataFrame with one row per successfully evaluated model and
        the columns Model, Accuracy, Precision, Recall, F1 Score, ROC-AUC.
        The frame is also written to SUMMARY_FILE, unless it is empty.
    """
    metrics_rows = []
    for model_name in MODELS:
        mode_key = model_name.replace("final_model_", "").replace(".h5", "").replace(" ", "_").lower()
        dataset = DATASETS.get(mode_key)
        if not dataset:
            # Previously skipped silently; surface the configuration gap.
            st.warning(f"No test dataset configured for model {model_name}")
            continue

        model_path = os.path.join(MODEL_PATH, model_name)
        if not os.path.exists(model_path):
            st.error(f"Model file not found: {model_path}")
            continue
        try:
            model = load_model(model_path)
        except Exception as e:
            st.error(f"Error loading model {model_name}: {e}")
            continue

        X_test_path, y_test_path = dataset
        try:
            X_test = np.load(os.path.join(DATASET_PATH, X_test_path))
            y_test = np.load(os.path.join(DATASET_PATH, y_test_path))
        except Exception as e:
            st.error(f"Error loading dataset for {mode_key}: {e}")
            continue

        # Reshape X_test if needed: add a trailing channel axis when the model
        # expects (features, 1) and the data is stored as (samples, features).
        if len(X_test.shape) == 2 and model.input_shape[1:] == (X_test.shape[1], 1):
            X_test = X_test.reshape(X_test.shape[0], X_test.shape[1], 1)
        if X_test.shape[1:] != model.input_shape[1:]:
            st.error(f"Shape mismatch: Model {model_name} expects {model.input_shape[1:]}, but dataset has {X_test.shape[1:]}")
            continue

        try:
            y_pred_prob = model.predict(X_test)
            metrics = _compute_metrics(y_test, y_pred_prob)
        except Exception as e:
            st.error(f"Error during evaluation for model {model_name}: {e}")
            continue
        metrics_rows.append({"Model": mode_key, **metrics})

    # Fixed columns keep the frame well-formed even when no model succeeded.
    metrics_df = pd.DataFrame(metrics_rows, columns=_METRIC_COLUMNS)
    if not metrics_df.empty:
        # Never cache an empty summary: the file's existence is what makes the
        # page skip evaluation, so an empty file would be reused on later runs.
        metrics_df.to_csv(SUMMARY_FILE, index=False)
    return metrics_df
# Streamlit App Function
def run():
    """Render the "Model Performance" page.

    Shows the static description of the training framework, then the metrics
    table and plots. Metrics are read from SUMMARY_FILE when it exists and
    holds data; otherwise evaluate_models() is run to compute them.
    """
    st.header("Model Performance")
    st.subheader("Model Training Scheme")
    st.write(""" Below is the illusteration of the model training and evaluation framework. The framework includes the following steps:
""")
    st.image("./streamlit_ui/img/training.png", caption="Model Training and Evaluation Framework")
    st.subheader("Types of Models and Input Features")
    st.write("""The workflow supports six distinct models, combining two classification settings (binary and multi-class) with three feature extraction modes (MFCC, Log-Mel Spectrogram, and Augmented Features).
Binary classification models distinguish between Normal and Abnormal, while multi-class models categorize audio into broader conditions like Normal, Chronic Respiratory Diseases, and Respiratory Infections.
MFCC and Log-Mel features require 2D CNN architectures due to their structured, matrix-like representations, whereas augmented features leverage 1D CNNs to handle diverse feature sets.
""")
    st.subheader("Feature Types and Their Characteristics")
    st.markdown("""
| **Feature Mode** | **Definition** | **Representation Type** | **Model Type Required** |
|-------------------|-----------------------------------------------------------------------------------------------------------|--------------------------|--------------------------|
| **MFCC** | Represents short-term power spectrum using the Mel scale, emphasizing human auditory perception. | 2D Matrix | 2D CNN |
| **Log-Mel** | A spectrogram emphasizing quieter sounds by applying a logarithmic scale to Mel frequencies. | 2D Matrix | 2D CNN |
| **Augmented** | Features derived from MFCCs with added variability through augmentation techniques such as noise addition, pitch shifting, and time stretching. | 1D Sequence | 1D CNN |
""")
    st.write("""
Using multiple feature types allows models to leverage diverse characteristics of the audio data, ensuring better generalization and robustness:
- **MFCC:** Captures the most relevant features of human auditory perception.
- **Log-Mel:** Highlights quieter components, providing a balanced representation of both loud and soft sounds.
- **Augmented:** Adds variability to training data, improving robustness and reducing overfitting risks.
""")
    st.subheader("Training, Optimization, and Evaluation")
    st.markdown("""
### **Model Building Process**
The model building process was built to handle both binary and multi-class classification.
The architecture is designed dynamically to adapt to the input features and classification requirements.
- **Binary Classification:** The output layer uses a single neuron with a sigmoid activation function.
- **Multi-Class Classification:** The output layer uses softmax activation with a number of neurons equal to the classes.
---
### **Why Both Binary and Multi-Class?**
My goal was creating a `Comprehensive Testing Framework` that we could test different types of input, models, and outputs to see which one suits well!
This is usually what I do in my projects, so the overall idea is creating a framework that you can subtitute its compoionents. For example, you can implement **Audio Spectrogram Transformer** and add it to the testing framework easily.
- **Binary Classification:** Focuses on distinguishing Normal from Abnormal cases.
This approach simplifies the task --> but it faces overfitting and generalization issues.
- **Multi-Class Classification:** Groups Abnormal cases into specific categories (e.g., Chronic Respiratory Diseases, Respiratory Infections).
Due to data imbalance, some groups with fewer than five cases were merged together. I searched about disease and grouped them based on some papers that I read. This granularity was between binary and all-class classification.
---
### **Training Scheme**
Models are trained with:
- **Loss Functions:**
- `binary_crossentropy` for binary classification.
- `categorical_crossentropy` for multi-class classification.
- **Regularization:** Dropout layers mitigate overfitting.
- **Optimizer:** Adamax is chosen for its adaptability to varying learning rates.
- **Early Stopping:** Monitors validation loss and halts training if no improvement is observed after five consecutive epochs.
---
### **Hyperparameter Optimization**
Optimization with Optuna tunes:
- Number of filters in convolutional layers.
- Units in dense layers.
- Dropout rates.
- Learning rates.
---
### **Evaluation Landscape**
Models are evaluated using:
- **Validation Accuracy:** Ensures performance during training.
- **Test Metrics:** Precision, recall, F1-score, confusion matrix, and ROC-AUC.
---
### **Model Selection**
The best model for each feature mode and classification task is selected based on:
- Performance on unseen test data.
- Robustness to imbalances and augmentation artifacts.
---
### **Challenges Addressed**
- Balancing complexity (multi-class) with generalizability (binary) --> sice data is highly imbalanced!
- Handling diverse feature representations and their impact on model performance.
- Making the framework modifiable for future.
""")
    st.subheader("Model Names and Configurations")
    st.markdown("""
To keep the visualizations concise, the models are referred to using short names in the images. Below is the mapping between these names and their configurations:
### **Binary Classification Models**
- **binary_augmented:** Uses augmented features.
Model file: `final_model_binary_augmented.h5`
- **binary_log_mel:** Uses Log-Mel Spectrogram features.
Model file: `final_model_binary_log_mel.h5`
- **binary_mfcc:** Uses MFCC features.
Model file: `final_model_binary_mfcc.h5`
Classes:
- **Abnormal**
- **Normal**
---
### **Multi-Class Classification Models**
- **multi_augmented:** Uses augmented features.
Model file: `final_model_multi_augmented.h5`
- **multi_log_mel:** Uses Log-Mel Spectrogram features.
Model file: `final_model_multi_log_mel.h5`
- **multi_mfcc:** Uses MFCC features.
Model file: `final_model_multi_mfcc.h5`
Classes:
- **Chronic Respiratory Diseases**
- **Normal**
- **Respiratory Infections**
""")

    # Check if the summary file exists
    metrics_df = None
    if os.path.exists(SUMMARY_FILE):
        st.info("Loading precomputed metrics from summary file.")
        try:
            metrics_df = pd.read_csv(SUMMARY_FILE)
        except pd.errors.EmptyDataError:
            # A summary file with no columns cannot be parsed; fall through
            # and recompute instead of crashing the page.
            st.warning("Cached summary file is empty; re-evaluating models.")

    if metrics_df is None:
        # Add a loading spinner while the evaluation is running
        with st.spinner("Processing models and generating metrics, please wait..."):
            metrics_df = evaluate_models()
        st.success("Model evaluation completed and saved!")

    st.subheader("Model Metrics")
    st.dataframe(metrics_df)

    st.subheader("Metrics Visualization")
    # All plots need at least one metrics row (and the "Model" column), so
    # they are rendered only for a non-empty table.
    if not metrics_df.empty:
        metric_to_plot = st.selectbox("Select Metric to Visualize", metrics_df.columns[1:])
        fig, ax = plt.subplots(figsize=(10, 6))
        # seaborn 0.13 deprecates `palette` without `hue`; assigning the x
        # variable to `hue` with legend=False keeps one colour per bar.
        sns.barplot(data=metrics_df, x="Model", y=metric_to_plot, hue="Model", ax=ax, palette="viridis", legend=False)
        ax.set_title(f"Model {metric_to_plot}", fontsize=16, fontweight='bold')
        ax.set_xlabel("Model", fontsize=12)
        ax.set_ylabel(metric_to_plot, fontsize=12)
        st.pyplot(fig)
        # Streamlit re-runs this function on every interaction; close the
        # figure so pyplot does not keep accumulating open figures.
        plt.close(fig)

        # Additional Image Types
        st.subheader("Additional Visualizations")
        st.text("Below are other types of visualizations for analysis:")

        # Heatmap of Metrics
        st.subheader("Metrics Heatmap")
        fig, ax = plt.subplots(figsize=(10, 6))
        sns.heatmap(metrics_df.set_index("Model"), annot=True, fmt=".2f", cmap="coolwarm", ax=ax)
        ax.set_title("Heatmap of Model Metrics", fontsize=16, fontweight='bold')
        st.pyplot(fig)
        plt.close(fig)

        # Pairplot of Metrics
        st.subheader("Metrics Pairplot")
        if len(metrics_df) > 1:  # Pairplot requires at least two rows
            grid = sns.pairplot(metrics_df.drop("Model", axis=1))
            # `.figure` is the current name of the grid's matplotlib figure
            # (`.fig` is the deprecated alias).
            st.pyplot(grid.figure)
            plt.close(grid.figure)