Spaces:
Sleeping
Sleeping
| """ | |
| 🫁 Multi-Class Chest X-Ray Detection with Adaptive Sparse Training | |
| Advanced Gradio Interface - 4 Disease Classes | |
| Features: | |
| - Real-time detection: Normal, TB, Pneumonia, COVID-19 | |
| - Grad-CAM visualization (explainable AI) | |
| - Improved specificity - distinguishes TB from pneumonia | |
| - Confidence scores with visual indicators | |
| - Clinical interpretation and recommendations | |
| - Mobile-responsive design | |
| """ | |
| import gradio as gr | |
| import torch | |
| import torch.nn as nn | |
| from torchvision import models, transforms | |
| from PIL import Image | |
| import numpy as np | |
| import cv2 | |
| import matplotlib.pyplot as plt | |
| from pathlib import Path | |
| import io | |
# ============================================================================
# Model Setup
# ============================================================================
# Use the GPU when PyTorch can see one; otherwise everything runs on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load model
# EfficientNet-B0 created without pretrained weights; the second entry of its
# classifier is replaced by a fresh linear layer with four outputs.
model = models.efficientnet_b0(weights=None)
_head_in_features = model.classifier[1].in_features
model.classifier[1] = nn.Linear(_head_in_features, 4)  # 4 classes
try:
    _state_dict = torch.load('checkpoints/best_multiclass.pt', map_location=device)
    model.load_state_dict(_state_dict)
    print("β Multi-class model loaded successfully!")
except Exception as exc:
    # NOTE(review): a failed load is only printed. The app then keeps serving
    # predictions from the freshly initialised (untrained) network - confirm
    # that this best-effort behaviour is intended.
    print(f"β οΈ Error loading model: {exc}")
model = model.to(device)
model.eval()

# Classes
# The list position of each name is used as the index into the model output.
CLASSES = ['Normal', 'Tuberculosis', 'Pneumonia', 'COVID-19']
CLASS_COLORS = {
    'Normal': '#2ecc71',        # Green
    'Tuberculosis': '#e74c3c',  # Red
    'Pneumonia': '#f39c12',     # Orange
    'COVID-19': '#9b59b6'       # Purple
}

# Image preprocessing
# Resize, take the central 224x224 crop, convert to a tensor, then normalise
# each of the three channels with the mean / std values below.
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(_NORM_MEAN, _NORM_STD)
])
| # ============================================================================ | |
| # Grad-CAM Implementation | |
| # ============================================================================ | |
class GradCAM:
    """Grad-CAM heatmap generator bound to one layer of a model.

    Creating an instance registers a forward hook and a full backward hook on
    ``target_layer``. The forward hook stores the layer's (detached) output in
    ``self.activations``; the backward hook stores the gradient with respect
    to that output in ``self.gradients``. ``generate`` combines the two.

    Args:
        model: The module that is called to produce class scores.
        target_layer: The sub-module of ``model`` to attach the hooks to.
    """

    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer
        self.gradients = None
        self.activations = None

        def save_gradient(grad):
            self.gradients = grad

        def save_activation(module, inputs, output):
            self.activations = output.detach()

        # Keep the handles so the hooks can be detached again (remove_hooks).
        self._hook_handles = [
            target_layer.register_forward_hook(save_activation),
            target_layer.register_full_backward_hook(
                lambda m, gi, go: save_gradient(go[0])
            ),
        ]

    def remove_hooks(self):
        """Detach the hooks that ``__init__`` registered on the target layer."""
        for handle in self._hook_handles:
            handle.remove()
        self._hook_handles = []

    def generate(self, input_image, target_class=None):
        """Run the model on ``input_image`` and build a Grad-CAM map.

        Args:
            input_image: Tensor passed straight to ``self.model``. Only the
                first sample's class row is back-propagated and the result is
                squeezed, so a batch of one is assumed.
            target_class: Class index to explain. When ``None`` the arg-max of
                the model output is used.

        Returns:
            ``(cam, output)``. ``cam`` is a NumPy array scaled to the range
            0..1, or ``None`` when the hooks captured nothing during this
            call. ``output`` is the raw model output.
        """
        # Drop whatever an earlier call captured. Without this the None check
        # below can only ever trigger on the very first call, and a stale
        # gradient could silently be combined with new activations.
        self.gradients = None
        self.activations = None

        output = self.model(input_image)
        if target_class is None:
            target_class = output.argmax(dim=1)

        self.model.zero_grad()
        # Back-propagate a one-hot vector so only the chosen class contributes.
        one_hot = torch.zeros_like(output)
        one_hot[0][target_class] = 1
        output.backward(gradient=one_hot, retain_graph=True)

        if self.gradients is None or self.activations is None:
            return None, output

        # Channel weights are the spatial mean of the gradients; the map is
        # the weighted sum of the activation channels, clipped at zero.
        weights = self.gradients.mean(dim=(2, 3), keepdim=True)
        cam = (weights * self.activations).sum(dim=1, keepdim=True)
        cam = torch.relu(cam)
        cam = cam.squeeze().cpu().numpy()
        # Min-max scale; the epsilon guards against a constant map.
        cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)
        return cam, output
