Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import numpy as np | |
| from PIL import Image | |
| import tensorflow as tf | |
| from tensorflow import keras | |
| from tensorflow.keras import layers | |
| from tensorflow.keras.applications import EfficientNetB0 | |
| import cv2 | |
| import pickle | |
| import os | |
| # Import Grad-CAM utilities | |
| from gradcam_utils import ( | |
| make_gradcam_heatmap, | |
| overlay_heatmap_on_image, | |
| get_last_conv_layer_name | |
| ) | |
# Configuration: input resolution expected by the model.
IMG_SIZE = 224

# Disease labels, in the order of the model's output units
# (index i of the prediction vector corresponds to all_diseases[i]).
all_diseases = [
    'Atelectasis',
    'Cardiomegaly',
    'Effusion',
    'Infiltration',
    'Mass',
    'Nodule',
    'Pneumonia',
    'Pneumothorax',
    'Consolidation',
    'Edema',
    'Emphysema',
    'Fibrosis',
    'Pleural_Thickening',
    'Hernia',
]

# Number of output classes, kept in lock-step with the label list above (14).
NUM_CLASSES = len(all_diseases)
def build_model(img_size, num_classes):
    """Build the EfficientNetB0-based multi-label classifier.

    The backbone is ImageNet-initialised, fully trainable and average-pooled.
    Its pooled features feed a small MLP head:
    Dense(512) -> Dropout(0.3) -> Dense(256) -> Dropout(0.2).
    The output is a sigmoid Dense layer, so each class gets an independent
    probability.

    Args:
        img_size: Height/width of the square RGB input.
        num_classes: Number of sigmoid output units.

    Returns:
        An uncompiled ``keras.Model``.
    """
    image_input = layers.Input(shape=(img_size, img_size, 3))
    backbone = EfficientNetB0(
        include_top=False,
        weights='imagenet',
        input_tensor=image_input,
        pooling='avg',
    )
    backbone.trainable = True

    # Layers are created in exactly the original order (Dense, Dropout, Dense,
    # Dropout) so positional weight loading from best_model.h5 still lines up.
    features = backbone.output
    for units, drop_rate in ((512, 0.3), (256, 0.2)):
        features = layers.Dense(units, activation='relu')(features)
        features = layers.Dropout(drop_rate)(features)

    # float32 output keeps the probabilities in full precision (presumably
    # guarding against a mixed-precision policy used in training — confirm).
    probabilities = layers.Dense(num_classes, activation='sigmoid', dtype='float32')(features)
    return keras.Model(inputs=image_input, outputs=probabilities)
# Load model: build the architecture, then restore trained weights from disk.
print("Loading model...")
model = build_model(img_size=IMG_SIZE, num_classes=NUM_CLASSES)
try:
    model.load_weights('best_model.h5')
    print("✅ Model loaded successfully!")
except Exception as load_error:
    # Best-effort: the app keeps running, but with only ImageNet backbone
    # weights and an untrained head its predictions are not meaningful.
    print(f"⚠️ Warning: Could not load model weights - {load_error}")
# Load label encoder, falling back to a disease -> index mapping built from all_diseases.
# NOTE(review): pickle.load can execute arbitrary code; only ship trusted .pkl files.
# NOTE(review): label_encoder is not referenced in the rest of the code visible here.
try:
    with open('label_encoder.pkl', 'rb') as encoder_file:
        label_encoder = pickle.load(encoder_file)
    print("✅ Label encoder loaded!")
except Exception as load_error:
    print(f"⚠️ Creating default label encoder - {load_error}")
    label_encoder = dict(zip(all_diseases, range(len(all_diseases))))
# Load per-disease decision thresholds. On any failure, fall back to a uniform 0.5
# and record that the optimal thresholds are unavailable.
# NOTE(review): pickle.load can execute arbitrary code; only ship trusted .pkl files.
try:
    with open('optimal_thresholds.pkl', 'rb') as thresholds_file:
        optimal_thresholds = pickle.load(thresholds_file)
    print("✅ Optimal thresholds loaded!")
    use_optimal_thresholds = True
except Exception as load_error:
    print(f"⚠️ Using default threshold 0.5 - {load_error}")
    optimal_thresholds = dict.fromkeys(all_diseases, 0.5)
    use_optimal_thresholds = False
# Get last conv layer for Grad-CAM, falling back to a fixed layer name if lookup fails.
# 'top_conv' is expected to be EfficientNetB0's final conv layer name in
# keras.applications — verify against the installed TensorFlow version.
try:
    last_conv_layer = get_last_conv_layer_name(model)
    print(f"✅ Grad-CAM layer: {last_conv_layer}")
