Spaces:
Sleeping
Sleeping
| """ | |
| π« TB Detection with Adaptive Sparse Training | |
| Advanced Gradio Interface with Modern UI/UX | |
| Features: | |
| - Real-time TB detection from chest X-rays | |
| - Grad-CAM visualization (explainable AI) | |
| - Confidence scores with visual indicators | |
| - Multi-image batch processing | |
| - Interactive dashboard with metrics | |
| - Mobile-responsive design | |
| """ | |
| import gradio as gr | |
| import torch | |
| import torch.nn as nn | |
| from torchvision import models, transforms | |
| from PIL import Image | |
| import numpy as np | |
| import cv2 | |
| import matplotlib.pyplot as plt | |
| from pathlib import Path | |
| import io | |
| import json | |
# ============================================================================
# Model Setup
# ============================================================================
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load model: EfficientNet-B0 architecture built without pretrained weights,
# with the last classifier layer replaced by a 2-output head (one per class).
model = models.efficientnet_b0(weights=None)
model.classifier[1] = nn.Linear(model.classifier[1].in_features, 2)

# True only once the fine-tuned checkpoint has been restored. Loading stays
# best-effort so the UI can still start, but on failure the network keeps its
# random initialization and every prediction is meaningless -- make that
# explicit instead of only printing the exception.
MODEL_LOADED = False
try:
    model.load_state_dict(torch.load('checkpoints/best.pt', map_location=device))
    MODEL_LOADED = True
    print("β Model loaded successfully!")
except Exception as e:
    print(f"β οΈ Error loading model: {e}")
    print("WARNING: continuing with randomly initialized weights; "
          "predictions will NOT be meaningful until checkpoints/best.pt loads.")