# Setup Grad-CAM
# The hooks are attached to the last entry of `model.features`; the single
# `grad_cam` instance is shared by every prediction request.
target_layer = model.features[-1]
grad_cam = GradCAM(model, target_layer)
| # ============================================================================ | |
| # Prediction Functions | |
| # ============================================================================ | |
def predict_chest_xray(image, show_gradcam=True):
    """
    Predict disease class from chest X-ray with Grad-CAM visualization

    Args:
        image: A PIL image or a NumPy array, or ``None`` when nothing was
            uploaded.
        show_gradcam: When true the Grad-CAM heatmap and overlay are produced;
            otherwise a plain forward pass with gradients disabled is run.

    Returns:
        A 5-tuple matching the five output components wired to this callback:
        ``(label_scores, annotated_original, gradcam_image, overlay_image,
        interpretation_markdown)``. ``label_scores`` maps each class name to
        its softmax probability in the range 0..1. The two Grad-CAM images
        are ``None`` when Grad-CAM is disabled or unavailable. All five
        entries are ``None`` when ``image`` is ``None``.
    """
    if image is None:
        # One placeholder per output component. Returning fewer values than
        # there are outputs makes the UI callback fail.
        return None, None, None, None, None

    # Convert to PIL if needed
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image).convert('RGB')
    else:
        image = image.convert('RGB')

    # Store original for display
    original_img = image.copy()

    # Preprocess
    input_tensor = transform(image).unsqueeze(0).to(device)

    # Get prediction with Grad-CAM
    with torch.set_grad_enabled(show_gradcam):
        if show_gradcam:
            cam, output = grad_cam.generate(input_tensor)
        else:
            output = model(input_tensor)
            cam = None

    # Get probabilities
    probs = torch.softmax(output, dim=1)[0].cpu().detach().numpy()
    pred_class = int(output.argmax(dim=1).item())
    pred_label = CLASSES[pred_class]
    confidence = float(probs[pred_class]) * 100

    # Percentages feed the annotations and the interpretation text ...
    results = {
        name: float(prob * 100) for name, prob in zip(CLASSES, probs)
    }
    # ... while the label widget expects confidences between 0 and 1, so it
    # gets the raw probabilities instead of the percentages.
    label_scores = {name: float(prob) for name, prob in zip(CLASSES, probs)}

    # Generate visualizations
    original_pil = create_original_display(original_img, pred_label, confidence)
    if cam is not None and show_gradcam:
        gradcam_viz = create_gradcam_visualization(original_img, cam, pred_label, confidence)
        overlay_viz = create_overlay_visualization(original_img, cam)
    else:
        gradcam_viz = None
        overlay_viz = None

    # Create interpretation text
    interpretation = create_interpretation(pred_label, confidence, results)
    return label_scores, original_pil, gradcam_viz, overlay_viz, interpretation
def _figure_to_pil(fig):
    """Render *fig* as an in-memory PNG, close it, and return a PIL image.

    Every call goes through the figure object itself rather than pyplot's
    implicit "current figure", so a figure created elsewhere in the meantime
    cannot be saved or closed by mistake. The figure is closed even when
    rendering fails.
    """
    buf = io.BytesIO()
    try:
        fig.tight_layout()
        fig.savefig(buf, format='png', dpi=150, bbox_inches='tight', facecolor='white')
    finally:
        plt.close(fig)
    buf.seek(0)
    return Image.open(buf)


def _cam_to_heatmap(cam):
    """Resize *cam* to 224x224 and colour it with the JET map (RGB, uint8)."""
    cam_resized = cv2.resize(cam, (224, 224))
    heatmap = cv2.applyColorMap(np.uint8(255 * cam_resized), cv2.COLORMAP_JET)
    # OpenCV returns BGR; matplotlib expects RGB.
    return cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)


