Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| import torch.nn as nn | |
| import cv2 | |
| import numpy as np | |
| import albumentations as A | |
| from albumentations.pytorch import ToTensorV2 | |
| import segmentation_models_pytorch as smp | |
| from torchvision import models, transforms | |
| from PIL import Image | |
| import os | |
| # ========================================== | |
| # 1. CONFIGURATION | |
| # ========================================== | |
# Inference device for both networks.
DEVICE = "cpu"  # Hugging Face Free Tier uses CPU

# Side length (in pixels) that inputs are resized to before each network.
SEG_IMG_SIZE = 224
CLS_IMG_SIZE = 224

# Class labels keyed by classifier output index
# (ensure these match your folder indices 0,1,2,3).
CLASSES = {
    0: "No Tumor",
    1: "Glioma Tumor",
    2: "Meningioma Tumor",
    3: "Pituitary Tumor",
}
| # ========================================== | |
| # 2. LOAD MODELS | |
| # ========================================== | |
| # A. Segmentation Model (Swin-UNet) | |
def load_seg_model():
    """Build the segmentation U-Net and load its checkpoint when present.

    The network is an ``smp.Unet`` on a Swin-tiny encoder, created without
    pretrained encoder weights, taking 3 input channels and producing one
    output channel with no final activation (raw logits).

    If ``swin_unet_best.pth`` is missing, a warning is printed and the
    randomly initialised network is returned instead of raising.

    Returns:
        The model, moved to ``DEVICE`` and switched to eval mode.
    """
    checkpoint_path = "swin_unet_best.pth"
    net = smp.Unet(
        encoder_name="tu-swin_tiny_patch4_window7_224",
        encoder_weights=None,
        in_channels=3,
        classes=1,
        activation=None,
    )

    try:
        state_dict = torch.load(checkpoint_path, map_location=DEVICE)
        net.load_state_dict(state_dict)
        print("✅ Segmentation Model Loaded")
    except FileNotFoundError:
        print("⚠️ Warning: swin_unet_best.pth not found. Using random weights.")

    # Module.to() and Module.eval() both return the module itself.
    return net.to(DEVICE).eval()
| # B. Classification Model (EfficientNet-B3) | |
def load_cls_model():
    """Build the EfficientNet-B3 classifier and load its checkpoint when present.

    The torchvision backbone is created without pretrained weights and the
    second element of its ``classifier`` head is replaced by a 4-output linear
    layer so the layout matches the trained checkpoint.

    If ``efficientnet_b3_cls.pth`` is missing, a warning is printed and the
    network is returned with its untrained weights instead of raising.

    Returns:
        The model, moved to ``DEVICE`` and switched to eval mode.
    """
    checkpoint_path = "efficientnet_b3_cls.pth"
    net = models.efficientnet_b3(weights=None)

    # Recreate the head exactly as trained
    head = net.classifier
    head[1] = nn.Linear(head[1].in_features, 4)

    try:
        state_dict = torch.load(checkpoint_path, map_location=DEVICE)
        net.load_state_dict(state_dict)
        print("✅ Classification Model Loaded")
    except FileNotFoundError:
        print("⚠️ Warning: efficientnet_b3_cls.pth not found.")

    # Module.to() and Module.eval() both return the module itself.
    return net.to(DEVICE).eval()
# Build both networks once at import time; analyze_mri reuses them per request.
seg_model, cls_model = load_seg_model(), load_cls_model()

# ==========================================
# 3. PREPROCESSING
# ==========================================
# Per-channel normalisation statistics shared by both pipelines
# (the standard ImageNet mean / std values).
_NORM_MEAN = (0.485, 0.456, 0.406)
_NORM_STD = (0.229, 0.224, 0.225)

# Albumentations for Segmentation: called as seg_transform(image=array),
# the tensor is returned under the "image" key.
seg_transform = A.Compose(
    [
        A.Resize(SEG_IMG_SIZE, SEG_IMG_SIZE),
        A.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
        ToTensorV2(),
    ]
)

