# papsofts's picture
# Initial commit from local
# 3612742 verified
"""
Pneumonia Detection System using Deep Learning Ensemble
Capabilities:
- Supports multiple pre-trained model architectures (VGG19, ResNet50, etc.)
- Handles JPEG, PNG, and DICOM chest X-ray images
- Provides both individual model and weighted ensemble predictions
- Includes confidence scoring and detailed model-wise analysis
"""
from secrets import choice
import gradio as gr
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.models as models
from PIL import Image
import numpy as np
import os
from pathlib import Path
# ----------------------------------------------------------------------------
# Debug Configuration for Development and Troubleshooting
# Enable via CLINICAL_DEBUG environment variable
# ----------------------------------------------------------------------------
# Debug output is opt-in: only the exact values "1", "true" or "True" in the
# CLINICAL_DEBUG environment variable switch it on.
DEBUG = os.environ.get("CLINICAL_DEBUG", "0") in {"1", "true", "True"}


def _dbg(msg):
    """Print ``msg`` with a ``[DEBUG]`` prefix when ``DEBUG`` is enabled."""
    if not DEBUG:
        return
    print(f"[DEBUG] {msg}")
# ----------------------------------------------------------------------------
# DICOM Medical Image Support
# Enables support for medical-grade DICOM format chest X-rays
# Requires pydicom package - gracefully degrades if not available
# ----------------------------------------------------------------------------
# pydicom is an optional dependency: DICOM_AVAILABLE records whether the
# import succeeded so DICOM handling can be refused when it is missing.
try:
    import pydicom  # noqa: F401
except ImportError:
    DICOM_AVAILABLE = False
else:
    DICOM_AVAILABLE = True
# ----------------------------------------------------------------------------
# Neural Network Model Architectures
# Each class implements a specific deep learning architecture
# Modified for binary classification (Normal vs Pneumonia)
# ----------------------------------------------------------------------------
class MobileNetV2Model(nn.Module):
    """MobileNetV2 whose last classifier layer emits ``num_classes`` logits."""

    def __init__(self, num_classes=2):
        super().__init__()
        # Built without pretrained weights (weights=None).
        backbone = models.mobilenet_v2(weights=None)
        # Swap the stock output layer for a 1280 -> num_classes linear head.
        backbone.classifier[1] = nn.Linear(1280, num_classes)
        # Keep the attribute name ``model``: it prefixes every state_dict key.
        self.model = backbone

    def forward(self, x):
        """Return the wrapped network's output for ``x``."""
        return self.model(x)
class ResNet50Model(nn.Module):
    """ResNet50 whose fully connected layer emits ``num_classes`` logits."""

    def __init__(self, num_classes=2):
        super().__init__()
        # Built without pretrained weights (weights=None).
        backbone = models.resnet50(weights=None)
        # Swap the stock ``fc`` layer for a 2048 -> num_classes linear head.
        backbone.fc = nn.Linear(2048, num_classes)
        # Keep the attribute name ``model``: it prefixes every state_dict key.
        self.model = backbone

    def forward(self, x):
        """Return the wrapped network's output for ``x``."""
        return self.model(x)
class EfficientNetB0Model(nn.Module):
    """EfficientNet-B0 whose last classifier layer emits ``num_classes`` logits."""

    def __init__(self, num_classes=2):
        super().__init__()
        # Imported here rather than at module level, as in the original code.
        from torchvision.models import efficientnet_b0

        # Built without pretrained weights (weights=None).
        backbone = efficientnet_b0(weights=None)
        # Reuse the input width of the stock output layer for the new head.
        head_inputs = backbone.classifier[1].in_features
        backbone.classifier[1] = nn.Linear(head_inputs, num_classes)
        # Keep the attribute name ``model``: it prefixes every state_dict key.
        self.model = backbone

    def forward(self, x):
        """Return the wrapped network's output for ``x``."""
        return self.model(x)
