Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| import numpy as np | |
| import cv2 | |
| from PIL import Image | |
| import matplotlib.pyplot as plt | |
| import io | |
| import base64 | |
| from torchvision import transforms | |
| import torch.nn.functional as F | |
| import warnings | |
| warnings.filterwarnings("ignore") | |
# Global variable to store model
# Lazily populated by load_model() on first request; None until then.
model = None
# Select GPU when CUDA is available, otherwise run inference on CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def load_model():
    """Lazily download and cache the pretrained brain-segmentation U-Net.

    On first call, fetches the model from torch.hub, switches it to eval
    mode, and moves it to the module-level ``device``; the instance is
    cached in the ``model`` global for subsequent calls.

    Returns:
        The cached model, or ``None`` if loading failed.
    """
    global model
    # Fast path: model already loaded in a previous call.
    if model is not None:
        return model
    try:
        print("Loading brain segmentation model...")
        net = torch.hub.load(
            'mateuszbuda/brain-segmentation-pytorch',
            'unet',
            in_channels=3,
            out_channels=1,
            init_features=32,
            pretrained=True,
            force_reload=False,
        )
        net.eval()
        model = net.to(device)
        print("Model loaded successfully!")
    except Exception as e:
        # Leave the global as None so callers can report a load failure.
        print(f"Error loading model: {e}")
        model = None
    return model
def preprocess_image(image):
    """Prepare an uploaded image for the segmentation model.

    Accepts either a numpy array or a PIL image. The image is converted
    to RGB, resized to the model's expected 256x256 input, and normalized
    with ImageNet statistics.

    Returns:
        A ``(tensor, pil_image)`` pair: a 1x3x256x256 float tensor and the
        resized RGB PIL image used later for visualization.
    """
    pil_img = Image.fromarray(image) if isinstance(image, np.ndarray) else image
    if pil_img.mode != 'RGB':
        pil_img = pil_img.convert('RGB')
    # Pillow >= 9.1 exposes resampling filters on Image.Resampling; fall
    # back to the legacy module-level constant on older releases.
    try:
        resample = Image.Resampling.LANCZOS
    except AttributeError:
        resample = Image.LANCZOS
    pil_img = pil_img.resize((256, 256), resample)
    # ToTensor scales to [0, 1]; Normalize applies ImageNet mean/std.
    normalize = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    # unsqueeze adds the batch dimension expected by the model.
    return normalize(pil_img).unsqueeze(0), pil_img
def create_overlay_visualization(original_img, mask, alpha=0.6):
    """Blend a red tumor highlight over the original scan.

    Args:
        original_img: RGB PIL image (or anything np.array accepts).
        mask: 2-D 0/1 array marking tumor pixels.
        alpha: blend weight given to the highlight layer.

    Returns:
        uint8 numpy array of the blended image.
    """
    base = np.array(original_img)
    # Paint tumor pixels into the red channel only; other channels stay 0.
    highlight = np.zeros_like(base)
    highlight[..., 0] = mask * 255
    # Saturating weighted sum: base*(1-alpha) + highlight*alpha.
    return cv2.addWeighted(base, 1 - alpha, highlight, alpha, 0)
def _render_results(original_array, mask_colored, overlay):
    """Render the three-panel comparison figure and return it as a PIL image."""
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    fig.suptitle('Brain Tumor Segmentation Results', fontsize=16, fontweight='bold')
    panels = [
        (original_array, 'Original Image'),
        (mask_colored, 'Tumor Segmentation'),
        (overlay, 'Overlay (Red = Tumor)'),
    ]
    for ax, (img, title) in zip(axes, panels):
        ax.imshow(img)
        ax.set_title(title, fontsize=12, fontweight='bold')
        ax.axis('off')
    plt.tight_layout()
    # Rasterize the figure into an in-memory PNG.
    buf = io.BytesIO()
    fig.savefig(buf, format='png', dpi=150, bbox_inches='tight', facecolor='white')
    # Close THIS figure (not just pyplot's implicit "current" one) so
    # repeated web requests do not leak matplotlib state.
    plt.close(fig)
    buf.seek(0)
    result_image = Image.open(buf)
    # Image.open is lazy; force the decode now so the result does not
    # depend on the transient buffer staying alive.
    result_image.load()
    return result_image


