# emiraran's picture
# Update app.py
# d426f60 verified
import gradio as gr
import numpy as np
from PIL import Image
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.applications import EfficientNetB0
import cv2
import pickle
import os
# Import Grad-CAM utilities
from gradcam_utils import (
make_gradcam_heatmap,
overlay_heatmap_on_image,
get_last_conv_layer_name
)
# Configuration
IMG_SIZE = 224  # the model consumes IMG_SIZE x IMG_SIZE RGB images

# Disease labels. The position of each name is the index of the matching
# model output (predict() maps output index -> all_diseases[index]).
all_diseases = [
    'Atelectasis', 'Cardiomegaly', 'Effusion', 'Infiltration',
    'Mass', 'Nodule', 'Pneumonia', 'Pneumothorax',
    'Consolidation', 'Edema', 'Emphysema', 'Fibrosis',
    'Pleural_Thickening', 'Hernia',
]

# Derived from the label list so the size of the output layer can never
# disagree with the labels used to interpret it.
NUM_CLASSES = len(all_diseases)
def build_model(img_size, num_classes):
    """Assemble the EfficientNetB0-based multi-label classifier.

    An ImageNet-initialised EfficientNetB0 backbone (no top, global average
    pooling) feeds a small fully connected head. The head ends in
    ``num_classes`` independent sigmoid outputs.

    Args:
        img_size: Height and width of the square RGB input.
        num_classes: Number of sigmoid output units (one per label).

    Returns:
        An uncompiled ``keras.Model``.
    """
    image_in = layers.Input(shape=(img_size, img_size, 3))
    backbone = EfficientNetB0(
        include_top=False,
        weights='imagenet',
        input_tensor=image_in,
        pooling='avg',
    )
    # Full fine-tuning: every backbone layer stays trainable.
    backbone.trainable = True

    # Head: two Dense(relu) -> Dropout stages. The layers are created in the
    # same order as before, so auto-generated layer names are unchanged.
    features = backbone.output
    for units, drop_rate in ((512, 0.3), (256, 0.2)):
        features = layers.Dense(units, activation='relu')(features)
        features = layers.Dropout(drop_rate)(features)

    # Output layer is pinned to float32 (presumably so outputs stay float32
    # even if a mixed-precision policy was active during training).
    probs = layers.Dense(num_classes, activation='sigmoid', dtype='float32')(features)
    return keras.Model(inputs=image_in, outputs=probs)
# Load model: build the network, then restore the trained weights.
# This is best effort. If loading fails, the app still starts with the
# ImageNet backbone and an untrained classification head.
print("Loading model...")
model = build_model(IMG_SIZE, NUM_CLASSES)
try:
    model.load_weights('best_model.h5')
except Exception as weights_err:
    print(f"⚠️ Warning: Could not load model weights - {weights_err}")
else:
    print("✅ Model loaded successfully!")
# Load label encoder saved at training time. On failure, fall back to a
# {disease_name: index} mapping built from all_diseases.
# NOTE(review): pickle.load can execute arbitrary code, so only ship a trusted
# label_encoder.pkl. label_encoder does not appear to be used by the
# prediction path in this file.
try:
    with open('label_encoder.pkl', 'rb') as f:
        label_encoder = pickle.load(f)
except Exception as enc_err:
    print(f"⚠️ Creating default label encoder - {enc_err}")
    label_encoder = dict(zip(all_diseases, range(len(all_diseases))))
else:
    print("✅ Label encoder loaded!")
# Load optimal thresholds: per-disease decision thresholds (the About text
# says they were tuned for F1). use_optimal_thresholds records whether the
# file loaded. If it did not, every disease falls back to 0.5.
# NOTE(review): predict() calls .get() on this object, so the pickle is
# assumed to hold a dict keyed by disease name. Confirm.
try:
    with open('optimal_thresholds.pkl', 'rb') as f:
        optimal_thresholds = pickle.load(f)
except Exception as thr_err:
    print(f"⚠️ Using default threshold 0.5 - {thr_err}")
    optimal_thresholds = dict.fromkeys(all_diseases, 0.5)
    use_optimal_thresholds = False
else:
    print("✅ Optimal thresholds loaded!")
    use_optimal_thresholds = True
# Get last conv layer for Grad-CAM: the layer whose activations Grad-CAM
# explains. The fallback 'top_conv' is presumably the name of EfficientNetB0's
# final convolution in Keras (NOTE(review): confirm against model.summary()).
try:
    last_conv_layer = get_last_conv_layer_name(model)
    print(f"✅ Grad-CAM layer: {last_conv_layer}")
