Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| import torch.nn as nn | |
| import numpy as np | |
| import cv2 | |
| from PIL import Image | |
| import matplotlib.pyplot as plt | |
| import io | |
| from torchvision import transforms | |
| import torchvision.models as models | |
| from torchvision.models import detection | |
| import warnings | |
| warnings.filterwarnings("ignore") | |
# Global variables: model is a lazily-populated cache shared across requests
# (filled in by load_robust_model()); device picks GPU when available.
model = None
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
class TumorDetector:
    """Builds a Mask R-CNN adapted for binary (background/tumor) segmentation.

    NOTE(review): the replaced box/mask heads are randomly initialized —
    no fine-tuned brain-tumor weights are loaded here; confirm intent.
    """

    def __init__(self):
        self.model = None
        self.device = device  # module-level global: cuda when available

    def load_maskrcnn_model(self):
        """Load Mask R-CNN for tumor instance segmentation; True on success."""
        try:
            print("π Loading Mask R-CNN for brain tumor detection...")
            # Start from the COCO-pretrained detector, then swap both heads
            # for a two-class problem: background (0) and tumor (1).
            self.model = detection.maskrcnn_resnet50_fpn(pretrained=True)
            heads = self.model.roi_heads
            n_classes = 2
            heads.box_predictor = detection.faster_rcnn.FastRCNNPredictor(
                heads.box_predictor.cls_score.in_features, n_classes
            )
            # 256 hidden channels in the replacement mask head.
            heads.mask_predictor = detection.mask_rcnn.MaskRCNNPredictor(
                heads.mask_predictor.conv5_mask.in_channels, 256, n_classes
            )
            self.model.eval()
            self.model = self.model.to(self.device)
            print("β Model loaded successfully!")
            return True
        except Exception as exc:
            print(f"β Error loading model: {exc}")
            return False
def load_robust_model():
    """Load the most robust brain tumor detection model.

    Tries Mask R-CNN first, then falls back to a PyTorch Hub U-Net.
    Caches the result in the module-level ``model`` global; returns the
    cached model, or None when every loader failed.
    """
    global model
    if model is None:
        detector = TumorDetector()
        # Preferred option: instance segmentation via Mask R-CNN.
        if detector.load_maskrcnn_model():
            model = detector.model
            print("β Using Mask R-CNN for comprehensive tumor detection")
        else:
            # Fallback: semantic-segmentation U-Net from PyTorch Hub.
            try:
                print("π Falling back to PyTorch Hub U-Net...")
                model = torch.hub.load(
                    'mateuszbuda/brain-segmentation-pytorch',
                    'unet',
                    in_channels=3,
                    out_channels=1,
                    init_features=32,
                    pretrained=True,
                    force_reload=False
                )
                model.eval()
                model = model.to(device)
                print("β Fallback model loaded!")
            # Fix: was a bare `except:`, which also swallowed SystemExit /
            # KeyboardInterrupt and hid the failure reason entirely.
            except Exception as exc:
                model = None
                print("β All models failed to load!")
                print(f"Fallback loader error: {exc}")
    return model
def enhance_mri_image(image):
    """Advanced MRI enhancement for better tumor detection.

    Accepts a PIL image or a numpy array (grayscale, RGB, or RGBA) and
    returns an enhanced uint8 RGB array. Pipeline: CLAHE -> Gaussian
    denoise -> histogram equalization -> min-max normalize -> sharpen.

    NOTE(review): CLAHE/equalizeHist require uint8 input — assumes the
    upload is 8-bit; confirm for 16-bit DICOM-derived arrays.
    """
    if isinstance(image, Image.Image):
        image_np = np.array(image)
    else:
        image_np = image
    # Collapse to grayscale for the intensity-based enhancement steps.
    if len(image_np.shape) == 3:
        # Fix: PNG uploads are often RGBA (4 channels); COLOR_RGB2GRAY
        # raises on those, so pick the matching conversion code.
        if image_np.shape[2] == 4:
            gray = cv2.cvtColor(image_np, cv2.COLOR_RGBA2GRAY)
        else:
            gray = cv2.cvtColor(image_np, cv2.COLOR_RGB2GRAY)
    else:
        gray = image_np
    # 1. CLAHE for local contrast without blowing out bright regions.
    clahe = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8, 8))
    enhanced = clahe.apply(gray)
    # 2. Light Gaussian blur for noise reduction.
    denoised = cv2.GaussianBlur(enhanced, (3, 3), 0)
    # 3. Global histogram equalization.
    hist_eq = cv2.equalizeHist(denoised)
    # 4. Stretch intensities to the full 0-255 range.
    normalized = cv2.normalize(hist_eq, None, 0, 255, cv2.NORM_MINMAX)
    # 5. Sharpen edges with a standard 3x3 high-boost kernel.
    kernel = np.array([[-1, -1, -1], [-1, 9, -1], [-1, -1, -1]])
    sharpened = cv2.filter2D(normalized, -1, kernel)
    # Back to 3-channel RGB, as the downstream transform expects.
    enhanced_rgb = cv2.cvtColor(sharpened, cv2.COLOR_GRAY2RGB)
    return enhanced_rgb
