"""Gradio demo for multi-label chest X-ray classification with Grad-CAM.

The module builds an EfficientNetB0-based classifier, loads its weights and
auxiliary artifacts (label encoder, per-disease thresholds) on a best-effort
basis, and exposes a Gradio Blocks UI whose "Analyze" button calls
``predict``.
"""

import os
import pickle
import traceback

import cv2
import gradio as gr
import numpy as np
import tensorflow as tf
from PIL import Image
from PIL import ImageEnhance
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.applications import EfficientNetB0

# Import Grad-CAM utilities
from gradcam_utils import (
    make_gradcam_heatmap,
    overlay_heatmap_on_image,
    get_last_conv_layer_name
)

# Configuration
IMG_SIZE = 224
NUM_CLASSES = 14

# Disease labels (index i corresponds to model output unit i)
all_diseases = [
    'Atelectasis', 'Cardiomegaly', 'Effusion', 'Infiltration',
    'Mass', 'Nodule', 'Pneumonia', 'Pneumothorax',
    'Consolidation', 'Edema', 'Emphysema', 'Fibrosis',
    'Pleural_Thickening', 'Hernia'
]


def build_model(img_size, num_classes):
    """Build the EfficientNetB0 model.

    Args:
        img_size: Height and width (in pixels) of the square RGB input.
        num_classes: Number of independent sigmoid outputs.

    Returns:
        A ``keras.Model`` mapping ``(img_size, img_size, 3)`` images to
        ``num_classes`` sigmoid probabilities.
    """
    inputs = layers.Input(shape=(img_size, img_size, 3))

    # NOTE(review): weights='imagenet' triggers a download at start-up even
    # though load_weights() below normally overwrites them. Kept so that the
    # architecture/initialisation matches the original script exactly.
    base_model = EfficientNetB0(
        include_top=False,
        weights='imagenet',
        input_tensor=inputs,
        pooling='avg'
    )
    base_model.trainable = True

    x = base_model.output
    x = layers.Dense(512, activation='relu')(x)
    x = layers.Dropout(0.3)(x)
    x = layers.Dense(256, activation='relu')(x)
    x = layers.Dropout(0.2)(x)
    outputs = layers.Dense(num_classes, activation='sigmoid', dtype='float32')(x)

    model = keras.Model(inputs=inputs, outputs=outputs)
    return model


# Load model
print("Loading model...")
model = build_model(IMG_SIZE, NUM_CLASSES)

# Best-effort: the app still starts without weights, but predictions are then
# meaningless. The warning below is the only signal of that.
try:
    model.load_weights('best_model.h5')
    print("✅ Model loaded successfully!")
except Exception as e:
    print(f"⚠️ Warning: Could not load model weights - {e}")

# Load label encoder
# SECURITY: pickle.load executes arbitrary code; these files must only ever
# come from a trusted source (never from user uploads).
try:
    with open('label_encoder.pkl', 'rb') as f:
        label_encoder = pickle.load(f)
    print("✅ Label encoder loaded!")
except Exception as e:
    print(f"⚠️ Creating default label encoder - {e}")
    label_encoder = {disease: idx for idx, disease in enumerate(all_diseases)}

# Load optimal thresholds
try:
    with open('optimal_thresholds.pkl', 'rb') as f:
        optimal_thresholds = pickle.load(f)
    print("✅ Optimal thresholds loaded!")
    use_optimal_thresholds = True
except Exception as e:
    print(f"⚠️ Using default threshold 0.5 - {e}")
    optimal_thresholds = {disease: 0.5 for disease in all_diseases}
    use_optimal_thresholds = False

# Get last conv layer for Grad-CAM
try:
    last_conv_layer = get_last_conv_layer_name(model)
    print(f"✅ Grad-CAM layer: {last_conv_layer}")
except Exception as e:
    # Was a bare ``except:``, which also swallowed KeyboardInterrupt and
    # SystemExit and hid the reason for the fallback.
    last_conv_layer = 'top_conv'
    print(f"⚠️ Using default Grad-CAM layer: {last_conv_layer} - {e}")