def _format_report(binary_mask, threshold):
    """Build the markdown analysis report from the thresholded mask."""
    total_pixels = 256 * 256
    tumor_pixels = int(np.sum(binary_mask))
    tumor_percentage = (tumor_pixels / total_pixels) * 100
    return f"""
## π§ Brain Tumor Segmentation Analysis
**π Tumor Statistics:**
- Total pixels analyzed: {total_pixels:,}
- Tumor pixels detected: {tumor_pixels:,}
- Tumor area percentage: {tumor_percentage:.2f}%
**π― Model Information:**
- Model: Pre-trained U-Net for brain segmentation
- Input resolution: 256Γ256 pixels
- Detection threshold: {threshold}
- Device: {device.type.upper()}
**β οΈ Medical Disclaimer:**
This is an AI tool for research and educational purposes only.
Always consult qualified medical professionals for diagnosis.
"""


def predict_tumor(image):
    """Run tumor segmentation on an uploaded image.

    Args:
        image: PIL image (or numpy array) from the Gradio input, or None.

    Returns:
        ``(result_image, analysis_markdown)``. On failure ``result_image``
        is None and the second element carries an error message.
    """
    # Load model if not loaded yet (cached after the first call).
    current_model = load_model()
    if current_model is None:
        return None, "β Model failed to load. Please try again later."
    if image is None:
        return None, "β οΈ Please upload an image first."
    try:
        print("Processing image...")
        # Preprocess to a normalized 1x3x256x256 tensor plus the resized PIL image.
        input_tensor, original_img = preprocess_image(image)
        input_tensor = input_tensor.to(device)
        # Forward pass without gradient tracking; sigmoid maps logits to
        # a per-pixel tumor probability.
        with torch.no_grad():
            prediction = torch.sigmoid(current_model(input_tensor))
        prediction = prediction.squeeze().cpu().numpy()
        # Binarize the probability map (threshold is tunable).
        threshold = 0.5
        binary_mask = (prediction > threshold).astype(np.uint8)
        # Build the three visualizations: original, colored mask, overlay.
        original_array = np.array(original_img)
        mask_colored = np.zeros((256, 256, 3), dtype=np.uint8)
        mask_colored[:, :, 0] = binary_mask * 255  # red channel marks tumor
        overlay = create_overlay_visualization(original_img, binary_mask, alpha=0.4)
        result_image = _render_results(original_array, mask_colored, overlay)
        analysis_text = _format_report(binary_mask, threshold)
        print("Processing completed successfully!")
        return result_image, analysis_text
    except Exception as e:
        # Surface the failure in the UI instead of crashing the app.
        error_msg = f"β Error during prediction: {str(e)}"
        print(error_msg)
        return None, error_msg
def clear_all():
    """Reset the input image, the result image, and the analysis panel."""
    placeholder = "Upload an image and click 'Analyze Image' to see results."
    return None, None, placeholder