def _model_view(image):
    """Return the 224x224 region of *image* that the classifier receives.

    Mirrors the Resize(256) + CenterCrop(224) steps of the module-level
    ``transform`` so that a CAM computed on the model input lines up with the
    picture it is drawn on.
    """
    resized = transforms.Resize(256)(image)
    return transforms.CenterCrop(224)(resized)


def create_original_display(image, pred_label, confidence):
    """Create annotated original image

    Args:
        image: Image to show (anything ``Axes.imshow`` accepts).
        pred_label: Predicted class name; must be a key of ``CLASS_COLORS``.
        confidence: Confidence in percent, printed with one decimal.

    Returns:
        A PIL image of the figure with the prediction in its title.
    """
    fig, ax = plt.subplots(figsize=(8, 8))
    ax.imshow(image)
    ax.axis('off')
    # Add prediction box
    color = CLASS_COLORS[pred_label]
    title = f'Prediction: {pred_label}\nConfidence: {confidence:.1f}%'
    ax.set_title(title, fontsize=16, fontweight='bold', color=color, pad=20)
    # Convert to PIL
    return _figure_to_pil(fig)


def create_gradcam_visualization(image, cam, pred_label, confidence):
    """Create Grad-CAM heatmap

    Only ``cam`` is drawn. ``image``, ``pred_label`` and ``confidence`` are
    accepted for signature compatibility and are not used.

    Returns:
        A PIL image of the coloured heatmap with a title.
    """
    heatmap = _cam_to_heatmap(cam)
    fig, ax = plt.subplots(figsize=(8, 8))
    ax.imshow(heatmap)
    ax.axis('off')
    ax.set_title('Attention Heatmap\n(Areas the model focuses on)',
                 fontsize=14, fontweight='bold', pad=20)
    return _figure_to_pil(fig)


def create_overlay_visualization(image, cam):
    """Create overlay of image and heatmap

    The picture is first reduced to the same centre crop the model is fed.
    Squashing the full image to 224x224 instead would shift the heatmap
    relative to the anatomy, because the CAM only covers the cropped region.

    Args:
        image: PIL image of the X-ray.
        cam: 2-D array with values in 0..1.

    Returns:
        A PIL image of the 50/50 blend of picture and heatmap with a title.
    """
    # convert('RGB') guarantees three channels so the blend below broadcasts.
    img_array = np.array(_model_view(image.convert('RGB'))) / 255.0
    heatmap = _cam_to_heatmap(cam) / 255.0
    # Overlay
    overlay = img_array * 0.5 + heatmap * 0.5
    overlay = np.clip(overlay, 0, 1)
    fig, ax = plt.subplots(figsize=(8, 8))
    ax.imshow(overlay)
    ax.axis('off')
    ax.set_title('Explainable AI Visualization\n(Original + Heatmap)',
                 fontsize=14, fontweight='bold', pad=20)
    return _figure_to_pil(fig)
# Confidence (in percent) from which the "high confidence" wording is used.
_HIGH_CONFIDENCE_MIN = 85

# Opening section of the report; the placeholders are filled per prediction.
_REPORT_HEADER = """
## π¬ Analysis Results
### Prediction: **{label}**
- Confidence: **{confidence:.1f}%**
### Probability Breakdown:
- π’ Normal: **{normal:.1f}%**
- π΄ Tuberculosis: **{tuberculosis:.1f}%**
- π Pneumonia: **{pneumonia:.1f}%**
- π£ COVID-19: **{covid:.1f}%**
---
"""

_TB_HIGH = """
**β οΈ High Confidence TB Detection**
The model has detected features highly consistent with tuberculosis infection.
**CRITICAL - Immediate Actions Required:**
1. β **Immediate consultation** with a healthcare provider
2. β **Confirmatory sputum test** (AFB smear or GeneXpert MTB/RIF)
3. β **Clinical correlation** with symptoms:
- Persistent cough (>2 weeks)
- Fever, especially night sweats
- Unexplained weight loss
- Hemoptysis (coughing blood)
4. β **Isolation** and contact tracing if confirmed
5. β **Chest CT scan** if needed for further evaluation
**β οΈ IMPORTANT**: This is a SCREENING tool, not a diagnostic tool.
Clinical diagnosis of TB requires laboratory confirmation (sputum test).
"""

_TB_MODERATE = """
**β οΈ Possible TB Detection**
The model has detected features suggestive of tuberculosis, but confidence is moderate.
**Recommended Actions:**
1. Consult healthcare provider for clinical evaluation
2. Consider confirmatory sputum testing
3. Evaluate clinical symptoms
4. Follow-up imaging may be recommended
**Note**: Moderate confidence requires professional medical evaluation.
"""