except Exception as layer_err:
    # Catch Exception rather than using a bare except, so KeyboardInterrupt
    # and SystemExit still propagate. Report the reason like the other loaders.
    last_conv_layer = 'top_conv'
    print(f"⚠️ Using default Grad-CAM layer: {last_conv_layer} - {layer_err}")
def preprocess_image(image):
    """Convert an uploaded image into a batch of one normalised model input.

    Args:
        image: ``None``, a NumPy array (as delivered by
            ``gr.Image(type="numpy")``) or a PIL image.

    Returns:
        ``None`` when *image* is ``None``. Otherwise a float32 array of shape
        ``(1, IMG_SIZE, IMG_SIZE, 3)`` with values in ``[0, 1]``.
    """
    if image is None:
        return None
    if isinstance(image, np.ndarray):
        # NOTE(review): assumes the array already holds 0-255 pixel values.
        image = Image.fromarray(image.astype('uint8'))
    # convert('RGB') also handles grayscale and RGBA uploads.
    image = image.convert('RGB').resize((IMG_SIZE, IMG_SIZE))
    # Use float32, the same dtype generate_gradcam feeds the model.
    # The previous float64 array was twice the size for no benefit.
    img_array = np.asarray(image, dtype=np.float32) / 255.0
    return img_array[np.newaxis, ...]
def predict_with_tta(image, n_augmentations=3):
    """Average model predictions over the original image and random augmentations.

    Each augmented copy goes through three steps:
    - horizontal flip with probability 0.5,
    - rotation by a uniform angle in [-10, 10] degrees (black fill),
    - brightness scaling by a uniform factor in [0.9, 1.1].

    All views are scored in a single batched ``model.predict`` call.

    Args:
        image: NumPy array or PIL image.
        n_augmentations: Number of augmented copies scored in addition to the
            original image.

    Returns:
        ``(mean_pred, std_pred)``: the per-class mean and standard deviation of
        the model outputs across the ``n_augmentations + 1`` views.
    """
    # Imported once per call instead of on every loop iteration.
    from PIL import ImageEnhance

    if isinstance(image, np.ndarray):
        image = Image.fromarray(image.astype('uint8'))
    image = image.convert('RGB').resize((IMG_SIZE, IMG_SIZE))

    views = [image]
    for _ in range(n_augmentations):
        # Random draws happen in the same order as before: flip, angle, brightness.
        aug_img = image.transpose(Image.FLIP_LEFT_RIGHT) if np.random.random() > 0.5 else image
        aug_img = aug_img.rotate(np.random.uniform(-10, 10), fillcolor=(0, 0, 0))
        aug_img = ImageEnhance.Brightness(aug_img).enhance(np.random.uniform(0.9, 1.1))
        views.append(aug_img)

    # One batched forward pass instead of n_augmentations + 1 separate calls.
    batch = np.stack([np.asarray(view, dtype=np.float32) / 255.0 for view in views])
    predictions = model.predict(batch, verbose=0)
    return predictions.mean(axis=0), predictions.std(axis=0)
def generate_gradcam(image, disease_idx):
    """Return the resized *image* overlaid with a Grad-CAM heatmap for one class.

    Args:
        image: NumPy array or PIL image.
        disease_idx: Index of the model output (and of ``all_diseases``) to explain.

    Returns:
        The result of ``overlay_heatmap_on_image`` for the IMG_SIZE x IMG_SIZE
        RGB version of *image*.
    """
    pil_img = Image.fromarray(image.astype('uint8')) if isinstance(image, np.ndarray) else image
    rgb_small = pil_img.convert('RGB').resize((IMG_SIZE, IMG_SIZE))

    # Batch of one, scaled to [0, 1], cast to float32.
    batch = np.expand_dims(np.array(rgb_small) / 255.0, axis=0).astype(np.float32)

    cam = make_gradcam_heatmap(batch, model, last_conv_layer, disease_idx)
    # Blend the heatmap onto the resized RGB image (alpha=0.5).
    return overlay_heatmap_on_image(rgb_small, cam, alpha=0.5)
