# NOTE: the "Spaces: Runtime error" banner that appeared here was Hugging Face
# Spaces page chrome captured by the scrape — it is not part of the program.
| """ | |
| Pneumonia Detection System using Deep Learning Ensemble | |
| Capabilities: | |
| - Supports multiple pre-trained model architectures (VGG19, ResNet50, etc.) | |
| - Handles JPEG, PNG, and DICOM chest X-ray images | |
| - Provides both individual model and weighted ensemble predictions | |
| - Includes confidence scoring and detailed model-wise analysis | |
| """ | |
| from secrets import choice | |
| import gradio as gr | |
| import torch | |
| import torch.nn as nn | |
| import torchvision.transforms as transforms | |
| import torchvision.models as models | |
| from PIL import Image | |
| import numpy as np | |
| import os | |
| from pathlib import Path | |
| # ---------------------------------------------------------------------------- | |
| # Debug Configuration for Development and Troubleshooting | |
| # Enable via CLINICAL_DEBUG environment variable | |
| # ---------------------------------------------------------------------------- | |
| DEBUG = os.getenv("CLINICAL_DEBUG", "0") in ("1", "true", "True") | |
| def _dbg(msg): | |
| if DEBUG: | |
| print(f"[DEBUG] {msg}") | |
| # ---------------------------------------------------------------------------- | |
| # DICOM Medical Image Support | |
| # Enables support for medical-grade DICOM format chest X-rays | |
| # Requires pydicom package - gracefully degrades if not available | |
| # ---------------------------------------------------------------------------- | |
# Optional DICOM support: import pydicom if it is installed; otherwise flag
# it as unavailable so the app degrades gracefully to JPEG/PNG only.
DICOM_AVAILABLE = True
try:
    import pydicom
except ImportError:
    DICOM_AVAILABLE = False
| # ---------------------------------------------------------------------------- | |
| # Neural Network Model Architectures | |
| # Each class implements a specific deep learning architecture | |
| # Modified for binary classification (Normal vs Pneumonia) | |
| # ---------------------------------------------------------------------------- | |
class MobileNetV2Model(nn.Module):
    """MobileNetV2 backbone adapted for binary (Normal vs Pneumonia) output."""

    def __init__(self, num_classes: int = 2):
        super().__init__()
        backbone = models.mobilenet_v2(weights=None)
        # Replace the final classifier layer (1280 features -> num_classes).
        # Attribute must stay named `model` to match checkpoint state-dict keys.
        backbone.classifier[1] = nn.Linear(1280, num_classes)
        self.model = backbone

    def forward(self, x):
        return self.model(x)
class ResNet50Model(nn.Module):
    """ResNet50 backbone adapted for binary (Normal vs Pneumonia) output."""

    def __init__(self, num_classes: int = 2):
        super().__init__()
        backbone = models.resnet50(weights=None)
        # Replace the fully connected head (2048 features -> num_classes).
        # Attribute must stay named `model` to match checkpoint state-dict keys.
        backbone.fc = nn.Linear(2048, num_classes)
        self.model = backbone

    def forward(self, x):
        return self.model(x)
class EfficientNetB0Model(nn.Module):
    """EfficientNet-B0 backbone adapted for binary (Normal vs Pneumonia) output."""

    def __init__(self, num_classes: int = 2):
        super().__init__()
        # Use the torchvision.models module already imported at file level
        # rather than a redundant local import.
        backbone = models.efficientnet_b0(weights=None)
        # Replace the classifier head, preserving its input feature width.
        in_features = backbone.classifier[1].in_features
        backbone.classifier[1] = nn.Linear(in_features, num_classes)
        self.model = backbone

    def forward(self, x):
        return self.model(x)
class VGG19Model(nn.Module):
    """VGG19 backbone adapted for binary (Normal vs Pneumonia) output."""

    def __init__(self, num_classes: int = 2):
        super().__init__()
        backbone = models.vgg19(weights=None)
        # Replace the last classifier layer (4096 features -> num_classes).
        # Attribute must stay named `model` to match checkpoint state-dict keys.
        backbone.classifier[6] = nn.Linear(4096, num_classes)
        self.model = backbone

    def forward(self, x):
        return self.model(x)