class VGG19Model(nn.Module):
    """VGG19 whose last classifier layer emits ``num_classes`` logits."""

    def __init__(self, num_classes=2):
        super().__init__()
        # Built without pretrained weights (weights=None).
        backbone = models.vgg19(weights=None)
        # Swap classifier layer 6 for a 4096 -> num_classes linear head.
        backbone.classifier[6] = nn.Linear(4096, num_classes)
        # Keep the attribute name ``model``: it prefixes every state_dict key.
        self.model = backbone

    def forward(self, x):
        """Return the wrapped network's output for ``x``."""
        return self.model(x)
class DenseNet121Model(nn.Module):
    """DenseNet121 whose classifier emits ``num_classes`` logits."""

    def __init__(self, num_classes=2):
        super().__init__()
        # Built without pretrained weights (weights=None).
        backbone = models.densenet121(weights=None)
        # Swap the stock classifier for a 1024 -> num_classes linear head.
        backbone.classifier = nn.Linear(1024, num_classes)
        # Keep the attribute name ``model``: it prefixes every state_dict key.
        self.model = backbone

    def forward(self, x):
        """Return the wrapped network's output for ``x``."""
        return self.model(x)
# ----------------------------------------------------------------------------
# DICOM & Image Processing
# ----------------------------------------------------------------------------
def process_dicom_file(file_obj):
    """Read a DICOM file and return it as an 8-bit RGB PIL image.

    Args:
        file_obj: Object whose ``name`` attribute holds the path, or anything
            that can be handed to ``pydicom.dcmread`` directly.

    Returns:
        PIL.Image.Image in mode "RGB".

    Raises:
        ValueError: If pydicom is not installed.
    """
    if not DICOM_AVAILABLE:
        raise ValueError("DICOM support not available. Please install pydicom.")
    import pydicom

    source = getattr(file_obj, "name", file_obj)
    dataset = pydicom.dcmread(source)
    pixels = dataset.pixel_array.astype(np.float32)

    # Linear rescale; falls back to the identity when the tags are absent.
    slope = float(getattr(dataset, "RescaleSlope", 1.0))
    intercept = float(getattr(dataset, "RescaleIntercept", 0.0))
    pixels = pixels * slope + intercept

    # MONOCHROME1 images are flipped so intensities match the other inputs.
    photometric = getattr(dataset, "PhotometricInterpretation", "").upper()
    if photometric == "MONOCHROME1":
        pixels = pixels.max() - pixels

    # Min-max stretch to 0..255; the 1e-6 floor avoids dividing by zero on a
    # constant image.
    lowest, highest = pixels.min(), pixels.max()
    pixels = (pixels - lowest) / max(highest - lowest, 1e-6)
    pixels = (pixels * 255).clip(0, 255).astype(np.uint8)

    return Image.fromarray(pixels, mode="L").convert("RGB")
def process_uploaded_image(file_obj):
    """Load an uploaded JPEG, PNG or DICOM file as an RGB PIL image.

    Args:
        file_obj: ``None``, a filesystem path, or an object whose ``name``
            attribute holds the path.

    Returns:
        PIL.Image.Image in mode "RGB", or ``None`` when ``file_obj`` is ``None``.
    """
    if file_obj is None:
        return None

    # Resolve the path once so a plain path is routed exactly like a wrapper
    # object. Inspecting only ``file_obj.name`` made a bare "scan.dcm" string
    # skip the DICOM branch; ``str()`` also tolerates non-string names.
    source = getattr(file_obj, "name", file_obj)
    if str(source).lower().endswith((".dcm", ".dicom")):
        return process_dicom_file(file_obj)

    # ``convert`` returns a fully loaded copy, so the file handle opened by
    # ``Image.open`` can be released immediately instead of lingering when the
    # image is already RGB.
    with Image.open(source) as image:
        return image.convert("RGB")