def _to_resized_rgb(image):
    """Return ``image`` as an RGB PIL image resized to IMG_SIZE x IMG_SIZE.

    Accepts either a NumPy array (cast to uint8) or a PIL image.
    """
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image.astype('uint8'))
    return image.convert('RGB').resize((IMG_SIZE, IMG_SIZE))


def _scale_pixels(pil_image):
    """Convert a PIL image to an array scaled to [0, 1].

    NOTE(review): keras.applications EfficientNet embeds its own rescaling
    and is documented to expect inputs in [0, 255]. Dividing by 255 here is
    only correct if training used the same scaling. Confirm against the
    training pipeline before changing.
    """
    return np.array(pil_image) / 255.0


def preprocess_image(image):
    """Preprocess image for prediction.

    Args:
        image: NumPy array or PIL image, or ``None``.

    Returns:
        A ``(1, IMG_SIZE, IMG_SIZE, 3)`` array scaled to [0, 1], or ``None``
        when ``image`` is ``None``.
    """
    if image is None:
        return None
    return np.expand_dims(_scale_pixels(_to_resized_rgb(image)), axis=0)


def predict_with_tta(image, n_augmentations=3):
    """Perform Test-Time Augmentation for more robust predictions.

    The original image plus ``n_augmentations`` randomly augmented copies
    (random horizontal flip, rotation in [-10, 10] degrees, brightness in
    [0.9, 1.1]) are scored and aggregated.

    Args:
        image: NumPy array or PIL image.
        n_augmentations: Number of augmented copies to add to the original.

    Returns:
        ``(mean_pred, std_pred)``: per-class mean and standard deviation
        across all scored views.
    """
    base = _to_resized_rgb(image)

    # Original image first, then the augmented versions.
    views = [_scale_pixels(base)]
    for _ in range(n_augmentations):
        aug_img = base.transpose(Image.FLIP_LEFT_RIGHT) if np.random.random() > 0.5 else base

        angle = np.random.uniform(-10, 10)
        aug_img = aug_img.rotate(angle, fillcolor=(0, 0, 0))

        enhancer = ImageEnhance.Brightness(aug_img)
        aug_img = enhancer.enhance(np.random.uniform(0.9, 1.1))

        views.append(_scale_pixels(aug_img))

    # Score every view in a single batched call instead of one predict()
    # invocation per view.
    predictions = model.predict(np.stack(views, axis=0), verbose=0)

    mean_pred = np.mean(predictions, axis=0)
    std_pred = np.std(predictions, axis=0)
    return mean_pred, std_pred


def generate_gradcam(image, disease_idx):
    """Generate Grad-CAM visualization for a specific disease.

    Args:
        image: NumPy array or PIL image.
        disease_idx: Index of the model output to explain.

    Returns:
        Whatever ``overlay_heatmap_on_image`` returns for the resized image
        and the computed heatmap (blended with ``alpha=0.5``).
    """
    img_resized = _to_resized_rgb(image)
    img_array = np.expand_dims(_scale_pixels(img_resized), axis=0).astype(np.float32)

    heatmap = make_gradcam_heatmap(img_array, model, last_conv_layer, disease_idx)
    overlaid_image = overlay_heatmap_on_image(img_resized, heatmap, alpha=0.5)
    return overlaid_image


