Spaces:
Build error
Build error
| """ | |
| Respiratory Disease Classification Web Application | |
| Combines CatBoost (tabular) and LSTM (audio) models with weighted ensemble | |
| """ | |
| import gradio as gr | |
| import numpy as np | |
| import pandas as pd | |
| import torch | |
| from catboost import CatBoostClassifier | |
| import warnings | |
| warnings.filterwarnings('ignore') | |
| from utils import TabularPreprocessor, AudioPreprocessor, get_disease_info | |
| from model import load_lstm_model, predict_audio, Config # Import Config for model loading compatibility | |
# ---------------------------------------------------------------------------
# Runtime setup: compute device, trained models, preprocessors, ensemble weights
# ---------------------------------------------------------------------------

# Use CUDA when PyTorch reports it as available; otherwise stay on the CPU.
_backend = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(_backend)
print(f"Using device: {device}")

print("Loading models...")
try:
    # Tabular branch: CatBoost classifier read from its .cbm file.
    catboost_model = CatBoostClassifier()
    catboost_model.load_model('catboost_tabular_model.cbm')
    print("✓ CatBoost model loaded successfully")

    # Audio branch: LSTM checkpoint, loaded for the device selected above.
    lstm_model = load_lstm_model('best_respiratory_model.pth', device=device)
    print("✓ LSTM model loaded successfully")
except Exception as exc:
    # Report the failure, then re-raise so start-up stops here instead of
    # continuing with a missing model.
    print(f"Error loading models: {exc}")
    raise

# Feature preparation for the tabular branch.
tabular_preprocessor = TabularPreprocessor()

# Feature-extraction settings for the audio branch.
_audio_settings = {
    'sample_rate': 16000,
    'duration': 1.0,  # 1 second audio
    'n_mfcc': 20,
    'n_fft': 2048,
    'hop_length': 512,
    'n_mels': 64,
}
audio_preprocessor = AudioPreprocessor(**_audio_settings)
print("Preprocessors initialized")

# Weights applied to each model's class probabilities in the ensemble.
TABULAR_WEIGHT = 0.8
AUDIO_WEIGHT = 0.2
# Class labels, indexed by the class id used in both models' probability vectors.
CLASS_NAMES = ('Healthy', 'COPD', 'Asthma')

# Shown when one or both recordings are missing.
MISSING_AUDIO_MESSAGE = "⚠️ Please upload both cough and vowel audio files (1 second each)"


def _format_report(disease_info, ensemble_proba, tabular_pred, tabular_proba,
                   audio_pred, audio_proba, tabular_weight, audio_weight):
    """Render the Markdown report for one prediction.

    Args:
        disease_info: Mapping with 'name', 'description' and
            'recommendations' (an iterable of strings) entries.
        ensemble_proba: Indexable holding the three blended class probabilities.
        tabular_pred: Class index chosen by the tabular model.
        tabular_proba: Indexable of the tabular model's class probabilities.
        audio_pred: Class index chosen by the audio model.
        audio_proba: Indexable of the audio model's class probabilities.
        tabular_weight: Ensemble weight of the tabular model.
        audio_weight: Ensemble weight of the audio model.

    Returns:
        The report as a single Markdown string.
    """
    score_lines = [
        f"- **{label}**: {ensemble_proba[idx]:.1%}"
        for idx, label in enumerate(CLASS_NAMES)
    ]
    recommendation_lines = [
        f"{idx}. {text}"
        for idx, text in enumerate(disease_info['recommendations'], 1)
    ]
    lines = [
        "",
        "# 🏥 Respiratory Disease Classification Results",
        f"## 📊 Final Diagnosis: **{disease_info['name']}**",
        "### Confidence Scores:",
        *score_lines,
        "---",
        "## 📝 Description:",
        f"{disease_info['description']}",
        "## 💡 Recommendations:",
        "",
        *recommendation_lines,
        "---",
        "## 🔬 Model Details:",
        f"- **Tabular Model (CatBoost)**: Weight = {tabular_weight:.0%}",
        f"- Prediction: {CLASS_NAMES[tabular_pred]}",
        f"- Confidence: {tabular_proba[tabular_pred]:.1%}",
        f"- **Audio Model (LSTM)**: Weight = {audio_weight:.0%}",
        f"- Prediction: {CLASS_NAMES[audio_pred]}",
        f"- Confidence: {audio_proba[audio_pred]:.1%}",
        "---",
        "**⚠️ Disclaimer**: This is a predictive model and should not replace professional medical diagnosis.",
        "Please consult with a healthcare professional for proper medical advice.",
        "",
    ]
    return "\n".join(lines)