# Custom CSS for better styling.
# Injected via gr.Blocks(css=...): constrains/centers the app container,
# styles the gradient header banner (#title), rounds buttons and output
# images, and themes the progress bar.
css = """
.gradio-container {
    max-width: 1200px !important;
    margin: auto !important;
}
#title {
    text-align: center;
    background: linear-gradient(90deg, #667eea 0%, #764ba2 100%);
    color: white;
    padding: 20px;
    border-radius: 10px;
    margin-bottom: 20px;
}
.output-image {
    border-radius: 10px;
    box-shadow: 0 4px 8px rgba(0,0,0,0.1);
}
button {
    border-radius: 8px;
    font-weight: 500;
}
.progress-bar {
    background: linear-gradient(90deg, #667eea 0%, #764ba2 100%);
}
"""
# Create Gradio interface.
# Layout: header banner, then a two-column row (input controls | results),
# then an informational footer; event handlers are wired at the end.
with gr.Blocks(css=css, title="π§ Brain Tumor Segmentation AI", theme=gr.themes.Soft()) as app:
    # Header banner (styled by the #title CSS rule above).
    gr.HTML("""
    <div id="title">
        <h1>π§ Brain Tumor Segmentation AI</h1>
        <p>Upload an MRI brain scan to detect and visualize tumor regions using deep learning</p>
    </div>
    """)
    with gr.Row():
        # Left column: upload controls and usage instructions.
        with gr.Column(scale=1):
            gr.Markdown("### π€ Input Image")
            # Image input with camera option; type="pil" matches what
            # predict_tumor/preprocess_image expect.
            image_input = gr.Image(
                label="Upload Brain MRI Scan",
                type="pil",
                sources=["upload", "webcam"],
                height=300
            )
            with gr.Row():
                predict_btn = gr.Button("π Analyze Image", variant="primary", scale=2)
                clear_btn = gr.Button("ποΈ Clear", variant="secondary", scale=1)
            # Static instructions card.
            gr.HTML("""
            <div style="margin-top: 20px; padding: 15px; background-color: #f0f8ff; border-radius: 8px; border-left: 4px solid #667eea;">
                <h4>π Instructions:</h4>
                <ul style="margin: 10px 0; padding-left: 20px;">
                    <li>Upload a brain MRI scan image</li>
                    <li>Supported formats: PNG, JPG, JPEG</li>
                    <li>For best results, use clear, high-contrast MRI images</li>
                    <li>Camera option available for mobile devices</li>
                </ul>
            </div>
            """)
        # Right column: segmentation figure and markdown report.
        with gr.Column(scale=2):
            gr.Markdown("### π Segmentation Results")
            # Output image (the 3-panel matplotlib figure from predict_tumor).
            output_image = gr.Image(
                label="Segmentation Results",
                type="pil",
                height=400,
                elem_classes=["output-image"]
            )
            # Analysis text (markdown report from predict_tumor).
            analysis_output = gr.Markdown(
                value="Upload an image and click 'Analyze Image' to see results.",
                elem_id="analysis"
            )
    # Footer with about/disclaimer information.
    gr.HTML("""
    <div style="margin-top: 30px; padding: 20px; background-color: #f9f9f9; border-radius: 10px; border: 1px solid #e1e4e8;">
        <div style="display: grid; grid-template-columns: 1fr 1fr; gap: 20px;">
            <div>
                <h4 style="color: #667eea; margin-bottom: 10px;">π¬ About This Tool</h4>
                <p><strong>Model:</strong> Pre-trained U-Net for brain segmentation</p>
                <p><strong>Technology:</strong> PyTorch + Deep Learning</p>
                <p><strong>Purpose:</strong> Research & Educational Use</p>
            </div>
            <div>
                <h4 style="color: #d73027; margin-bottom: 10px;">β οΈ Medical Disclaimer</h4>
                <p style="color: #d73027; font-weight: 500;">
                    This AI tool is for research and educational purposes only.<br>
                    <strong>NOT for medical diagnosis.</strong> Always consult healthcare professionals.
                </p>
            </div>
        </div>
        <hr style="margin: 20px 0; border: none; border-top: 1px solid #e1e4e8;">
        <p style="text-align: center; color: #666; margin: 10px 0;">
            Made with β€οΈ using Gradio β’ Powered by PyTorch β’ Hosted on π€ Hugging Face Spaces
        </p>
    </div>
    """)
    # Event handlers: analyze runs the model; clear resets all three widgets.
    predict_btn.click(
        fn=predict_tumor,
        inputs=[image_input],
        outputs=[output_image, analysis_output],
        show_progress=True
    )
    clear_btn.click(
        fn=clear_all,
        inputs=[],
        outputs=[image_input, output_image, analysis_output]
    )
# Launch the app.
if __name__ == "__main__":
    print("Starting Brain Tumor Segmentation App...")
    # Bind to all interfaces on port 7860 (the standard Hugging Face
    # Spaces port); surface server errors in the UI, no public share link.
    app.launch(
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True,
        share=False
    )