# ----------------------------------------------------------------------------
# Pneumonia Model System
# ----------------------------------------------------------------------------
class PneumoniaModelSystem:
    """
    Main orchestrator for pneumonia detection models.

    Manages model loading, inference, and ensemble predictions.

    Attributes:
        device (str): Computation device ('cpu' or 'cuda')
        models (dict): Loaded models keyed by name; each value is
            ``{"model": <network>, "info": <definition dict>}``
        transform (transforms.Compose): Image preprocessing pipeline
        model_definitions (dict): Architecture, checkpoint file name and
            description for every known model name
        ensemble_weights (dict): Relative weight of each model in the ensemble
    """

    def __init__(self, device="cpu"):
        self.device = device
        self.models = {}
        # Standard ImageNet normalization and preprocessing
        self.transform = transforms.Compose(
            [
                transforms.Resize((224, 224)),  # Resize to model input size
                transforms.ToTensor(),  # Convert to tensor (0-1 range)
                transforms.Normalize(  # ImageNet mean and std
                    mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225],
                ),
            ]
        )
        self.model_definitions = {
            "Model_A1_7CB_Appr_D": {
                "architecture": "VGG19",
                "file": "Model_A1_7CB_Appr_D.pt",
                "description": "VGG19 - Model A1 with 7CB approach D",
            },
            "Model_C_Appr_B": {
                "architecture": "MobileNetV2",
                "file": "Model_C_Appr_B.pt",
                "description": "MobileNetV2 - Model C with approach B",
            },
            "Model_F_Appr_B": {
                "architecture": "ResNet50",
                "file": "Model_F_Appr_B.pt",
                "description": "ResNet50 - Model F with approach B",
            },
            "Model_G_Appr_B": {
                "architecture": "EfficientNet-B0",
                "file": "Model_G_Appr_B.pt",
                "description": "EfficientNet-B0 - Model G with approach B",
            },
            "Model_H_Appr_B": {
                "architecture": "DenseNet121",
                "file": "Model_H_Appr_B.pt",
                "description": "DenseNet121 - Model H with approach B",
            },
        }
        # Relative weights; predict_ensemble renormalises over the models that
        # actually take part, so a subset still yields probabilities.
        self.ensemble_weights = {
            "Model_A1_7CB_Appr_D": 0.30,
            "Model_C_Appr_B": 0.25,
            "Model_F_Appr_B": 0.15,
            "Model_G_Appr_B": 0.15,
            "Model_H_Appr_B": 0.15,
        }

    def _create_model(self, arch):
        """Instantiate a two-class network for ``arch`` on ``self.device``.

        Raises:
            ValueError: If ``arch`` has no registered wrapper class.
        """
        builders = {
            "MobileNetV2": MobileNetV2Model,
            # NOTE(review): ResNet50 is disabled here although
            # ``model_definitions`` still lists a ResNet50 checkpoint, so that
            # entry is always reported as a load failure. Confirm this is
            # intended before re-enabling.
            # "ResNet50": ResNet50Model,
            "EfficientNet-B0": EfficientNetB0Model,
            "VGG19": VGG19Model,
            "DenseNet121": DenseNet121Model,
        }
        if arch not in builders:
            # A bare KeyError would surface only as "'ResNet50'" in the log.
            raise ValueError(
                f"Unsupported architecture {arch!r}; expected one of {sorted(builders)}"
            )
        return builders[arch](num_classes=2).to(self.device)

    def load_models(self, model_dir="models"):
        """Load every defined model whose checkpoint exists in ``model_dir``.

        Missing files and per-model failures are printed and skipped, so one
        bad checkpoint never prevents the others from loading.

        Args:
            model_dir: Directory containing the ``.pt`` checkpoint files.

        Returns:
            PneumoniaModelSystem: ``self``, to allow call chaining.
        """
        model_dir = Path(model_dir)
        for name, info in self.model_definitions.items():
            path = model_dir / info["file"]
            if not path.exists():
                print(f" Model missing: {path}")
                continue
            try:
                model = self._create_model(info["architecture"])
                # SECURITY: torch.load unpickles the file; only load
                # checkpoints from a trusted source.
                state_dict = torch.load(path, map_location=self.device)
                # strict=False tolerates key mismatches, which would otherwise
                # leave layers at their random initial values without any
                # notice. Report what did not line up.
                outcome = model.load_state_dict(state_dict, strict=False)
                missing = list(outcome.missing_keys)
                unexpected = list(outcome.unexpected_keys)
                if missing or unexpected:
                    print(
                        f" Warning: checkpoint for {name} does not fully match "
                        f"{info['architecture']}: {len(missing)} missing key(s) "
                        f"{missing[:3]}, {len(unexpected)} unexpected key(s) "
                        f"{unexpected[:3]}"
                    )
                model.eval()
                self.models[name] = {"model": model, "info": info}
                _dbg(f"Loaded {name}")
            except Exception as e:
                print(f" Failed to load {name}: {e}")
        return self

    @staticmethod
    def _summarize(probs):
        """Convert a two-element probability vector into result fields.

        Index 0 is the NORMAL probability and index 1 the PNEUMONIA one.
        """
        idx = int(probs.argmax())
        return {
            "prediction": "PNEUMONIA" if idx == 1 else "NORMAL",
            "confidence": float(probs[idx]),
            "pneumonia_probability": float(probs[1]),
            "normal_probability": float(probs[0]),
        }

    def predict_single_model(self, image, model_name):
        """Classify ``image`` with one loaded model.

        Args:
            image: Input accepted by ``self.transform``.
            model_name: Key of a model in ``self.models``.

        Returns:
            dict: ``prediction``, ``confidence``, ``pneumonia_probability``,
            ``normal_probability`` and ``model_used``.

        Raises:
            KeyError: If ``model_name`` is not a loaded model.
        """
        model = self.models[model_name]["model"]
        img_tensor = self.transform(image).unsqueeze(0).to(self.device)
        with torch.no_grad():
            probs = torch.softmax(model(img_tensor), dim=1)[0].cpu().numpy()
        return {**self._summarize(probs), "model_used": model_name}

    def predict_ensemble(self, image, selected_models=None):
        """
        Perform ensemble prediction using selected models

        The softmax outputs of the participating models are averaged using
        ``ensemble_weights`` (weight 1 for a model without an entry) and
        divided by the total weight used.

        Args:
            image: Input image to process
            selected_models: List of model names to include in ensemble. If None, use all models.

        Returns:
            dict: ``prediction``, ``confidence``, ``pneumonia_probability``,
            ``normal_probability``, ``models_used`` and
            ``individual_predictions`` (one dict per participating model).

        Raises:
            ValueError: If no loaded model takes part in the ensemble.
        """
        img_tensor = self.transform(image).unsqueeze(0).to(self.device)
        ensemble_probs = torch.zeros(1, 2).to(self.device)
        total_weight = 0
        individual_predictions = []
        with torch.no_grad():
            for name, entry in self.models.items():
                # Skip if model not selected
                if selected_models is not None and name not in selected_models:
                    continue
                weight = self.ensemble_weights.get(name, 1)
                probs = torch.softmax(entry["model"](img_tensor), dim=1)[0]
                ensemble_probs += weight * probs
                total_weight += weight
                # Store individual model predictions
                individual_predictions.append(
                    {
                        "model_name": name,
                        "architecture": entry["info"]["architecture"],
                        **self._summarize(probs.cpu().numpy()),
                        "weight": weight,
                    }
                )
        if total_weight == 0:
            raise ValueError("No models selected for ensemble prediction")
        ensemble_probs /= total_weight
        return {
            **self._summarize(ensemble_probs[0].cpu().numpy()),
            "models_used": [p["model_name"] for p in individual_predictions],
            "individual_predictions": individual_predictions,
        }