def predict(image, use_tta, use_thresholds, show_gradcam, top_k=5):
    """Main prediction function with Grad-CAM support.

    Args:
        image: Uploaded image (NumPy array or PIL image) or ``None``.
        use_tta: Whether to average predictions over augmented views.
        use_thresholds: Whether to classify with the per-disease thresholds
            (only honoured when they were loaded at start-up).
        show_gradcam: Whether to render Grad-CAM overlays for the top 3
            predictions.
        top_k: Number of highest-probability diseases to report.

    Returns:
        ``(result_markdown, {disease: probability}, gradcam_images_or_None,
        gradcam_markdown)``. On failure the first element carries the error
        message and the rest are empty.
    """
    if image is None:
        return "⚠️ Please upload an image first.", {}, None, ""

    try:
        # UI sliders may deliver a float; slicing needs an int. Clamp so the
        # value always refers to a valid number of classes.
        top_k = max(1, min(int(top_k), len(all_diseases)))

        # Get predictions
        if use_tta:
            predictions, std = predict_with_tta(image, n_augmentations=3)
            tta_text = "\n\n*✅ Using Test-Time Augmentation (TTA)*"
        else:
            img_array = preprocess_image(image)
            predictions = model.predict(img_array, verbose=0)[0]
            std = np.zeros_like(predictions)
            tta_text = ""

        # Optimal thresholds apply only if requested AND actually loaded.
        apply_optimal = bool(use_thresholds and use_optimal_thresholds)
        if apply_optimal:
            threshold_text = "\n*✅ Using optimal thresholds for classification*"
        else:
            threshold_text = "\n*Using default threshold: 0.5*"

        # Get top K predictions
        top_indices = np.argsort(predictions)[::-1][:top_k]

        results = {}
        parts = [f"## 🏥 Prediction Results (Top {top_k})\n\n"]
        gradcam_images = []

        for i, idx in enumerate(top_indices, 1):
            disease = all_diseases[idx]
            prob = float(predictions[idx])
            percentage = prob * 100
            results[disease] = prob

            threshold = optimal_thresholds.get(disease, 0.5) if apply_optimal else 0.5
            is_positive = prob >= threshold
            status = "✅ POSITIVE" if is_positive else "❌ NEGATIVE"

            # Confidence indicator
            if percentage > 70:
                confidence = "🔴 High"
            elif percentage > 40:
                confidence = "🟡 Medium"
            else:
                confidence = "🟢 Low"

            parts.append(f"**{i}. {disease}** {status}\n")
            parts.append(f" - Probability: **{percentage:.2f}%**\n")
            parts.append(f" - Threshold: {threshold:.3f}\n")
            parts.append(f" - Confidence: {confidence}\n")
            if use_tta:
                parts.append(f" - Uncertainty (±std): {std[idx]*100:.2f}%\n")
            parts.append("\n")

            # Generate Grad-CAM for top 3 if requested
            if show_gradcam and i <= 3:
                gradcam_images.append(generate_gradcam(image, idx))

        parts.append(threshold_text)
        parts.append(tta_text)
        parts.append(
            "\n\n---\n\n*⚠️ **Medical Disclaimer:** This is an AI tool for educational purposes only. NOT for clinical diagnosis.*"
        )
        result_text = "".join(parts)

        # Prepare Grad-CAM output
        if show_gradcam and gradcam_images:
            gradcam_gallery = gradcam_images
            gradcam_text = f"## 🔥 Grad-CAM Visualizations\n\nShowing attention maps for top {len(gradcam_images)} predictions.\n\n**Red areas** = High attention (model focuses here)\n**Blue areas** = Low attention"
        else:
            gradcam_gallery = None
            gradcam_text = ""

        return result_text, results, gradcam_gallery, gradcam_text

    except Exception as e:
        # UI boundary: report the error to the user, but keep the traceback
        # in the server log so failures remain debuggable.
        traceback.print_exc()
        error_msg = f"❌ Error: {str(e)}"
        return error_msg, {}, None, ""


# Custom CSS
custom_css = """
.container {
    max-width: 1400px;
    margin: auto;
}
.gradio-container {
    font-family: 'IBM Plex Sans', sans-serif;
}
.gr-button-primary {
    background: linear-gradient(90deg, #667eea 0%, #764ba2 100%);
    border: none;
}
"""

