# Page header captured along with this file (not code):
# author ArchCoder · commit 90efbfd "Update app.py" (verified) · page links: raw, history, blame · size 15.9 kB
import gradio as gr
import torch
import torch.nn as nn
import numpy as np
import cv2
from PIL import Image
import matplotlib.pyplot as plt
import io
from torchvision import transforms
import torchvision.models as models
from torchvision.models import detection
import warnings
# Silence every warning category for the whole process.
warnings.simplefilter("ignore")

# Shared model cache; filled lazily by load_robust_model() and None until then.
model = None
# Prefer a visible CUDA device, otherwise run everything on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class TumorDetector:
    """Builder for a 2-class (background, tumor) torchvision Mask R-CNN.

    ``model`` holds the built network (None until ``load_maskrcnn_model``
    succeeds). ``device`` is the module-level torch device the network is
    moved to.
    """

    def __init__(self):
        # Built network; stays None until load_maskrcnn_model() succeeds.
        self.model = None
        # Target device, taken from the module-level `device`.
        self.device = device

    def load_maskrcnn_model(self, weights_path=None):
        """Load Mask R-CNN for tumor instance segmentation.

        Starts from torchvision's pretrained ``maskrcnn_resnet50_fpn`` and
        replaces its box and mask predictors with 2-class heads (background,
        tumor). Those replacement heads are freshly initialized. Unless a
        fine-tuned checkpoint is supplied, their tumor predictions are
        untrained.

        Args:
            weights_path: Optional path to a ``state_dict`` saved from this
                exact 2-class architecture. The default None loads no
                checkpoint, which matches the previous behavior, and prints a
                warning.

        Returns:
            True when the network was built (and the checkpoint, if any,
            loaded) and moved to ``self.device`` in eval mode. False on any
            error, in which case ``self.model`` is reset to None.
        """
        try:
            print("πŸ”„ Loading Mask R-CNN for brain tumor detection...")
            net = detection.maskrcnn_resnet50_fpn(pretrained=True)

            # Replace both predictors with 2-class heads (background, tumor).
            num_classes = 2
            in_features = net.roi_heads.box_predictor.cls_score.in_features
            net.roi_heads.box_predictor = detection.faster_rcnn.FastRCNNPredictor(in_features, num_classes)
            in_features_mask = net.roi_heads.mask_predictor.conv5_mask.in_channels
            hidden_layer = 256
            net.roi_heads.mask_predictor = detection.mask_rcnn.MaskRCNNPredictor(
                in_features_mask, hidden_layer, num_classes
            )

            if weights_path is not None:
                # SECURITY: torch.load unpickles the file and can run arbitrary
                # code. Only pass checkpoints from a trusted source.
                state_dict = torch.load(weights_path, map_location=self.device)
                net.load_state_dict(state_dict)
            else:
                print(
                    "WARNING: no fine-tuned weights supplied; the 2-class box/mask "
                    "heads are randomly initialized, so tumor predictions are untrained."
                )

            net.eval()
            self.model = net.to(self.device)
            print("βœ… Model loaded successfully!")
            return True
        except Exception as e:
            # Best-effort: report the failure and let the caller fall back.
            self.model = None
            print(f"❌ Error loading model: {e}")
            return False
def load_robust_model():
    """Return the shared segmentation model, loading it on first use.

    It first tries TumorDetector's Mask R-CNN. If that fails, it falls back to
    the U-Net published on PyTorch Hub as
    ``mateuszbuda/brain-segmentation-pytorch``. The result is cached in the
    module-level ``model``, so later calls return it directly. If every option
    fails, ``model`` stays None and the next call tries again.

    Returns:
        The loaded model in eval mode on ``device``, or None.
    """
    global model
    if model is not None:
        return model

    detector = TumorDetector()
    if detector.load_maskrcnn_model():
        model = detector.model
        print("βœ… Using Mask R-CNN for comprehensive tumor detection")
        return model

    # Fallback to PyTorch Hub U-Net
    try:
        print("πŸ”„ Falling back to PyTorch Hub U-Net...")
        unet = torch.hub.load(
            'mateuszbuda/brain-segmentation-pytorch',
            'unet',
            in_channels=3,
            out_channels=1,
            init_features=32,
            pretrained=True,
            force_reload=False
        )
        unet.eval()
        model = unet.to(device)
        print("βœ… Fallback model loaded!")
    except Exception as e:
        # Narrowed from a bare `except:` so Ctrl-C / SystemExit still propagate,
        # and the cause is logged instead of being silently discarded.
        model = None
        print(f"Fallback U-Net failed to load: {e}")
        print("❌ All models failed to load!")
    return model
