Jeremy Mattathias Mboe
Fix: Remove info parameter from Gradio components for compatibility
7e903cb
"""
Respiratory Disease Classification Web Application
Combines CatBoost (tabular) and LSTM (audio) models with weighted ensemble
"""
import gradio as gr
import numpy as np
import pandas as pd
import torch
from catboost import CatBoostClassifier
import warnings
warnings.filterwarnings('ignore')
from utils import TabularPreprocessor, AudioPreprocessor, get_disease_info
from model import load_lstm_model, predict_audio, Config # Import Config for model loading compatibility
# ---------------------------------------------------------------------------
# Runtime setup: compute device, trained models, feature preprocessors and
# ensemble weights. All of this executes once, when the module is imported.
# ---------------------------------------------------------------------------

# Run on the GPU when PyTorch reports CUDA as available, else on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

print("Loading models...")
try:
    # Tabular classifier, restored from its saved CatBoost model file.
    catboost_model = CatBoostClassifier()
    catboost_model.load_model("catboost_tabular_model.cbm")
    print("✓ CatBoost model loaded successfully")

    # Audio classifier, restored via the project's loader onto `device`.
    lstm_model = load_lstm_model("best_respiratory_model.pth", device=device)
    print("✓ LSTM model loaded successfully")
except Exception as load_error:
    # Show the failure on the console, then propagate it so the app does not
    # start without its models.
    print(f"Error loading models: {load_error}")
    raise

# Feature extractors that turn raw UI inputs into model-ready features.
tabular_preprocessor = TabularPreprocessor()
audio_preprocessor = AudioPreprocessor(
    sample_rate=16000,
    duration=1.0,  # seconds of audio per clip
    n_mfcc=20,
    n_mels=64,
    n_fft=2048,
    hop_length=512,
)
print("Preprocessors initialized")

# Ensemble weights: each model's class probabilities are scaled by its weight
# before the two are summed in predict().
AUDIO_WEIGHT = 0.2
TABULAR_WEIGHT = 0.8
# Display names for the class indices; entry i labels column i of every
# probability vector handled below.
_CLASS_NAMES = ("Healthy", "COPD", "Asthma")


def _build_tabular_record(age, gender, tb_contact, wheezing, phlegm_cough,
                          family_asthma, fever, cold_present, pack_years):
    """Map UI field values onto the keys passed to the tabular preprocessor.

    Blank optional fields arrive as ``None``. Age, fever and cold become NaN;
    pack-years defaults to 0.
    """
    return {
        'age': age if age is not None else np.nan,
        'gender': gender,
        'tbContactHistory': tb_contact,
        'wheezingHistory': wheezing,
        'phlegmCough': phlegm_cough,
        'familyAsthmaHistory': family_asthma,
        'feverHistory': fever if fever is not None else np.nan,
        'coldPresent': cold_present if cold_present is not None else np.nan,
        'packYears': pack_years if pack_years is not None else 0,
    }


def _format_report(disease_info, ensemble_proba, tabular_proba, tabular_pred,
                   audio_proba, audio_pred):
    """Render the Markdown report shown in the results panel.

    Args:
        disease_info: Mapping with 'name', 'description' and 'recommendations'
            keys, as returned by ``get_disease_info``.
        ensemble_proba: Combined class probabilities, ordered like _CLASS_NAMES.
        tabular_proba, tabular_pred: CatBoost probabilities and argmax index.
        audio_proba, audio_pred: LSTM probabilities and predicted index.

    Returns:
        The Markdown report as a single string.
    """
    lines = [
        "",
        "# 🏥 Respiratory Disease Classification Results",
        f"## 📊 Final Diagnosis: **{disease_info['name']}**",
        "### Confidence Scores:",
        *(f"- **{name}**: {p:.1%}" for name, p in zip(_CLASS_NAMES, ensemble_proba)),
        "---",
        "## 📝 Description:",
        f"{disease_info['description']}",
        "## 💡 Recommendations:",
        "",  # blank line between the heading and the numbered list
        *(f"{i}. {rec}" for i, rec in enumerate(disease_info['recommendations'], 1)),
        "---",
        "## 🔬 Model Details:",
        f"- **Tabular Model (CatBoost)**: Weight = {TABULAR_WEIGHT:.0%}",
        f"- Prediction: {_CLASS_NAMES[tabular_pred]}",
        f"- Confidence: {tabular_proba[tabular_pred]:.1%}",
        f"- **Audio Model (LSTM)**: Weight = {AUDIO_WEIGHT:.0%}",
        f"- Prediction: {_CLASS_NAMES[audio_pred]}",
        f"- Confidence: {audio_proba[audio_pred]:.1%}",
        "---",
        "**⚠️ Disclaimer**: This is a predictive model and should not replace professional medical diagnosis.",
        "Please consult with a healthcare professional for proper medical advice.",
        "",
    ]
    return "\n".join(lines)