def preprocess_for_detection(image):
    """Preprocess image for comprehensive tumor detection.

    Returns (batched float tensor normalized with ImageNet statistics,
    the enhanced 512x512 PIL image used for display).
    """
    # Contrast/edge enhancement first, then resize with a high-quality filter.
    enhanced_pil = Image.fromarray(enhance_mri_image(image))
    enhanced_pil = enhanced_pil.resize((512, 512), Image.LANCZOS)
    # Standard ImageNet mean/std normalization for the pretrained backbone.
    to_tensor = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    return to_tensor(enhanced_pil).unsqueeze(0), enhanced_pil
def detect_all_tumors(image):
    """Comprehensive tumor detection and segmentation.

    Runs the lazily-loaded global model on an uploaded MRI (PIL image from
    the Gradio input) and returns a (result PIL image, markdown report)
    pair, or (None, error/help message) when the model or input is missing.
    """
    current_model = load_robust_model()
    if current_model is None:
        return None, "β Model failed to load. Please check your setup."
    if image is None:
        return None, "β οΈ Please upload a brain MRI image."
    try:
        print("π§ Detecting ALL brain tumors in the image...")
        # Enhance, resize to 512x512, and ImageNet-normalize the input.
        input_tensor, processed_img = preprocess_for_detection(image)
        input_tensor = input_tensor.to(device)
        # Inference only — no gradients needed.
        with torch.no_grad():
            if hasattr(current_model, 'roi_heads'):  # Mask R-CNN path
                predictions = current_model(input_tensor)
                # Instance-segmentation outputs for the single batch item.
                boxes = predictions[0]['boxes'].cpu().numpy()
                masks = predictions[0]['masks'].cpu().numpy()
                scores = predictions[0]['scores'].cpu().numpy()
                # Keep only detections above the confidence threshold.
                threshold = 0.5
                high_conf_mask = scores > threshold
                final_masks = masks[high_conf_mask]
                final_boxes = boxes[high_conf_mask]
                final_scores = scores[high_conf_mask]
                print(f"π― Detected {len(final_masks)} tumor(s) with confidence > {threshold}")
            else:  # U-Net path: one probability map for the whole image
                prediction = current_model(input_tensor)
                prediction = torch.sigmoid(prediction)
                prediction = prediction.squeeze().cpu().numpy()
                # Binarize, then split into separate tumors via connected
                # components so each lesion can be reported individually.
                binary_mask = (prediction > 0.3).astype(np.uint8)
                num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(binary_mask)
                final_masks = []
                for i in range(1, num_labels):  # label 0 is background
                    if stats[i, cv2.CC_STAT_AREA] > 100:  # drop tiny noise blobs
                        tumor_mask = (labels == i).astype(np.uint8)
                        final_masks.append(tumor_mask)
                print(f"π― Detected {len(final_masks)} separate tumor region(s)")
        # --- Visualization: 2x3 matplotlib grid ---
        original_array = np.array(image.resize((512, 512)))
        processed_array = np.array(processed_img)
        fig, axes = plt.subplots(2, 3, figsize=(18, 12))
        fig.suptitle('π§ Comprehensive Brain Tumor Detection', fontsize=20, fontweight='bold')
        # Row 1: original, enhanced, all-tumors overlay.
        axes[0,0].imshow(original_array)
        axes[0,0].set_title('Original MRI', fontsize=14, fontweight='bold')
        axes[0,0].axis('off')
        axes[0,1].imshow(processed_array)
        axes[0,1].set_title('Enhanced Image', fontsize=14, fontweight='bold')
        axes[0,1].axis('off')
        # Paint each detected mask in its own color, then alpha-blend.
        combined_overlay = original_array.copy()
        colors = [(255,0,0), (0,255,0), (0,0,255), (255,255,0), (255,0,255)]  # per-tumor colors, cycled
        if len(final_masks) > 0:
            for i, mask in enumerate(final_masks):
                color = colors[i % len(colors)]
                if len(mask.shape) == 3:
                    mask = mask[0]  # Mask R-CNN masks are (1, H, W)
                mask_resized = cv2.resize(mask, (512, 512))
                combined_overlay[mask_resized > 0.5] = color
            combined_overlay = cv2.addWeighted(original_array, 0.6, combined_overlay, 0.4, 0)
        axes[0,2].imshow(combined_overlay)
        axes[0,2].set_title(f'All Tumors Detected ({len(final_masks)})', fontsize=14, fontweight='bold')
        axes[0,2].axis('off')
        # Row 2: first two tumor masks individually, plus a tissue pie chart.
        if len(final_masks) >= 1:
            mask1 = final_masks[0]
            if len(mask1.shape) == 3:
                mask1 = mask1[0]
            mask1_colored = np.zeros((512, 512, 3), dtype=np.uint8)
            mask1_resized = cv2.resize(mask1, (512, 512))
            mask1_colored[:, :, 0] = mask1_resized * 255  # red channel
            axes[1,0].imshow(mask1_colored)
            axes[1,0].set_title('Tumor Region 1', fontsize=14)
            axes[1,0].axis('off')
        else:
            axes[1,0].text(0.5, 0.5, 'No Tumor\nDetected', ha='center', va='center', fontsize=16)
            axes[1,0].axis('off')
        if len(final_masks) >= 2:
            mask2 = final_masks[1]
            if len(mask2.shape) == 3:
                mask2 = mask2[0]
            mask2_colored = np.zeros((512, 512, 3), dtype=np.uint8)
            mask2_resized = cv2.resize(mask2, (512, 512))
            mask2_colored[:, :, 1] = mask2_resized * 255  # green channel
            axes[1,1].imshow(mask2_colored)
            axes[1,1].set_title('Tumor Region 2', fontsize=14)
            axes[1,1].axis('off')
        else:
            axes[1,1].text(0.5, 0.5, 'Single Tumor\nOnly', ha='center', va='center', fontsize=16)
            axes[1,1].axis('off')
        # Tissue-distribution pie chart. NOTE(review): this sums raw mask
        # values (unthresholded for float Mask R-CNN masks), which differs
        # slightly from the >0.5-thresholded statistics computed below.
        if len(final_masks) > 0:
            total_pixels = 512 * 512
            tumor_pixels = sum([np.sum(cv2.resize(mask[0] if len(mask.shape) == 3 else mask, (512, 512))) for mask in final_masks])
            healthy_pixels = total_pixels - tumor_pixels
            if tumor_pixels > 0:
                axes[1,2].pie([healthy_pixels, tumor_pixels],
                              labels=['Healthy', 'Tumor'],
                              colors=['lightblue', 'red'],
                              autopct='%1.1f%%',
                              startangle=90)
                axes[1,2].set_title('Tissue Distribution', fontsize=14, fontweight='bold')
            else:
                axes[1,2].text(0.5, 0.5, 'No Tumors\nDetected', ha='center', va='center', fontsize=16)
                axes[1,2].axis('off')
        else:
            axes[1,2].text(0.5, 0.5, 'Healthy\nBrain', ha='center', va='center', fontsize=16, color='green')
            axes[1,2].axis('off')
        plt.tight_layout()
        # Render the figure into an in-memory PNG and reopen as PIL.
        buf = io.BytesIO()
        plt.savefig(buf, format='png', dpi=150, bbox_inches='tight', facecolor='white')
        buf.seek(0)
        plt.close()
        result_image = Image.open(buf)
        # Per-tumor pixel counts at the >0.5 threshold for the report.
        total_tumor_pixels = 0
        tumor_areas = []
        if len(final_masks) > 0:
            for i, mask in enumerate(final_masks):
                if len(mask.shape) == 3:
                    mask = mask[0]
                mask_resized = cv2.resize(mask, (512, 512))
                pixels = np.sum(mask_resized > 0.5)
                total_tumor_pixels += pixels
                tumor_areas.append(pixels)
        total_percentage = (total_tumor_pixels / (512*512)) * 100
        # Markdown analysis report shown next to the figure.
        analysis_text = f"""
## π§ Comprehensive Brain Tumor Analysis
### π― Detection Summary:
- **Tumors Detected**: **{len(final_masks)} tumor region(s)**
- **Total Tumor Area**: {total_tumor_pixels:,} pixels ({total_percentage:.2f}%)
- **Detection Model**: {'Mask R-CNN Instance Segmentation' if hasattr(current_model, 'roi_heads') else 'Enhanced U-Net Segmentation'}
### π Individual Tumor Analysis:
"""
        for i, area in enumerate(tumor_areas):
            percentage = (area / (512*512)) * 100
            analysis_text += f"- **Tumor {i+1}**: {area:,} pixels ({percentage:.2f}%)\n"
        analysis_text += f"""
### π¬ Technical Details:
- **Enhancement**: CLAHE + Histogram Equalization + Edge Enhancement
- **Resolution**: 512Γ512 pixels for high-precision detection
- **Detection Threshold**: Multiple confidence levels
- **Processing**: GPU-accelerated inference
### π― Clinical Insights:
- **Status**: {'π΄ MULTIPLE TUMORS DETECTED' if len(final_masks) > 1 else 'π΄ TUMOR DETECTED' if len(final_masks) == 1 else 'π’ NO TUMORS DETECTED'}
- **Complexity**: {'High (multiple lesions)' if len(final_masks) > 1 else 'Standard (single lesion)' if len(final_masks) == 1 else 'Normal brain'}
- **Recommendation**: {'Immediate specialist consultation' if total_percentage > 2.0 else 'Medical evaluation advised' if total_percentage > 0 else 'Regular monitoring'}
### β οΈ Medical Disclaimer:
This AI analysis is for **research purposes only**. Results should be verified by qualified radiologists. Not for diagnostic use.
"""
        print("β Comprehensive tumor analysis completed!")
        return result_image, analysis_text
    except Exception as e:
        # Boundary handler: surface the error in the UI instead of crashing.
        error_msg = f"β Error during tumor detection: {str(e)}"
        print(error_msg)
        return None, error_msg