def _rescale_to_uint8(array):
    """Min-max rescale a numeric array to uint8 0..255; a constant array maps to zeros."""
    values = np.asarray(array, dtype=np.float64)
    lo, hi = values.min(), values.max()
    if hi <= lo:
        return np.zeros(values.shape, dtype=np.uint8)
    return np.rint((values - lo) / (hi - lo) * 255.0).astype(np.uint8)


def enhance_mri_image(image):
    """Advanced MRI enhancement for better tumor detection.

    A grayscale version of *image* goes through CLAHE, a 3x3 Gaussian blur,
    global histogram equalization, min-max normalization and a 3x3 sharpening
    filter. The result is then replicated into three identical RGB channels.

    Args:
        image: A PIL image or a NumPy array. Arrays may be 2-D grayscale or
            (H, W, C) with C = 1, 2 (gray + alpha), 3 (RGB) or 4 (RGBA).
            Non-uint8 data (e.g. 16-bit scans, floats) is min-max rescaled to
            0..255 first.

    Returns:
        ``np.ndarray`` of shape (H, W, 3) and dtype uint8.

    Raises:
        ValueError: if the array is not 2-D or 3-D.
    """
    image_np = np.array(image) if isinstance(image, Image.Image) else np.asarray(image)

    # equalizeHist only accepts 8-bit input, and cvtColor rejects most other
    # dtypes, so bring everything to uint8 up front.
    if image_np.dtype != np.uint8:
        image_np = _rescale_to_uint8(image_np)

    # Convert to grayscale for processing.
    if image_np.ndim == 3 and image_np.shape[2] in (1, 2):
        # Single-channel or gray+alpha: the first channel is the intensity.
        gray = image_np[:, :, 0]
    elif image_np.ndim == 3:
        conversion = cv2.COLOR_RGBA2GRAY if image_np.shape[2] == 4 else cv2.COLOR_RGB2GRAY
        gray = cv2.cvtColor(image_np, conversion)
    elif image_np.ndim == 2:
        gray = image_np
    else:
        raise ValueError(f"Expected a 2-D or 3-D image array, got shape {image_np.shape}")
    gray = np.ascontiguousarray(gray)

    # 1. CLAHE for local contrast.
    clahe = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8, 8))
    enhanced = clahe.apply(gray)
    # 2. Gaussian blur for noise reduction.
    denoised = cv2.GaussianBlur(enhanced, (3, 3), 0)
    # 3. Global histogram equalization.
    hist_eq = cv2.equalizeHist(denoised)
    # 4. Stretch intensities to the full 0..255 range.
    normalized = cv2.normalize(hist_eq, None, 0, 255, cv2.NORM_MINMAX)
    # 5. Sharpen: center 9, neighbours -1 (kernel sums to 1, preserving brightness).
    kernel = np.array([[-1, -1, -1], [-1, 9, -1], [-1, -1, -1]])
    sharpened = cv2.filter2D(normalized, -1, kernel)

    # Back to 3 identical RGB channels for the downstream RGB pipeline.
    return cv2.cvtColor(sharpened, cv2.COLOR_GRAY2RGB)
def preprocess_for_detection(image):
    """Prepare *image* as model input.

    The image is enhanced with ``enhance_mri_image``, resized to 512x512 with
    Lanczos resampling, converted to a float tensor and normalized with the
    ImageNet channel mean/std.

    Returns:
        ``(tensor, pil_image)``. ``tensor`` has shape (1, 3, 512, 512).
        ``pil_image`` is the enhanced, resized RGB image before tensor
        conversion.

    NOTE(review): torchvision detection models normalize their inputs
    internally, so feeding this tensor straight to Mask R-CNN may normalize it
    twice. Confirm which consumers rely on the normalized tensor.
    """
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]

    resized = Image.fromarray(enhance_mri_image(image)).resize((512, 512), Image.LANCZOS)

    as_tensor = transforms.functional.to_tensor(resized)
    normalized = transforms.functional.normalize(as_tensor, mean=imagenet_mean, std=imagenet_std)
    # Add the leading batch dimension.
    return normalized[None, ...], resized
# Side length (pixels) of the square working resolution used for inference,
# visualization and every area statistic below.
_DETECT_SIZE = 512


