import os

import gradio as gr
import torch
import torch.nn as nn
import cv2
import numpy as np
import albumentations as A
from albumentations.pytorch import ToTensorV2
import segmentation_models_pytorch as smp
from torchvision import models, transforms
from PIL import Image

# ==========================================
# 1. CONFIGURATION
# ==========================================
# Hugging Face Free Tier uses CPU; a GPU is used automatically when present.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
SEG_IMG_SIZE = 224
CLS_IMG_SIZE = 224

# Class Labels (Ensure these match your folder indices 0,1,2,3)
CLASSES = {0: 'No Tumor', 1: 'Glioma Tumor', 2: 'Meningioma Tumor', 3: 'Pituitary Tumor'}
# Width of the classifier head; derived so it can never drift from CLASSES.
NUM_CLASSES = len(CLASSES)

# ==========================================
# 2. LOAD MODELS
# ==========================================


def _load_checkpoint(model, path, loaded_msg, missing_msg):
    """Load the state dict at *path* into *model*; return it on DEVICE in eval mode.

    Args:
        model: Freshly constructed ``nn.Module`` whose architecture matches the
            checkpoint.
        path: Path to a ``state_dict`` file saved with ``torch.save``.
        loaded_msg: Printed after a successful load.
        missing_msg: Printed when *path* does not exist. The model then keeps
            its freshly initialised weights so the demo can still start.

    Returns:
        The same *model*, moved to ``DEVICE`` and switched to eval mode.

    Raises:
        RuntimeError: Propagated from ``load_state_dict`` when the checkpoint
            keys/shapes do not match the architecture (deliberately not hidden).
    """
    try:
        # A state dict only contains tensors and plain containers, so the
        # restricted (non-arbitrary-pickle) loader is sufficient.
        state_dict = torch.load(path, map_location=DEVICE, weights_only=True)
    except FileNotFoundError:
        print(missing_msg)
    else:
        model.load_state_dict(state_dict)
        print(loaded_msg)
    model.to(DEVICE)
    model.eval()
    return model


# A. Segmentation Model (Swin-UNet)
def load_seg_model():
    """Build the Swin-encoder U-Net (1-channel logit output) and load its weights."""
    model = smp.Unet(
        encoder_name="tu-swin_tiny_patch4_window7_224",
        encoder_weights=None,
        in_channels=3,
        classes=1,
        activation=None,
    )
    return _load_checkpoint(
        model,
        "swin_unet_best.pth",
        "✅ Segmentation Model Loaded",
        "⚠️ Warning: swin_unet_best.pth not found. Using random weights.",
    )


# B. Classification Model (EfficientNet-B3)
def load_cls_model():
    """Build EfficientNet-B3 with a NUM_CLASSES-way head and load its weights."""
    model = models.efficientnet_b3(weights=None)
    # Recreate the head exactly as trained: classifier[1] is the final Linear layer.
    num_ftrs = model.classifier[1].in_features
    model.classifier[1] = nn.Linear(num_ftrs, NUM_CLASSES)
    return _load_checkpoint(
        model,
        "efficientnet_b3_cls.pth",
        "✅ Classification Model Loaded",
        "⚠️ Warning: efficientnet_b3_cls.pth not found.",
    )


seg_model = load_seg_model()
cls_model = load_cls_model()
# ==========================================
# 3. PREPROCESSING
# ==========================================
# Albumentations for Segmentation
seg_transform = A.Compose([
    A.Resize(SEG_IMG_SIZE, SEG_IMG_SIZE),
    A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ToTensorV2(),
])

