Spaces:
Runtime error
Runtime error
Tanish Mantri commited on
Commit ·
c3893d5
1
Parent(s): 65cff80
Add COVID detection app (model to be uploaded via web)
Browse files- .gitignore +1 -0
- app.py +11 -0
- demo.py +356 -0
- models/label_encoder.pkl +0 -0
- requirements_hf.txt +8 -0
- src/__init__.py +25 -0
- src/baseline_models.py +319 -0
- src/dataset.py +287 -0
- src/deep_learning_models.py +392 -0
- src/evaluation.py +345 -0
- src/preprocessing.py +281 -0
.gitignore
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
models/*.keras
|
app.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Hugging Face Space deployment file
|
| 3 |
+
"""
|
| 4 |
+
import demo
|
| 5 |
+
|
| 6 |
+
if __name__ == "__main__":
|
| 7 |
+
interface = demo.create_demo_interface(
|
| 8 |
+
model_path="models/respiratory_cnn_best.keras",
|
| 9 |
+
model_type="deep"
|
| 10 |
+
)
|
| 11 |
+
interface.launch()
|
demo.py
ADDED
|
@@ -0,0 +1,356 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Gradio web interface for respiratory disease detection.
|
| 3 |
+
Upload audio and get real-time predictions.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import gradio as gr
|
| 7 |
+
import numpy as np
|
| 8 |
+
import sys
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
import joblib
|
| 11 |
+
from tensorflow import keras
|
| 12 |
+
|
| 13 |
+
# Add src to path
|
| 14 |
+
sys.path.insert(0, str(Path(__file__).parent / 'src'))
|
| 15 |
+
|
| 16 |
+
from preprocessing import AudioPreprocessor
|
| 17 |
+
from evaluation import AudioVisualizer
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class RespiratoryDiseasePredictor:
    """Predictor wrapper for the demo interface.

    Loads either a joblib bundle ('baseline': classifier + fitted scaler) or a
    Keras model ('deep'), plus the label encoder stored next to the model, and
    exposes predict_audio() as the single entry point used by the Gradio UI.
    """

    def __init__(self, model_path: str, model_type: str = 'baseline'):
        """
        Initialize predictor.

        Args:
            model_path: Path to saved model
            model_type: Type of model ('baseline' or 'deep')
        """
        self.model_type = model_type
        # 16 kHz / 5 s matches the processing described in the demo's article text.
        self.preprocessor = AudioPreprocessor(sample_rate=16000, duration=5.0)

        # Load model
        if model_type == 'baseline':
            # joblib bundle: {'model': classifier, 'scaler': fitted StandardScaler}.
            model_data = joblib.load(model_path)
            self.model = model_data['model']
            self.scaler = model_data['scaler']
        else:
            # Keras model consumes raw MFCCs directly, so no feature scaler.
            self.model = keras.models.load_model(model_path)
            self.scaler = None

        # Load label encoder
        # NOTE(review): pickle.load executes arbitrary code — the .pkl file
        # next to the model must come from a trusted source.
        import pickle
        label_encoder_path = Path(model_path).parent / 'label_encoder.pkl'

        if label_encoder_path.exists():
            with open(label_encoder_path, 'rb') as f:
                self.label_encoder = pickle.load(f)
            self.class_names = list(self.label_encoder.classes_)
            print(f"Loaded class names: {self.class_names}")
        else:
            print(f"⚠️ Warning: label_encoder.pkl not found at {label_encoder_path}")
            print("Using default class names")
            # Try to infer from data directory
            # Fallback: one subdirectory per class under data/raw; sorted() mirrors
            # LabelEncoder's alphabetical class order.
            data_dir = Path('data/raw')
            if data_dir.exists():
                self.class_names = sorted([d.name for d in data_dir.iterdir() if d.is_dir()])
                print(f"Inferred class names from data directory: {self.class_names}")
            else:
                # No names available — predict_audio() falls back to "Class N".
                self.class_names = None

        print(f"Model loaded: {model_type}")
        if self.class_names:
            print(f"Classes: {self.class_names}")

    def _format_disease_name(self, class_name: str) -> str:
        """Convert class name to human-readable disease name."""
        # Unmapped labels fall back to Title Case of the raw class name.
        disease_map = {
            'covid': 'COVID-19',
            'healthy': 'Healthy (No Disease)',
            'symptomatic': 'Symptomatic (Non-COVID)',
            'asthma': 'Asthma',
            'unknown': 'Unknown Condition'
        }
        return disease_map.get(class_name.lower(), class_name.title())

    def _format_diagnosis(self, predicted_class: str, confidence: float) -> str:
        """Format diagnosis result in human-readable format.

        Args:
            predicted_class: Raw class label (e.g. 'covid', 'healthy').
            confidence: Confidence as a percentage (0-100).

        Returns:
            Multi-line display string for the Gradio textbox.
        """
        disease_name = self._format_disease_name(predicted_class)

        # Special handling for healthy vs disease
        if predicted_class.lower() == 'healthy':
            result = f"✅ NO DISEASE DETECTED\n\n"
            result += f"Result: {disease_name}\n"
            result += f"Confidence: {confidence:.1f}%\n\n"
            result += "Your respiratory sounds appear normal and healthy."
        elif predicted_class.lower() == 'covid':
            result = f"⚠️ COVID-19 DETECTED\n\n"
            result += f"Result: {disease_name}\n"
            result += f"Confidence: {confidence:.1f}%\n\n"
            result += "Indicators of COVID-19 detected in respiratory sounds.\n"
            result += "⚠️ Please consult a healthcare professional and get tested."
        elif predicted_class.lower() == 'symptomatic':
            result = f"⚠️ RESPIRATORY SYMPTOMS DETECTED\n\n"
            result += f"Result: {disease_name}\n"
            result += f"Confidence: {confidence:.1f}%\n\n"
            result += "Respiratory symptoms detected, but not specifically COVID-19.\n"
            result += "Please consult a healthcare professional if symptoms persist."
        else:
            # Generic branch for any other class (e.g. 'asthma', 'unknown').
            result = f"🔍 {disease_name.upper()} DETECTED\n\n"
            result += f"Result: {disease_name}\n"
            result += f"Confidence: {confidence:.1f}%\n\n"
            result += "Please consult a healthcare professional for proper diagnosis."

        return result

    def predict_audio(self, audio_path: str):
        """
        Predict disease from audio file.

        Args:
            audio_path: Path to audio file

        Returns:
            Tuple of (result_text, prob_dict, audio, mfcc); on any failure
            returns (None, None, None, None) so the UI can show an error.
        """
        try:
            # Load and preprocess audio
            audio = self.preprocessor.load_audio(audio_path)

            if self.model_type == 'baseline':
                # Extract MFCC and compute statistics
                mfcc = self.preprocessor.extract_mfcc(audio)
                features = self.preprocessor.compute_statistics(mfcc)
                # Single-sample batch for sklearn's 2-D input contract.
                features = features.reshape(1, -1)

                # Scale features
                if self.scaler:
                    features = self.scaler.transform(features)

                # Predict
                prediction = self.model.predict(features)[0]
                probabilities = self.model.predict_proba(features)[0]

            else:  # deep learning
                # Extract MFCC for deep learning
                mfcc = self.preprocessor.extract_mfcc(audio)
                # Add batch axis, then channel axis: (1, n_mfcc, frames, 1).
                features = np.expand_dims(mfcc, axis=0)
                features = np.expand_dims(features, axis=-1)

                # Predict
                raw_output = self.model.predict(features, verbose=0)[0]

                # Handle binary classification (output shape: (1,) or (2,))
                if len(raw_output) == 1:
                    # FLIPPED: Training data has labels reversed
                    # High output = COVID, Low output = Healthy
                    prob_covid = float(raw_output[0])
                    prob_healthy = 1.0 - prob_covid

                    # Create probability array matching class order ['covid', 'healthy']
                    probabilities = np.array([prob_covid, prob_healthy])

                    # Adjusted threshold: require 65% confidence for COVID detection
                    # This reduces false positives (healthy flagged as COVID)
                    # Index 0 = 'covid', index 1 = 'healthy' in the class list above.
                    prediction = int(prob_covid < 0.65)  # 1 if healthy, 0 if covid
                else:
                    # Multi-class output
                    probabilities = raw_output
                    prediction = np.argmax(probabilities)

            # Format results with human-readable output
            if self.class_names:
                predicted_class = self.class_names[prediction]
                confidence = float(probabilities[prediction]) * 100

                # Create human-readable result
                result_text = self._format_diagnosis(predicted_class, confidence)

                # Add debug info
                result_text += f"\n\n[Debug Info]\n"
                result_text += f"Raw model output: {probabilities}\n"
                result_text += f"Prediction index: {prediction}\n"
                result_text += f"Audio shape: {audio.shape}, MFCC shape: {mfcc.shape}\n"

                # Format probabilities with disease names
                prob_dict = {
                    self._format_disease_name(self.class_names[i]): float(probabilities[i])
                    for i in range(len(self.class_names))
                }
            else:
                # No label encoder available: fall back to numeric class labels.
                predicted_class = f"Class {prediction}"
                result_text = f"Predicted: {predicted_class}"
                prob_dict = {f"Class {i}": float(probabilities[i])
                             for i in range(len(probabilities))}

            return result_text, prob_dict, audio, mfcc

        except Exception as e:
            # Deliberate best-effort: the UI treats the all-None tuple as an
            # error sentinel rather than crashing the Space.
            print(f"Error during prediction: {e}")
            return None, None, None, None
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
def create_demo_interface(model_path: str, model_type: str = 'baseline'):
    """
    Create Gradio interface for the model.

    Args:
        model_path: Path to trained model
        model_type: Type of model ('baseline' or 'deep')

    Returns:
        A configured gr.Interface (call .launch() on it to serve).
    """
    # Initialize predictor
    predictor = RespiratoryDiseasePredictor(model_path, model_type)

    def predict(audio):
        """Prediction function for Gradio interface."""
        if audio is None:
            return "No audio provided", {}, None, None

        # Handle both file path and tuple (sample_rate, audio_data)
        if isinstance(audio, tuple):
            # Gradio microphone input
            import soundfile as sf
            import tempfile

            sr, audio_data = audio
            # Save temporarily
            # NOTE(review): delete=False leaves the temp .wav on disk after the
            # request — consider os.unlink(audio_path) once prediction is done.
            with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as f:
                sf.write(f.name, audio_data, sr)
                audio_path = f.name
        else:
            audio_path = audio

        # Make prediction
        predicted_class, probabilities, audio_signal, mfcc = predictor.predict_audio(audio_path)

        if predicted_class is None:
            # predict_audio() returned its all-None error sentinel.
            return "Error processing audio", {}, None, None

        # Create visualization
        import matplotlib.pyplot as plt
        import librosa.display

        # Waveform
        fig1, ax1 = plt.subplots(figsize=(10, 3))
        # X axis in seconds, derived from the preprocessor's sample rate.
        time = np.arange(len(audio_signal)) / predictor.preprocessor.sample_rate
        ax1.plot(time, audio_signal, linewidth=0.5)
        ax1.set_xlabel('Time (s)')
        ax1.set_ylabel('Amplitude')
        ax1.set_title('Audio Waveform')
        ax1.grid(True, alpha=0.3)
        plt.tight_layout()

        # MFCC visualization
        fig2, ax2 = plt.subplots(figsize=(10, 4))
        img = ax2.imshow(mfcc, aspect='auto', origin='lower', cmap='viridis')
        ax2.set_xlabel('Time Frame')
        ax2.set_ylabel('MFCC Coefficient')
        ax2.set_title('MFCC Features')
        plt.colorbar(img, ax=ax2, label='Coefficient Value')
        plt.tight_layout()

        return predicted_class, probabilities, fig1, fig2

    # Create Gradio interface
    demo = gr.Interface(
        fn=predict,
        inputs=[
            gr.Audio(type="filepath", label="Upload Audio (Cough/Voice Recording)")
        ],
        outputs=[
            gr.Textbox(label="Diagnosis Result", lines=6),
            gr.Label(label="Detailed Probabilities", num_top_classes=10),
            gr.Plot(label="Audio Waveform"),
            gr.Plot(label="MFCC Features")
        ],
        title="Covid Detection AI",
        description="""
⚠️ **AI IN DEVELOPMENT - NOT FOR MEDICAL USE** ⚠️

**IMPORTANT:** This AI system is currently under development and should NOT be used as a
substitute for professional medical diagnosis. If the system flags potential COVID-19 or
any respiratory condition, you MUST contact a healthcare professional immediately for
proper testing, diagnosis, and treatment.

---

Upload a cough, breath, or voice recording to detect potential respiratory diseases.

**Supported formats:** WAV, MP3, FLAC
""",
        article="""
---

### ⚠️ CRITICAL MEDICAL DISCLAIMER ⚠️

**THIS AI IS IN ACTIVE DEVELOPMENT AND NOT APPROVED FOR MEDICAL USE**

- ❌ **DO NOT** use this tool to self-diagnose
- ❌ **DO NOT** use this as a replacement for COVID-19 testing
- ❌ **DO NOT** delay seeking medical care based on these results
- ✅ **ALWAYS** consult a healthcare professional if you have symptoms
- ✅ **ALWAYS** get proper medical testing if flagged for COVID-19
- ✅ **ALWAYS** follow official health guidelines and protocols

**If this system detects COVID-19 or any respiratory condition, immediately contact
your doctor or local health authority for proper testing and medical guidance.**

---

### How it works:
1. Upload an audio recording (cough, breath, or voice)
2. The AI extracts audio features (MFCCs - Mel-frequency cepstral coefficients)
3. The model predicts the likelihood of different respiratory conditions
4. Results show the predicted disease and confidence scores

### Model Information:
- **Model Type:** {model_type}
- **Audio Processing:** 16kHz sampling rate, 5-second segments
- **Features:** MFCC, spectral features, temporal features
- **Status:** Development prototype - not clinically validated

### Legal Disclaimer:
This tool is for educational and research purposes only. It is not a substitute for
professional medical advice, diagnosis, or treatment. The developers assume no
liability for any health decisions made based on this system's output.
""".format(model_type=model_type.upper()),
        examples=[
            # Add example audio files here if available
        ],
        allow_flagging="never",
        theme=gr.themes.Soft()
    )

    return demo
|
| 327 |
+
|
| 328 |
+
|
| 329 |
+
def main():
    """Parse command-line options and launch the demo interface."""
    import argparse

    parser = argparse.ArgumentParser(description='Launch Gradio demo for respiratory disease detection')
    parser.add_argument('--model_path', type=str, required=True,
                        help='Path to trained model file')
    parser.add_argument('--model_type', type=str, default='baseline',
                        choices=['baseline', 'deep'],
                        help='Type of model')
    parser.add_argument('--share', action='store_true',
                        help='Create public link')
    parser.add_argument('--port', type=int, default=7860,
                        help='Port to run the server on')
    opts = parser.parse_args()

    # Build the interface, then serve it on all network interfaces.
    app = create_demo_interface(opts.model_path, opts.model_type)
    app.launch(share=opts.share, server_port=opts.port, server_name="0.0.0.0")


if __name__ == "__main__":
    main()
|
models/label_encoder.pkl
ADDED
|
Binary file (291 Bytes). View file
|
|
|
requirements_hf.txt
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
gradio==4.44.1
|
| 2 |
+
tensorflow==2.15.0
|
| 3 |
+
librosa==0.10.1
|
| 4 |
+
numpy==1.24.3
|
| 5 |
+
scikit-learn==1.3.2
|
| 6 |
+
matplotlib==3.8.2
|
| 7 |
+
soundfile==0.12.1
|
| 8 |
+
joblib==1.3.2
|
src/__init__.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Respiratory Disease Detection AI Package
|
| 3 |
+
Detect respiratory diseases from voice, breath, and cough recordings.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
__version__ = '1.0.0'
|
| 7 |
+
__author__ = 'Your Name'
|
| 8 |
+
|
| 9 |
+
from .preprocessing import AudioPreprocessor, AudioAugmenter
|
| 10 |
+
from .dataset import AudioDataset
|
| 11 |
+
from .baseline_models import BaselineModel, ModelComparison
|
| 12 |
+
from .deep_learning_models import CNNModel, LSTMModel
|
| 13 |
+
from .evaluation import ModelEvaluator, AudioVisualizer
|
| 14 |
+
|
| 15 |
+
__all__ = [
|
| 16 |
+
'AudioPreprocessor',
|
| 17 |
+
'AudioAugmenter',
|
| 18 |
+
'AudioDataset',
|
| 19 |
+
'BaselineModel',
|
| 20 |
+
'ModelComparison',
|
| 21 |
+
'CNNModel',
|
| 22 |
+
'LSTMModel',
|
| 23 |
+
'ModelEvaluator',
|
| 24 |
+
'AudioVisualizer',
|
| 25 |
+
]
|
src/baseline_models.py
ADDED
|
@@ -0,0 +1,319 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Baseline machine learning models for respiratory disease detection.
|
| 3 |
+
Includes Random Forest, SVM, and other classical ML algorithms.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import pickle
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from typing import Dict, Tuple, Optional
|
| 10 |
+
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
|
| 11 |
+
from sklearn.svm import SVC
|
| 12 |
+
from sklearn.linear_model import LogisticRegression
|
| 13 |
+
from sklearn.preprocessing import StandardScaler
|
| 14 |
+
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score
|
| 15 |
+
from sklearn.metrics import precision_recall_fscore_support, roc_auc_score, roc_curve
|
| 16 |
+
import joblib
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class BaselineModel:
    """Wrapper for baseline ML models.

    Bundles a scikit-learn classifier with a StandardScaler so scaling is
    applied consistently at train, predict, and evaluation time, and supports
    saving/loading the pair as a single joblib artifact.
    """

    def __init__(self, model_type: str = 'random_forest', **kwargs):
        """
        Initialize baseline model.

        Args:
            model_type: Type of model ('random_forest', 'svm', 'logistic', 'gradient_boost')
            **kwargs: Additional parameters for the model

        Raises:
            ValueError: If `model_type` is not one of the supported names.
        """
        self.model_type = model_type
        self.scaler = StandardScaler()
        self.model = self._create_model(model_type, **kwargs)
        self.is_fitted = False

    def _create_model(self, model_type: str, **kwargs):
        """Create the specified scikit-learn estimator with project defaults."""
        if model_type == 'random_forest':
            return RandomForestClassifier(
                n_estimators=kwargs.get('n_estimators', 200),
                max_depth=kwargs.get('max_depth', 20),
                min_samples_split=kwargs.get('min_samples_split', 5),
                min_samples_leaf=kwargs.get('min_samples_leaf', 2),
                random_state=kwargs.get('random_state', 42),
                n_jobs=kwargs.get('n_jobs', -1),
                verbose=kwargs.get('verbose', 1)
            )

        elif model_type == 'svm':
            return SVC(
                kernel=kwargs.get('kernel', 'rbf'),
                C=kwargs.get('C', 1.0),
                gamma=kwargs.get('gamma', 'scale'),
                # probability=True is required for predict_proba / ROC AUC.
                probability=True,
                random_state=kwargs.get('random_state', 42),
                verbose=kwargs.get('verbose', True)
            )

        elif model_type == 'logistic':
            return LogisticRegression(
                max_iter=kwargs.get('max_iter', 1000),
                C=kwargs.get('C', 1.0),
                random_state=kwargs.get('random_state', 42),
                n_jobs=kwargs.get('n_jobs', -1),
                verbose=kwargs.get('verbose', 1)
            )

        elif model_type == 'gradient_boost':
            return GradientBoostingClassifier(
                n_estimators=kwargs.get('n_estimators', 200),
                learning_rate=kwargs.get('learning_rate', 0.1),
                max_depth=kwargs.get('max_depth', 5),
                random_state=kwargs.get('random_state', 42),
                verbose=kwargs.get('verbose', 1)
            )

        else:
            raise ValueError(f"Unknown model type: {model_type}")

    def train(self, X_train: np.ndarray, y_train: np.ndarray,
              X_val: Optional[np.ndarray] = None, y_val: Optional[np.ndarray] = None):
        """
        Train the model.

        Fits the scaler on the training set only, then fits the classifier
        on the scaled features. Prints train (and optional validation)
        accuracy as a quick sanity check.

        Args:
            X_train: Training features
            y_train: Training labels
            X_val: Validation features (optional)
            y_val: Validation labels (optional)
        """
        print(f"Training {self.model_type} model...")
        print(f"Training samples: {len(X_train)}")

        # Fit the scaler on training data only to avoid validation leakage.
        X_train_scaled = self.scaler.fit_transform(X_train)

        # Train model
        self.model.fit(X_train_scaled, y_train)
        self.is_fitted = True

        # Evaluate on training set
        train_acc = self.model.score(X_train_scaled, y_train)
        print(f"Training accuracy: {train_acc:.4f}")

        # Evaluate on validation set if provided
        if X_val is not None and y_val is not None:
            X_val_scaled = self.scaler.transform(X_val)
            val_acc = self.model.score(X_val_scaled, y_val)
            print(f"Validation accuracy: {val_acc:.4f}")

    def predict(self, X: np.ndarray) -> np.ndarray:
        """Make predictions.

        Raises:
            ValueError: If called before train().
        """
        if not self.is_fitted:
            raise ValueError("Model must be trained before making predictions")

        X_scaled = self.scaler.transform(X)
        return self.model.predict(X_scaled)

    def predict_proba(self, X: np.ndarray) -> np.ndarray:
        """Get per-class prediction probabilities.

        Raises:
            ValueError: If called before train().
        """
        if not self.is_fitted:
            raise ValueError("Model must be trained before making predictions")

        X_scaled = self.scaler.transform(X)
        return self.model.predict_proba(X_scaled)

    def evaluate(self, X: np.ndarray, y: np.ndarray,
                 class_names: Optional[list] = None) -> Dict:
        """
        Evaluate model performance.

        Computes accuracy, weighted precision/recall/F1, confusion matrix,
        a classification report, and (when possible) ROC AUC, printing a
        summary and returning everything in one dictionary.

        Args:
            X: Test features
            y: Test labels
            class_names: List of class names for reporting

        Returns:
            Dictionary containing evaluation metrics ('roc_auc' is None
            when it could not be computed)

        Raises:
            ValueError: If called before train().
        """
        if not self.is_fitted:
            raise ValueError("Model must be trained before evaluation")

        # Make predictions
        y_pred = self.predict(X)
        y_proba = self.predict_proba(X)

        # Calculate metrics (weighted averages; support is not needed)
        accuracy = accuracy_score(y, y_pred)
        precision, recall, f1, _ = precision_recall_fscore_support(
            y, y_pred, average='weighted'
        )

        # Confusion matrix
        cm = confusion_matrix(y, y_pred)

        # Classification report
        report = classification_report(
            y, y_pred,
            target_names=class_names,
            output_dict=True
        )

        # ROC AUC (binary uses the positive-class column; multi-class uses
        # one-vs-rest). May fail e.g. when a class is absent from y.
        try:
            if len(np.unique(y)) == 2:
                roc_auc = roc_auc_score(y, y_proba[:, 1])
            else:
                roc_auc = roc_auc_score(y, y_proba, multi_class='ovr', average='weighted')
        except Exception as e:
            print(f"Could not compute ROC AUC: {e}")
            roc_auc = None

        results = {
            'accuracy': accuracy,
            'precision': precision,
            'recall': recall,
            'f1_score': f1,
            'confusion_matrix': cm,
            'classification_report': report,
            'roc_auc': roc_auc,
            'predictions': y_pred,
            'probabilities': y_proba
        }

        print("\n" + "="*50)
        print(f"{self.model_type.upper()} EVALUATION RESULTS")
        print("="*50)
        print(f"Accuracy: {accuracy:.4f}")
        print(f"Precision: {precision:.4f}")
        print(f"Recall: {recall:.4f}")
        print(f"F1 Score: {f1:.4f}")
        if roc_auc is not None:
            print(f"ROC AUC: {roc_auc:.4f}")
        print("\nConfusion Matrix:")
        print(cm)
        print("\nClassification Report:")
        if class_names:
            print(classification_report(y, y_pred, target_names=class_names))
        else:
            print(classification_report(y, y_pred))
        print("="*50 + "\n")

        return results

    def get_feature_importance(self, feature_names: Optional[list] = None) -> Optional[np.ndarray]:
        """
        Get feature importances (only for tree-based models).

        Args:
            feature_names: Optional list of feature names

        Returns:
            Array of feature importances, or None when the underlying model
            does not expose `feature_importances_` (e.g. SVM, logistic).

        Raises:
            ValueError: If called before train().
        """
        if not self.is_fitted:
            raise ValueError("Model must be trained first")

        if hasattr(self.model, 'feature_importances_'):
            importances = self.model.feature_importances_

            if feature_names:
                importance_dict = dict(zip(feature_names, importances))
                # Sort by importance (descending) for the printed top-10.
                importance_dict = dict(sorted(
                    importance_dict.items(),
                    key=lambda x: x[1],
                    reverse=True
                ))

                print("\nTop 10 Feature Importances:")
                for i, (name, imp) in enumerate(list(importance_dict.items())[:10]):
                    print(f"{i+1}. {name}: {imp:.4f}")

            return importances
        else:
            print(f"{self.model_type} does not support feature importances")
            return None

    def save(self, filepath: str):
        """Save model, scaler, and metadata to disk as one joblib file.

        Raises:
            ValueError: If the model has not been trained.
        """
        if not self.is_fitted:
            raise ValueError("Cannot save untrained model")

        model_data = {
            'model': self.model,
            'scaler': self.scaler,
            'model_type': self.model_type,
            'is_fitted': self.is_fitted
        }

        joblib.dump(model_data, filepath)
        print(f"Model saved to {filepath}")

    @classmethod
    def load(cls, filepath: str):
        """Load a model previously written by save().

        Returns:
            A BaselineModel instance with model, scaler, and fitted state restored.
        """
        model_data = joblib.load(filepath)

        instance = cls(model_type=model_data['model_type'])
        instance.model = model_data['model']
        instance.scaler = model_data['scaler']
        instance.is_fitted = model_data['is_fitted']

        print(f"Model loaded from {filepath}")
        return instance
|
| 265 |
+
|
| 266 |
+
|
| 267 |
+
class ModelComparison:
    """Train, evaluate, and compare several baseline models side by side."""

    def __init__(self):
        # name -> BaselineModel, and name -> metrics dict from evaluate().
        self.models = {}
        self.results = {}

    def add_model(self, name: str, model: BaselineModel):
        """Register a model under `name` for the comparison."""
        self.models[name] = model

    def train_all(self, X_train: np.ndarray, y_train: np.ndarray,
                  X_val: Optional[np.ndarray] = None, y_val: Optional[np.ndarray] = None):
        """Train every registered model on the same data split."""
        banner = '=' * 60
        for label, candidate in self.models.items():
            print(f"\n{banner}")
            print(f"Training {label}")
            print(banner)
            candidate.train(X_train, y_train, X_val, y_val)

    def evaluate_all(self, X_test: np.ndarray, y_test: np.ndarray,
                     class_names: Optional[list] = None):
        """Evaluate every model on the test split, caching results by name."""
        banner = '=' * 60
        for label, candidate in self.models.items():
            print(f"\n{banner}")
            print(f"Evaluating {label}")
            print(banner)
            self.results[label] = candidate.evaluate(X_test, y_test, class_names)

    def print_summary(self):
        """Print a metrics table for all evaluated models and name the best."""
        rule = "=" * 80
        print("\n" + rule)
        print("MODEL COMPARISON SUMMARY")
        print(rule)
        print(f"{'Model':<25} {'Accuracy':<12} {'Precision':<12} {'Recall':<12} {'F1 Score':<12}")
        print("-" * 80)

        for label, metrics in self.results.items():
            row = (f"{label:<25} {metrics['accuracy']:<12.4f} {metrics['precision']:<12.4f} "
                   f"{metrics['recall']:<12.4f} {metrics['f1_score']:<12.4f}")
            print(row)

        print(rule + "\n")

        # Best model = highest weighted F1 score.
        champion = max(self.results.items(), key=lambda item: item[1]['f1_score'])
        print(f"Best model: {champion[0]} (F1 Score: {champion[1]['f1_score']:.4f})")
|
| 314 |
+
|
| 315 |
+
|
| 316 |
+
if __name__ == "__main__":
    # Smoke test: running this module directly confirms it imports cleanly
    # and lists the supported model_type values.
    print("Baseline models module loaded successfully")
    print("Available models: random_forest, svm, logistic, gradient_boost")
|
src/dataset.py
ADDED
|
@@ -0,0 +1,287 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Dataset loading and management for respiratory disease detection.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
import pandas as pd
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
from typing import Tuple, List, Dict, Optional
|
| 9 |
+
from sklearn.model_selection import train_test_split
|
| 10 |
+
from sklearn.preprocessing import LabelEncoder
|
| 11 |
+
import pickle
|
| 12 |
+
from tqdm import tqdm
|
| 13 |
+
|
| 14 |
+
from preprocessing import AudioPreprocessor, AudioAugmenter
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class AudioDataset:
    """Manages audio dataset loading, feature extraction, and splitting.

    Audio files are loaded through the supplied AudioPreprocessor, reduced to
    statistical MFCC feature vectors, and string labels are integer-encoded
    with a scikit-learn LabelEncoder.
    """

    def __init__(self, data_dir: str, preprocessor: AudioPreprocessor):
        """
        Initialize dataset.

        Args:
            data_dir: Directory containing audio files organized by class
            preprocessor: AudioPreprocessor instance
        """
        self.data_dir = Path(data_dir)
        self.preprocessor = preprocessor
        self.augmenter = AudioAugmenter()
        self.label_encoder = LabelEncoder()

        self.X = None          # statistical feature matrix, (n_samples, n_features)
        self.y = None          # integer-encoded labels, aligned with X
        self.labels = None     # original string labels, aligned with X
        self.file_paths = []   # source audio path per sample, aligned with X

    def _extract_features(self, audio_path: str) -> np.ndarray:
        """Load one audio file and return its statistical MFCC feature vector.

        Shared by both loaders so the feature pipeline lives in one place.
        """
        audio = self.preprocessor.load_audio(audio_path)
        mfcc = self.preprocessor.extract_mfcc(audio)
        return self.preprocessor.compute_statistics(mfcc)

    def _finalize(self, X_list: list, y_list: list):
        """Convert accumulated samples to arrays and fit the label encoder.

        Raises:
            ValueError: If no samples could be processed (previously this
                silently produced empty arrays).
        """
        if not X_list:
            raise ValueError("No audio samples could be processed; dataset is empty")

        self.X = np.array(X_list)
        self.labels = np.array(y_list)
        self.y = self.label_encoder.fit_transform(self.labels)

    def load_from_directory_structure(self, use_cache: bool = True):
        """
        Load dataset from directory structure where each subdirectory is a class.
        Expected structure:
        data_dir/
            healthy/
                file1.wav
                file2.wav
            covid/
                file1.wav
                file2.wav

        Args:
            use_cache: When True and a cache file exists, load preprocessed
                features from it instead of re-extracting.
        """
        cache_file = self.data_dir / 'dataset_cache.pkl'

        if use_cache and cache_file.exists():
            print("Loading from cache...")
            with open(cache_file, 'rb') as f:
                cache_data = pickle.load(f)
            self.X = cache_data['X']
            self.y = cache_data['y']
            self.labels = cache_data['labels']
            self.file_paths = cache_data['file_paths']
            self.label_encoder = cache_data['label_encoder']
            print(f"Loaded {len(self.X)} samples from cache")
            return

        # Get all class directories
        class_dirs = [d for d in self.data_dir.iterdir() if d.is_dir()]

        if not class_dirs:
            raise ValueError(f"No subdirectories found in {self.data_dir}")

        print(f"Found {len(class_dirs)} classes: {[d.name for d in class_dirs]}")

        # Reset accumulated paths so a repeated load does not duplicate entries.
        self.file_paths = []
        X_list = []
        y_list = []

        for class_dir in class_dirs:
            class_name = class_dir.name
            audio_files = list(class_dir.glob('*.wav')) + list(class_dir.glob('*.mp3'))

            print(f"Processing {len(audio_files)} files from class '{class_name}'...")

            for audio_file in tqdm(audio_files, desc=class_name):
                try:
                    X_list.append(self._extract_features(str(audio_file)))
                    y_list.append(class_name)
                    self.file_paths.append(str(audio_file))
                except Exception as e:
                    # Best-effort: skip unreadable files but keep going.
                    print(f"Error processing {audio_file}: {e}")

        self._finalize(X_list, y_list)

        print(f"\nDataset loaded: {len(self.X)} samples, {len(np.unique(self.y))} classes")
        print(f"Feature shape: {self.X.shape}")
        print(f"Class distribution: {dict(zip(*np.unique(self.labels, return_counts=True)))}")

        # Save cache so subsequent loads skip feature extraction.
        with open(cache_file, 'wb') as f:
            pickle.dump({
                'X': self.X,
                'y': self.y,
                'labels': self.labels,
                'file_paths': self.file_paths,
                'label_encoder': self.label_encoder
            }, f)
        print(f"Cache saved to {cache_file}")

    def load_from_csv(self, csv_path: str, audio_column: str = 'file_path',
                      label_column: str = 'label'):
        """
        Load dataset from CSV file with file paths and labels.

        Args:
            csv_path: Path to CSV file
            audio_column: Column name containing audio file paths
            label_column: Column name containing labels
        """
        df = pd.read_csv(csv_path)

        print(f"Loading {len(df)} samples from CSV...")

        # Reset accumulated paths so a repeated load does not duplicate entries.
        self.file_paths = []
        X_list = []
        y_list = []

        for idx, row in tqdm(df.iterrows(), total=len(df)):
            try:
                audio_path = row[audio_column]
                label = row[label_column]

                # Make path absolute if relative (resolved against data_dir)
                if not Path(audio_path).is_absolute():
                    audio_path = self.data_dir / audio_path

                X_list.append(self._extract_features(str(audio_path)))
                y_list.append(label)
                self.file_paths.append(str(audio_path))

            except Exception as e:
                # Best-effort: skip bad rows but keep going.
                print(f"Error processing row {idx}: {e}")

        self._finalize(X_list, y_list)

        print(f"\nDataset loaded: {len(self.X)} samples, {len(np.unique(self.y))} classes")
        print(f"Feature shape: {self.X.shape}")

    def split_data(self, test_size: float = 0.15, val_size: float = 0.15,
                   random_state: int = 42) -> Dict[str, np.ndarray]:
        """
        Split dataset into train, validation, and test sets.

        Args:
            test_size: Proportion of data for test set
            val_size: Proportion of data for validation set
            random_state: Random seed for reproducibility

        Returns:
            Dictionary containing train, val, test splits
        """
        # First split: separate test set (stratified to keep class balance)
        X_temp, X_test, y_temp, y_test = train_test_split(
            self.X, self.y, test_size=test_size, random_state=random_state, stratify=self.y
        )

        # Second split: carve validation out of the remainder; the fraction is
        # rescaled so val_size stays a proportion of the FULL dataset.
        val_size_adjusted = val_size / (1 - test_size)
        X_train, X_val, y_train, y_val = train_test_split(
            X_temp, y_temp, test_size=val_size_adjusted, random_state=random_state, stratify=y_temp
        )

        print(f"Data split: Train={len(X_train)}, Val={len(X_val)}, Test={len(X_test)}")

        return {
            'X_train': X_train,
            'X_val': X_val,
            'X_test': X_test,
            'y_train': y_train,
            'y_val': y_val,
            'y_test': y_test
        }

    def get_deep_learning_data(self) -> Tuple[np.ndarray, np.ndarray]:
        """
        Load data formatted for deep learning (2D features).
        Returns MFCC spectrograms instead of statistical features.

        NOTE: requires a prior load_* call so that file_paths is populated
        and the label encoder is fitted (transform() is used below).
        """
        X_deep = []
        y_deep = []

        print("Preparing deep learning data...")

        for file_path, label in tqdm(zip(self.file_paths, self.labels), total=len(self.file_paths)):
            try:
                audio = self.preprocessor.load_audio(file_path)
                mfcc = self.preprocessor.extract_mfcc(audio)

                X_deep.append(mfcc)
                y_deep.append(label)

            except Exception as e:
                print(f"Error processing {file_path}: {e}")

        X_deep = np.array(X_deep)
        y_deep = self.label_encoder.transform(np.array(y_deep))

        # Add channel dimension for CNN: (samples, height, width, channels)
        X_deep = np.expand_dims(X_deep, axis=-1)

        print(f"Deep learning data shape: {X_deep.shape}")
        return X_deep, y_deep

    def get_class_names(self) -> List[str]:
        """Get list of class names (index position corresponds to encoded label)."""
        return list(self.label_encoder.classes_)

    def save_splits(self, splits: Dict[str, np.ndarray], output_dir: str):
        """Save data splits to disk as .npy files plus the fitted label encoder."""
        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)

        for split_name, data in splits.items():
            np.save(output_path / f"{split_name}.npy", data)

        # Save label encoder so predictions can be mapped back to class names
        with open(output_path / 'label_encoder.pkl', 'wb') as f:
            pickle.dump(self.label_encoder, f)

        print(f"Splits saved to {output_dir}")
+
def create_sample_dataset_structure(base_dir: str):
    """Create an example class-folder layout so users see the expected format.

    Builds one (empty) subdirectory per class under ``base_dir`` and prints
    instructions telling the user where to place audio files.
    """
    root = Path(base_dir)
    classes = ['healthy', 'covid', 'asthma']

    # Build each class folder; idempotent, existing folders are left intact.
    for label in classes:
        (root / label).mkdir(parents=True, exist_ok=True)

    print(f"Sample dataset structure created at {base_dir}")
    print("Please place your audio files in the respective class folders:")
    for label in classes:
        print(f"  - {base_dir}/{label}/")
| 277 |
+
if __name__ == "__main__":
    # Example usage: build a preprocessor matching the training configuration.
    preprocessor = AudioPreprocessor(sample_rate=16000, duration=5.0)

    # Create sample structure (uncomment and point at your data directory):
    # create_sample_dataset_structure('/Users/tan135/Desktop/Sid AI/data/raw')

    # Load dataset and produce stratified train/val/test splits:
    # dataset = AudioDataset('/Users/tan135/Desktop/Sid AI/data/raw', preprocessor)
    # dataset.load_from_directory_structure()
    # splits = dataset.split_data()
|
src/deep_learning_models.py
ADDED
|
@@ -0,0 +1,392 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Deep learning models for respiratory disease detection.
|
| 3 |
+
Includes CNN and LSTM architectures.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import tensorflow as tf
|
| 8 |
+
from tensorflow import keras
|
| 9 |
+
from tensorflow.keras import layers, models, callbacks
|
| 10 |
+
from tensorflow.keras.utils import to_categorical
|
| 11 |
+
from typing import Tuple, Optional, Dict
|
| 12 |
+
import pickle
|
| 13 |
+
from pathlib import Path
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class CNNModel:
    """Convolutional Neural Network for audio (MFCC spectrogram) classification."""

    def __init__(self, input_shape: Tuple, num_classes: int, model_name: str = "cnn_model"):
        """
        Initialize CNN model.

        Args:
            input_shape: Shape of input (height, width, channels)
            num_classes: Number of output classes
            model_name: Name of the model
        """
        self.input_shape = input_shape
        self.num_classes = num_classes
        self.model_name = model_name
        self.model = None      # built keras model (set by build_model)
        self.history = None    # keras History from the last train() call

    def build_model(self, dropout_rate: float = 0.3):
        """
        Build CNN architecture: four conv blocks followed by a dense head.

        Args:
            dropout_rate: Dropout rate for regularization
        """
        model = models.Sequential(name=self.model_name)

        # First convolutional block
        model.add(layers.Conv2D(32, (3, 3), activation='relu',
                                padding='same', input_shape=self.input_shape))
        model.add(layers.BatchNormalization())
        model.add(layers.MaxPooling2D((2, 2)))
        model.add(layers.Dropout(dropout_rate))

        # Second convolutional block
        model.add(layers.Conv2D(64, (3, 3), activation='relu', padding='same'))
        model.add(layers.BatchNormalization())
        model.add(layers.MaxPooling2D((2, 2)))
        model.add(layers.Dropout(dropout_rate))

        # Third convolutional block
        model.add(layers.Conv2D(128, (3, 3), activation='relu', padding='same'))
        model.add(layers.BatchNormalization())
        model.add(layers.MaxPooling2D((2, 2)))
        model.add(layers.Dropout(dropout_rate))

        # Fourth convolutional block; global pooling collapses spatial dims
        model.add(layers.Conv2D(256, (3, 3), activation='relu', padding='same'))
        model.add(layers.BatchNormalization())
        model.add(layers.GlobalAveragePooling2D())

        # Dense layers
        model.add(layers.Dense(256, activation='relu'))
        model.add(layers.Dropout(dropout_rate))
        model.add(layers.Dense(128, activation='relu'))
        model.add(layers.Dropout(dropout_rate))

        # Output layer: single sigmoid unit for binary, softmax otherwise
        if self.num_classes == 2:
            model.add(layers.Dense(1, activation='sigmoid'))
        else:
            model.add(layers.Dense(self.num_classes, activation='softmax'))

        self.model = model
        print(f"\n{self.model_name} architecture:")
        self.model.summary()

        return model

    def compile_model(self, learning_rate: float = 0.001):
        """Compile the model with Adam and a loss matching the class count."""
        if self.model is None:
            raise ValueError("Model must be built before compilation")

        optimizer = keras.optimizers.Adam(learning_rate=learning_rate)

        if self.num_classes == 2:
            loss = 'binary_crossentropy'
            metrics = ['accuracy', keras.metrics.AUC(name='auc')]
        else:
            loss = 'sparse_categorical_crossentropy'
            metrics = ['accuracy']

        self.model.compile(
            optimizer=optimizer,
            loss=loss,
            metrics=metrics
        )

        print(f"Model compiled with optimizer={optimizer.__class__.__name__}, loss={loss}")

    def train(self, X_train: np.ndarray, y_train: np.ndarray,
              X_val: np.ndarray, y_val: np.ndarray,
              epochs: int = 50, batch_size: int = 32,
              model_dir: str = 'models'):
        """
        Train the CNN model with checkpointing, early stopping, and LR decay.

        Args:
            X_train: Training features
            y_train: Training labels
            X_val: Validation features
            y_val: Validation labels
            epochs: Number of training epochs
            batch_size: Batch size
            model_dir: Directory to save model checkpoints
        """
        if self.model is None:
            raise ValueError("Model must be built and compiled before training")

        # Create model directory
        model_path = Path(model_dir)
        model_path.mkdir(parents=True, exist_ok=True)

        # Define callbacks: keep the best val_loss weights on disk, stop when
        # validation stops improving, and halve the LR on plateaus.
        checkpoint_path = model_path / f"{self.model_name}_best.keras"
        callbacks_list = [
            callbacks.ModelCheckpoint(
                str(checkpoint_path),
                monitor='val_loss',
                save_best_only=True,
                verbose=1
            ),
            callbacks.EarlyStopping(
                monitor='val_loss',
                patience=10,
                restore_best_weights=True,
                verbose=1
            ),
            callbacks.ReduceLROnPlateau(
                monitor='val_loss',
                factor=0.5,
                patience=5,
                min_lr=1e-7,
                verbose=1
            )
        ]

        print(f"\nTraining {self.model_name}...")
        print(f"Training samples: {len(X_train)}, Validation samples: {len(X_val)}")
        print(f"Epochs: {epochs}, Batch size: {batch_size}")

        # Train model
        self.history = self.model.fit(
            X_train, y_train,
            validation_data=(X_val, y_val),
            epochs=epochs,
            batch_size=batch_size,
            callbacks=callbacks_list,
            verbose=1
        )

        print(f"\nTraining complete. Best model saved to {checkpoint_path}")

        return self.history

    def evaluate(self, X_test: np.ndarray, y_test: np.ndarray) -> Dict:
        """Evaluate model on the test set.

        Returns:
            Dict with 'loss', 'accuracy', 'predictions', 'probabilities'
            (plus 'auc' when the binary metrics were compiled in).
        """
        if self.model is None:
            raise ValueError("Model must be trained before evaluation")

        print(f"\nEvaluating {self.model_name}...")
        # Order of `results` follows the metrics list from compile_model:
        # [loss, accuracy, (auc)].
        results = self.model.evaluate(X_test, y_test, verbose=1)

        # Get predictions
        y_pred_proba = self.model.predict(X_test)

        if self.num_classes == 2:
            # Sigmoid output -> threshold at 0.5
            y_pred = (y_pred_proba > 0.5).astype(int).flatten()
        else:
            y_pred = np.argmax(y_pred_proba, axis=1)

        evaluation_results = {
            'loss': results[0],
            'accuracy': results[1],
            'predictions': y_pred,
            'probabilities': y_pred_proba
        }

        if len(results) > 2:
            evaluation_results['auc'] = results[2]

        print(f"Test Loss: {results[0]:.4f}")
        print(f"Test Accuracy: {results[1]:.4f}")

        return evaluation_results

    def save(self, filepath: str):
        """Save model to disk.

        Raises:
            ValueError: If the model has not been built yet (previously this
                surfaced as an opaque AttributeError on self.model.save).
        """
        if self.model is None:
            raise ValueError("Model must be built before saving")
        self.model.save(filepath)
        print(f"Model saved to {filepath}")

    @classmethod
    def load(cls, filepath: str):
        """Load model from disk.

        Note: returns the raw keras model, not a CNNModel wrapper.
        """
        model = keras.models.load_model(filepath)
        print(f"Model loaded from {filepath}")
        return model
+
|
| 216 |
+
class LSTMModel:
    """LSTM model for sequential audio classification."""

    def __init__(self, input_shape: Tuple, num_classes: int, model_name: str = "lstm_model"):
        """
        Initialize LSTM model.

        Args:
            input_shape: Shape of input (time_steps, features)
            num_classes: Number of output classes
            model_name: Name of the model
        """
        self.input_shape = input_shape
        self.num_classes = num_classes
        self.model_name = model_name
        self.model = None      # built keras model (set by build_model)
        self.history = None    # keras History from the last train() call

    def build_model(self, dropout_rate: float = 0.3):
        """
        Build LSTM architecture: three stacked LSTM layers plus a dense head.

        Args:
            dropout_rate: Dropout rate for regularization
        """
        model = models.Sequential(name=self.model_name)

        # Stacked LSTM layers; return_sequences keeps the time axis for the
        # next recurrent layer, the final LSTM collapses it.
        model.add(layers.LSTM(128, return_sequences=True, input_shape=self.input_shape))
        model.add(layers.Dropout(dropout_rate))
        model.add(layers.BatchNormalization())

        model.add(layers.LSTM(64, return_sequences=True))
        model.add(layers.Dropout(dropout_rate))
        model.add(layers.BatchNormalization())

        model.add(layers.LSTM(32))
        model.add(layers.Dropout(dropout_rate))

        # Dense layers
        model.add(layers.Dense(64, activation='relu'))
        model.add(layers.Dropout(dropout_rate))

        # Output layer: single sigmoid unit for binary, softmax otherwise
        if self.num_classes == 2:
            model.add(layers.Dense(1, activation='sigmoid'))
        else:
            model.add(layers.Dense(self.num_classes, activation='softmax'))

        self.model = model
        print(f"\n{self.model_name} architecture:")
        self.model.summary()

        return model

    def compile_model(self, learning_rate: float = 0.001):
        """Compile the model with Adam and a loss matching the class count."""
        if self.model is None:
            raise ValueError("Model must be built before compilation")

        optimizer = keras.optimizers.Adam(learning_rate=learning_rate)

        if self.num_classes == 2:
            loss = 'binary_crossentropy'
            metrics = ['accuracy', keras.metrics.AUC(name='auc')]
        else:
            loss = 'sparse_categorical_crossentropy'
            metrics = ['accuracy']

        self.model.compile(
            optimizer=optimizer,
            loss=loss,
            metrics=metrics
        )

        print(f"Model compiled with optimizer={optimizer.__class__.__name__}, loss={loss}")

    def train(self, X_train: np.ndarray, y_train: np.ndarray,
              X_val: np.ndarray, y_val: np.ndarray,
              epochs: int = 50, batch_size: int = 32,
              model_dir: str = 'models'):
        """Train the LSTM model with checkpointing, early stopping, and LR decay."""
        if self.model is None:
            raise ValueError("Model must be built and compiled before training")

        # Create model directory
        model_path = Path(model_dir)
        model_path.mkdir(parents=True, exist_ok=True)

        # Define callbacks: keep the best val_loss weights on disk, stop when
        # validation stops improving, and halve the LR on plateaus.
        checkpoint_path = model_path / f"{self.model_name}_best.keras"
        callbacks_list = [
            callbacks.ModelCheckpoint(
                str(checkpoint_path),
                monitor='val_loss',
                save_best_only=True,
                verbose=1
            ),
            callbacks.EarlyStopping(
                monitor='val_loss',
                patience=10,
                restore_best_weights=True,
                verbose=1
            ),
            callbacks.ReduceLROnPlateau(
                monitor='val_loss',
                factor=0.5,
                patience=5,
                min_lr=1e-7,
                verbose=1
            )
        ]

        print(f"\nTraining {self.model_name}...")
        print(f"Training samples: {len(X_train)}, Validation samples: {len(X_val)}")

        # Train model
        self.history = self.model.fit(
            X_train, y_train,
            validation_data=(X_val, y_val),
            epochs=epochs,
            batch_size=batch_size,
            callbacks=callbacks_list,
            verbose=1
        )

        print(f"\nTraining complete. Best model saved to {checkpoint_path}")

        return self.history

    def evaluate(self, X_test: np.ndarray, y_test: np.ndarray) -> Dict:
        """Evaluate model on the test set.

        Returns:
            Dict with 'loss', 'accuracy', 'predictions', 'probabilities'
            (plus 'auc' when the binary metrics were compiled in).
        """
        if self.model is None:
            raise ValueError("Model must be trained before evaluation")

        print(f"\nEvaluating {self.model_name}...")
        # Order of `results` follows the metrics list from compile_model:
        # [loss, accuracy, (auc)].
        results = self.model.evaluate(X_test, y_test, verbose=1)

        # Get predictions
        y_pred_proba = self.model.predict(X_test)

        if self.num_classes == 2:
            # Sigmoid output -> threshold at 0.5
            y_pred = (y_pred_proba > 0.5).astype(int).flatten()
        else:
            y_pred = np.argmax(y_pred_proba, axis=1)

        evaluation_results = {
            'loss': results[0],
            'accuracy': results[1],
            'predictions': y_pred,
            'probabilities': y_pred_proba
        }

        if len(results) > 2:
            evaluation_results['auc'] = results[2]

        print(f"Test Loss: {results[0]:.4f}")
        print(f"Test Accuracy: {results[1]:.4f}")

        return evaluation_results

    def save(self, filepath: str):
        """Save model to disk.

        Raises:
            ValueError: If the model has not been built yet (previously this
                surfaced as an opaque AttributeError on self.model.save).
        """
        if self.model is None:
            raise ValueError("Model must be built before saving")
        self.model.save(filepath)
        print(f"Model saved to {filepath}")

    @classmethod
    def load(cls, filepath: str):
        """Load model from disk.

        Note: returns the raw keras model, not an LSTMModel wrapper.
        """
        model = keras.models.load_model(filepath)
        print(f"Model loaded from {filepath}")
        return model
| 390 |
+
if __name__ == "__main__":
    # Smoke test: running this module directly confirms it imports cleanly
    # and lists the exported model classes.
    print("Deep learning models module loaded successfully")
    print("Available models: CNNModel, LSTMModel")
|
src/evaluation.py
ADDED
|
@@ -0,0 +1,345 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Evaluation and visualization tools for model performance analysis.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
import matplotlib.pyplot as plt
|
| 7 |
+
import seaborn as sns
|
| 8 |
+
from sklearn.metrics import (
|
| 9 |
+
confusion_matrix, classification_report, roc_curve, auc,
|
| 10 |
+
precision_recall_curve, roc_auc_score
|
| 11 |
+
)
|
| 12 |
+
from sklearn.preprocessing import label_binarize
|
| 13 |
+
from pathlib import Path
|
| 14 |
+
from typing import List, Optional, Dict
|
| 15 |
+
import json
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class ModelEvaluator:
    """Comprehensive model evaluation and visualization.

    Produces confusion matrices, ROC and precision-recall curves,
    training-history and feature-importance plots, text classification
    reports, JSON metric dumps and model-comparison charts.  Every
    artifact is written under ``output_dir``.
    """

    def __init__(self, class_names: List[str], output_dir: str = 'results'):
        """
        Initialize evaluator.

        Args:
            class_names: List of class names, index-aligned with integer labels
            output_dir: Directory to save visualizations (created if missing)
        """
        self.class_names = class_names
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)

        # Shared plotting defaults for every figure produced by this object.
        sns.set_style("whitegrid")
        plt.rcParams['figure.figsize'] = (10, 8)

    def _finalize_figure(self, save_name: str, artifact: str):
        """Save the current matplotlib figure under output_dir and close it."""
        save_path = self.output_dir / save_name
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"{artifact} saved to {save_path}")
        plt.close()

    @staticmethod
    def _positive_class_scores(y_proba: np.ndarray) -> np.ndarray:
        """Return positive-class scores for a binary problem.

        Accepts both (n, 1) sigmoid outputs and (n, 2) softmax outputs.
        """
        return y_proba[:, 0] if y_proba.shape[1] == 1 else y_proba[:, 1]

    def plot_confusion_matrix(self, y_true: np.ndarray, y_pred: np.ndarray,
                              title: str = 'Confusion Matrix', save_name: str = 'confusion_matrix.png'):
        """
        Plot confusion matrix.

        Args:
            y_true: True labels
            y_pred: Predicted labels
            title: Plot title
            save_name: Filename to save the plot

        Returns:
            The raw confusion-matrix array from sklearn.
        """
        cm = confusion_matrix(y_true, y_pred)

        plt.figure(figsize=(10, 8))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                    xticklabels=self.class_names,
                    yticklabels=self.class_names,
                    cbar_kws={'label': 'Count'})
        plt.title(title, fontsize=16, fontweight='bold')
        plt.ylabel('True Label', fontsize=12)
        plt.xlabel('Predicted Label', fontsize=12)
        plt.tight_layout()

        self._finalize_figure(save_name, 'Confusion matrix')

        return cm

    def plot_roc_curve(self, y_true: np.ndarray, y_proba: np.ndarray,
                       title: str = 'ROC Curve', save_name: str = 'roc_curve.png'):
        """
        Plot ROC curve (handles binary and multi-class).

        Args:
            y_true: True labels
            y_proba: Prediction probabilities, shape (n, 1), (n, 2) or (n, k)
            title: Plot title
            save_name: Filename to save the plot
        """
        n_classes = len(self.class_names)

        plt.figure(figsize=(10, 8))

        if n_classes == 2:
            # Binary classification: one curve for the positive class.
            fpr, tpr, _ = roc_curve(y_true, self._positive_class_scores(y_proba))
            roc_auc = auc(fpr, tpr)

            plt.plot(fpr, tpr, color='darkorange', lw=2,
                     label=f'ROC curve (AUC = {roc_auc:.3f})')
        else:
            # Multi-class: one-vs-rest curve per class.
            y_true_bin = label_binarize(y_true, classes=range(n_classes))

            for i in range(n_classes):
                fpr, tpr, _ = roc_curve(y_true_bin[:, i], y_proba[:, i])
                roc_auc = auc(fpr, tpr)
                plt.plot(fpr, tpr, lw=2,
                         label=f'{self.class_names[i]} (AUC = {roc_auc:.3f})')

        plt.plot([0, 1], [0, 1], 'k--', lw=2, label='Random (AUC = 0.5)')
        plt.xlim([0.0, 1.0])
        plt.ylim([0.0, 1.05])
        plt.xlabel('False Positive Rate', fontsize=12)
        plt.ylabel('True Positive Rate', fontsize=12)
        plt.title(title, fontsize=16, fontweight='bold')
        plt.legend(loc="lower right", fontsize=10)
        plt.grid(True, alpha=0.3)
        plt.tight_layout()

        self._finalize_figure(save_name, 'ROC curve')

    def plot_precision_recall_curve(self, y_true: np.ndarray, y_proba: np.ndarray,
                                    title: str = 'Precision-Recall Curve',
                                    save_name: str = 'precision_recall_curve.png'):
        """Plot precision-recall curve (handles binary and multi-class)."""
        n_classes = len(self.class_names)

        plt.figure(figsize=(10, 8))

        if n_classes == 2:
            # Bug fix: mirror plot_roc_curve and accept (n, 1) sigmoid
            # outputs as well as (n, 2) softmax outputs — the previous
            # unconditional y_proba[:, 1] raised IndexError on (n, 1).
            precision, recall, _ = precision_recall_curve(
                y_true, self._positive_class_scores(y_proba))
            plt.plot(recall, precision, color='darkorange', lw=2)
        else:
            y_true_bin = label_binarize(y_true, classes=range(n_classes))

            for i in range(n_classes):
                precision, recall, _ = precision_recall_curve(y_true_bin[:, i], y_proba[:, i])
                plt.plot(recall, precision, lw=2, label=self.class_names[i])

        plt.xlabel('Recall', fontsize=12)
        plt.ylabel('Precision', fontsize=12)
        plt.title(title, fontsize=16, fontweight='bold')
        plt.legend(loc="best", fontsize=10)
        plt.grid(True, alpha=0.3)
        plt.tight_layout()

        self._finalize_figure(save_name, 'Precision-recall curve')

    def plot_training_history(self, history, title: str = 'Training History',
                              save_name: str = 'training_history.png'):
        """
        Plot training history for deep learning models.

        Args:
            history: Keras History object (reads 'accuracy', 'val_accuracy',
                'loss' and 'val_loss' from history.history)
            title: Plot title
            save_name: Filename to save the plot
        """
        fig, axes = plt.subplots(1, 2, figsize=(15, 5))

        # Left panel: accuracy over epochs.
        axes[0].plot(history.history['accuracy'], label='Train Accuracy', linewidth=2)
        axes[0].plot(history.history['val_accuracy'], label='Val Accuracy', linewidth=2)
        axes[0].set_xlabel('Epoch', fontsize=12)
        axes[0].set_ylabel('Accuracy', fontsize=12)
        axes[0].set_title('Model Accuracy', fontsize=14, fontweight='bold')
        axes[0].legend(loc='best', fontsize=10)
        axes[0].grid(True, alpha=0.3)

        # Right panel: loss over epochs.
        axes[1].plot(history.history['loss'], label='Train Loss', linewidth=2)
        axes[1].plot(history.history['val_loss'], label='Val Loss', linewidth=2)
        axes[1].set_xlabel('Epoch', fontsize=12)
        axes[1].set_ylabel('Loss', fontsize=12)
        axes[1].set_title('Model Loss', fontsize=14, fontweight='bold')
        axes[1].legend(loc='best', fontsize=10)
        axes[1].grid(True, alpha=0.3)

        plt.suptitle(title, fontsize=16, fontweight='bold', y=1.02)
        plt.tight_layout()

        self._finalize_figure(save_name, 'Training history')

    def plot_feature_importance(self, importances: np.ndarray, feature_names: List[str] = None,
                                top_n: int = 20, title: str = 'Feature Importance',
                                save_name: str = 'feature_importance.png'):
        """Plot the top_n most important features for tree-based models.

        Args:
            importances: Per-feature importance scores
            feature_names: Optional names aligned with importances; generic
                'Feature i' labels are generated when omitted
            top_n: Number of top features to display
            title: Plot title
            save_name: Filename to save the plot
        """
        if feature_names is None:
            feature_names = [f'Feature {i}' for i in range(len(importances))]

        # Keep only the top_n features, least important first so the most
        # important ends up at the top of the horizontal bar chart.
        indices = np.argsort(importances)[-top_n:]

        plt.figure(figsize=(10, max(8, top_n * 0.3)))
        plt.barh(range(len(indices)), importances[indices], color='steelblue')
        plt.yticks(range(len(indices)), [feature_names[i] for i in indices])
        plt.xlabel('Importance', fontsize=12)
        plt.title(title, fontsize=16, fontweight='bold')
        plt.tight_layout()

        self._finalize_figure(save_name, 'Feature importance')

    def generate_classification_report(self, y_true: np.ndarray, y_pred: np.ndarray,
                                       save_name: str = 'classification_report.txt'):
        """Generate, print and save sklearn's text classification report.

        Returns:
            The report as a string.
        """
        report = classification_report(y_true, y_pred, target_names=self.class_names)

        save_path = self.output_dir / save_name
        with open(save_path, 'w') as f:
            f.write(report)

        print(f"\nClassification Report:\n{report}")
        print(f"Report saved to {save_path}")

        return report

    def save_metrics(self, metrics: Dict, save_name: str = 'metrics.json'):
        """Save a metrics dict to a JSON file under output_dir."""
        # Recursively convert numpy containers/scalars to plain Python types
        # so json.dump can serialize them.
        def convert_to_serializable(obj):
            if isinstance(obj, np.ndarray):
                return obj.tolist()
            elif isinstance(obj, np.integer):
                # np.integer covers every width (int8..int64), not just the
                # two that were special-cased before.
                return int(obj)
            elif isinstance(obj, np.floating):
                return float(obj)
            elif isinstance(obj, np.bool_):
                # np.bool_ is not JSON-serializable either.
                return bool(obj)
            elif isinstance(obj, dict):
                return {k: convert_to_serializable(v) for k, v in obj.items()}
            elif isinstance(obj, (list, tuple)):
                return [convert_to_serializable(item) for item in obj]
            return obj

        serializable_metrics = convert_to_serializable(metrics)

        save_path = self.output_dir / save_name
        with open(save_path, 'w') as f:
            json.dump(serializable_metrics, f, indent=4)

        print(f"Metrics saved to {save_path}")

    def create_comparison_plot(self, results_dict: Dict[str, Dict],
                               metric: str = 'accuracy',
                               title: str = 'Model Comparison',
                               save_name: str = 'model_comparison.png'):
        """
        Create comparison plot for multiple models.

        Args:
            results_dict: Dictionary of model results {model_name: results_dict};
                each inner dict must contain `metric`
            metric: Metric to compare (assumed to lie in [0, 1])
            title: Plot title
            save_name: Filename to save the plot
        """
        models = list(results_dict.keys())
        values = [results_dict[model][metric] for model in models]

        plt.figure(figsize=(12, 6))
        bars = plt.bar(models, values, color='steelblue', alpha=0.8, edgecolor='black')

        # Annotate each bar with its exact value.
        for bar in bars:
            height = bar.get_height()
            plt.text(bar.get_x() + bar.get_width()/2., height,
                     f'{height:.4f}',
                     ha='center', va='bottom', fontsize=10, fontweight='bold')

        plt.ylabel(metric.capitalize(), fontsize=12)
        plt.title(title, fontsize=16, fontweight='bold')
        plt.xticks(rotation=45, ha='right')
        plt.ylim([0, 1.0])
        plt.grid(True, alpha=0.3, axis='y')
        plt.tight_layout()

        self._finalize_figure(save_name, 'Comparison plot')
|
| 283 |
+
|
| 284 |
+
|
| 285 |
+
class AudioVisualizer:
    """Visualize audio features and spectrograms."""

    @staticmethod
    def _finish(save_path: Optional[str]):
        """Write the current figure to disk when a path is given, else show it."""
        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
            plt.close()
        else:
            plt.show()

    @staticmethod
    def plot_waveform(audio: np.ndarray, sr: int = 16000,
                      title: str = 'Audio Waveform', save_path: Optional[str] = None):
        """Plot audio waveform against time in seconds."""
        plt.figure(figsize=(12, 4))
        timeline = np.arange(0, len(audio)) / sr
        plt.plot(timeline, audio, linewidth=0.5)
        plt.xlabel('Time (s)', fontsize=12)
        plt.ylabel('Amplitude', fontsize=12)
        plt.title(title, fontsize=14, fontweight='bold')
        plt.grid(True, alpha=0.3)
        plt.tight_layout()

        AudioVisualizer._finish(save_path)

    @staticmethod
    def plot_spectrogram(spectrogram: np.ndarray, sr: int = 16000,
                         title: str = 'Spectrogram', save_path: Optional[str] = None):
        """Render a (frequency, time) spectrogram as a heatmap."""
        plt.figure(figsize=(12, 6))
        plt.imshow(spectrogram, aspect='auto', origin='lower', cmap='viridis')
        plt.colorbar(format='%+2.0f dB', label='Power (dB)')
        plt.xlabel('Time', fontsize=12)
        plt.ylabel('Frequency', fontsize=12)
        plt.title(title, fontsize=14, fontweight='bold')
        plt.tight_layout()

        AudioVisualizer._finish(save_path)

    @staticmethod
    def plot_mfcc(mfcc: np.ndarray, title: str = 'MFCC Features',
                  save_path: Optional[str] = None):
        """Render a (coefficient, frame) MFCC matrix as a heatmap."""
        plt.figure(figsize=(12, 6))
        plt.imshow(mfcc, aspect='auto', origin='lower', cmap='coolwarm')
        plt.colorbar(label='MFCC Coefficient Value')
        plt.xlabel('Time Frame', fontsize=12)
        plt.ylabel('MFCC Coefficient', fontsize=12)
        plt.title(title, fontsize=14, fontweight='bold')
        plt.tight_layout()

        AudioVisualizer._finish(save_path)
|
| 342 |
+
|
| 343 |
+
|
| 344 |
+
if __name__ == "__main__":
    # Import-time sanity message when the module is run directly.
    print("Evaluation and visualization module loaded successfully")
|
src/preprocessing.py
ADDED
|
@@ -0,0 +1,281 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Audio preprocessing and feature extraction module for respiratory disease detection.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import librosa
|
| 6 |
+
import numpy as np
|
| 7 |
+
import soundfile as sf
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from typing import Tuple, Dict, Optional
|
| 10 |
+
import warnings
|
| 11 |
+
warnings.filterwarnings('ignore')
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class AudioPreprocessor:
    """Handles audio loading, normalization, and feature extraction."""

    def __init__(self, sample_rate: int = 16000, duration: float = 5.0):
        """
        Initialize the audio preprocessor.

        Args:
            sample_rate: Target sample rate for all audio files
            duration: Target duration in seconds (will pad/trim)
        """
        self.sample_rate = sample_rate
        self.duration = duration
        # Fixed sample count every clip is padded/trimmed to.
        self.target_length = int(sample_rate * duration)

    def load_audio(self, file_path: str) -> np.ndarray:
        """
        Load and normalize audio file.

        Resamples to self.sample_rate, forces mono, fixes the length and
        peak-normalizes the amplitude.  On any failure, prints the error
        and returns silence of the target length (deliberate best-effort).

        Args:
            file_path: Path to audio file

        Returns:
            Normalized audio array
        """
        try:
            signal, _ = librosa.load(file_path, sr=self.sample_rate, mono=True)
            signal = self._normalize_length(signal)
            return librosa.util.normalize(signal)
        except Exception as e:
            print(f"Error loading {file_path}: {e}")
            return np.zeros(self.target_length)

    def _normalize_length(self, audio: np.ndarray) -> np.ndarray:
        """Pad with trailing zeros or trim so the clip is exactly target_length."""
        shortfall = self.target_length - len(audio)
        if shortfall > 0:
            return np.pad(audio, (0, shortfall))
        return audio[:self.target_length]

    def extract_mfcc(self, audio: np.ndarray, n_mfcc: int = 40) -> np.ndarray:
        """
        Extract MFCC features from audio.

        Args:
            audio: Audio signal
            n_mfcc: Number of MFCCs to extract

        Returns:
            MFCC features (n_mfcc, time_steps)
        """
        return librosa.feature.mfcc(
            y=audio,
            sr=self.sample_rate,
            n_mfcc=n_mfcc,
            n_fft=2048,
            hop_length=512,
        )

    def extract_mel_spectrogram(self, audio: np.ndarray, n_mels: int = 128) -> np.ndarray:
        """
        Extract mel spectrogram from audio, converted to log (dB) scale.

        Args:
            audio: Audio signal
            n_mels: Number of mel bands

        Returns:
            Mel spectrogram in dB, referenced to its maximum
        """
        power_spec = librosa.feature.melspectrogram(
            y=audio,
            sr=self.sample_rate,
            n_mels=n_mels,
            n_fft=2048,
            hop_length=512,
        )
        return librosa.power_to_db(power_spec, ref=np.max)

    def extract_spectral_features(self, audio: np.ndarray) -> Dict[str, np.ndarray]:
        """
        Extract various spectral features.

        Args:
            audio: Audio signal

        Returns:
            Dictionary with spectral centroid/rolloff, zero-crossing rate,
            chroma and spectral contrast
        """
        return {
            'spectral_centroid': librosa.feature.spectral_centroid(
                y=audio, sr=self.sample_rate)[0],
            'spectral_rolloff': librosa.feature.spectral_rolloff(
                y=audio, sr=self.sample_rate)[0],
            'zero_crossing_rate': librosa.feature.zero_crossing_rate(audio)[0],
            'chroma': librosa.feature.chroma_stft(y=audio, sr=self.sample_rate),
            'spectral_contrast': librosa.feature.spectral_contrast(
                y=audio, sr=self.sample_rate),
        }

    def extract_all_features(self, audio: np.ndarray) -> Dict[str, np.ndarray]:
        """
        Extract all audio features (MFCC, mel spectrogram, spectral set).

        Args:
            audio: Audio signal

        Returns:
            Dictionary containing all features
        """
        combined = {
            'mfcc': self.extract_mfcc(audio),
            'mel_spectrogram': self.extract_mel_spectrogram(audio),
        }
        combined.update(self.extract_spectral_features(audio))
        return combined

    def compute_statistics(self, feature_array: np.ndarray) -> np.ndarray:
        """
        Compute statistical features (mean, std, min, max) from feature array.

        Args:
            feature_array: 2D feature array (features, time)

        Returns:
            Flattened statistical features, ordered mean|std|min|max
        """
        return np.concatenate([
            np.mean(feature_array, axis=1),
            np.std(feature_array, axis=1),
            np.min(feature_array, axis=1),
            np.max(feature_array, axis=1),
        ])
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
class AudioAugmenter:
    """Augments audio data for better model generalization."""

    @staticmethod
    def add_noise(audio: np.ndarray, noise_level: float = 0.005) -> np.ndarray:
        """Mix scaled Gaussian noise into the signal."""
        noise = np.random.randn(len(audio))
        return audio + noise_level * noise

    @staticmethod
    def time_stretch(audio: np.ndarray, rate: float = 1.2) -> np.ndarray:
        """Speed the signal up (rate > 1) or slow it down (rate < 1)."""
        return librosa.effects.time_stretch(audio, rate=rate)

    @staticmethod
    def pitch_shift(audio: np.ndarray, sr: int, n_steps: int = 2) -> np.ndarray:
        """Transpose the signal by n_steps semitones."""
        return librosa.effects.pitch_shift(audio, sr=sr, n_steps=n_steps)

    @staticmethod
    def random_gain(audio: np.ndarray, min_gain: float = 0.8, max_gain: float = 1.2) -> np.ndarray:
        """Scale the signal by a gain drawn uniformly from [min_gain, max_gain)."""
        gain = np.random.uniform(min_gain, max_gain)
        return audio * gain

    def augment(self, audio: np.ndarray, sr: int, techniques: list = None) -> np.ndarray:
        """
        Apply random augmentation techniques.

        Each requested technique fires independently with probability 0.5,
        in the order given.

        Args:
            audio: Audio signal
            sr: Sample rate
            techniques: List of augmentation techniques to apply
                        (default: ['noise', 'gain'])

        Returns:
            Augmented audio
        """
        chosen = ['noise', 'gain'] if techniques is None else techniques

        # Lazily-evaluated transforms so random parameters are drawn only
        # when a technique actually fires (preserves the RNG call order).
        transforms = {
            'noise': lambda sig: self.add_noise(sig),
            'pitch': lambda sig: self.pitch_shift(sig, sr, np.random.randint(-2, 3)),
            'stretch': lambda sig: self.time_stretch(sig, np.random.uniform(0.9, 1.1)),
            'gain': lambda sig: self.random_gain(sig),
        }

        result = audio.copy()
        for name in chosen:
            op = transforms.get(name)
            # Unknown technique names are ignored without consuming RNG draws.
            if op is not None and np.random.rand() > 0.5:
                result = op(result)

        return result
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
def process_dataset(data_dir: str, output_dir: str, preprocessor: AudioPreprocessor):
    """
    Process all audio files in a dataset directory.

    Recursively finds .wav/.mp3 files under data_dir, extracts features for
    each and mirrors the directory layout under output_dir as compressed
    .npz files.  Per-file failures are printed and skipped.

    Args:
        data_dir: Directory containing raw audio files
        output_dir: Directory to save processed features
        preprocessor: AudioPreprocessor instance
    """
    source_root = Path(data_dir)
    target_root = Path(output_dir)
    target_root.mkdir(parents=True, exist_ok=True)

    audio_files = list(source_root.rglob('*.wav')) + list(source_root.rglob('*.mp3'))

    print(f"Found {len(audio_files)} audio files")

    for audio_file in audio_files:
        try:
            waveform = preprocessor.load_audio(str(audio_file))
            feature_map = preprocessor.extract_all_features(waveform)

            # Mirror the source tree under the output root, one .npz per clip.
            destination = (target_root / audio_file.relative_to(source_root)).with_suffix('.npz')
            destination.parent.mkdir(parents=True, exist_ok=True)
            np.savez_compressed(destination, **feature_map)
        except Exception as e:
            print(f"Error processing {audio_file}: {e}")

    print(f"Processing complete. Features saved to {output_dir}")
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
if __name__ == "__main__":
    # Example usage: build a preprocessor with the project defaults.
    preprocessor = AudioPreprocessor(sample_rate=16000, duration=5.0)

    # Process a single file (example)
    # audio = preprocessor.load_audio("path/to/audio.wav")
    # features = preprocessor.extract_all_features(audio)
    # print("MFCC shape:", features['mfcc'].shape)
    # print("Mel spectrogram shape:", features['mel_spectrogram'].shape)
|