def predict(age, gender, tb_contact, wheezing, phlegm_cough,
            family_asthma, fever, cold_present, pack_years,
            cough_audio, vowel_audio):
    """Run both models on one patient record and blend their probabilities.

    The tabular features go through ``tabular_preprocessor`` and the CatBoost
    model; the two recordings go through ``audio_preprocessor`` and the LSTM
    model. The class probabilities are then combined as
    ``TABULAR_WEIGHT * tabular + AUDIO_WEIGHT * audio`` and the highest
    blended probability selects the final class.

    Args:
        age: Patient age; ``None`` is passed on as NaN.
        gender, tb_contact, wheezing, phlegm_cough, family_asthma: Values
            forwarded unchanged to the tabular preprocessor.
        fever, cold_present: Optional values; ``None`` is passed on as NaN.
        pack_years: Smoking pack-years; ``None`` is treated as 0.
        cough_audio: Path to the cough audio file.
        vowel_audio: Path to the vowel audio file.

    Returns:
        A ``(text, chart)`` pair. On success ``text`` is the Markdown report
        and ``chart`` maps each class label to its blended probability. When a
        recording is missing or any step raises, ``text`` is a plain message
        string and ``chart`` is ``None``.
    """
    # The first return value feeds a Markdown output, so it must always be a
    # string. `not` also rejects an empty path string.
    if not cough_audio or not vowel_audio:
        return MISSING_AUDIO_MESSAGE, None

    try:
        # ========== TABULAR MODEL PREDICTION ==========
        tabular_data = {
            'age': age if age is not None else np.nan,
            'gender': gender,
            'tbContactHistory': tb_contact,
            'wheezingHistory': wheezing,
            'phlegmCough': phlegm_cough,
            'familyAsthmaHistory': family_asthma,
            'feverHistory': fever if fever is not None else np.nan,
            'coldPresent': cold_present if cold_present is not None else np.nan,
            'packYears': pack_years if pack_years is not None else 0,
        }
        X_tabular = tabular_preprocessor.transform(tabular_data)

        # predict_proba returns one row per sample; a single record is scored.
        tabular_proba = catboost_model.predict_proba(X_tabular)[0]
        tabular_pred = int(np.argmax(tabular_proba))
        print(f"Tabular prediction: {tabular_pred}, proba: {tabular_proba}")

        # ========== AUDIO MODEL PREDICTION ==========
        # Model was trained with CONCAT mode (120 features = 60 cough + 60 vowel)
        combined_audio_features = audio_preprocessor.extract_from_both_audios(
            cough_audio, vowel_audio, combine_mode="concat"
        )
        # NOTE(review): audio_proba is assumed to be a 3-element NumPy vector
        # compatible with tabular_proba — confirm against model.predict_audio.
        audio_pred, audio_proba = predict_audio(
            lstm_model, combined_audio_features, device=device
        )
        print(f"Audio prediction: {audio_pred}, proba: {audio_proba}")

        # ========== ENSEMBLE PREDICTION ==========
        ensemble_proba = (TABULAR_WEIGHT * tabular_proba +
                          AUDIO_WEIGHT * audio_proba)
        final_pred = int(np.argmax(ensemble_proba))
        disease_info = get_disease_info(final_pred)

        # ========== FORMAT RESULTS ==========
        result_text = _format_report(
            disease_info, ensemble_proba,
            tabular_pred, tabular_proba,
            audio_pred, audio_proba,
            TABULAR_WEIGHT, AUDIO_WEIGHT,
        )
        prob_chart = {
            label: float(ensemble_proba[idx])
            for idx, label in enumerate(CLASS_NAMES)
        }
        return result_text, prob_chart

    except Exception as e:
        # UI boundary: show the failure to the user and log the traceback
        # instead of letting the exception escape the callback.
        error_msg = f"❌ Error during prediction: {str(e)}\n\nPlease check your inputs and try again."
        print(f"Prediction error: {e}")
        import traceback
        traceback.print_exc()
        return error_msg, None
# ========== GRADIO INTERFACE ==========

# Stylesheet passed to gr.Blocks: container font plus h1/h2 rules for
# elements carrying the .output-markdown class.
custom_css = """
.gradio-container {
font-family: 'Arial', sans-serif;
}
.output-markdown h1 {
color: #2c3e50;
border-bottom: 3px solid #3498db;
padding-bottom: 10px;
}
.output-markdown h2 {
color: #34495e;
margin-top: 20px;
}
"""

# Page header and usage instructions.
_INTRO_MARKDOWN = """
# 🫁 Respiratory Disease Classification System
This application uses a combination of **tabular medical data** (80%) and **audio recordings** (20%)
to predict respiratory diseases: **Healthy**, **COPD**, or **Asthma**.
### 📋 Instructions:
1. Fill in the medical information below
2. Upload **1-second** cough audio (WAV format)
3. Upload **1-second** vowel audio (WAV format)
4. Click **"Predict Disease"** to get results
"""

