# Brain / app.py
# Bhargavitippareddy's picture
# Update app.py
# f8dc300 verified
import gradio as gr
import torch
import torch.nn as nn
import cv2
import numpy as np
import albumentations as A
from albumentations.pytorch import ToTensorV2
import segmentation_models_pytorch as smp
from torchvision import models, transforms
from PIL import Image
import os
# ==========================================
# 1. CONFIGURATION
# ==========================================
# Inference device; the Hugging Face free tier only offers CPU.
DEVICE = "cpu"

# Square input edge length (pixels) fed to the segmentation and
# classification networks respectively.
SEG_IMG_SIZE = CLS_IMG_SIZE = 224

# Class index -> human-readable label. The order must match the folder
# indices (0, 1, 2, 3) used when the classifier was trained.
CLASSES = dict(
    enumerate(
        (
            'No Tumor',
            'Glioma Tumor',
            'Meningioma Tumor',
            'Pituitary Tumor',
        )
    )
)
# ==========================================
# 2. LOAD MODELS
# ==========================================
# A. Segmentation Model (Swin-UNet)
def load_seg_model():
    """Build the segmentation network and load its trained weights.

    The network is an ``smp.Unet`` with a ``tu-swin_tiny_patch4_window7_224``
    encoder, 3 input channels and a single output channel of raw logits
    (``activation=None``).

    Weights are read from ``swin_unet_best.pth`` in the working directory.
    If that file does not exist a warning is printed and the randomly
    initialised network is returned instead of failing.

    Returns:
        The model, moved to ``DEVICE`` and switched to eval mode.
    """
    network = smp.Unet(
        encoder_name="tu-swin_tiny_patch4_window7_224",
        encoder_weights=None,  # no pretrained encoder weights are requested here
        in_channels=3,
        classes=1,
        activation=None,
    )

    try:
        state = torch.load("swin_unet_best.pth", map_location=DEVICE)
        network.load_state_dict(state)
    except FileNotFoundError:
        print("⚠️ Warning: swin_unet_best.pth not found. Using random weights.")
    else:
        print("✅ Segmentation Model Loaded")

    network.to(DEVICE)
    network.eval()
    return network
# B. Classification Model (EfficientNet-B3)
def load_cls_model():
    """Build the EfficientNet-B3 classifier and load its trained weights.

    The stock torchvision head is replaced by a 4-way linear layer so the
    architecture matches the checkpoint. Weights are read from
    ``efficientnet_b3_cls.pth`` in the working directory; if the file does
    not exist a warning is printed and the untrained network is returned.

    Returns:
        The model, moved to ``DEVICE`` and switched to eval mode.
    """
    network = models.efficientnet_b3(weights=None)

    # Swap the final linear layer for one with four outputs, keeping the
    # input width of the layer it replaces.
    head_in_features = network.classifier[1].in_features
    network.classifier[1] = nn.Linear(head_in_features, 4)

    try:
        state = torch.load("efficientnet_b3_cls.pth", map_location=DEVICE)
        network.load_state_dict(state)
    except FileNotFoundError:
        print("⚠️ Warning: efficientnet_b3_cls.pth not found.")
    else:
        print("✅ Classification Model Loaded")

    network.to(DEVICE)
    network.eval()
    return network
# Build both networks once at import time; analyze_mri reuses these globals.
seg_model = load_seg_model()
cls_model = load_cls_model()

# ==========================================
# 3. PREPROCESSING
# ==========================================
# Per-channel normalisation statistics shared by both pipelines
# (these are the commonly used ImageNet values).
_NORM_MEAN = (0.485, 0.456, 0.406)
_NORM_STD = (0.229, 0.224, 0.225)

# Segmentation preprocessing (Albumentations): resize, normalise, to tensor.
# Called as ``seg_transform(image=array)["image"]``.
seg_transform = A.Compose(
    [
        A.Resize(SEG_IMG_SIZE, SEG_IMG_SIZE),
        A.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
        ToTensorV2(),
    ]
)

