import gradio as gr
import torch
import torch.nn as nn
import cv2
import numpy as np
import albumentations as A
from albumentations.pytorch import ToTensorV2
import segmentation_models_pytorch as smp
from torchvision import models, transforms
from PIL import Image
import os
# ==========================================
# 1. CONFIGURATION
# ==========================================
DEVICE = "cpu" # Hugging Face Free Tier uses CPU
SEG_IMG_SIZE = 224  # Swin-Tiny "window7_224" encoder expects 224x224 input
CLS_IMG_SIZE = 224  # EfficientNet-B3 input size used at training time — TODO confirm
# Class Labels (Ensure these match your folder indices 0,1,2,3)
CLASSES = {0: 'No Tumor', 1: 'Glioma Tumor', 2: 'Meningioma Tumor', 3: 'Pituitary Tumor'}
# ==========================================
# 2. LOAD MODELS
# ==========================================
# A. Segmentation Model (Swin-UNet)
def load_seg_model():
    """Build the Swin-UNet segmentation model and load trained weights.

    Returns:
        An eval-mode ``smp.Unet`` moved to ``DEVICE``. If the checkpoint
        file is missing, the model keeps its random initialization (a
        warning is printed) so the app can still start for UI testing.
    """
    model = smp.Unet(
        encoder_name="tu-swin_tiny_patch4_window7_224",
        encoder_weights=None,  # weights come from the checkpoint below
        in_channels=3,
        classes=1,             # single-channel binary tumor mask
        activation=None        # raw logits; sigmoid is applied at inference
    )
    try:
        # weights_only=True: restrict torch.load to plain tensors/state dicts,
        # avoiding the arbitrary-code-execution risk of full pickle loading
        # (and matching the default behavior of torch >= 2.6).
        state = torch.load("swin_unet_best.pth", map_location=DEVICE,
                           weights_only=True)
        model.load_state_dict(state)
        print("✅ Segmentation Model Loaded")
    except FileNotFoundError:
        print("⚠️ Warning: swin_unet_best.pth not found. Using random weights.")
    model.to(DEVICE)
    model.eval()
    return model
# B. Classification Model (EfficientNet-B3)
def load_cls_model():
    """Build the EfficientNet-B3 classifier and load trained weights.

    Returns:
        An eval-mode 4-class EfficientNet-B3 on ``DEVICE``. If the checkpoint
        is missing, the randomly initialized model is returned (a warning is
        printed) so the app can still start.
    """
    model = models.efficientnet_b3(weights=None)
    # Recreate the head exactly as trained: swap the final Linear for a
    # 4-way output matching the CLASSES mapping.
    num_ftrs = model.classifier[1].in_features
    model.classifier[1] = nn.Linear(num_ftrs, 4)
    try:
        # weights_only=True: safe state-dict loading (see load_seg_model).
        state = torch.load("efficientnet_b3_cls.pth", map_location=DEVICE,
                           weights_only=True)
        model.load_state_dict(state)
        print("✅ Classification Model Loaded")
    except FileNotFoundError:
        print("⚠️ Warning: efficientnet_b3_cls.pth not found.")
    model.to(DEVICE)
    model.eval()
    return model