class DenseNet121Model(nn.Module):
    """DenseNet121 backbone adapted for binary (Normal vs Pneumonia) output."""

    def __init__(self, num_classes: int = 2):
        super().__init__()
        backbone = models.densenet121(weights=None)
        # Replace the classifier head (1024 features -> num_classes).
        # Attribute must stay named `model` to match checkpoint state-dict keys.
        backbone.classifier = nn.Linear(1024, num_classes)
        self.model = backbone

    def forward(self, x):
        return self.model(x)
| # ---------------------------------------------------------------------------- | |
| # DICOM & Image Processing | |
| # ---------------------------------------------------------------------------- | |
def process_dicom_file(file_obj):
    """Read a DICOM chest X-ray and return it as an RGB PIL Image.

    Applies the modality rescale (slope/intercept), inverts MONOCHROME1
    images per the DICOM photometric convention, then min-max normalizes
    the pixel data to 8-bit grayscale before converting to RGB.

    Raises:
        ValueError: if pydicom is not installed.
    """
    if not DICOM_AVAILABLE:
        raise ValueError("DICOM support not available. Please install pydicom.")
    import pydicom
    source = file_obj.name if hasattr(file_obj, "name") else file_obj
    dataset = pydicom.dcmread(source)
    pixels = dataset.pixel_array.astype(np.float32)
    # Apply the modality rescale; the defaults (1.0, 0.0) leave data unchanged.
    slope = float(getattr(dataset, "RescaleSlope", 1.0))
    intercept = float(getattr(dataset, "RescaleIntercept", 0.0))
    pixels = pixels * slope + intercept
    # MONOCHROME1 stores inverted intensities — flip so higher means brighter.
    if getattr(dataset, "PhotometricInterpretation", "").upper() == "MONOCHROME1":
        pixels = pixels.max() - pixels
    # Min-max normalize to [0, 255]; the epsilon guards constant-valued images.
    low, high = pixels.min(), pixels.max()
    pixels = (pixels - low) / max(high - low, 1e-6)
    pixels = (pixels * 255).clip(0, 255).astype(np.uint8)
    return Image.fromarray(pixels, mode="L").convert("RGB")
def process_uploaded_image(file_obj):
    """Open an uploaded JPEG/PNG/DICOM file and return an RGB PIL Image.

    Args:
        file_obj: A file-like object exposing a ``name`` attribute, or a
            plain filesystem path. ``None`` is passed through unchanged.

    Returns:
        PIL.Image.Image in RGB mode, or ``None`` when no file was given.
    """
    if file_obj is None:
        return None
    # BUG FIX: the original used getattr(file_obj, "name", "") — a plain
    # string path has no .name attribute, so its ".dcm" extension was never
    # seen and DICOM files given as paths were opened as ordinary images.
    name = str(getattr(file_obj, "name", file_obj)).lower()
    if name.endswith((".dcm", ".dicom")):
        return process_dicom_file(file_obj)
    image = Image.open(file_obj.name if hasattr(file_obj, "name") else file_obj)
    if image.mode != "RGB":
        image = image.convert("RGB")
    return image