model = model.to(device)
# Inference mode for layers such as dropout / batch-norm.
model.eval()
# Classes: human-readable labels indexed by the classifier's output position
# (predict_tb looks up CLASSES[argmax]), so index 0 -> 'Normal',
# index 1 -> 'Tuberculosis'.
CLASSES = ['Normal', 'Tuberculosis']
# Display color for each label, in the same order: green, then red.
CLASS_COLORS = dict(zip(CLASSES, ['#2ecc71', '#e74c3c']))
# Image preprocessing applied before inference: scale the shorter side to
# 256 px, take the central 224x224 crop, convert to a tensor, then normalize
# each RGB channel with the ImageNet mean/std values below.
transform = transforms.Compose([
    transforms.Resize(size=256),
    transforms.CenterCrop(size=224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
# ============================================================================
# Grad-CAM Implementation
# ============================================================================
class GradCAM:
    """Gradient-weighted class activation mapping for one target layer.

    Hooks ``target_layer`` so that each forward pass records its output
    (activations) and each backward pass records the gradient flowing into
    that output. ``generate`` combines both into a heatmap normalized to
    [0, 1].
    """

    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer
        self.gradients = None
        self.activations = None
        # Keep the handles so the hooks can be detached again via remove().
        self._handles = [
            target_layer.register_forward_hook(self._save_activation),
            target_layer.register_full_backward_hook(self._save_gradient),
        ]

    def _save_activation(self, module, inputs, output):
        """Forward hook: remember the target layer's output."""
        self.activations = output.detach()

    def _save_gradient(self, module, grad_input, grad_output):
        """Backward hook: remember d(score)/d(target layer output)."""
        self.gradients = grad_output[0].detach()

    def remove(self):
        """Detach the forward/backward hooks from the target layer."""
        for handle in self._handles:
            handle.remove()
        self._handles = []

    def generate(self, input_image, target_class=None):
        """Run the model and build a Grad-CAM map for the first batch item.

        Args:
            input_image: model input batch; only item 0 is explained.
            target_class: class index to explain (int or 1-element tensor).
                Defaults to the predicted (argmax) class of item 0.

        Returns:
            ``(cam, output)`` where ``cam`` is a 2-D numpy array in [0, 1]
            shaped like the target layer's spatial output, or ``None`` if the
            hooks did not fire; ``output`` is the raw model output.
        """
        # Clear results from any previous call so a hook that did not fire
        # this time cannot make us silently reuse stale tensors.
        self.gradients = None
        self.activations = None

        output = self.model(input_image)
        if target_class is None:
            target_class = output[0].argmax()
        target_class = int(target_class)

        self.model.zero_grad()
        # Back-propagate only the selected class score of item 0.
        one_hot = torch.zeros_like(output)
        one_hot[0, target_class] = 1
        output.backward(gradient=one_hot)

        if self.gradients is None or self.activations is None:
            return None, output

        # Channel weights = spatially averaged gradients.
        weights = self.gradients.mean(dim=(2, 3), keepdim=True)
        cam = (weights * self.activations).sum(dim=1, keepdim=True)
        cam = torch.relu(cam)[0, 0].cpu().numpy()
        # Min-max normalize; epsilon avoids division by zero for flat maps.
        cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)
        return cam, output
# Setup Grad-CAM on the last stage of the EfficientNet feature extractor
# (presumably the final spatial feature map before pooling -- confirm if the
# backbone changes).
target_layer = model.features[-1]
grad_cam = GradCAM(model=model, target_layer=target_layer)
# ============================================================================
# Prediction Functions
# ============================================================================
def predict_tb(image, show_gradcam=True):
    """
    Predict TB from chest X-ray with Grad-CAM visualization.

    Args:
        image: the X-ray as a PIL image or numpy array (as delivered by the
            Gradio ``Image`` input), or ``None`` when nothing was uploaded.
        show_gradcam: when true, run Grad-CAM and also return the heatmap
            and overlay images; when false those two outputs are ``None``.

    Returns:
        A 5-tuple ``(label_confidences, original_pil, gradcam_viz,
        overlay_viz, interpretation)`` -- one value per Gradio output wired
        to the Analyze button. ``label_confidences`` maps each class name to
        its probability as a fraction in [0, 1], the scale ``gr.Label``
        renders as a percentage. The Markdown interpretation uses percentages.
    """
    if image is None:
        # Gradio needs one value per output component (5); returning fewer
        # makes it raise when Analyze is clicked without an image.
        return None, None, None, None, "Upload an X-ray image and click 'Analyze' to get results."

    # Normalize every input to an RGB PIL image.
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    image = image.convert('RGB')

    # Untouched copy used for the visualizations.
    original_img = image.copy()

    input_tensor = transform(image).unsqueeze(0).to(device)

    # Grad-CAM needs autograd for its backward pass; plain inference does not.
    with torch.set_grad_enabled(show_gradcam):
        if show_gradcam:
            cam, output = grad_cam.generate(input_tensor)
        else:
            output = model(input_tensor)
            cam = None

    probs = torch.softmax(output.detach(), dim=1)[0].cpu().numpy()
    pred_class = int(output.argmax(dim=1).item())
    pred_label = CLASSES[pred_class]
    confidence = float(probs[pred_class]) * 100

    # gr.Label expects fractions in [0, 1]; the interpretation text uses %.
    label_confidences = {label: float(p) for label, p in zip(CLASSES, probs)}
    results = {label: p * 100 for label, p in label_confidences.items()}

    original_pil = create_original_display(original_img, pred_label, confidence)
    if show_gradcam and cam is not None:
        gradcam_viz = create_gradcam_visualization(original_img, cam, pred_label, confidence)
        overlay_viz = create_overlay_visualization(original_img, cam)
    else:
        gradcam_viz = None
        overlay_viz = None

    interpretation = create_interpretation(pred_label, confidence, results)
    return label_confidences, original_pil, gradcam_viz, overlay_viz, interpretation
def create_original_display(image, pred_label, confidence):
    """Render ``image`` with a colored prediction/confidence title.

    Args:
        image: the image to show (predict_tb passes its RGB PIL original).
        pred_label: predicted class name; must be a key of CLASS_COLORS.
        confidence: confidence of ``pred_label`` in percent.

    Returns:
        PIL image of the rendered matplotlib figure (PNG, white background).
    """
    fig, ax = plt.subplots(figsize=(8, 8))
    try:
        ax.imshow(image)
        ax.axis('off')
        # Title colored by the predicted class.
        color = CLASS_COLORS[pred_label]
        title = f'Prediction: {pred_label}\nConfidence: {confidence:.1f}%'
        ax.set_title(title, fontsize=16, fontweight='bold', color=color, pad=20)
        fig.tight_layout()
        buf = io.BytesIO()
        fig.savefig(buf, format='png', dpi=150, bbox_inches='tight', facecolor='white')
    finally:
        # Close this exact figure (not pyplot's "current" one) so it is never
        # leaked on error and concurrent requests cannot close each other's.
        plt.close(fig)
    buf.seek(0)
    return Image.open(buf)
def create_gradcam_visualization(image, cam, pred_label, confidence):
    """Render the Grad-CAM heatmap on its own as a PIL image.

    Args:
        image: the analyzed image; not drawn here (kept for interface
            compatibility).
        cam: 2-D Grad-CAM map, expected in [0, 1] as produced by
            ``GradCAM.generate``.
        pred_label: unused; kept for interface compatibility.
        confidence: unused; kept for interface compatibility.

    Returns:
        PIL image of the rendered heatmap figure.
    """
    # Upsample the CAM to 224x224; clip so the uint8 cast cannot wrap around.
    cam_resized = np.clip(cv2.resize(cam, (224, 224)), 0.0, 1.0)
    # JET colormap needs uint8 input; OpenCV returns BGR, matplotlib wants RGB.
    heatmap = cv2.applyColorMap(np.uint8(255 * cam_resized), cv2.COLORMAP_JET)
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)

    fig, ax = plt.subplots(figsize=(8, 8))
    try:
        ax.imshow(heatmap)
        ax.axis('off')
        ax.set_title('Attention Heatmap\n(Areas the model focuses on)',
                     fontsize=14, fontweight='bold', pad=20)
        fig.tight_layout()
        buf = io.BytesIO()
        fig.savefig(buf, format='png', dpi=150, bbox_inches='tight', facecolor='white')
    finally:
        # Close this exact figure so it is not leaked if rendering fails.
        plt.close(fig)
    buf.seek(0)
    return Image.open(buf)
def create_overlay_visualization(image, cam):
    """Blend the Grad-CAM heatmap 50/50 over the region the model analyzed.

    The classifier's preprocessing resizes the shorter side to 256 px and
    then center-crops 224x224, so ``cam`` describes only that central crop.
    The same geometric steps (without tensor conversion/normalization) are
    applied here so the heatmap lines up with the anatomy it refers to.
    Squashing the full image to 224x224 would shift and distort it.

    Args:
        image: PIL image that was fed to the model before preprocessing.
        cam: 2-D Grad-CAM map, expected in [0, 1] as produced by
            ``GradCAM.generate``.

    Returns:
        PIL image of the rendered overlay figure.
    """
    # Must mirror the Resize/CenterCrop sizes used by the model's transform.
    model_view = transforms.CenterCrop(224)(transforms.Resize(256)(image))
    img_array = np.asarray(model_view.convert('RGB'), dtype=np.float64) / 255.0

    cam_resized = np.clip(cv2.resize(cam, (224, 224)), 0.0, 1.0)
    # JET colormap needs uint8 input; OpenCV returns BGR, matplotlib wants RGB.
    heatmap = cv2.applyColorMap(np.uint8(255 * cam_resized), cv2.COLORMAP_JET)
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB) / 255.0

    # Equal-weight blend of image and heatmap, kept in the displayable range.
    overlay = np.clip(img_array * 0.5 + heatmap * 0.5, 0, 1)

    fig, ax = plt.subplots(figsize=(8, 8))
    try:
        ax.imshow(overlay)
        ax.axis('off')
        ax.set_title('Explainable AI Visualization\n(Original + Heatmap)',
                     fontsize=14, fontweight='bold', pad=20)
        fig.tight_layout()
        buf = io.BytesIO()
        fig.savefig(buf, format='png', dpi=150, bbox_inches='tight', facecolor='white')
    finally:
        # Close this exact figure so it is not leaked if rendering fails.
        plt.close(fig)
    buf.seek(0)
    return Image.open(buf)