# Instantiate both models once at import time; reused by every request.
seg_model = load_seg_model()
cls_model = load_cls_model()
# ==========================================
# 3. PREPROCESSING
# ==========================================
# Albumentations for Segmentation — resize to the encoder's fixed input
# size and normalize with ImageNet statistics (standard for pretrained
# backbones), then convert HWC uint8 -> CHW float tensor.
seg_transform = A.Compose([
A.Resize(SEG_IMG_SIZE, SEG_IMG_SIZE),
A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
ToTensorV2()
])
# Torchvision for Classification — input arrives as a numpy array from
# Gradio, hence the ToPILImage step; same ImageNet normalization.
cls_transform = transforms.Compose([
transforms.ToPILImage(),
transforms.Resize((CLS_IMG_SIZE, CLS_IMG_SIZE)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
# ==========================================
# 4. PREDICTION PIPELINE
# ==========================================
def analyze_mri(image):
    """Run classification and segmentation on an uploaded MRI scan.

    Args:
        image: H x W x 3 RGB numpy array from Gradio, or None if nothing
            was uploaded.

    Returns:
        Tuple of (annotated image, {class label: confidence}), or
        (None, None) when no image was provided.
    """
    if image is None:
        return None, None

    # --- 1. Classification ---
    cls_input = cls_transform(image).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        cls_out = cls_model(cls_input)
        probs = torch.softmax(cls_out, dim=1)[0]
    # {Label: Confidence} dictionary for the gr.Label output component.
    confidences = {CLASSES[i]: float(probs[i]) for i in range(4)}
    # Top class decides whether/which color the mask is drawn with.
    top_class_id = torch.argmax(probs).item()

    # --- 2. Segmentation ---
    h, w = image.shape[:2]
    aug = seg_transform(image=image)
    seg_input = aug['image'].unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        seg_out = seg_model(seg_input)
        pred_mask = (torch.sigmoid(seg_out) > 0.5).float().cpu().numpy().squeeze()
    # Resize the binary mask back to the original resolution. INTER_NEAREST
    # keeps values strictly {0, 1}; the default bilinear interpolation would
    # produce fractional edge values that `pred_mask == 1` silently drops.
    pred_mask = cv2.resize(pred_mask, (w, h), interpolation=cv2.INTER_NEAREST)

    # --- 3. Visualization ---
    output_image = image.copy()
    # Draw the overlay only when the classifier found a tumor AND the
    # segmentation produced any positive pixels. A "No Tumor" prediction
    # leaves the image untouched, matching the note shown in the UI.
    if top_class_id != 0 and np.any(pred_mask):
        overlay = output_image.copy()
        # Color coding per tumor type: Glioma=Red, Meningioma=Blue,
        # Pituitary=Green (keys follow the CLASSES mapping).
        colors = {1: (255, 0, 0), 2: (0, 0, 255), 3: (0, 255, 0)}
        color = colors.get(top_class_id, (255, 0, 0))  # Default Red
        # Apply mask, then blend with the original for translucency.
        overlay[pred_mask == 1] = color
        output_image = cv2.addWeighted(image, 0.65, overlay, 0.35, 0)
        # Add contours for a sharper tumor boundary.
        contours, _ = cv2.findContours(pred_mask.astype(np.uint8),
                                       cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        cv2.drawContours(output_image, contours, -1, color, 2)
    return output_image, confidences
# ==========================================
# 5. GRADIO UI LAYOUT
# ==========================================
# Custom CSS for a medical look
# NOTE(review): the notes markdown below uses class="important-note", which
# has no rule here — consider adding styling for it.
custom_css = """
.container {max-width: 1100px; margin: auto; padding-top: 20px;}
#header {text-align: center; margin-bottom: 20px;}
#header h1 {color: #2c3e50; font-family: 'Helvetica', sans-serif;}
.gr-button-primary {background-color: #3498db !important; border: none;}
"""
# Check for example images in the root folder. Only actual image files are
# offered as examples, in a deterministic (sorted) order.
examples = []
if os.path.exists("test_images_hf"):  # Assuming you unzipped the test images here
    image_exts = ('.png', '.jpg', '.jpeg', '.bmp', '.tif', '.tiff')
    for root, _, files in os.walk("test_images_hf"):
        # Take up to two images per sub-folder so each class directory
        # contributes a couple of examples.
        picked = [f for f in sorted(files) if f.lower().endswith(image_exts)][:2]
        for f in picked:
            examples.append(os.path.join(root, f))
# Create Interface — two-column layout: upload/examples on the left,
# tabbed segmentation image + confidence label on the right.
with gr.Blocks(css=custom_css, title="BrainInsightAI: Brain Tumor Analysis") as demo:
    with gr.Column(elem_id="header"):
        gr.Markdown("# 🧠Brain Tumor Diagnosis & Segmentation")
        gr.Markdown("Artificial Intelligence System for automated MRI analysis. Supports classification of **Glioma, Meningioma, and Pituitary** tumors with pixel-level segmentation.")
    # --- IMPORTANT NOTES SECTION ---
    gr.Markdown(
        """
<div class="important-note">
<h3>⚠️ Important Usage Notes:</h3>
<ul>
<li><strong>Image Requirement:</strong> Please ensure you upload a clear <strong>MRI Brain Scan</strong> (T1-weighted contrast-enhanced recommended). Uploading non-MRI images (e.g., photos of people, animals) will yield incorrect results.</li>
<li><strong>No Tumor Logic:</strong> If the model predicts "No Tumor", the segmentation mask will remain blank (just the original image).</li>
<li><strong>Privacy:</strong> Images are processed in RAM and not stored permanently on this server.</li>
</ul>
</div>
"""
    )
    with gr.Row():
        # Left Column: Input
        with gr.Column():
            input_img = gr.Image(label="Upload MRI Scan", type="numpy", height=400)
            analyze_btn = gr.Button("🔍 Analyze Scan", variant="primary")
            # Examples section (only shown when bundled test images exist)
            if examples:
                gr.Examples(examples=examples, inputs=input_img)
            else:
                gr.Markdown("*Upload an image to start.*")
        # Right Column: Output
        with gr.Column():
            # Tabbed output for a cleaner look
            with gr.Tabs():
                with gr.Tab("Visual Segmentation"):
                    output_img = gr.Image(label="Tumor Location", type="numpy")
                with gr.Tab("Diagnostic Confidence"):
                    output_lbl = gr.Label(label="Predicted Pathology", num_top_classes=4)
    # Footer
    gr.Markdown("---")
    gr.Markdown("**Note:** This is an AI research prototype for testing purpose of our model.")
    # Fixed: the bold marker was previously unclosed, rendering literal asterisks.
    gr.Markdown("**Developed by Bhargavi Tippareddy**")
    # Logic: wire the button to the prediction pipeline.
    analyze_btn.click(
        fn=analyze_mri,
        inputs=input_img,
        outputs=[output_img, output_lbl]
    )
# Script entry point: launch the Gradio server when run directly.
if __name__ == "__main__":
    demo.launch(share=True)