def _to_binary_mask(mask, threshold=0.5):
    """Return *mask* as a boolean (_DETECT_SIZE, _DETECT_SIZE) array.

    Accepts a single-channel (1, H, W) mask or a plain (H, W) mask. The mask is
    resized to the working size if needed, and pixels strictly above
    *threshold* are kept.
    """
    if mask.ndim == 3:
        mask = mask[0]
    mask = np.asarray(mask, dtype=np.float32)
    if mask.shape != (_DETECT_SIZE, _DETECT_SIZE):
        mask = cv2.resize(mask, (_DETECT_SIZE, _DETECT_SIZE))
    return mask > threshold


def _maskrcnn_masks(current_model, processed_img, threshold=0.5):
    """Run Mask R-CNN on the enhanced image; return one boolean mask per detection above *threshold*."""
    # torchvision detection models take a list of un-normalized CHW tensors in
    # [0, 1] and apply ImageNet normalization themselves. Feeding them an
    # already-normalized tensor would normalize twice, so rebuild a plain
    # [0, 1] tensor from the enhanced PIL image.
    image_tensor = transforms.ToTensor()(processed_img).to(device)
    prediction = current_model([image_tensor])[0]
    scores = prediction['scores'].cpu().numpy()
    masks = prediction['masks'].cpu().numpy()
    kept = [_to_binary_mask(m) for m in masks[scores > threshold]]
    print(f"🎯 Detected {len(kept)} tumor(s) with confidence > {threshold}")
    return kept


def _unet_masks(current_model, input_tensor, threshold=0.3, min_area=100):
    """Run the U-Net fallback; return one boolean mask per connected region larger than *min_area* px."""
    output = current_model(input_tensor.to(device))
    # Per its PyTorch Hub documentation, the brain-segmentation U-Net already
    # ends in a sigmoid, so `output` is a probability map. Applying sigmoid
    # again squashes every pixel into (0.5, 0.73), which is always > 0.3 and
    # flags the whole image as tumor. Only apply it to logit-like outputs.
    if output.min() < 0 or output.max() > 1:
        output = torch.sigmoid(output)
    probability = output.squeeze().cpu().numpy()
    binary_mask = (probability > threshold).astype(np.uint8)
    # Connected components separate individual tumor regions.
    num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(binary_mask)
    regions = [
        _to_binary_mask((labels == label).astype(np.uint8))
        for label in range(1, num_labels)  # label 0 is the background
        if stats[label, cv2.CC_STAT_AREA] > min_area
    ]
    print(f"🎯 Detected {len(regions)} separate tumor region(s)")
    return regions


def _render_figure(original_array, processed_array, masks, tumor_pixels):
    """Draw the 2x3 summary figure and return it as a PNG-backed PIL image."""
    fig, axes = plt.subplots(2, 3, figsize=(18, 12))
    try:
        fig.suptitle('🧠 Comprehensive Brain Tumor Detection', fontsize=20, fontweight='bold')

        # Row 1: original, enhanced, and all tumors blended over the original.
        axes[0, 0].imshow(original_array)
        axes[0, 0].set_title('Original MRI', fontsize=14, fontweight='bold')
        axes[0, 0].axis('off')
        axes[0, 1].imshow(processed_array)
        axes[0, 1].set_title('Enhanced Image', fontsize=14, fontweight='bold')
        axes[0, 1].axis('off')

        colors = [(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0), (255, 0, 255)]  # one per tumor, cycled
        overlay = original_array.copy()
        for i, mask in enumerate(masks):
            overlay[mask] = colors[i % len(colors)]
        overlay = cv2.addWeighted(original_array, 0.6, overlay, 0.4, 0)
        axes[0, 2].imshow(overlay)
        axes[0, 2].set_title(f'All Tumors Detected ({len(masks)})', fontsize=14, fontweight='bold')
        axes[0, 2].axis('off')

        # Row 2: first tumor in red, second in green, then the area pie chart.
        for idx, channel in ((0, 0), (1, 1)):
            ax = axes[1, idx]
            if len(masks) > idx:
                colored = np.zeros((_DETECT_SIZE, _DETECT_SIZE, 3), dtype=np.uint8)
                colored[:, :, channel] = masks[idx].astype(np.uint8) * 255
                ax.imshow(colored)
                ax.set_title(f'Tumor Region {idx + 1}', fontsize=14)
            else:
                # "Single Tumor Only" only makes sense when exactly one was found.
                label = 'Single Tumor\nOnly' if masks else 'No Tumor\nDetected'
                ax.text(0.5, 0.5, label, ha='center', va='center', fontsize=16)
            ax.axis('off')

        ax = axes[1, 2]
        total_pixels = _DETECT_SIZE * _DETECT_SIZE
        if masks and tumor_pixels > 0:
            ax.pie([total_pixels - tumor_pixels, tumor_pixels],
                   labels=['Healthy', 'Tumor'],
                   colors=['lightblue', 'red'],
                   autopct='%1.1f%%',
                   startangle=90)
            ax.set_title('Tissue Distribution', fontsize=14, fontweight='bold')
        elif masks:
            ax.text(0.5, 0.5, 'No Tumors\nDetected', ha='center', va='center', fontsize=16)
            ax.axis('off')
        else:
            ax.text(0.5, 0.5, 'Healthy\nBrain', ha='center', va='center', fontsize=16, color='green')
            ax.axis('off')

        fig.tight_layout()
        buf = io.BytesIO()
        fig.savefig(buf, format='png', dpi=150, bbox_inches='tight', facecolor='white')
    finally:
        # Always release the figure, even if drawing fails, so repeated
        # requests do not accumulate open pyplot figures.
        plt.close(fig)
    buf.seek(0)
    return Image.open(buf)