# Classification preprocessing (torchvision): array -> PIL -> resized,
# normalised tensor. Called directly as ``cls_transform(array)``.
cls_transform = transforms.Compose(
    [
        transforms.ToPILImage(),
        transforms.Resize((CLS_IMG_SIZE, CLS_IMG_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(mean=list(_NORM_MEAN), std=list(_NORM_STD)),
    ]
)
# ==========================================
# 4. PREDICTION PIPELINE
# ==========================================
# Index of the 'No Tumor' entry in CLASSES; no overlay is drawn for it.
_NO_TUMOR_CLASS_ID = 0

# Overlay colour (RGB) per predicted tumour class:
# Glioma = blue, Meningioma = red-orange, Pituitary = green.
_OVERLAY_COLORS = {1: (0, 0, 255), 2: (212, 28, 15), 3: (0, 255, 0)}
_DEFAULT_OVERLAY_COLOR = (255, 0, 0)  # red, for any class id not listed above


def _to_rgb(image):
    """Return *image* as a contiguous (H, W, 3) array.

    Grayscale input (2-D, or 1/2 channels) is replicated across three
    channels; an alpha channel on 4-channel input is dropped.
    """
    if image.ndim == 2:
        image = image[:, :, np.newaxis]
    channels = image.shape[2]
    if channels < 3:
        image = np.repeat(image[:, :, :1], 3, axis=2)
    elif channels > 3:
        image = image[:, :, :3]
    return np.ascontiguousarray(image)


def _classify(image):
    """Run the classifier; return ({label: probability}, top class id)."""
    cls_input = cls_transform(image).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        probs = torch.softmax(cls_model(cls_input), dim=1)[0]
    confidences = {label: float(probs[idx]) for idx, label in CLASSES.items()}
    return confidences, int(torch.argmax(probs).item())


def _predict_mask(image):
    """Run the segmenter; return a uint8 {0, 1} mask at *image*'s resolution."""
    height, width = image.shape[:2]
    seg_input = seg_transform(image=image)["image"].unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        seg_out = seg_model(seg_input)
    mask = (torch.sigmoid(seg_out) > 0.5).cpu().numpy().squeeze().astype(np.uint8)
    # Nearest-neighbour keeps the mask strictly binary. The default bilinear
    # resize produces fractional edge values that the later equality test
    # and uint8 cast would treat inconsistently.
    return cv2.resize(mask, (width, height), interpolation=cv2.INTER_NEAREST)


def _draw_overlay(image, mask, color):
    """Blend *color* over the masked pixels of *image* and outline the region."""
    tinted = image.copy()
    tinted[mask.astype(bool)] = color
    # Pixels outside the mask are identical in both inputs, so only the
    # masked region is visibly tinted by the blend.
    blended = cv2.addWeighted(image, 0.65, tinted, 0.35, 0)
    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    cv2.drawContours(blended, contours, -1, color, 2)
    return blended


def analyze_mri(image):
    """Classify an MRI image and highlight the predicted tumour region.

    Args:
        image: NumPy image array as delivered by the Gradio image input
            (``type="numpy"``), or ``None`` when nothing was uploaded.
            Grayscale and 4-channel arrays are converted to 3 channels.

    Returns:
        A tuple ``(output_image, confidences)``:

        * ``output_image`` - a copy of the input with the segmentation mask
          tinted and outlined in a class-specific colour. The image is
          returned without any overlay when the mask is empty or when the
          top predicted class is 'No Tumor'.
        * ``confidences`` - dict mapping every label in ``CLASSES`` to its
          softmax probability.

        ``(None, None)`` is returned when *image* is ``None``.
    """
    if image is None:
        return None, None

    rgb = _to_rgb(image)

    # --- 1. Classification ---
    confidences, top_class_id = _classify(rgb)

    # --- 2. Segmentation ---
    pred_mask = _predict_mask(rgb)

    # --- 3. Visualization ---
    output_image = rgb.copy()
    # Skip the overlay for a 'No Tumor' prediction, as the usage notes
    # shown in the UI promise, even if the segmenter flagged some pixels.
    if top_class_id != _NO_TUMOR_CLASS_ID and pred_mask.any():
        color = _OVERLAY_COLORS.get(top_class_id, _DEFAULT_OVERLAY_COLOR)
        output_image = _draw_overlay(rgb, pred_mask, color)

    return output_image, confidences
# ==========================================
# 5. GRADIO UI LAYOUT
# ==========================================
# Custom CSS for a medical look.
# Passed to gr.Blocks(css=...) when the interface is built. It styles a
# centred ".container", the element with id "header" (and its <h1>), and
# elements with the ".gr-button-primary" class.
custom_css = """
.container {max-width: 1100px; margin: auto; padding-top: 20px;}
#header {text-align: center; margin-bottom: 20px;}
#header h1 {color: #2c3e50; font-family: 'Helvetica', sans-serif;}
.gr-button-primary {background-color: #3498db !important; border: none;}
"""
# Folder (relative to the working directory) that may hold example scans.
_EXAMPLE_DIR = "test_images_hf"
# File extensions accepted as example images (compared case-insensitively).
_IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg", ".bmp", ".tif", ".tiff", ".webp")


def collect_example_images(root_dir, per_folder=2):
    """Return up to *per_folder* image paths from each folder under *root_dir*.

    Only files whose extension is in ``_IMAGE_EXTENSIONS`` are considered,
    so stray files (archives, README, ``.DS_Store``, ...) never end up as
    clickable examples. Folders and files are visited in sorted order,
    which makes the selection deterministic across runs.

    Args:
        root_dir: Directory to scan recursively.
        per_folder: Maximum number of images taken from each folder.

    Returns:
        A list of file paths; empty when *root_dir* does not exist or
        contains no images.
    """
    found = []
    # os.walk yields nothing for a missing directory, so no existence
    # check is needed.
    for folder, subdirs, files in os.walk(root_dir):
        subdirs.sort()  # in-place sort fixes the order os.walk descends in
        images = sorted(f for f in files if f.lower().endswith(_IMAGE_EXTENSIONS))
        found.extend(os.path.join(folder, name) for name in images[:per_folder])
    return found


# Example images offered in the UI (empty when the folder is absent).
examples = collect_example_images(_EXAMPLE_DIR)
# Usage notes rendered above the input/output columns. Kept at column 0 so
# the Markdown renderer treats the HTML as markup rather than an indented
# code block.
_USAGE_NOTES_HTML = """
<div class="important-note">
<h3>⚠️ Important Usage Notes:</h3>
<ul>
<li><strong>Image Requirement:</strong> Please ensure you upload a clear <strong>MRI Brain Scan</strong> (T1-weighted contrast-enhanced recommended). Uploading non-MRI images (e.g., photos of people, animals) will yield incorrect results.</li>
<li><strong>No Tumor Logic:</strong> If the model predicts "No Tumor", the segmentation mask will remain blank (just the original image).</li>
<li><strong>Privacy:</strong> Images are processed in RAM and not stored permanently on this server.</li>
</ul>
</div>
"""

# Create Interface
with gr.Blocks(css=custom_css, title="BrainInsightAI: Brain Tumor Analysis") as demo:
    # Header: title and one-paragraph description.
    with gr.Column(elem_id="header"):
        gr.Markdown("# 🧠Brain Tumor Diagnosis & Segmentation")
        gr.Markdown("Artificial Intelligence System for automated MRI analysis. Supports classification of **Glioma, Meningioma, and Pituitary** tumors with pixel-level segmentation.")

    # --- IMPORTANT NOTES SECTION ---
    gr.Markdown(_USAGE_NOTES_HTML)

    with gr.Row():
        # Left Column: Input
        with gr.Column():
            input_img = gr.Image(label="Upload MRI Scan", type="numpy", height=400)
            analyze_btn = gr.Button("🔍 Analyze Scan", variant="primary")
            # Examples section: only shown when example files were found.
            if examples:
                gr.Examples(examples=examples, inputs=input_img)
            else:
                gr.Markdown("*Upload an image to start.*")

        # Right Column: Output
        with gr.Column():
            # Tabbed output for cleaner look
            with gr.Tabs():
                with gr.Tab("Visual Segmentation"):
                    output_img = gr.Image(label="Tumor Location", type="numpy")
                with gr.Tab("Diagnostic Confidence"):
                    output_lbl = gr.Label(label="Predicted Pathology", num_top_classes=4)

    # Footer
    gr.Markdown("---")
    gr.Markdown("**Note:** This is an AI research prototype for testing purpose of our model.")
    # The closing ** was missing, which rendered literal asterisks.
    gr.Markdown("**Developed by Bhargavi Tippareddy**")

    # Logic: the button feeds the uploaded image to analyze_mri, whose two
    # return values fill the overlay image and the confidence label.
    analyze_btn.click(
        fn=analyze_mri,
        inputs=input_img,
        outputs=[output_img, output_lbl],
    )
def main():
    """Start the Gradio app, asking Gradio for a public share link."""
    demo.launch(share=True)


if __name__ == "__main__":
    main()