# Clinical guidance for a Tuberculosis prediction, indexed by confidence
# tier: 0 = at least 90%, 1 = at least 70%, 2 = below 70%.
_TB_GUIDANCE = (
    """
**β οΈ High Confidence TB Detection**
The model has detected features highly consistent with tuberculosis infection.
**Recommended Actions:**
1. Immediate consultation with a healthcare provider
2. Confirmatory sputum test (AFB smear or GeneXpert)
3. Clinical correlation with symptoms (cough, fever, weight loss, night sweats)
4. Isolation and contact tracing if confirmed
**Note**: This is a screening tool. Clinical diagnosis requires laboratory confirmation.
""",
    """
**β οΈ Moderate Confidence TB Detection**
The model has detected features suggestive of tuberculosis.
**Recommended Actions:**
1. Consult healthcare provider for further evaluation
2. Consider confirmatory testing
3. Monitor symptoms closely
**Note**: Moderate confidence requires clinical correlation.
""",
    """
**β οΈ Low Confidence TB Detection**
The model has detected some features that may indicate tuberculosis, but confidence is low.
**Recommended Actions:**
1. Clinical evaluation recommended
2. Consider additional imaging or testing if symptomatic
3. Repeat X-ray if indicated
**Note**: Low confidence predictions should be interpreted cautiously.
""",
)