def _build_report(masks, tumor_pixels, model_name, threshold_note):
    """Build the Markdown analysis report shown next to the figure."""
    total_pixels = _DETECT_SIZE * _DETECT_SIZE
    count = len(masks)
    total_percentage = (tumor_pixels / total_pixels) * 100

    status = ('πŸ”΄ MULTIPLE TUMORS DETECTED' if count > 1
              else 'πŸ”΄ TUMOR DETECTED' if count == 1
              else '🟒 NO TUMORS DETECTED')
    complexity = ('High (multiple lesions)' if count > 1
                  else 'Standard (single lesion)' if count == 1
                  else 'Normal brain')
    recommendation = ('Immediate specialist consultation' if total_percentage > 2.0
                      else 'Medical evaluation advised' if total_percentage > 0
                      else 'Regular monitoring')
    processing = 'GPU-accelerated inference' if device.type == 'cuda' else 'CPU inference'

    lines = [
        "",
        "## 🧠 Comprehensive Brain Tumor Analysis",
        "### 🎯 Detection Summary:",
        f"- **Tumors Detected**: **{count} tumor region(s)**",
        f"- **Total Tumor Area**: {tumor_pixels:,} pixels ({total_percentage:.2f}%)",
        f"- **Detection Model**: {model_name}",
        "### πŸ“Š Individual Tumor Analysis:",
    ]
    for i, mask in enumerate(masks):
        area = int(mask.sum())
        percentage = (area / total_pixels) * 100
        lines.append(f"- **Tumor {i+1}**: {area:,} pixels ({percentage:.2f}%)")
    lines += [
        "",
        "### πŸ”¬ Technical Details:",
        "- **Enhancement**: CLAHE + Histogram Equalization + Edge Enhancement",
        "- **Resolution**: 512Γ—512 pixels for high-precision detection",
        f"- **Detection Threshold**: {threshold_note}",
        f"- **Processing**: {processing}",
        "### 🎯 Clinical Insights:",
        f"- **Status**: {status}",
        f"- **Complexity**: {complexity}",
        f"- **Recommendation**: {recommendation}",
        "### ⚠️ Medical Disclaimer:",
        "This AI analysis is for **research purposes only**. Results should be verified by qualified radiologists. Not for diagnostic use.",
        "",
    ]
    return "\n".join(lines)


def detect_all_tumors(image):
    """Comprehensive tumor detection and segmentation.

    Args:
        image: The uploaded scan as a PIL image. A NumPy array is converted.
            None means nothing was uploaded.

    Returns:
        ``(figure, markdown)``: a PIL image of the 2x3 summary figure and a
        Markdown report. If the model is unavailable, the input is missing, or
        any step raises, the figure is None and the Markdown holds the message.
    """
    current_model = load_robust_model()
    if current_model is None:
        return None, "❌ Model failed to load. Please check your setup."
    if image is None:
        return None, "⚠️ Please upload a brain MRI image."
    try:
        print("🧠 Detecting ALL brain tumors in the image...")
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)
        # Uploads can be RGBA, grayscale or palette images. Forcing RGB keeps
        # every array below (H, W, 3) uint8, so colour overlays and blending work.
        image = image.convert('RGB')

        input_tensor, processed_img = preprocess_for_detection(image)
        is_maskrcnn = hasattr(current_model, 'roi_heads')
        with torch.no_grad():
            if is_maskrcnn:
                masks = _maskrcnn_masks(current_model, processed_img, threshold=0.5)
                threshold_note = 'detection score > 0.5'
            else:
                # NOTE(review): input_tensor is ImageNet-normalized; the Hub
                # U-Net's usage example standardizes each image by its own
                # mean/std instead. Confirm the intended preprocessing.
                masks = _unet_masks(current_model, input_tensor, threshold=0.3, min_area=100)
                threshold_note = 'pixel probability > 0.3 (regions > 100 px)'

        # Union of all masks, so overlapping instances are not counted twice
        # (double counting could exceed 100% and break the pie chart).
        tumor_pixels = int(np.logical_or.reduce(masks).sum()) if masks else 0

        original_array = np.array(image.resize((_DETECT_SIZE, _DETECT_SIZE)))
        processed_array = np.array(processed_img)
        result_image = _render_figure(original_array, processed_array, masks, tumor_pixels)

        model_name = 'Mask R-CNN Instance Segmentation' if is_maskrcnn else 'Enhanced U-Net Segmentation'
        analysis_text = _build_report(masks, tumor_pixels, model_name, threshold_note)

        print("βœ… Comprehensive tumor analysis completed!")
        return result_image, analysis_text
    except Exception as e:
        error_msg = f"❌ Error during tumor detection: {str(e)}"
        print(error_msg)
        return None, error_msg