| # ---------------------------------------------------------------------------- | |
| # Pneumonia Model System | |
| # ---------------------------------------------------------------------------- | |
class PneumoniaModelSystem:
    """
    Main orchestrator for pneumonia detection models.

    Manages model loading, inference, and weighted ensemble predictions.

    Attributes:
        device (str): Computation device ('cpu' or 'cuda').
        models (dict): name -> {"model": nn.Module, "info": dict} for every
            checkpoint that loaded successfully.
        transform (transforms.Compose): Image preprocessing pipeline.
    """

    def __init__(self, device="cpu"):
        self.device = device
        self.models = {}
        # All backbones expect 224x224 RGB inputs normalized with the
        # standard ImageNet statistics.
        self.transform = transforms.Compose(
            [
                transforms.Resize((224, 224)),  # resize to model input size
                transforms.ToTensor(),          # convert to tensor (0-1 range)
                transforms.Normalize(           # ImageNet mean and std
                    mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225],
                ),
            ]
        )
        # Registry of known checkpoints: display name -> architecture + file.
        self.model_definitions = {
            "Model_A1_7CB_Appr_D": {
                "architecture": "VGG19",
                "file": "Model_A1_7CB_Appr_D.pt",
                "description": "VGG19 - Model A1 with 7CB approach D",
            },
            "Model_C_Appr_B": {
                "architecture": "MobileNetV2",
                "file": "Model_C_Appr_B.pt",
                "description": "MobileNetV2 - Model C with approach B",
            },
            "Model_F_Appr_B": {
                "architecture": "ResNet50",
                "file": "Model_F_Appr_B.pt",
                "description": "ResNet50 - Model F with approach B",
            },
            "Model_G_Appr_B": {
                "architecture": "EfficientNet-B0",
                "file": "Model_G_Appr_B.pt",
                "description": "EfficientNet-B0 - Model G with approach B",
            },
            "Model_H_Appr_B": {
                "architecture": "DenseNet121",
                "file": "Model_H_Appr_B.pt",
                "description": "DenseNet121 - Model H with approach B",
            },
        }
        # Relative contribution of each model to the weighted ensemble.
        self.ensemble_weights = {
            "Model_A1_7CB_Appr_D": 0.30,
            "Model_C_Appr_B": 0.25,
            "Model_F_Appr_B": 0.15,
            "Model_G_Appr_B": 0.15,
            "Model_H_Appr_B": 0.15,
        }

    def _create_model(self, arch):
        """Instantiate a fresh (unweighted) network for *arch* on self.device.

        Raises:
            KeyError: if *arch* is not a supported architecture name.
        """
        # BUG FIX: the "ResNet50" entry was commented out, so the
        # Model_F_Appr_B checkpoint (architecture "ResNet50") raised KeyError
        # here and could never be loaded by load_models.
        factories = {
            "MobileNetV2": MobileNetV2Model,
            "ResNet50": ResNet50Model,
            "EfficientNet-B0": EfficientNetB0Model,
            "VGG19": VGG19Model,
            "DenseNet121": DenseNet121Model,
        }
        return factories[arch](num_classes=2).to(self.device)

    def load_models(self, model_dir="models"):
        """Load every checkpoint in model_definitions from *model_dir*.

        Missing files and failed loads are reported and skipped, so the
        system keeps whatever subset loaded successfully.

        Returns:
            self, for call chaining.
        """
        model_dir = Path(model_dir)
        for name, info in self.model_definitions.items():
            path = model_dir / info["file"]
            if not path.exists():
                print(f" Model missing: {path}")
                continue
            try:
                model = self._create_model(info["architecture"])
                state_dict = torch.load(path, map_location=self.device)
                # strict=False tolerates minor key mismatches between the
                # checkpoint and the freshly constructed architecture.
                model.load_state_dict(state_dict, strict=False)
                model.eval()
                self.models[name] = {"model": model, "info": info}
                _dbg(f"Loaded {name}")
            except Exception as e:
                print(f" Failed to load {name}: {e}")
        return self

    def predict_single_model(self, image, model_name):
        """Run one loaded model on a PIL image and summarize its output.

        Raises:
            KeyError: if *model_name* has not been loaded.
        """
        model = self.models[model_name]["model"]
        img_tensor = self.transform(image).unsqueeze(0).to(self.device)
        with torch.no_grad():
            probs = torch.softmax(model(img_tensor), dim=1)[0].cpu().numpy()
        idx = int(probs.argmax())
        # Class index 1 is pneumonia, index 0 is normal.
        label = "PNEUMONIA" if idx == 1 else "NORMAL"
        return {
            "prediction": label,
            "confidence": float(probs[idx]),
            "pneumonia_probability": float(probs[1]),
            "normal_probability": float(probs[0]),
            "model_used": model_name,
        }

    def predict_ensemble(self, image, selected_models=None):
        """
        Perform a weighted-average ensemble prediction.

        Args:
            image: Input PIL image to classify.
            selected_models: Model names to include in the ensemble;
                None means use all loaded models.

        Returns:
            dict: ensemble prediction plus per-model breakdowns.

        Raises:
            ValueError: if the selection matches no loaded model.
        """
        img_tensor = self.transform(image).unsqueeze(0).to(self.device)
        ensemble_probs = torch.zeros(1, 2).to(self.device)
        total_weight = 0.0
        individual_predictions = []
        with torch.no_grad():
            for name, info in self.models.items():
                # Skip models excluded by the caller's selection.
                if selected_models is not None and name not in selected_models:
                    continue
                model = info["model"]
                # Models without a configured weight contribute with weight 1.
                w = self.ensemble_weights.get(name, 1)
                probs = torch.softmax(model(img_tensor), dim=1)[0]
                ensemble_probs += w * probs
                total_weight += w
                prob_np = probs.cpu().numpy()
                idx = int(prob_np.argmax())
                individual_predictions.append({
                    "model_name": name,
                    "architecture": info["info"]["architecture"],
                    "prediction": "PNEUMONIA" if idx == 1 else "NORMAL",
                    "confidence": float(prob_np[idx]),
                    "pneumonia_probability": float(prob_np[1]),
                    "normal_probability": float(prob_np[0]),
                    "weight": w,
                })
        if total_weight == 0:
            raise ValueError("No models selected for ensemble prediction")
        # Normalize by the weight actually used (depends on the selection).
        ensemble_probs /= total_weight
        probs = ensemble_probs[0].cpu().numpy()
        idx = int(probs.argmax())
        label = "PNEUMONIA" if idx == 1 else "NORMAL"
        return {
            "prediction": label,
            "confidence": float(probs[idx]),
            "pneumonia_probability": float(probs[1]),
            "normal_probability": float(probs[0]),
            "models_used": [p["model_name"] for p in individual_predictions],
            "individual_predictions": individual_predictions,
        }