# Same tiers, used for every prediction other than 'Tuberculosis'.
_NORMAL_GUIDANCE = (
    """
**β High Confidence Normal Result**
The chest X-ray shows no significant features suggestive of active tuberculosis.
**Note**:
- This does not completely rule out latent TB infection
- Consult healthcare provider if symptomatic
- Regular screening recommended for high-risk individuals
""",
    """
**β Moderate Confidence Normal Result**
The chest X-ray appears largely normal, though some uncertainty exists.
**Recommended Actions:**
- If symptomatic, seek clinical evaluation
- Consider repeat imaging if indicated
""",
    """
**β οΈ Low Confidence Normal Result**
The model suggests the X-ray may be normal, but confidence is low.
**Recommended Actions:**
- Clinical correlation strongly recommended
- Consider expert radiologist review
- Additional testing if symptomatic
""",
)

# Model facts and disclaimer appended to every report.
_ABOUT_FOOTER = """
---
### π― About This Model
- **Accuracy**: 99.29% on validation set
- **Energy Efficient**: Uses only 10% of computational resources
- **Technology**: Adaptive Sparse Training (AST)
- **Training**: 50 epochs on chest X-ray dataset
### β οΈ Important Disclaimer
This is an AI screening tool designed to assist healthcare providers. It is NOT a substitute for:
- Professional medical diagnosis
- Laboratory confirmation
- Clinical evaluation by qualified healthcare providers
Always consult with healthcare professionals for proper diagnosis and treatment.
"""


def create_interpretation(pred_label, confidence, results):
    """Assemble the Markdown report shown in the results panel.

    Args:
        pred_label: predicted class name; 'Tuberculosis' selects the TB
            guidance, anything else selects the Normal guidance.
        confidence: confidence of ``pred_label`` in percent.
        results: mapping with 'Normal' and 'Tuberculosis' probabilities in
            percent (a missing key raises KeyError).

    Returns:
        Markdown text: header with probability breakdown, tier-specific
        clinical guidance, then the model/disclaimer footer.
    """
    normal_pct = results['Normal']
    tb_pct = results['Tuberculosis']

    report_head = f"""
## π¬ Analysis Results
### Prediction: **{pred_label}**
- Confidence: **{confidence:.1f}%**
### Probability Breakdown:
- π’ Normal: **{normal_pct:.1f}%**
- π΄ Tuberculosis: **{tb_pct:.1f}%**
### Clinical Interpretation:
"""

    # Tier index: 0 = high (>= 90), 1 = moderate (>= 70), 2 = low.
    tier = 0 if confidence >= 90 else (1 if confidence >= 70 else 2)
    guidance_table = _TB_GUIDANCE if pred_label == 'Tuberculosis' else _NORMAL_GUIDANCE

    return report_head + guidance_table[tier] + _ABOUT_FOOTER