# Torchvision for Classification
cls_transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((CLS_IMG_SIZE, CLS_IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Overlay colour per predicted class id (keys follow CLASSES), in RGB order
# because Gradio hands the app RGB arrays.
# Glioma=Red, Meningioma=Blue, Pituitary=Green. "No Tumor" has no entry: if the
# segmenter still finds a region it is drawn in MASK_DEFAULT_COLOR.
MASK_COLORS = {1: (255, 0, 0), 2: (0, 0, 255), 3: (0, 255, 0)}
MASK_DEFAULT_COLOR = (255, 0, 0)  # Default Red


# ==========================================
# 4. PREDICTION PIPELINE
# ==========================================
def _to_rgb(image):
    """Return *image* as an H x W x 3 array (grayscale is replicated, alpha dropped)."""
    if image.ndim == 2:
        image = image[:, :, None]
    if image.shape[2] == 1:
        return np.repeat(image, 3, axis=2)
    if image.shape[2] == 4:
        # Slicing off alpha yields a strided view; OpenCV drawing needs contiguity.
        return np.ascontiguousarray(image[:, :, :3])
    return image


def _classify(image):
    """Run the classifier on an RGB array; return ({label: probability}, top class id)."""
    cls_input = cls_transform(image).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        probs = torch.softmax(cls_model(cls_input), dim=1)[0]
    confidences = {label: float(probs[idx]) for idx, label in CLASSES.items()}
    return confidences, torch.argmax(probs).item()


def _segment(image):
    """Run the segmenter on an RGB array; return a boolean mask of the image's H x W."""
    h, w = image.shape[:2]
    seg_input = seg_transform(image=image)['image'].unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        seg_out = seg_model(seg_input)
    mask = (torch.sigmoid(seg_out) > 0.5).squeeze().cpu().numpy().astype(np.uint8)
    # Nearest-neighbour keeps the mask strictly 0/1. The default bilinear
    # interpolation produced fractional edge values that the overlay test
    # (== 1), the "any mask" test and the uint8 cast for contours disagreed on.
    mask = cv2.resize(mask, (w, h), interpolation=cv2.INTER_NEAREST)
    return mask.astype(bool)


def _draw_overlay(image, mask, class_id):
    """Blend a class-coloured *mask* onto *image* and outline it; return a new array."""
    if not mask.any():
        return image.copy()
    color = MASK_COLORS.get(class_id, MASK_DEFAULT_COLOR)
    overlay = image.copy()
    overlay[mask] = color
    blended = cv2.addWeighted(image, 0.65, overlay, 0.35, 0)
    # Add contours for sharper edge. [-2] selects the contour list on both
    # OpenCV 3.x (3 return values) and 4.x (2 return values).
    contours = cv2.findContours(
        mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
    )[-2]
    cv2.drawContours(blended, contours, -1, color, 2)
    return blended


def analyze_mri(image):
    """Classify an MRI scan and overlay the predicted tumour segmentation.

    Args:
        image: Numpy image from ``gr.Image(type="numpy")``, or None when
            nothing was uploaded. Grayscale and RGBA inputs are converted to RGB.

    Returns:
        ``(annotated_image, confidences)``. ``confidences`` maps every label in
        CLASSES to its softmax probability. Returns ``(None, None)`` for no input.
    """
    if image is None:
        return None, None
    image = _to_rgb(image)
    confidences, top_class_id = _classify(image)
    mask = _segment(image)
    return _draw_overlay(image, mask, top_class_id), confidences


# ==========================================
# 5. GRADIO UI LAYOUT
# ==========================================
# Custom CSS for a medical look
custom_css = """
.container {max-width: 1100px; margin: auto; padding-top: 20px;}
#header {text-align: center; margin-bottom: 20px;}
#header h1 {color: #2c3e50; font-family: 'Helvetica', sans-serif;}
.gr-button-primary {background-color: #3498db !important; border: none;}
"""

# Example images: up to two per sub-folder of test_images_hf (if it was unzipped
# here). Only image files are considered and the scan is sorted so the gallery
# is stable between restarts.
IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg", ".bmp", ".tif", ".tiff")
examples = []
if os.path.isdir("test_images_hf"):
    for root, dirs, files in os.walk("test_images_hf"):
        dirs.sort()  # in-place sort makes os.walk visit sub-folders in order
        images = sorted(f for f in files if f.lower().endswith(IMAGE_EXTENSIONS))
        examples.extend(os.path.join(root, f) for f in images[:2])

# Create Interface
with gr.Blocks(css=custom_css, title="BrainInsightAI: Brain Tumor Analysis") as demo:
    with gr.Column(elem_id="header"):
        gr.Markdown("# 🧠Brain Tumor Diagnosis & Segmentation")
        gr.Markdown("Artificial Intelligence System for automated MRI analysis. Supports classification of **Glioma, Meningioma, and Pituitary** tumors with pixel-level segmentation.")

    # --- IMPORTANT NOTES SECTION ---
    # NOTE(review): only the heading of the notes survives; the list body looks
    # like it was stripped (e.g. lost HTML) — restore the intended notes.
    gr.Markdown(
        """

⚠️ Important Usage Notes:

"""
    )

    with gr.Row():
        # Left Column: Input
        with gr.Column():
            input_img = gr.Image(label="Upload MRI Scan", type="numpy", height=400)
            analyze_btn = gr.Button("🔍 Analyze Scan", variant="primary")

            # Examples section
            if examples:
                gr.Examples(examples=examples, inputs=input_img)
            else:
                gr.Markdown("*Upload an image to start.*")

        # Right Column: Output
        with gr.Column():
            # Tabbed output for cleaner look
            with gr.Tabs():
                with gr.Tab("Visual Segmentation"):
                    output_img = gr.Image(label="Tumor Location", type="numpy")
                with gr.Tab("Diagnostic Confidence"):
                    output_lbl = gr.Label(label="Predicted Pathology", num_top_classes=4)

    # Footer
    gr.Markdown("---")
    gr.Markdown("**Note:** This is an AI research prototype for testing purpose of our model.")
    gr.Markdown("**Developed by Bhargavi Tippareddy**")

    # Logic: event listeners must be registered inside the Blocks context.
    analyze_btn.click(
        fn=analyze_mri,
        inputs=input_img,
        outputs=[output_img, output_lbl],
    )

if __name__ == "__main__":
    demo.launch(share=True)