def clear_all():
    """Reset the UI: clear both image panes and restore the prompt text."""
    placeholder = "Upload a brain MRI image for comprehensive tumor analysis."
    return None, None, placeholder
# Custom CSS injected into the Gradio app: caps the overall width and
# styles the #title banner with a gradient background.
css = """
.gradio-container {
    max-width: 1400px !important;
    margin: auto !important;
}
#title {
    text-align: center;
    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
    color: white;
    padding: 30px;
    border-radius: 15px;
    margin-bottom: 30px;
    box-shadow: 0 10px 20px rgba(0,0,0,0.1);
}
"""
# Build the Gradio interface: title banner, upload column (left),
# results column (right), and the two button event handlers.
with gr.Blocks(css=css, title="π§ Comprehensive Brain Tumor Detection") as app:
    # Header banner, styled by the #title CSS rule above.
    gr.HTML("""
    <div id="title">
        <h1>π§ Advanced Brain Tumor Detection AI</h1>
        <p style="font-size: 18px; margin-top: 15px;">
            Detects ALL Tumors β’ Instance Segmentation β’ Multi-Tumor Analysis
        </p>
        <p style="font-size: 14px; margin-top: 10px; opacity: 0.9;">
            Powered by Mask R-CNN + Enhanced Image Processing
        </p>
    </div>
    """)
    with gr.Row():
        # Left column: MRI upload and action buttons.
        with gr.Column(scale=1):
            gr.Markdown("### π€ Upload Brain MRI")
            image_input = gr.Image(
                label="Brain MRI Scan",
                type="pil",  # detect_all_tumors expects a PIL image
                sources=["upload", "webcam"],
                height=350
            )
            with gr.Row():
                analyze_btn = gr.Button("π Detect All Tumors", variant="primary", scale=2, size="lg")
                clear_btn = gr.Button("ποΈ Clear", variant="secondary", scale=1)
        # Right column: rendered analysis figure plus markdown report.
        with gr.Column(scale=2):
            gr.Markdown("### π Comprehensive Analysis")
            output_image = gr.Image(
                label="Complete Tumor Analysis",
                type="pil",
                height=600
            )
            analysis_output = gr.Markdown(
                value="Upload a brain MRI image to detect and analyze ALL tumors present.",
                elem_id="analysis"
            )
    # Event handlers: analyze runs detection; clear resets all three widgets.
    analyze_btn.click(
        fn=detect_all_tumors,
        inputs=[image_input],
        outputs=[output_image, analysis_output],
        show_progress=True
    )
    clear_btn.click(
        fn=clear_all,
        inputs=[],
        outputs=[image_input, output_image, analysis_output]
    )
# Script entry point: bind to all interfaces on port 7860 (the
# HuggingFace Spaces default) without creating a public share link.
if __name__ == "__main__":
    print("π Starting Comprehensive Brain Tumor Detection System...")
    app.launch(
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True,
        share=False
    )