# Create Gradio interface
with gr.Blocks(css=custom_css, title="Medical Image Classifier") as demo:
    gr.Markdown(
        """
        # 🏥 Medical Image Classification System
        ### AI-Powered Disease Detection with Grad-CAM Visualization

        Upload a chest X-ray to detect 14 thoracic diseases using EfficientNetB0 + Grad-CAM explainability.
        """
    )

    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(
                label="📁 Upload Chest X-ray Image",
                type="numpy"
            )

            gr.Markdown("### ⚙️ Options")

            use_tta = gr.Checkbox(
                label="🔄 Test-Time Augmentation (TTA)",
                value=False,
                info="More accurate but 3-4x slower"
            )

            use_thresholds = gr.Checkbox(
                label="🎯 Use Optimal Thresholds",
                value=True,
                info="Disease-specific thresholds for better classification"
            )

            show_gradcam = gr.Checkbox(
                label="🔥 Show Grad-CAM Visualization",
                value=True,
                info="Visual explanation of model predictions"
            )

            top_k = gr.Slider(
                minimum=1,
                maximum=14,
                value=5,
                step=1,
                label="📊 Number of predictions to show"
            )

            predict_btn = gr.Button("🔍 Analyze Image", variant="primary", size="lg")

            gr.Markdown(
                """
                ---
                ### 📋 Detectable Conditions

                **Lung Conditions:**
                - Atelectasis, Pneumonia, Pneumothorax
                - Consolidation, Infiltration, Emphysema

                **Cardiac:**
                - Cardiomegaly, Edema

                **Abnormal Growths:**
                - Mass, Nodule, Fibrosis

                **Others:**
                - Effusion, Pleural Thickening, Hernia
                """
            )

        with gr.Column(scale=1):
            result_text = gr.Markdown(label="📊 Analysis Results")
            result_plot = gr.Label(label="Probability Distribution", num_top_classes=10)

            gradcam_text = gr.Markdown(label="Grad-CAM Info")
            gradcam_gallery = gr.Gallery(
                label="🔥 Grad-CAM Heatmaps",
                columns=3,
                height="auto"
            )

    # Examples section
    gr.Markdown("### 📸 Example Images")

    if os.path.exists('example_1.jpg') and os.path.exists('example_2.jpg'):
        gr.Examples(
            examples=[
                ["example_1.jpg", False, True, True, 5],
                ["example_2.jpg", True, True, True, 5],
            ],
            inputs=[image_input, use_tta, use_thresholds, show_gradcam, top_k],
            outputs=[result_text, result_plot, gradcam_gallery, gradcam_text],
            fn=predict,
            cache_examples=False
        )

    # Button click event
    predict_btn.click(
        fn=predict,
        inputs=[image_input, use_tta, use_thresholds, show_gradcam, top_k],
        outputs=[result_text, result_plot, gradcam_gallery, gradcam_text]
    )

    gr.Markdown(
        """
        ---
        ## ℹ️ About This System

        ### 🧠 Model Architecture
        - **Base:** EfficientNetB0 (ImageNet pre-trained)
        - **Custom Layers:** Dense(512) → Dropout → Dense(256) → Dropout → Output(14)
        - **Loss:** Binary Focal Cross-Entropy (α=0.25, γ=2.0)
        - **Training:** Full fine-tuning @ lr=1e-5

        ### 🎯 Optimal Thresholds
        Disease-specific thresholds optimized for F1-score, providing better balance between precision and recall compared to the default 0.5 threshold.

        ### 🔥 Grad-CAM Visualization
        **Gradient-weighted Class Activation Mapping** shows which regions of the X-ray the model focuses on when making predictions. Red areas indicate high attention.

        **Reference:** Selvaraju et al. (2017) - Grad-CAM: Visual Explanations from Deep Networks

        ### 🔄 Test-Time Augmentation
        - Processes original + 3 augmented versions
        - Augmentations: horizontal flip, rotation (±10°), brightness (0.9-1.1x)
        - Final prediction = average of all versions
        - Provides uncertainty estimates via standard deviation

        ---

        ## ⚠️ Medical Disclaimer

        **FOR EDUCATIONAL USE ONLY - NOT FOR CLINICAL DIAGNOSIS**

        - ❌ Not FDA approved or clinically validated
        - ❌ Not a substitute for professional medical diagnosis
        - ✅ For research and educational purposes only
        - ✅ Always consult qualified healthcare professionals

        ---

        ### 🔧 Technical Stack
        - TensorFlow 2.10.0 | Gradio 3.48.0
        - Model: ~5M parameters | Inference: ~0.5-2s

        **Built with ❤️ for AI in Healthcare**
        """
    )

# Launch
if __name__ == "__main__":
    demo.launch(
        share=False,
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True
    )