def predict(age, gender, tb_contact, wheezing, phlegm_cough,
            family_asthma, fever, cold_present, pack_years,
            cough_audio, vowel_audio):
    """Classify a patient by combining the tabular and audio models.

    Args:
        age, gender, tb_contact, wheezing, phlegm_cough, family_asthma,
        fever, cold_present, pack_years: The nine tabular inputs from the UI.
            ``age``, ``fever``, ``cold_present`` and ``pack_years`` may be
            ``None`` when left blank.
        cough_audio: Filepath of the uploaded cough recording, or ``None``.
        vowel_audio: Filepath of the uploaded vowel recording, or ``None``.

    Returns:
        A ``(markdown_report, probabilities)`` pair for the Markdown and Label
        outputs. If audio is missing or prediction fails, the report is an
        error-message string and ``probabilities`` is ``None``.
    """
    # The first output feeds a Markdown component, so it must be a string on
    # every path, including this validation path.
    if not cough_audio or not vowel_audio:
        return "⚠️ Please upload both cough and vowel audio files (1 second each)", None

    try:
        # ========== TABULAR MODEL PREDICTION ==========
        record = _build_tabular_record(age, gender, tb_contact, wheezing, phlegm_cough,
                                       family_asthma, fever, cold_present, pack_years)
        x_tabular = tabular_preprocessor.transform(record)
        # NOTE(review): assumes catboost_model.classes_ is [0, 1, 2] so that
        # column i corresponds to _CLASS_NAMES[i] — confirm.
        tabular_proba = catboost_model.predict_proba(x_tabular)[0]
        tabular_pred = int(np.argmax(tabular_proba))
        print(f"Tabular prediction: {tabular_pred}, proba: {tabular_proba}")

        # ========== AUDIO MODEL PREDICTION ==========
        # Model was trained with CONCAT mode (120 features = 60 cough + 60 vowel)
        combined_audio_features = audio_preprocessor.extract_from_both_audios(
            cough_audio, vowel_audio, combine_mode="concat"
        )
        audio_pred, audio_proba = predict_audio(
            lstm_model, combined_audio_features, device=device
        )
        # predict_audio's return types are not visible here; normalise them so
        # the arithmetic and indexing below work for a list, ndarray or CPU tensor.
        audio_proba = np.asarray(audio_proba, dtype=float)
        audio_pred = int(audio_pred)
        print(f"Audio prediction: {audio_pred}, proba: {audio_proba}")

        # ========== ENSEMBLE PREDICTION ==========
        # Weighted average of the two probability vectors. Dividing by the
        # weight total keeps it a distribution even if the weights are retuned
        # to something that does not sum to 1.
        total_weight = TABULAR_WEIGHT + AUDIO_WEIGHT
        ensemble_proba = (TABULAR_WEIGHT * tabular_proba
                          + AUDIO_WEIGHT * audio_proba) / total_weight
        final_pred = int(np.argmax(ensemble_proba))
        disease_info = get_disease_info(final_pred)

        # ========== FORMAT RESULTS ==========
        result_text = _format_report(disease_info, ensemble_proba,
                                     tabular_proba, tabular_pred,
                                     audio_proba, audio_pred)
        prob_chart = {name: float(p) for name, p in zip(_CLASS_NAMES, ensemble_proba)}
        return result_text, prob_chart
    except Exception as e:
        # UI callback boundary: log the full traceback to the console, and show
        # the user a readable message instead of crashing the request.
        import traceback
        print(f"Prediction error: {e}")
        traceback.print_exc()
        error_msg = f"❌ Error during prediction: {str(e)}\n\nPlease check your inputs and try again."
        return error_msg, None