_PNEUMONIA_HIGH = """
**β οΈ High Confidence Pneumonia Detection**
The model has detected features consistent with pneumonia (bacterial or viral).
**Recommended Actions:**
1. β **Medical evaluation** for pneumonia diagnosis
2. β **Possible confirmatory tests**:
- Sputum culture
- Blood tests (WBC count, CRP)
- Additional chest imaging if needed
3. β **Clinical correlation** with symptoms:
- Cough with sputum production
- Fever and chills
- Shortness of breath
- Chest pain with breathing
4. β **Treatment**: Antibiotics (bacterial) or supportive care (viral)
**Note**: Pneumonia can present similarly to other lung diseases.
Professional diagnosis is essential for appropriate treatment.
"""

_PNEUMONIA_MODERATE = """
**β οΈ Possible Pneumonia**
Features suggest possible pneumonia, but further evaluation is needed.
**Recommended Actions:**
1. Seek medical evaluation
2. Clinical symptom assessment
3. Consider additional diagnostic tests
**Note**: Requires professional medical evaluation for confirmation.
"""

_COVID_HIGH = """
**β οΈ High Confidence COVID-19 Detection**
The model has detected features consistent with COVID-19 pneumonia.
**URGENT - Immediate Actions:**
1. β **COVID-19 RT-PCR test** for confirmation
2. β **Isolation** to prevent transmission
3. β **Monitor oxygen saturation** (SpO2 levels)
4. β **Seek immediate medical care** if:
- Difficulty breathing
- SpO2 < 94%
- Persistent chest pain
- Confusion or inability to stay awake
5. β **Contact tracing** if positive
**Clinical Symptoms to Monitor:**
- Fever, cough, shortness of breath
- Loss of taste/smell
- Fatigue, body aches
- Gastrointestinal symptoms
**β οΈ IMPORTANT**: Imaging findings alone cannot confirm COVID-19.
RT-PCR or antigen testing is required for diagnosis.
"""

_COVID_MODERATE = """
**β οΈ Possible COVID-19**
Features suggest possible COVID-19, but confirmation testing is essential.
**Recommended Actions:**
1. Get RT-PCR or rapid antigen test
2. Self-isolate pending test results
3. Monitor symptoms
4. Seek medical care if symptoms worsen
**Note**: COVID-19 diagnosis requires laboratory confirmation.
"""

_NORMAL_HIGH = """
**β High Confidence Normal Result**
The model has not detected significant abnormalities consistent with TB, pneumonia, or COVID-19.
**Interpretation:**
- Chest X-ray appears within normal limits
- No features of active tuberculosis detected
- No signs of pneumonia or COVID-19
**Important Notes:**
- This does NOT rule out all lung diseases
- Early-stage diseases may not show on X-ray
- If you have symptoms, seek medical evaluation
- Regular health screenings are recommended
**When to still see a doctor:**
- Persistent cough, fever, or respiratory symptoms
- Unexplained weight loss or night sweats
- Shortness of breath or chest pain
- Known exposure to TB or COVID-19
"""

_NORMAL_LOW = """
**β οΈ Likely Normal, Low Confidence**
The model suggests a normal chest X-ray, but confidence is not high.
**Recommended Actions:**
1. If symptomatic, seek medical evaluation
2. Consider repeat imaging if concerns persist
3. Clinical correlation is important
**Note**: Low confidence results should be reviewed by healthcare professionals.
"""