# ----------------------------------------------------------------------------
# Initialize model system
# ----------------------------------------------------------------------------
# Build the single shared model system (CPU only) at import time.
model_system = PneumoniaModelSystem(device="cpu")
try:
    # load_models prints per-model problems itself and skips those models.
    model_system.load_models("models")
    # NOTE: this is a live dict view, not a list; it is converted with list()
    # wherever it is used.
    available_models = model_system.models.keys()
    print(f"✅ Loaded models: {list(available_models)}")
except Exception as e:
    # Any failure inside the try (including the status print) leaves the UI
    # with no selectable models.
    print(f" Model loading error: {e}")
    available_models = []
# ----------------------------------------------------------------------------
# Gradio Inference Function
# ----------------------------------------------------------------------------
class _UploadPath:
    """Minimal stand-in for an upload object: exposes the path as ``name``."""

    def __init__(self, path):
        self.name = path


def _confidence_color(confidence):
    """Return the heading colour for a confidence value between 0 and 1."""
    if confidence >= 0.9:
        return "green"
    if confidence >= 0.75:
        return "blue"
    if confidence >= 0.6:
        return "orange"
    return "red"


def _individual_results_table(predictions):
    """Render the per-model ensemble predictions as an HTML table."""
    parts = [
        "\n<div style='margin-top: 20px;'>\n"
        "<h3>Individual Model Predictions</h3>\n"
        "<table style='width: 100%; border-collapse: collapse; margin-top: 10px;'>\n"
        "<tr style='background-color: #f2f2f2;'>\n"
        "<th style='padding: 8px; border: 1px solid #ddd; text-align: left;'>Model</th>\n"
        "<th style='padding: 8px; border: 1px solid #ddd; text-align: left;'>Architecture</th>\n"
        "<th style='padding: 8px; border: 1px solid #ddd; text-align: left;'>Prediction</th>\n"
        "<th style='padding: 8px; border: 1px solid #ddd; text-align: right;'>Confidence</th>\n"
        "<th style='padding: 8px; border: 1px solid #ddd; text-align: right;'>Weight</th>\n"
        "</tr>\n"
    ]
    for pred in predictions:
        color = "red" if pred["prediction"] == "PNEUMONIA" else "green"
        parts.append(
            "\n<tr>\n"
            f"<td style='padding: 8px; border: 1px solid #ddd;'>{pred['model_name']}</td>\n"
            f"<td style='padding: 8px; border: 1px solid #ddd;'>{pred['architecture']}</td>\n"
            f"<td style='padding: 8px; border: 1px solid #ddd; color: {color};'>{pred['prediction']}</td>\n"
            f"<td style='padding: 8px; border: 1px solid #ddd; text-align: right;'>{pred['confidence']*100:.1f}%</td>\n"
            f"<td style='padding: 8px; border: 1px solid #ddd; text-align: right;'>{pred['weight']:.3f}</td>\n"
            "</tr>\n"
        )
    parts.append("</table></div>")
    return "".join(parts)