| # ---------------------------------------------------------------------------- | |
| # Initialize model system | |
| # ---------------------------------------------------------------------------- | |
# Build the model system once at import time (CPU keeps the Space portable).
model_system = PneumoniaModelSystem(device="cpu")
try:
    model_system.load_models("models")
    # FIX: materialize as a list so the success and failure paths expose the
    # same type (the original left a live dict_keys view on success).
    available_models = list(model_system.models.keys())
    print(f"✅ Loaded models: {available_models}")
except Exception as e:
    print(f" Model loading error: {e}")
    available_models = []
| # ---------------------------------------------------------------------------- | |
| # Gradio Inference Function | |
| # ---------------------------------------------------------------------------- | |
| def predict_pneumonia(file_path, selected_model_option, selected_ensemble_models=None): | |
| """ | |
| Process uploaded X-ray image and perform pneumonia detection | |
| Args: | |
| file_path: Path to the uploaded image file | |
| selected_model_option: Selected model or ensemble option | |
| selected_ensemble_models: List of model names to include in ensemble prediction | |
| Returns: | |
| tuple: (processed_image, text_result, probability_dict, confidence_html, individual_results_html) | |
| """ | |
| if file_path is None: | |
| return None, "Please upload an X-ray image", {}, "", "" | |
| class FileObj: | |
| def __init__(self, path): | |
| self.name = path | |
| processed_image = process_uploaded_image(FileObj(file_path)) | |
| if selected_model_option == "Ensemble (All Models)": | |
| try: | |
| # Use selected models if provided, otherwise use all | |
| result = model_system.predict_ensemble(processed_image, selected_ensemble_models) | |
| model_info = f"Ensemble of {len(result['models_used'])} selected models" | |
| except ValueError as e: | |
| return None, "Please select at least one model for ensemble prediction", {}, "", "" | |
| # Create table for individual model predictions | |
| individual_results_html = """ | |
| <div style='margin-top: 20px;'> | |
| <h3>Individual Model Predictions</h3> | |
| <table style='width: 100%; border-collapse: collapse; margin-top: 10px;'> | |
| <tr style='background-color: #f2f2f2;'> | |
| <th style='padding: 8px; border: 1px solid #ddd; text-align: left;'>Model</th> | |
| <th style='padding: 8px; border: 1px solid #ddd; text-align: left;'>Architecture</th> | |
| <th style='padding: 8px; border: 1px solid #ddd; text-align: left;'>Prediction</th> | |
| <th style='padding: 8px; border: 1px solid #ddd; text-align: right;'>Confidence</th> | |
| <th style='padding: 8px; border: 1px solid #ddd; text-align: right;'>Weight</th> | |
| </tr> | |
| """ | |
| for pred in result["individual_predictions"]: | |
| color = "red" if pred["prediction"] == "PNEUMONIA" else "green" | |
| individual_results_html += f""" | |
| <tr> | |
| <td style='padding: 8px; border: 1px solid #ddd;'>{pred["model_name"]}</td> | |
| <td style='padding: 8px; border: 1px solid #ddd;'>{pred["architecture"]}</td> | |
| <td style='padding: 8px; border: 1px solid #ddd; color: {color};'>{pred["prediction"]}</td> | |
| <td style='padding: 8px; border: 1px solid #ddd; text-align: right;'>{pred["confidence"]*100:.1f}%</td> | |
| <td style='padding: 8px; border: 1px solid #ddd; text-align: right;'>{pred["weight"]:.3f}</td> | |
| </tr> | |
| """ | |
| individual_results_html += "</table></div>" | |
| else: | |
| result = model_system.predict_single_model(processed_image, selected_model_option) | |
| model_info = result["model_used"] | |
| individual_results_html = "" | |
| text = f"## Prediction: {result['prediction']}\n\n" | |
| #text += f"**Confidence:** {result['confidence']*100:.2f}%\n\n" | |
| #text += f"**Model Used:** {model_info}\n\n" | |
| if result["prediction"] == "PNEUMONIA": | |
| text += " **Pneumonia detected** – please consult a radiologist." | |
| else: | |
| text += "✓ **No pneumonia detected** – appears normal." | |
| prob_dict = { | |
| "Normal": result["normal_probability"], | |
| "Pneumonia": result["pneumonia_probability"], | |
| } | |
| color = ( | |
| "green" | |
| if result["confidence"] >= 0.9 | |
| else "blue" | |
| if result["confidence"] >= 0.75 | |
| else "orange" | |
| if result["confidence"] >= 0.6 | |
| else "red" | |
| ) | |
| html = f""" | |
| <div style='padding: 20px; background: #f8f8f8; border-radius: 10px;'> | |
| <h3 style='color:{color};'>Confidence: {result['confidence']*100:.1f}%</h3> | |
| <p>Model used: {model_info}</p> | |
| </div> | |
| """ | |
| return processed_image, text, prob_dict, html, individual_results_html | |
| # ---------------------------------------------------------------------------- | |
| # Gradio Web Interface Configuration | |
| # Defines the UI layout and components for the pneumonia detection system | |
| # Includes custom styling and responsive layout design | |
| # ---------------------------------------------------------------------------- | |
# Custom styling for the Gradio interface.
custom_css = """
.gradio-container { font-family: 'Arial', sans-serif; }
.output-markdown h2 { color: #2c3e50; }
"""