# Torchvision for Classification: goes through a PIL image so that
# Resize / ToTensor operate on the format they expect.
cls_transform = transforms.Compose(
    [
        transforms.ToPILImage(),
        transforms.Resize((CLS_IMG_SIZE, CLS_IMG_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(mean=list(_NORM_MEAN), std=list(_NORM_STD)),
    ]
)
| # ========================================== | |
| # 4. PREDICTION PIPELINE | |
| # ========================================== | |
def analyze_mri(image):
    """Classify an MRI scan and overlay the predicted tumor region on it.

    Args:
        image: Image array from the Gradio input (declared with
            ``type="numpy"``), or ``None`` when nothing was uploaded.
            Presumably H x W x 3 uint8 RGB; 2-D grayscale and 4-channel RGBA
            arrays are converted to RGB before processing.

    Returns:
        A ``(output_image, confidences)`` tuple. ``output_image`` is a copy of
        the input with the segmentation mask blended in and outlined, or an
        unmodified copy when the classifier predicts class 0 ("No Tumor") or
        the mask is empty. ``confidences`` maps every label in ``CLASSES`` to
        its softmax probability. Both are ``None`` when ``image`` is ``None``.
    """
    if image is None:
        return None, None

    # Both pipelines normalise with 3-channel statistics, so coerce other
    # layouts to RGB up front instead of failing inside the transforms.
    if image.ndim == 2:
        image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
    elif image.ndim == 3 and image.shape[2] == 4:
        image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)

    # --- 1. Classification ---
    cls_input = cls_transform(image).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        cls_out = cls_model(cls_input)
    probs = torch.softmax(cls_out, dim=1)[0]

    # Label output {Label: Confidence}; driven by CLASSES so the two cannot
    # drift apart.
    confidences = {label: float(probs[idx]) for idx, label in CLASSES.items()}
    top_class_id = int(torch.argmax(probs))

    # The UI promises that a "No Tumor" prediction leaves the scan untouched,
    # so skip segmentation entirely in that case.
    no_tumor_id = 0
    if top_class_id == no_tumor_id:
        return image.copy(), confidences

    # --- 2. Segmentation ---
    h, w = image.shape[:2]
    seg_input = seg_transform(image=image)["image"].unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        seg_out = seg_model(seg_input)
    pred_mask = (torch.sigmoid(seg_out) > 0.5).cpu().numpy().squeeze().astype(np.uint8)

    # Nearest-neighbour keeps the mask strictly binary. The default bilinear
    # interpolation produces fractional border values that neither
    # `pred_mask == 1` nor a uint8 cast would keep.
    pred_mask = cv2.resize(pred_mask, (w, h), interpolation=cv2.INTER_NEAREST)

    # --- 3. Visualization ---
    output_image = image.copy()
    if not pred_mask.any():
        return output_image, confidences

    # RGB overlay colour per predicted tumor class.
    # NOTE(review): an earlier comment claimed Glioma=Red, Meningioma=Blue,
    # Pituitary=Green, but the table has always mapped Glioma(1)->blue,
    # Meningioma(2)->dark red, Pituitary(3)->green. The runtime colours are
    # kept as they were; confirm the intended palette.
    colors = {1: (0, 0, 255), 2: (212, 28, 15), 3: (0, 255, 0)}
    color = colors.get(top_class_id, (255, 0, 0))  # Default Red

    # Paint the masked pixels, then blend so the anatomy stays visible.
    overlay = image.copy()
    overlay[pred_mask == 1] = color
    output_image = cv2.addWeighted(image, 0.65, overlay, 0.35, 0)

    # Add contours for a sharper edge.
    contours, _ = cv2.findContours(pred_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    cv2.drawContours(output_image, contours, -1, color, 2)

    return output_image, confidences
| # ========================================== | |
| # 5. GRADIO UI LAYOUT | |
| # ========================================== | |
| # Custom CSS for a medical look | |
# One CSS rule per entry; the empty first and last entries reproduce the
# leading and trailing newline of the stylesheet text.
custom_css = "\n".join(
    [
        "",
        ".container {max-width: 1100px; margin: auto; padding-top: 20px;}",
        "#header {text-align: center; margin-bottom: 20px;}",
        "#header h1 {color: #2c3e50; font-family: 'Helvetica', sans-serif;}",
        ".gr-button-primary {background-color: #3498db !important; border: none;}",
        "",
    ]
)
| # Check for example images in the root folder | |
def collect_example_images(root_dir, per_dir=2):
    """Return up to *per_dir* image paths from each directory under *root_dir*.

    Only files with a common image extension are considered, so stray files
    (``.DS_Store``, READMEs, zip leftovers) never end up as UI examples.
    Directories and files are visited in sorted order, which makes the
    selection reproducible across filesystems.

    Args:
        root_dir: Directory to scan recursively. A missing directory simply
            yields an empty list.
        per_dir: Maximum number of images taken from any single directory.

    Returns:
        A list of file paths (``root_dir``-relative if ``root_dir`` is).
    """
    image_exts = (".png", ".jpg", ".jpeg", ".bmp", ".tif", ".tiff", ".webp")
    found = []
    for dirpath, dirnames, filenames in os.walk(root_dir):
        dirnames.sort()  # in-place sort steers os.walk's traversal order
        images = sorted(f for f in filenames if f.lower().endswith(image_exts))
        found.extend(os.path.join(dirpath, f) for f in images[:per_dir])
    return found


# Assuming you unzipped the test images into "test_images_hf"; the list is
# empty when that folder does not exist.
examples = collect_example_images("test_images_hf")
| # Create Interface | |
# Page layout: header, usage notes, an input column and a tabbed output
# column, a footer, and the button wiring. Event listeners must be registered
# inside the Blocks context, so the click handler lives in this block too.
with gr.Blocks(css=custom_css, title="BrainInsightAI: Brain Tumor Analysis") as demo:
    with gr.Column(elem_id="header"):
        gr.Markdown("# 🧠Brain Tumor Diagnosis & Segmentation")
        gr.Markdown("Artificial Intelligence System for automated MRI analysis. Supports classification of **Glioma, Meningioma, and Pituitary** tumors with pixel-level segmentation.")

    # --- IMPORTANT NOTES SECTION ---
    # The HTML is kept flush-left inside the literal so Markdown cannot
    # mistake indented lines for a code block.
    gr.Markdown(
        """
<div class="important-note">
<h3>⚠️ Important Usage Notes:</h3>
<ul>
<li><strong>Image Requirement:</strong> Please ensure you upload a clear <strong>MRI Brain Scan</strong> (T1-weighted contrast-enhanced recommended). Uploading non-MRI images (e.g., photos of people, animals) will yield incorrect results.</li>
<li><strong>No Tumor Logic:</strong> If the model predicts "No Tumor", the segmentation mask will remain blank (just the original image).</li>
<li><strong>Privacy:</strong> Images are processed in RAM and not stored permanently on this server.</li>
</ul>
</div>
"""
    )

    with gr.Row():
        # Left Column: Input
        with gr.Column():
            input_img = gr.Image(label="Upload MRI Scan", type="numpy", height=400)
            analyze_btn = gr.Button("🔍 Analyze Scan", variant="primary")

            # Examples section: only shown when bundled sample scans exist.
            if examples:
                gr.Examples(examples=examples, inputs=input_img)
            else:
                gr.Markdown("*Upload an image to start.*")

        # Right Column: Output
        with gr.Column():
            # Tabbed output for cleaner look
            with gr.Tabs():
                with gr.Tab("Visual Segmentation"):
                    output_img = gr.Image(label="Tumor Location", type="numpy")
                with gr.Tab("Diagnostic Confidence"):
                    output_lbl = gr.Label(label="Predicted Pathology", num_top_classes=4)

    # Footer
    gr.Markdown("---")
    gr.Markdown("**Note:** This is an AI research prototype for testing purpose of our model.")
    # Bold marker is closed here; the unterminated "**" rendered as literal
    # asterisks.
    gr.Markdown("**Developed by Bhargavi Tippareddy**")

    # Logic: analyze_mri returns (image, confidences), matching the two outputs.
    analyze_btn.click(
        fn=analyze_mri,
        inputs=input_img,
        outputs=[output_img, output_lbl],
    )

if __name__ == "__main__":
    # NOTE(review): share=True asks Gradio for a public share link. Confirm
    # that is wanted when this runs outside a hosted Space.
    demo.launch(share=True)