def predict_pneumonia(file_path, selected_model_option, selected_ensemble_models=None):
    """
    Process uploaded X-ray image and perform pneumonia detection

    Args:
        file_path: Path to the uploaded image file
        selected_model_option: Selected model or ensemble option
        selected_ensemble_models: List of model names to include in ensemble prediction

    Returns:
        tuple: (processed_image, text_result, probability_dict, confidence_html, individual_results_html)
        When the input cannot be processed, the image is ``None``, the text
        explains the problem and the remaining fields are empty.
    """
    if file_path is None:
        return None, "Please upload an X-ray image", {}, "", ""

    # UI boundary: an unreadable upload (corrupt image, DICOM without pydicom
    # installed, ...) becomes a message in the result pane instead of an
    # unhandled exception in the event handler.
    try:
        processed_image = process_uploaded_image(_UploadPath(file_path))
    except Exception as exc:
        return None, f"Could not read the uploaded file: {exc}", {}, "", ""

    if selected_model_option == "Ensemble (All Models)":
        try:
            # Use selected models if provided, otherwise use all
            result = model_system.predict_ensemble(processed_image, selected_ensemble_models)
        except ValueError:
            return None, "Please select at least one model for ensemble prediction", {}, "", ""
        model_info = f"Ensemble of {len(result['models_used'])} selected models"
        individual_results_html = _individual_results_table(result["individual_predictions"])
    else:
        result = model_system.predict_single_model(processed_image, selected_model_option)
        model_info = result["model_used"]
        individual_results_html = ""

    text = f"## Prediction: {result['prediction']}\n\n"
    if result["prediction"] == "PNEUMONIA":
        text += " **Pneumonia detected** – please consult a radiologist."
    else:
        text += "✓ **No pneumonia detected** – appears normal."

    prob_dict = {
        "Normal": result["normal_probability"],
        "Pneumonia": result["pneumonia_probability"],
    }
    color = _confidence_color(result["confidence"])
    html = (
        "\n<div style='padding: 20px; background: #f8f8f8; border-radius: 10px;'>\n"
        f"<h3 style='color:{color};'>Confidence: {result['confidence']*100:.1f}%</h3>\n"
        f"<p>Model used: {model_info}</p>\n"
        "</div>\n"
    )
    return processed_image, text, prob_dict, html, individual_results_html