except Exception as e:
    # `except Exception` (not a bare `except:`) so Ctrl-C / SystemExit during
    # startup are not swallowed; the reason is logged like the other loaders.
    last_conv_layer = 'top_conv'
    print(f"⚠️ Using default Grad-CAM layer: {last_conv_layer} - {e}")
def preprocess_image(image):
    """Turn an uploaded image into a single-item model batch.

    Args:
        image: A numpy array (cast to uint8) or a PIL image, or None.

    Returns:
        None when ``image`` is None. Otherwise a float array of shape
        (1, IMG_SIZE, IMG_SIZE, 3) with RGB pixel values scaled to [0, 1].

    NOTE(review): keras.applications EfficientNet models are documented to
    expect 0-255 inputs; confirm training used this same /255 scaling.
    """
    if image is None:
        return None
    pil_image = Image.fromarray(image.astype('uint8')) if isinstance(image, np.ndarray) else image
    rgb_resized = pil_image.convert('RGB').resize((IMG_SIZE, IMG_SIZE))
    # Add the leading batch axis, then normalise.
    return np.array(rgb_resized)[np.newaxis, ...] / 255.0
def predict_with_tta(image, n_augmentations=3, rng=None):
    """Average model predictions over the image and randomly augmented copies.

    Each augmented copy is built in three steps:
      1. flip left-right with probability 0.5;
      2. rotate by a uniform angle in [-10, 10] degrees, filling with black;
      3. scale brightness by a uniform factor in [0.9, 1.1].

    All variants are sent through the model in a single batched ``predict``
    call, which avoids the per-call overhead of n+1 separate calls.

    Args:
        image: A numpy array (cast to uint8) or a PIL image.
        n_augmentations: Number of augmented copies added to the original.
        rng: Optional ``numpy.random.Generator`` for reproducible
            augmentations. Defaults to numpy's global RNG, as before.

    Returns:
        ``(mean_pred, std_pred)``: the per-output mean and standard deviation
        across the original plus ``n_augmentations`` predictions.
    """
    from PIL import ImageEnhance  # imported once, not on every augmentation

    # Both the np.random module and a Generator expose random() and uniform().
    rand = np.random if rng is None else rng

    if isinstance(image, np.ndarray):
        image = Image.fromarray(image.astype('uint8'))
    image = image.convert('RGB').resize((IMG_SIZE, IMG_SIZE))

    variants = [image]
    for _ in range(n_augmentations):
        aug_img = image.transpose(Image.FLIP_LEFT_RIGHT) if rand.random() > 0.5 else image
        aug_img = aug_img.rotate(rand.uniform(-10, 10), fillcolor=(0, 0, 0))
        aug_img = ImageEnhance.Brightness(aug_img).enhance(rand.uniform(0.9, 1.1))
        variants.append(aug_img)

    batch = np.stack([np.array(variant) for variant in variants]) / 255.0
    predictions = model.predict(batch, verbose=0)
    return predictions.mean(axis=0), predictions.std(axis=0)
def generate_gradcam(image, disease_idx):
    """Produce a Grad-CAM overlay for one output class.

    The image is resized and scaled to [0, 1] as float32. It is then passed to
    ``make_gradcam_heatmap`` using the module-level ``model`` and
    ``last_conv_layer``. The resulting heatmap is blended onto the resized RGB
    image at alpha 0.5.

    Args:
        image: A numpy array (cast to uint8) or a PIL image.
        disease_idx: Index of the model output to explain.

    Returns:
        Whatever ``overlay_heatmap_on_image`` returns (defined in gradcam_utils).
    """
    source = Image.fromarray(image.astype('uint8')) if isinstance(image, np.ndarray) else image
    resized = source.convert('RGB').resize((IMG_SIZE, IMG_SIZE))
    batch = (np.array(resized) / 255.0)[np.newaxis].astype(np.float32)

    heatmap = make_gradcam_heatmap(batch, model, last_conv_layer, disease_idx)
    return overlay_heatmap_on_image(resized, heatmap, alpha=0.5)