# ============================================================================
# Gradio Interface
# ============================================================================
# Custom CSS for modern UI, passed to gr.Blocks(css=...) when the interface
# is built. The .header and .footer classes are used by the gr.HTML header
# and footer snippets; .metric-box, .warning-box and .success-box are not
# referenced by any markup visible in this file.
# NOTE(review): .gr-button-primary / .gr-button-secondary and #component-0
# rely on Gradio's generated DOM class names/ids and may not match on the
# installed Gradio version -- verify they still apply.
custom_css = """
.gradio-container {
font-family: 'Inter', sans-serif;
max-width: 1400px !important;
}
.header {
text-align: center;
padding: 2rem;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
color: white;
border-radius: 10px;
margin-bottom: 2rem;
}
.metric-box {
background: linear-gradient(135deg, #f093fb 0%, #f5576c 100%);
padding: 1.5rem;
border-radius: 10px;
color: white;
text-align: center;
margin: 1rem 0;
}
.warning-box {
background-color: #fff3cd;
border-left: 4px solid #ffc107;
padding: 1rem;
margin: 1rem 0;
border-radius: 5px;
}
.success-box {
background-color: #d4edda;
border-left: 4px solid #28a745;
padding: 1rem;
margin: 1rem 0;
border-radius: 5px;
}
.footer {
text-align: center;
padding: 2rem;
margin-top: 2rem;
border-top: 2px solid #eee;
color: #666;
}
#component-0 {
max-width: 100% !important;
}
.gr-button-primary {
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%) !important;
border: none !important;
}
.gr-button-secondary {
background: linear-gradient(135deg, #f093fb 0%, #f5576c 100%) !important;
border: none !important;
}
"""
# Build interface
# Placeholder shown in the interpretation panel before any analysis and
# after "Clear"; defined once so both places stay in sync.
_INTERPRETATION_PLACEHOLDER = "Upload an X-ray image and click 'Analyze' to get results."

# Example images listed under the upload box. Only files that actually exist
# are offered, so a missing or partially populated examples/ folder cannot
# break the examples widget.
_EXAMPLE_FILES = [
    [example_path]
    for example_path in ("examples/normal_1.png", "examples/tb_1.png")
    if Path(example_path).is_file()
]