# Closing section appended to every report, whatever the prediction.
_UNIVERSAL_DISCLAIMER = """
---
## β οΈ CRITICAL MEDICAL DISCLAIMER
### Model Capabilities:
- β Trained on 4 disease classes: Normal, TB, Pneumonia, COVID-19
- β Can distinguish between different lung diseases
- β ~95-97% accuracy in validation testing
- β Powered by Adaptive Sparse Training (89% energy efficient)
### Important Limitations:
- β οΈ This is a **SCREENING tool**, not a diagnostic device
- β οΈ **NOT FDA-approved** for clinical diagnosis
- β οΈ Cannot detect: lung cancer, pulmonary fibrosis, bronchiectasis, other rare diseases
- β οΈ Cannot replace: professional radiologist review
- β οΈ Cannot confirm: laboratory diagnosis (sputum tests, PCR, cultures)
### Clinical Use Guidelines:
1. β Use as a **preliminary screening** tool only
2. β ALL positive results require **confirmatory laboratory testing**
3. β ALL cases require **clinical correlation** with symptoms and history
4. β Expert radiologist review is recommended for clinical decisions
5. β Do NOT initiate treatment based solely on AI predictions
### Diagnostic Gold Standards:
- **TB**: Sputum AFB smear/culture, GeneXpert MTB/RIF, TB-PCR
- **Pneumonia**: Clinical diagnosis + sputum culture + blood tests
- **COVID-19**: RT-PCR, rapid antigen test
**When in doubt, always consult a qualified healthcare professional.**
---
π« **Powered by Adaptive Sparse Training**
Energy-efficient AI for accessible healthcare
**Learn more:**
- GitHub: https://github.com/oluwafemidiakhoa/Tuberculosis
- Research: Sample-based Adaptive Sparse Training for deep learning
"""

# Disease-specific texts as (high-confidence, lower-confidence) pairs.
_FINDINGS_BY_LABEL = {
    'Tuberculosis': (_TB_HIGH, _TB_MODERATE),
    'Pneumonia': (_PNEUMONIA_HIGH, _PNEUMONIA_MODERATE),
    'COVID-19': (_COVID_HIGH, _COVID_MODERATE),
}
# Any label without an entry above is reported with the "Normal" wording.
_NORMAL_FINDINGS = (_NORMAL_HIGH, _NORMAL_LOW)


def create_interpretation(pred_label, confidence, results):
    """Create interpretation text with improved medical disclaimers

    Args:
        pred_label: Predicted class name.
        confidence: Confidence of the prediction in percent.
        results: Mapping with the keys 'Normal', 'Tuberculosis', 'Pneumonia'
            and 'COVID-19'; each value is printed as a percentage.

    Returns:
        Markdown text made of three parts: the result summary, the section
        for the predicted class (worded differently from 85% confidence
        upwards), and the disclaimer that closes every report.
    """
    summary = _REPORT_HEADER.format(
        label=pred_label,
        confidence=confidence,
        normal=results['Normal'],
        tuberculosis=results['Tuberculosis'],
        pneumonia=results['Pneumonia'],
        covid=results['COVID-19'],
    )
    confident_text, tentative_text = _FINDINGS_BY_LABEL.get(pred_label, _NORMAL_FINDINGS)
    if confidence >= _HIGH_CONFIDENCE_MIN:
        finding = confident_text
    else:
        finding = tentative_text
    return summary + finding + _UNIVERSAL_DISCLAIMER