with gr.Blocks(css=custom_css, title="Pneumonia Detection AI") as demo:
    gr.Markdown(
        """
        # 🫁 Pneumonia Detection from Chest X-rays
        Upload an X-ray image and select a model or the ensemble to analyze.
        **Disclaimer:** Research and educational use only.
        """
    )
    with gr.Row():
        with gr.Column():
            input_file = gr.File(
                label="Upload Chest X-Ray (JPEG, PNG, DICOM)",
                file_types=[".jpg", ".jpeg", ".png", ".dcm", ".dicom"],
                type="filepath",
            )
            model_selector = gr.Radio(
                choices=["Ensemble (All Models)"] + list(available_models),
                value="Ensemble (All Models)",
                label="Select Model",
            )
            # Ensemble-only controls, hidden when a single model is selected.
            with gr.Group(visible=True) as ensemble_options:
                gr.Markdown("### Select Models for Ensemble")
                ensemble_model_selector = gr.CheckboxGroup(
                    choices=list(available_models),
                    value=list(available_models),  # all selected by default
                    label="Models to include in ensemble",
                )
            predict_btn = gr.Button("Analyze X-Ray", variant="primary")
            individual_results = gr.HTML(label="Individual Model Results")
        with gr.Column():
            preview = gr.Image(label="Processed Image", height=400)
            output_text = gr.Markdown(label="Diagnosis Result")
            prob_chart = gr.JSON(label="Probability Distribution")
            confidence_html = gr.HTML(label="Confidence Level")

    def update_ensemble_components(choice):
        """Toggle visibility of ensemble-related components."""
        is_ensemble = choice == "Ensemble (All Models)"
        # BUG FIX: the original returned gr.Box.update(...) and
        # gr.Row.update(...) — gr.Box no longer exists (removed in Gradio 4)
        # and the bound outputs are a Group and an HTML component anyway, so
        # every radio change raised AttributeError. gr.update() produces a
        # generic update for whichever component each output slot is wired to.
        return gr.update(visible=is_ensemble), gr.update(visible=is_ensemble)

    # Show/hide the ensemble picker and results when the model choice changes.
    model_selector.change(
        fn=update_ensemble_components,
        inputs=[model_selector],
        outputs=[ensemble_options, individual_results],
    )
    predict_btn.click(
        fn=predict_pneumonia,
        inputs=[input_file, model_selector, ensemble_model_selector],
        outputs=[preview, output_text, prob_chart, confidence_html, individual_results],
    )

# Launch the app when run as a script (Hugging Face Spaces compatible).
if __name__ == "__main__":
    demo.launch()