# ----------------------------------------------------------------------------
# Gradio Web Interface Configuration
# Defines the UI layout and components for the pneumonia detection system
# Includes custom styling and responsive layout design
# ----------------------------------------------------------------------------
custom_css = """
.gradio-container { font-family: 'Arial', sans-serif; }
.output-markdown h2 { color: #2c3e50; }
"""

with gr.Blocks(css=custom_css, title="Pneumonia Detection AI") as demo:
    gr.Markdown(
        """
# 🫁 Pneumonia Detection from Chest X-rays
Upload an X-ray image and select a model or the ensemble to analyze.
**Disclaimer:** Research and educational use only.
"""
    )
    with gr.Row():
        # Left column: inputs and the per-model breakdown.
        with gr.Column():
            input_file = gr.File(
                label="Upload Chest X-Ray (JPEG, PNG, DICOM)",
                file_types=[".jpg", ".jpeg", ".png", ".dcm", ".dicom"],
                type="filepath",
            )
            model_selector = gr.Radio(
                choices=["Ensemble (All Models)"] + list(available_models),
                value="Ensemble (All Models)",
                label="Select Model",
            )
            # Create containers for ensemble components
            with gr.Group(visible=True) as ensemble_options:
                gr.Markdown("### Select Models for Ensemble")
                ensemble_model_selector = gr.CheckboxGroup(
                    choices=list(available_models),
                    value=list(available_models),  # All selected by default
                    label="Models to include in ensemble",
                )
            predict_btn = gr.Button("Analyze X-Ray", variant="primary")
            individual_results = gr.HTML(label="Individual Model Results")
        # Right column: prediction outputs.
        with gr.Column():
            preview = gr.Image(label="Processed Image", height=400)
            output_text = gr.Markdown(label="Diagnosis Result")
            prob_chart = gr.JSON(label="Probability Distribution")
            confidence_html = gr.HTML(label="Confidence Level")

    # Function to control visibility of ensemble results based on model selection
    def update_ensemble_components(choice):
        """Toggle visibility of ensemble-related components.

        Returns one update per output wired below: the ensemble options group
        and the individual-results HTML, in that order.
        """
        is_ensemble = choice == "Ensemble (All Models)"
        # gr.update() works for any component type. The previous
        # gr.Box.update / gr.Row.update calls named the wrong component types
        # (the targets are a Group and an HTML component) and are not
        # available in Gradio 4+.
        return [
            gr.update(visible=is_ensemble),
            gr.update(visible=is_ensemble),
        ]

    # Event handler for model selection changes
    model_selector.change(
        fn=update_ensemble_components,
        inputs=[model_selector],
        outputs=[ensemble_options, individual_results],
    )
    predict_btn.click(
        fn=predict_pneumonia,
        inputs=[input_file, model_selector, ensemble_model_selector],
        outputs=[preview, output_text, prob_chart, confidence_html, individual_results],
    )
# ----------------------------------------------------------------------------
# Launch (Hugging Face Spaces compatible)
# ----------------------------------------------------------------------------
# Start the Gradio server only when this file is executed as a script, not
# when it is imported.
if __name__ == "__main__":
    demo.launch()