# Recording guidance shown under the two audio inputs.
_AUDIO_TIPS_MARKDOWN = """
#### 🎵 Audio Recording Tips:
- Record in a quiet environment
- Keep recordings to exactly **1 second**
- Use WAV format for best quality
- For cough: record a clear single cough
- For vowel: sustain an "ahhh" sound
"""

# Footer: model summary, class legend and medical disclaimer.
_ABOUT_MARKDOWN = """
---
### ℹ️ About This System
**Models Used:**
- **CatBoost Classifier** (80% weight): Analyzes tabular medical data
- **LSTM Neural Network** (20% weight): Analyzes audio features (MFCC)
**Classes:**
- **Class 0**: Healthy (No respiratory disease)
- **Class 1**: COPD (Chronic Obstructive Pulmonary Disease)
- **Class 2**: Asthma
**⚠️ Medical Disclaimer:**
This application is for educational and research purposes only. It should NOT be used as a
substitute for professional medical advice, diagnosis, or treatment. Always seek the advice
of qualified health providers with any questions regarding a medical condition.
---
*Developed with ❤️ using Gradio, PyTorch, and CatBoost*
"""


def _binary_radio(label, value):
    """Create a 0/1 radio selector with the given label and initial value."""
    return gr.Radio(choices=[0, 1], label=label, value=value)


# NOTE(review): the nesting below is reconstructed from the order of the
# widgets; confirm the placement of the button and the examples in the layout.
with gr.Blocks(css=custom_css, title="Respiratory Disease Classifier") as demo:
    gr.Markdown(_INTRO_MARKDOWN)

    with gr.Row():
        # Left column: the nine tabular inputs.
        with gr.Column(scale=1):
            gr.Markdown("### 📝 Medical Information")
            age = gr.Number(label="Age (years)", value=None)
            gender = _binary_radio("Gender (0=Female, 1=Male)", 1)
            tb_contact = _binary_radio("TB Contact History (0=No, 1=Yes)", 0)
            wheezing = _binary_radio("Wheezing History (0=No, 1=Yes)", 0)
            phlegm_cough = _binary_radio("Phlegm Cough (0=No, 1=Yes)", 0)
            family_asthma = _binary_radio("Family Asthma History (0=No, 1=Yes)", 0)
            # The two optional fields start unselected.
            fever = _binary_radio("Fever History (0=No, 1=Yes, optional)", None)
            cold_present = _binary_radio("Cold Present (0=No, 1=Yes, optional)", None)
            pack_years = gr.Number(label="Pack Years (smoking, optional)", value=0)

        # Right column: the two recordings, delivered to predict() as file paths.
        with gr.Column(scale=1):
            gr.Markdown("### 🎤 Audio Recordings (1 second each)")
            cough_audio = gr.Audio(
                label="Cough Audio (WAV, 1 second)",
                type="filepath",
                sources=["upload"],
            )
            vowel_audio = gr.Audio(
                label="Vowel Audio (WAV, 1 second)",
                type="filepath",
                sources=["upload"],
            )
            gr.Markdown(_AUDIO_TIPS_MARKDOWN)

    predict_btn = gr.Button("🔍 Predict Disease", variant="primary", size="lg")
    gr.Markdown("---")

    # Results: Markdown report next to the per-class probability label.
    with gr.Row():
        with gr.Column(scale=2):
            output_text = gr.Markdown(label="Prediction Results")
        with gr.Column(scale=1):
            output_plot = gr.Label(
                label="Probability Distribution",
                num_top_classes=3,
            )

    # Inputs in the positional order of predict()'s parameters; shared by the
    # examples table and the button callback.
    form_inputs = [
        age, gender, tb_contact, wheezing, phlegm_cough,
        family_asthma, fever, cold_present, pack_years,
        cough_audio, vowel_audio,
    ]

    gr.Markdown("### 📌 Example Input")
    gr.Examples(
        examples=[
            [43, 1, 0, 1, 0, 0, 0, None, 0, None, None],  # Example 1
            [24, 0, 0, 1, 0, 0, 0, None, 0, None, None],  # Example 2
        ],
        inputs=form_inputs,
        label="Click to load example data (Note: Audio files need to be uploaded separately)",
    )

    # Connect prediction function
    predict_btn.click(
        fn=predict,
        inputs=form_inputs,
        outputs=[output_text, output_plot],
    )

    gr.Markdown(_ABOUT_MARKDOWN)
# Script entry point: the server is started only when this file is run directly.
if __name__ == "__main__":
    # Bind to 0.0.0.0 on port 7860 with share links disabled.
    launch_options = {
        "share": False,
        "server_name": "0.0.0.0",
        "server_port": 7860,
    }
    demo.launch(**launch_options)