def clear_all():
    """Reset the UI: empty the input image and result image, and restore the prompt text."""
    cleared_input = cleared_result = None
    prompt = "Upload a brain MRI image for comprehensive tumor analysis."
    return cleared_input, cleared_result, prompt
# Enhanced CSS
# Stylesheet passed to gr.Blocks(css=...). It caps the app container at 1400px
# and centers it, and it styles the element with id="title" (the gradient
# header banner emitted through gr.HTML).
css = """
.gradio-container {
max-width: 1400px !important;
margin: auto !important;
}
#title {
text-align: center;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
color: white;
padding: 30px;
border-radius: 15px;
margin-bottom: 30px;
box-shadow: 0 10px 20px rgba(0,0,0,0.1);
}
"""
# Create comprehensive Gradio interface
# Layout: a title banner on top, then one row holding the upload/controls
# column (scale=1) and the results column (scale=2). `app` is launched from
# the __main__ guard.
with gr.Blocks(css=css, title="🧠 Comprehensive Brain Tumor Detection") as app:
    # Header banner; the outer div is styled by the "#title" rule in `css`.
    gr.HTML("""
<div id="title">
<h1>🧠 Advanced Brain Tumor Detection AI</h1>
<p style="font-size: 18px; margin-top: 15px;">
Detects ALL Tumors β€’ Instance Segmentation β€’ Multi-Tumor Analysis
</p>
<p style="font-size: 14px; margin-top: 10px; opacity: 0.9;">
Powered by Mask R-CNN + Enhanced Image Processing
</p>
</div>
""")
    with gr.Row():
        # Left column: image input plus the action buttons.
        with gr.Column(scale=1):
            gr.Markdown("### πŸ“€ Upload Brain MRI")
            # type="pil" hands the click handler a PIL image (None when empty).
            image_input = gr.Image(
                label="Brain MRI Scan",
                type="pil",
                sources=["upload", "webcam"],
                height=350
            )
            with gr.Row():
                analyze_btn = gr.Button("πŸ” Detect All Tumors", variant="primary", scale=2, size="lg")
                clear_btn = gr.Button("πŸ—‘οΈ Clear", variant="secondary", scale=1)
        # Right column: the rendered summary figure and the Markdown report.
        with gr.Column(scale=2):
            gr.Markdown("### πŸ“Š Comprehensive Analysis")
            output_image = gr.Image(
                label="Complete Tumor Analysis",
                type="pil",
                height=600
            )
            analysis_output = gr.Markdown(
                value="Upload a brain MRI image to detect and analyze ALL tumors present.",
                elem_id="analysis"
            )
    # Event handlers
    # detect_all_tumors returns a (figure, markdown) pair matching these outputs.
    # NOTE(review): recent Gradio releases document show_progress as
    # "full"/"minimal"/"hidden"; confirm True is accepted by the pinned version.
    analyze_btn.click(
        fn=detect_all_tumors,
        inputs=[image_input],
        outputs=[output_image, analysis_output],
        show_progress=True
    )
    # clear_all returns three values: empty input, empty figure, reset prompt.
    clear_btn.click(
        fn=clear_all,
        inputs=[],
        outputs=[image_input, output_image, analysis_output]
    )
if __name__ == "__main__":
    # Listen on 0.0.0.0:7860, show handler errors in the UI, no public share link.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "show_error": True,
        "share": False,
    }
    print("πŸš€ Starting Comprehensive Brain Tumor Detection System...")
    app.launch(**launch_options)