| def predict(image, use_tta, use_thresholds, show_gradcam, top_k=5): | |
| """Main prediction function with Grad-CAM support""" | |
| if image is None: | |
| return "⚠️ Please upload an image first.", {}, None, "" | |
| try: | |
| # Get predictions | |
| if use_tta: | |
| predictions, std = predict_with_tta(image, n_augmentations=3) | |
| tta_text = "\n\n*✅ Using Test-Time Augmentation (TTA)*" | |
| else: | |
| img_array = preprocess_image(image) | |
| predictions = model.predict(img_array, verbose=0)[0] | |
| std = np.zeros_like(predictions) | |
| tta_text = "" | |
| # Apply optimal thresholds if requested | |
| if use_thresholds and use_optimal_thresholds: | |
| threshold_text = "\n*✅ Using optimal thresholds for classification*" | |
| else: | |
| threshold_text = "\n*Using default threshold: 0.5*" | |
| # Get top K predictions | |
| top_indices = np.argsort(predictions)[::-1][:top_k] | |
| # Create results | |
| results = {} | |
| result_text = f"## 🏥 Prediction Results (Top {top_k})\n\n" | |
| gradcam_images = [] | |
| for i, idx in enumerate(top_indices, 1): | |
| disease = all_diseases[idx] | |
| prob = float(predictions[idx]) | |
| percentage = prob * 100 | |
| results[disease] = prob | |
| # Determine if positive using optimal threshold | |
| if use_thresholds and use_optimal_thresholds: | |
| threshold = optimal_thresholds.get(disease, 0.5) | |
| is_positive = prob >= threshold | |
| status = "✅ POSITIVE" if is_positive else "❌ NEGATIVE" | |
| else: | |
| threshold = 0.5 | |
| is_positive = prob >= 0.5 | |
| status = "✅ POSITIVE" if is_positive else "❌ NEGATIVE" | |
| # Confidence indicator | |
| if percentage > 70: | |
| confidence = "🔴 High" | |
| elif percentage > 40: | |
| confidence = "🟡 Medium" | |
| else: | |
| confidence = "🟢 Low" | |
| result_text += f"**{i}. {disease}** {status}\n" | |
| result_text += f" - Probability: **{percentage:.2f}%**\n" | |
| result_text += f" - Threshold: {threshold:.3f}\n" | |
| result_text += f" - Confidence: {confidence}\n" | |
| if use_tta: | |
| result_text += f" - Uncertainty (±std): {std[idx]*100:.2f}%\n" | |
| result_text += "\n" | |
| # Generate Grad-CAM for top 3 if requested | |
| if show_gradcam and i <= 3: | |
| gradcam_img = generate_gradcam(image, idx) | |
| gradcam_images.append(gradcam_img) | |
| result_text += threshold_text | |
| result_text += tta_text | |
| result_text += "\n\n---\n\n*⚠️ **Medical Disclaimer:** This is an AI tool for educational purposes only. NOT for clinical diagnosis.*" | |
| # Prepare Grad-CAM output | |
| if show_gradcam and gradcam_images: | |
| gradcam_gallery = gradcam_images | |
| gradcam_text = f"## 🔥 Grad-CAM Visualizations\n\nShowing attention maps for top {len(gradcam_images)} predictions.\n\n**Red areas** = High attention (model focuses here)\n**Blue areas** = Low attention" | |
| else: | |
| gradcam_gallery = None | |
| gradcam_text = "" | |
| return result_text, results, gradcam_gallery, gradcam_text | |
| except Exception as e: | |
| error_msg = f"❌ Error: {str(e)}" | |
| return error_msg, {}, None, "" | |
# Custom CSS passed to gr.Blocks(css=...) below: caps the layout width, sets the
# page font, and gives primary buttons a purple gradient.
# NOTE(review): `.container` and `.gr-button-primary` look like class names from
# older Gradio releases; confirm they still match elements in the installed
# Gradio version, otherwise these rules are silently inert.
custom_css = """
.container {
max-width: 1400px;
margin: auto;
}
.gradio-container {
font-family: 'IBM Plex Sans', sans-serif;
}
.gr-button-primary {
background: linear-gradient(90deg, #667eea 0%, #764ba2 100%);
border: none;
}
"""
# Create Gradio interface.
# Layout:
#   - left column: upload widget and options;
#   - right column: results and Grad-CAM gallery;
#   - below: optional examples and an About/disclaimer section.
# Markdown sections are separated by blank lines. Without them, bold labels
# such as "**Cardiac:**" render as continuations of the preceding list item.
with gr.Blocks(css=custom_css, title="Medical Image Classifier") as demo:
    gr.Markdown(
        """
        # 🏥 Medical Image Classification System
        ### AI-Powered Disease Detection with Grad-CAM Visualization

        Upload a chest X-ray to detect 14 thoracic diseases using EfficientNetB0 + Grad-CAM explainability.
        """
    )

    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(
                label="📁 Upload Chest X-ray Image",
                type="numpy"
            )
            gr.Markdown("### ⚙️ Options")
            use_tta = gr.Checkbox(
                label="🔄 Test-Time Augmentation (TTA)",
                value=False,
                info="More accurate but 3-4x slower"
            )
            use_thresholds = gr.Checkbox(
                label="🎯 Use Optimal Thresholds",
                value=True,
                info="Disease-specific thresholds for better classification"
            )
            show_gradcam = gr.Checkbox(
                label="🔥 Show Grad-CAM Visualization",
                value=True,
                info="Visual explanation of model predictions"
            )
            top_k = gr.Slider(
                minimum=1,
                maximum=14,
                value=5,
                step=1,
                label="📊 Number of predictions to show"
            )
            predict_btn = gr.Button("🔍 Analyze Image", variant="primary", size="lg")
            gr.Markdown(
                """
                ---
                ### 📋 Detectable Conditions

                **Lung Conditions:**
                - Atelectasis, Pneumonia, Pneumothorax
                - Consolidation, Infiltration, Emphysema

                **Cardiac:**
                - Cardiomegaly, Edema

                **Abnormal Growths:**
                - Mass, Nodule, Fibrosis

                **Others:**
                - Effusion, Pleural Thickening, Hernia
                """
            )

        with gr.Column(scale=1):
            result_text = gr.Markdown(label="📊 Analysis Results")
            # Show as many classes as the slider can request (up to all 14),
            # rather than truncating at 10.
            result_plot = gr.Label(label="Probability Distribution", num_top_classes=len(all_diseases))
            gradcam_text = gr.Markdown(label="Grad-CAM Info")
            gradcam_gallery = gr.Gallery(
                label="🔥 Grad-CAM Heatmaps",
                columns=3,
                height="auto"
            )

    # Examples section: only rendered (header included) when the example files exist.
    if os.path.exists('example_1.jpg') and os.path.exists('example_2.jpg'):
        gr.Markdown("### 📸 Example Images")
        gr.Examples(
            examples=[
                ["example_1.jpg", False, True, True, 5],
                ["example_2.jpg", True, True, True, 5],
            ],
            inputs=[image_input, use_tta, use_thresholds, show_gradcam, top_k],
            outputs=[result_text, result_plot, gradcam_gallery, gradcam_text],
            fn=predict,
            cache_examples=False
        )

    # Button click event
    predict_btn.click(
        fn=predict,
        inputs=[image_input, use_tta, use_thresholds, show_gradcam, top_k],
        outputs=[result_text, result_plot, gradcam_gallery, gradcam_text]
    )

    # Parameter count: EfficientNetB0 without top (~4.05M) plus the Dense
    # 512/256/14 head (~0.79M) is about 4.8M.
    gr.Markdown(
        """
        ---
        ## ℹ️ About This System

        ### 🧠 Model Architecture
        - **Base:** EfficientNetB0 (ImageNet pre-trained)
        - **Custom Layers:** Dense(512) → Dropout → Dense(256) → Dropout → Output(14)
        - **Loss:** Binary Focal Cross-Entropy (α=0.25, γ=2.0)
        - **Training:** Full fine-tuning @ lr=1e-5

        ### 🎯 Optimal Thresholds
        Disease-specific thresholds optimized for F1-score, providing better balance between
        precision and recall compared to the default 0.5 threshold.

        ### 🔥 Grad-CAM Visualization
        **Gradient-weighted Class Activation Mapping** shows which regions of the X-ray
        the model focuses on when making predictions. Red areas indicate high attention.

        **Reference:** Selvaraju et al. (2017) - Grad-CAM: Visual Explanations from Deep Networks

        ### 🔄 Test-Time Augmentation
        - Processes original + 3 augmented versions
        - Augmentations: horizontal flip, rotation (±10°), brightness (0.9-1.1x)
        - Final prediction = average of all versions
        - Provides uncertainty estimates via standard deviation

        ---

        ## ⚠️ Medical Disclaimer
        **FOR EDUCATIONAL USE ONLY - NOT FOR CLINICAL DIAGNOSIS**

        - ❌ Not FDA approved or clinically validated
        - ❌ Not a substitute for professional medical diagnosis
        - ✅ For research and educational purposes only
        - ✅ Always consult qualified healthcare professionals

        ---

        ### 🔧 Technical Stack
        - TensorFlow 2.10.0 | Gradio 3.48.0
        - Model: ~4.8M parameters | Inference: ~0.5-2s

        **Built with ❤️ for AI in Healthcare**
        """
    )
# Launch the web server only when run as a script (not when imported),
# listening on all interfaces at port 7860 and surfacing errors in the UI.
if __name__ == "__main__":
    launch_options = dict(
        share=False,
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True,
    )
    demo.launch(**launch_options)