def predict(image, use_tta, use_thresholds, show_gradcam, top_k=5):
"""Main prediction function with Grad-CAM support"""
if image is None:
return "⚠️ Please upload an image first.", {}, None, ""
try:
# Get predictions
if use_tta:
predictions, std = predict_with_tta(image, n_augmentations=3)
tta_text = "\n\n*✅ Using Test-Time Augmentation (TTA)*"
else:
img_array = preprocess_image(image)
predictions = model.predict(img_array, verbose=0)[0]
std = np.zeros_like(predictions)
tta_text = ""
# Apply optimal thresholds if requested
if use_thresholds and use_optimal_thresholds:
threshold_text = "\n*✅ Using optimal thresholds for classification*"
else:
threshold_text = "\n*Using default threshold: 0.5*"
# Get top K predictions
top_indices = np.argsort(predictions)[::-1][:top_k]
# Create results
results = {}
result_text = f"## 🏥 Prediction Results (Top {top_k})\n\n"
gradcam_images = []
for i, idx in enumerate(top_indices, 1):
disease = all_diseases[idx]
prob = float(predictions[idx])
percentage = prob * 100
results[disease] = prob
# Determine if positive using optimal threshold
if use_thresholds and use_optimal_thresholds:
threshold = optimal_thresholds.get(disease, 0.5)
is_positive = prob >= threshold
status = "✅ POSITIVE" if is_positive else "❌ NEGATIVE"
else:
threshold = 0.5
is_positive = prob >= 0.5
status = "✅ POSITIVE" if is_positive else "❌ NEGATIVE"
# Confidence indicator
if percentage > 70:
confidence = "🔴 High"
elif percentage > 40:
confidence = "🟡 Medium"
else:
confidence = "🟢 Low"
result_text += f"**{i}. {disease}** {status}\n"
result_text += f" - Probability: **{percentage:.2f}%**\n"
result_text += f" - Threshold: {threshold:.3f}\n"
result_text += f" - Confidence: {confidence}\n"
if use_tta:
result_text += f" - Uncertainty (±std): {std[idx]*100:.2f}%\n"
result_text += "\n"
# Generate Grad-CAM for top 3 if requested
if show_gradcam and i <= 3:
gradcam_img = generate_gradcam(image, idx)
gradcam_images.append(gradcam_img)
result_text += threshold_text
result_text += tta_text
result_text += "\n\n---\n\n*⚠️ **Medical Disclaimer:** This is an AI tool for educational purposes only. NOT for clinical diagnosis.*"
# Prepare Grad-CAM output
if show_gradcam and gradcam_images:
gradcam_gallery = gradcam_images
gradcam_text = f"## 🔥 Grad-CAM Visualizations\n\nShowing attention maps for top {len(gradcam_images)} predictions.\n\n**Red areas** = High attention (model focuses here)\n**Blue areas** = Low attention"
else:
gradcam_gallery = None
gradcam_text = ""
return result_text, results, gradcam_gallery, gradcam_text
except Exception as e:
error_msg = f"❌ Error: {str(e)}"
return error_msg, {}, None, ""
# Custom CSS, passed to gr.Blocks(css=...) below. It caps the container width
# at 1400px, sets the font and gives primary buttons a purple gradient.
# NOTE(review): the selector '.gr-button-primary' may not match the class names
# emitted by the installed Gradio version. Confirm it still applies.
custom_css = """
.container {
max-width: 1400px;
margin: auto;
}
.gradio-container {
font-family: 'IBM Plex Sans', sans-serif;
}
.gr-button-primary {
background: linear-gradient(90deg, #667eea 0%, #764ba2 100%);
border: none;
}
"""
# Create Gradio interface.
# NOTE(review): indentation was missing in the reviewed copy, so the nesting of
# the layout below was inferred. Confirm it against the deployed app.
with gr.Blocks(css=custom_css, title="Medical Image Classifier") as demo:
    # Page header.
    gr.Markdown(
        """
# 🏥 Medical Image Classification System
### AI-Powered Disease Detection with Grad-CAM Visualization
Upload a chest X-ray to detect 14 thoracic diseases using EfficientNetB0 + Grad-CAM explainability.
"""
    )
    with gr.Row():
        # Left column: image upload, inference options, analyze button, label list.
        with gr.Column(scale=1):
            # type="numpy" means predict() receives a NumPy array (or None).
            image_input = gr.Image(
                label="📁 Upload Chest X-ray Image",
                type="numpy"
            )
            gr.Markdown("### ⚙️ Options")
            use_tta = gr.Checkbox(
                label="🔄 Test-Time Augmentation (TTA)",
                value=False,
                info="More accurate but 3-4x slower"
            )
            use_thresholds = gr.Checkbox(
                label="🎯 Use Optimal Thresholds",
                value=True,
                info="Disease-specific thresholds for better classification"
            )
            show_gradcam = gr.Checkbox(
                label="🔥 Show Grad-CAM Visualization",
                value=True,
                info="Visual explanation of model predictions"
            )
            # Range 1-14 matches the number of disease labels.
            top_k = gr.Slider(
                minimum=1,
                maximum=14,
                value=5,
                step=1,
                label="📊 Number of predictions to show"
            )
            predict_btn = gr.Button("🔍 Analyze Image", variant="primary", size="lg")
            gr.Markdown(
                """
---
### 📋 Detectable Conditions
**Lung Conditions:**
- Atelectasis, Pneumonia, Pneumothorax
- Consolidation, Infiltration, Emphysema
**Cardiac:**
- Cardiomegaly, Edema
**Abnormal Growths:**
- Mass, Nodule, Fibrosis
**Others:**
- Effusion, Pleural Thickening, Hernia
"""
            )
        # Right column: result widgets, wired to predict()'s four return values below.
        with gr.Column(scale=1):
            result_text = gr.Markdown(label="📊 Analysis Results")
            result_plot = gr.Label(label="Probability Distribution", num_top_classes=10)
            gradcam_text = gr.Markdown(label="Grad-CAM Info")
            gradcam_gallery = gr.Gallery(
                label="🔥 Grad-CAM Heatmaps",
                columns=3,
                height="auto"
            )
    # Examples section. The heading is always shown. The examples only appear
    # when both example files exist (paths relative to the working directory).
    gr.Markdown("### 📸 Example Images")
    if os.path.exists('example_1.jpg') and os.path.exists('example_2.jpg'):
        gr.Examples(
            # Each row: [image, use_tta, use_thresholds, show_gradcam, top_k].
            examples=[
                ["example_1.jpg", False, True, True, 5],
                ["example_2.jpg", True, True, True, 5],
            ],
            inputs=[image_input, use_tta, use_thresholds, show_gradcam, top_k],
            outputs=[result_text, result_plot, gradcam_gallery, gradcam_text],
            fn=predict,
            cache_examples=False  # run predict when an example is used, not ahead of time
        )
    # Button click event. Inputs follow predict()'s parameter order; outputs
    # follow its return tuple (report, probabilities, gallery, Grad-CAM text).
    predict_btn.click(
        fn=predict,
        inputs=[image_input, use_tta, use_thresholds, show_gradcam, top_k],
        outputs=[result_text, result_plot, gradcam_gallery, gradcam_text]
    )
    # Static "About" / disclaimer section at the bottom of the page.
    gr.Markdown(
        """
---
## ℹ️ About This System
### 🧠 Model Architecture
- **Base:** EfficientNetB0 (ImageNet pre-trained)
- **Custom Layers:** Dense(512) → Dropout → Dense(256) → Dropout → Output(14)
- **Loss:** Binary Focal Cross-Entropy (α=0.25, γ=2.0)
- **Training:** Full fine-tuning @ lr=1e-5
### 🎯 Optimal Thresholds
Disease-specific thresholds optimized for F1-score, providing better balance between
precision and recall compared to the default 0.5 threshold.
### 🔥 Grad-CAM Visualization
**Gradient-weighted Class Activation Mapping** shows which regions of the X-ray
the model focuses on when making predictions. Red areas indicate high attention.
**Reference:** Selvaraju et al. (2017) - Grad-CAM: Visual Explanations from Deep Networks
### 🔄 Test-Time Augmentation
- Processes original + 3 augmented versions
- Augmentations: horizontal flip, rotation (±10°), brightness (0.9-1.1x)
- Final prediction = average of all versions
- Provides uncertainty estimates via standard deviation
---
## ⚠️ Medical Disclaimer
**FOR EDUCATIONAL USE ONLY - NOT FOR CLINICAL DIAGNOSIS**
- ❌ Not FDA approved or clinically validated
- ❌ Not a substitute for professional medical diagnosis
- ✅ For research and educational purposes only
- ✅ Always consult qualified healthcare professionals
---
### 🔧 Technical Stack
- TensorFlow 2.10.0 | Gradio 3.48.0
- Model: ~30M parameters | Inference: ~0.5-2s
**Built with ❤️ for AI in Healthcare**
"""
    )
# Launch the server only when run as a script, not when imported.
if __name__ == "__main__":
    # Bind on all interfaces (0.0.0.0) on port 7860 with no public share
    # link, and surface errors in the UI.
    launch_options = dict(
        share=False,
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True,
    )
    demo.launch(**launch_options)