# ========== GRADIO INTERFACE ==========
# Stylesheet handed to gr.Blocks(css=...). It sets the page font and styles the
# h1/h2 headings inside elements that carry the `output-markdown` class.
# NOTE(review): no component in this file sets elem_classes=["output-markdown"];
# confirm the installed Gradio version adds that class itself, otherwise the
# heading rules have no effect.
custom_css = (
    "\n"
    ".gradio-container {\n"
    "font-family: 'Arial', sans-serif;\n"
    "}\n"
    ".output-markdown h1 {\n"
    "color: #2c3e50;\n"
    "border-bottom: 3px solid #3498db;\n"
    "padding-bottom: 10px;\n"
    "}\n"
    ".output-markdown h2 {\n"
    "color: #34495e;\n"
    "margin-top: 20px;\n"
    "}\n"
)
# Create Gradio interface
with gr.Blocks(css=custom_css, title="Respiratory Disease Classifier") as demo:
    # The percentages shown in the UI come from the same weights predict() uses,
    # so the description cannot drift from the actual ensemble mix.
    gr.Markdown(f"""
# 🫁 Respiratory Disease Classification System
This application uses a combination of **tabular medical data** ({TABULAR_WEIGHT:.0%}) and **audio recordings** ({AUDIO_WEIGHT:.0%})
to predict respiratory diseases: **Healthy**, **COPD**, or **Asthma**.
### 📋 Instructions:
1. Fill in the medical information below
2. Upload **1-second** cough audio (WAV format)
3. Upload **1-second** vowel audio (WAV format)
4. Click **"Predict Disease"** to get results
""")

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### 📝 Medical Information")
            age = gr.Number(
                label="Age (years)",
                value=None
            )
            gender = gr.Radio(
                choices=[0, 1],
                label="Gender (0=Female, 1=Male)",
                value=1
            )
            tb_contact = gr.Radio(
                choices=[0, 1],
                label="TB Contact History (0=No, 1=Yes)",
                value=0
            )
            wheezing = gr.Radio(
                choices=[0, 1],
                label="Wheezing History (0=No, 1=Yes)",
                value=0
            )
            phlegm_cough = gr.Radio(
                choices=[0, 1],
                label="Phlegm Cough (0=No, 1=Yes)",
                value=0
            )
            family_asthma = gr.Radio(
                choices=[0, 1],
                label="Family Asthma History (0=No, 1=Yes)",
                value=0
            )
            fever = gr.Radio(
                choices=[0, 1],
                label="Fever History (0=No, 1=Yes, optional)",
                value=None
            )
            cold_present = gr.Radio(
                choices=[0, 1],
                label="Cold Present (0=No, 1=Yes, optional)",
                value=None
            )
            pack_years = gr.Number(
                label="Pack Years (smoking, optional)",
                value=0
            )

        with gr.Column(scale=1):
            gr.Markdown("### 🎤 Audio Recordings (1 second each)")
            cough_audio = gr.Audio(
                label="Cough Audio (WAV, 1 second)",
                type="filepath",
                sources=["upload"]
            )
            vowel_audio = gr.Audio(
                label="Vowel Audio (WAV, 1 second)",
                type="filepath",
                sources=["upload"]
            )
            gr.Markdown("""
#### 🎵 Audio Recording Tips:
- Record in a quiet environment
- Keep recordings to exactly **1 second**
- Use WAV format for best quality
- For cough: record a clear single cough
- For vowel: sustain an "ahhh" sound
""")

    predict_btn = gr.Button(
        "🔍 Predict Disease",
        variant="primary",
        size="lg"
    )
    gr.Markdown("---")

    with gr.Row():
        with gr.Column(scale=2):
            output_text = gr.Markdown(label="Prediction Results")
        with gr.Column(scale=1):
            output_plot = gr.Label(
                label="Probability Distribution",
                num_top_classes=3
            )

    # Input groups, in the positional order predict() expects.
    tabular_inputs = [age, gender, tb_contact, wheezing, phlegm_cough,
                      family_asthma, fever, cold_present, pack_years]
    audio_inputs = [cough_audio, vowel_audio]

    # Example inputs. Only the tabular fields are populated: loading an example
    # writes a value into every component listed in `inputs`, so listing the
    # audio components with None would wipe recordings the user had uploaded.
    gr.Markdown("### 📌 Example Input")
    gr.Examples(
        examples=[
            [43, 1, 0, 1, 0, 0, 0, None, 0],  # Example 1
            [24, 0, 0, 1, 0, 0, 0, None, 0],  # Example 2
        ],
        inputs=tabular_inputs,
        label="Click to load example data (Note: Audio files need to be uploaded separately)"
    )

    # Connect prediction function
    predict_btn.click(
        fn=predict,
        inputs=tabular_inputs + audio_inputs,
        outputs=[output_text, output_plot]
    )

    gr.Markdown(f"""
---
### ℹ️ About This System
**Models Used:**
- **CatBoost Classifier** ({TABULAR_WEIGHT:.0%} weight): Analyzes tabular medical data
- **LSTM Neural Network** ({AUDIO_WEIGHT:.0%} weight): Analyzes audio features (MFCC)
**Classes:**
- **Class 0**: Healthy (No respiratory disease)
- **Class 1**: COPD (Chronic Obstructive Pulmonary Disease)
- **Class 2**: Asthma
**⚠️ Medical Disclaimer:**
This application is for educational and research purposes only. It should NOT be used as a
substitute for professional medical advice, diagnosis, or treatment. Always seek the advice
of qualified health providers with any questions regarding a medical condition.
---
*Developed with ❤️ using Gradio, PyTorch, and CatBoost*
""")
# Start the web server only when this file is executed directly, not when it
# is imported.
if __name__ == "__main__":
    # Bind to every network interface on port 7860, with no public share link.
    demo.launch(server_name="0.0.0.0", server_port=7860, share=False)