with gr.Blocks(css=custom_css, theme=gr.themes.Soft(), title="TB Detection AI") as demo:
    # Header
    gr.HTML("""
<div class="header">
<h1>π« Tuberculosis Detection AI</h1>
<p style="font-size: 1.2rem; margin-top: 1rem;">
Advanced chest X-ray analysis with Explainable AI
</p>
<p style="font-size: 0.9rem; opacity: 0.9;">
99.3% Accuracy | 89% Energy Efficient | Powered by Adaptive Sparse Training
</p>
</div>
""")

    # Main content
    with gr.Row():
        # Left column - Input
        with gr.Column(scale=1):
            gr.Markdown("### π€ Upload Chest X-Ray")
            image_input = gr.Image(
                label="Chest X-Ray Image",
                type="pil",
                sources=["upload", "webcam", "clipboard"],
                height=400
            )
            with gr.Row():
                predict_btn = gr.Button(
                    "π¬ Analyze X-Ray",
                    variant="primary",
                    size="lg"
                )
                clear_btn = gr.Button(
                    "π Clear",
                    variant="secondary",
                    size="lg"
                )
            gradcam_checkbox = gr.Checkbox(
                label="Enable Grad-CAM Visualization (Explainable AI)",
                value=True,
                info="Shows which areas the model focuses on"
            )
            # Examples (empty list when no example files are present)
            gr.Markdown("### π Example X-Rays")
            gr.Examples(
                examples=_EXAMPLE_FILES,
                inputs=image_input,
                label="Click to load example"
            )

        # Right column - Results
        with gr.Column(scale=1):
            gr.Markdown("### π Analysis Results")
            # Confidence meter
            confidence_output = gr.Label(
                label="Prediction Confidence",
                num_top_classes=2,
                show_label=True
            )
            # Interpretation
            interpretation_output = gr.Markdown(
                label="Clinical Interpretation",
                value=_INTERPRETATION_PLACEHOLDER
            )

    # Visualization section
    gr.Markdown("---")
    gr.Markdown("## π¬ Explainable AI Visualizations")
    gr.Markdown("See exactly where the model is looking to make its decision")
    with gr.Row():
        original_output = gr.Image(label="Original X-Ray with Prediction", height=300)
        gradcam_output = gr.Image(label="Attention Heatmap", height=300)
        overlay_output = gr.Image(label="Explainable AI Overlay", height=300)

    # Information section
    gr.Markdown("---")
    with gr.Accordion("βΉοΈ About This AI Model", open=False):
        gr.Markdown("""
### π― Model Performance
| Metric | Value |
|--------|-------|
| **Accuracy** | 99.29% |
| **Energy Savings** | 89.52% |
| **Training Method** | Adaptive Sparse Training (AST) |
| **Architecture** | EfficientNet-B0 |
| **Dataset** | TB Chest X-Ray Database (~3,500 images) |
### π Built for Global Health
This model is designed to run on low-power devices, making it accessible for:
- Rural clinics without high-end infrastructure
- Mobile health screening units
- Resource-limited healthcare settings
- Telemedicine networks
### β‘ Energy Efficiency
Uses only **10% of computational resources** compared to traditional models:
- Lower electricity costs
- Runs on affordable hardware
- Reduced carbon footprint
- Faster inference time (<2 seconds)
### π¬ How It Works
1. **Upload**: Provide a chest X-ray image
2. **Analysis**: Model analyzes lung patterns for TB indicators
3. **Grad-CAM**: Highlights regions of interest
4. **Result**: Get prediction with confidence score and clinical interpretation
### β οΈ Medical Disclaimer
This tool is designed to **assist** healthcare providers, not replace them:
- Always seek professional medical advice
- Confirmatory laboratory testing required
- Clinical correlation essential
- Not approved for standalone diagnostic use
### π Learn More
- [GitHub Repository](https://github.com/oluwafemidiakhoa/Tuberculosis)
- [Research Paper](#) (Coming soon)
- [Documentation](#)
### π¨ββοΈ For Healthcare Providers
This AI tool can help with:
- Initial screening in high-burden areas
- Triage in busy clinics
- Second opinion for challenging cases
- Remote consultation support
**Integration**: Can be integrated into existing PACS systems or used standalone.
""")

    # Usage guide
    with gr.Accordion("π How to Use", open=False):
        gr.Markdown("""
### Step-by-Step Guide
1. **Upload X-Ray**
- Click the upload area or drag & drop
- Supports PNG, JPG, JPEG formats
- Or use webcam/clipboard
2. **Enable Grad-CAM** (Recommended)
- Check the box to see AI explanations
- Shows which lung areas the model focuses on
- Helps understand the decision-making process
3. **Analyze**
- Click "π¬ Analyze X-Ray" button
- Wait 2-3 seconds for processing
- View results and visualizations
4. **Interpret Results**
- Check prediction confidence
- Review probability breakdown
- Read clinical interpretation
- Examine Grad-CAM heatmaps
5. **Clinical Action**
- Follow recommended actions
- Consult healthcare provider
- Arrange confirmatory testing if needed
### π‘ Tips for Best Results
- Use clear, well-exposed X-rays
- Ensure proper patient positioning (PA or AP view)
- Avoid heavily rotated or oblique views
- Check image quality before upload
### π΄ When to Seek Immediate Medical Attention
- High confidence TB detection
- Severe respiratory symptoms
- Hemoptysis (coughing blood)
- Significant weight loss
- Persistent fever
""")

    # Footer
    gr.HTML("""
<div class="footer">
<p><strong>π Built for Global Health | π Sustainable AI | π¬ Explainable AI</strong></p>
<p>Powered by Adaptive Sparse Training (Sundew Algorithm)</p>
<p>
<a href="https://github.com/oluwafemidiakhoa/Tuberculosis" target="_blank">GitHub</a> |
<a href="https://github.com/oluwafemidiakhoa" target="_blank">Developer</a> |
<a href="https://huggingface.co/mgbam" target="_blank">Hugging Face</a>
</p>
<p style="font-size: 0.8rem; color: #999; margin-top: 1rem;">
Β© 2024 Oluwafemi Idiakhoa | MIT License<br>
For research and educational purposes. Not approved for clinical use.
</p>
</div>
""")

    # Event handlers: predict_tb returns exactly one value per output below.
    predict_btn.click(
        fn=predict_tb,
        inputs=[image_input, gradcam_checkbox],
        outputs=[confidence_output, original_output, gradcam_output, overlay_output, interpretation_output]
    )
    # Reset the input, all result widgets, and restore the placeholder text.
    clear_btn.click(
        fn=lambda: (None, None, None, None, None, _INTERPRETATION_PLACEHOLDER),
        outputs=[image_input, confidence_output, original_output, gradcam_output, overlay_output, interpretation_output]
    )
# Launch the web server only when run as a script, not on import.
if __name__ == "__main__":
    launch_kwargs = {
        "server_name": "0.0.0.0",  # listen on all network interfaces
        "server_port": 7860,
        "share": False,            # no public share link
        "show_error": True,        # surface handler errors in the UI
    }
    demo.launch(**launch_kwargs)