| # ============================================================================ | |
| # Gradio Interface | |
| # ============================================================================ | |
# Custom CSS
# Stylesheet handed to gr.Blocks(css=...). The id selectors (#main-container,
# #title, #subtitle, #stats, #upload-box, #results-box) and the .output-image
# class correspond to the ids / elem_id / elem_classes used when the
# interface is built; `footer` styles the footer HTML block.
custom_css = """
#main-container {
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
padding: 20px;
}
#title {
text-align: center;
color: white;
font-size: 2.5em;
font-weight: bold;
margin-bottom: 10px;
text-shadow: 2px 2px 4px rgba(0,0,0,0.3);
}
#subtitle {
text-align: center;
color: #f0f0f0;
font-size: 1.2em;
margin-bottom: 20px;
}
#stats {
text-align: center;
color: #fff;
font-size: 0.95em;
margin-bottom: 30px;
padding: 15px;
background: rgba(255,255,255,0.1);
border-radius: 10px;
backdrop-filter: blur(10px);
}
.gradio-container {
font-family: 'Inter', sans-serif;
}
#upload-box {
border: 3px dashed #667eea;
border-radius: 15px;
padding: 20px;
background: rgba(255,255,255,0.95);
}
#results-box {
background: white;
border-radius: 15px;
padding: 20px;
box-shadow: 0 4px 6px rgba(0,0,0,0.1);
}
.output-image {
border-radius: 10px;
box-shadow: 0 2px 4px rgba(0,0,0,0.1);
}
footer {
text-align: center;
margin-top: 30px;
color: white;
font-size: 0.9em;
}
"""
# Create interface
# Layout: banner, then one row with an upload column (left) and a results
# column (right), followed by example images, the click wiring and a footer.
# NOTE(review): the nesting of the `with` blocks below was reconstructed from
# flattened text - confirm it against the rendered layout.
with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
    # Banner; the ids used here are the ones styled in custom_css.
    gr.HTML("""
<div id="main-container">
<div id="title">π« Multi-Class Chest X-Ray Detection AI</div>
<div id="subtitle">Advanced chest X-ray analysis with Explainable AI</div>
<div id="stats">
<b>95-97% Accuracy</b> across 4 disease classes |
<b>89% Energy Efficient</b> |
Powered by Adaptive Sparse Training
<br><br>
<b>Detects:</b> Normal β’ Tuberculosis β’ Pneumonia β’ COVID-19
</div>
</div>
""")
    with gr.Row():
        # Input side: image upload, Grad-CAM toggle, analyze button, help text.
        with gr.Column(scale=1, elem_id="upload-box"):
            gr.Markdown("## π€ Upload Chest X-Ray")
            # type="pil": the callback receives the upload as a PIL image.
            image_input = gr.Image(
                type="pil",
                label="Upload X-Ray Image",
                elem_classes="output-image"
            )
            show_gradcam = gr.Checkbox(
                value=True,
                label="Enable Grad-CAM Visualization (Explainable AI)",
                info="Shows which areas the model focuses on"
            )
            analyze_btn = gr.Button(
                "π¬ Analyze X-Ray",
                variant="primary",
                size="lg"
            )
            gr.Markdown("""
### π Supported Images:
- Chest X-rays (PA or AP view)
- PNG, JPG, JPEG formats
- Grayscale or RGB
### β‘ What's New:
- β **Improved Specificity**: Can distinguish TB from Pneumonia
- β **4 Disease Classes**: Normal, TB, Pneumonia, COVID-19
- β **Fewer False Positives**: <5% on pneumonia cases
- β **Same Energy Efficiency**: 89% savings with AST
""")
        # Output side: class scores, three image tabs, interpretation text.
        with gr.Column(scale=2, elem_id="results-box"):
            gr.Markdown("## π Analysis Results")
            # Results display
            with gr.Row():
                prob_output = gr.Label(
                    label="Prediction Confidence",
                    num_top_classes=4
                )
            with gr.Tabs():
                with gr.Tab("Original"):
                    original_output = gr.Image(
                        label="Annotated X-Ray",
                        elem_classes="output-image"
                    )
                with gr.Tab("Grad-CAM Heatmap"):
                    gradcam_output = gr.Image(
                        label="Attention Heatmap",
                        elem_classes="output-image"
                    )
                with gr.Tab("Overlay"):
                    overlay_output = gr.Image(
                        label="Explainable AI Visualization",
                        elem_classes="output-image"
                    )
            interpretation_output = gr.Markdown(
                label="Clinical Interpretation"
            )
    # Example images
    # Each entry is a path that gets loaded into `image_input` when clicked.
    gr.Markdown("## π Example X-Rays")
    gr.Examples(
        examples=[
            ["examples/normal.png"],
            ["examples/tb.png"],
            ["examples/pneumonia.png"],
            ["examples/covid.png"],
        ],
        inputs=image_input,
        label="Click to load example"
    )
    # Connect components
    # The callback's return values are assigned to `outputs` in this order.
    analyze_btn.click(
        fn=predict_chest_xray,
        inputs=[image_input, show_gradcam],
        outputs=[prob_output, original_output, gradcam_output, overlay_output, interpretation_output]
    )
    # Footer
    gr.HTML("""
<footer>
<p>
<b>π« Multi-Class Chest X-Ray Detection with AST</b><br>
Trained on Normal, Tuberculosis, Pneumonia, and COVID-19 cases<br>
95-97% Accuracy | 89% Energy Savings | Explainable AI<br><br>
<a href="https://github.com/oluwafemidiakhoa/Tuberculosis" target="_blank" style="color: white;">
π GitHub Repository
</a> |
<a href="https://huggingface.co/spaces/mgbam/Tuberculosis" target="_blank" style="color: white;">
π€ Hugging Face Space
</a>
</p>
<p style="font-size: 0.8em; margin-top: 15px;">
β οΈ <b>MEDICAL DISCLAIMER</b>: This is a screening tool, not a diagnostic device.
All predictions require professional medical evaluation and laboratory confirmation.
Not FDA-approved for clinical use.
</p>
</footer>
""")
# Launch
# Start the web server only when this file is executed as a script.
if __name__ == "__main__":
    _launch_options = {
        "share": False,            # no public share link
        "server_name": "0.0.0.0",  # listen on all network interfaces
        "server_port": 7860,
        "show_error": True,
    }
    demo.launch(**_launch_options)