File size: 7,704 Bytes
c71e49b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f8dc300
c71e49b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e882257
c71e49b
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
import gradio as gr
import torch
import torch.nn as nn
import cv2
import numpy as np
import albumentations as A
from albumentations.pytorch import ToTensorV2
import segmentation_models_pytorch as smp
from torchvision import models, transforms
from PIL import Image
import os

# ==========================================
# 1. CONFIGURATION
# ==========================================
# Inference device; the Hugging Face free tier only provides a CPU.
DEVICE = "cpu"
# Square input resolution (in pixels) fed to the segmentation and
# classification networks respectively.
SEG_IMG_SIZE = CLS_IMG_SIZE = 224

# Class id -> display label. The ids must match the folder indices
# (0, 1, 2, 3) the classifier was trained with.
CLASSES = dict(
    enumerate(('No Tumor', 'Glioma Tumor', 'Meningioma Tumor', 'Pituitary Tumor'))
)

# ==========================================
# 2. LOAD MODELS
# ==========================================

# A. Segmentation Model (Swin-UNet)
def load_seg_model():
    """Build the Swin-Tiny U-Net segmenter and load its trained weights.

    The network is an ``smp.Unet`` with a timm Swin-Tiny encoder, 3 input
    channels and a single raw-logit output channel. Weights are read from
    ``swin_unet_best.pth``. If that file is missing, a warning is printed
    and the randomly initialised network is returned instead.

    Returns:
        The model, moved to ``DEVICE`` and switched to eval mode.
    """
    net = smp.Unet(
        encoder_name="tu-swin_tiny_patch4_window7_224",
        encoder_weights=None,
        in_channels=3,
        classes=1,
        activation=None,  # keep logits; sigmoid is applied at inference time
    )
    try:
        state = torch.load("swin_unet_best.pth", map_location=DEVICE)
    except FileNotFoundError:
        print("⚠️ Warning: swin_unet_best.pth not found. Using random weights.")
    else:
        net.load_state_dict(state)
        print("✅ Segmentation Model Loaded")

    # nn.Module.to() and .eval() both return the module itself.
    return net.to(DEVICE).eval()

# B. Classification Model (EfficientNet-B3)
def load_cls_model():
    """Build the EfficientNet-B3 classifier and load its trained weights.

    The stock torchvision head is replaced by a 4-way ``nn.Linear`` so the
    module layout matches the checkpoint ``efficientnet_b3_cls.pth``. If that
    file is missing, a warning is printed and the untrained network is
    returned.

    Returns:
        The model, moved to ``DEVICE`` and switched to eval mode.
    """
    net = models.efficientnet_b3(weights=None)

    # Recreate the head exactly as trained: classifier[1] is the final Linear.
    head_in_features = net.classifier[1].in_features
    net.classifier[1] = nn.Linear(head_in_features, 4)

    try:
        state = torch.load("efficientnet_b3_cls.pth", map_location=DEVICE)
    except FileNotFoundError:
        print("⚠️ Warning: efficientnet_b3_cls.pth not found.")
    else:
        net.load_state_dict(state)
        print("✅ Classification Model Loaded")

    # nn.Module.to() and .eval() both return the module itself.
    return net.to(DEVICE).eval()

# Build both networks once at import time; analyze_mri reuses these globals
# for every request. Evaluation order is left to right: segmenter first.
seg_model, cls_model = load_seg_model(), load_cls_model()

# ==========================================
# 3. PREPROCESSING
# ==========================================
# ImageNet channel statistics, shared by both preprocessing pipelines.
_IMAGENET_MEAN = (0.485, 0.456, 0.406)
_IMAGENET_STD = (0.229, 0.224, 0.225)

# Segmentation: albumentations pipeline, invoked as
# seg_transform(image=array)["image"] -> normalised CHW tensor.
seg_transform = A.Compose(
    [
        A.Resize(height=SEG_IMG_SIZE, width=SEG_IMG_SIZE),
        A.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
        ToTensorV2(),
    ]
)

# Classification: torchvision pipeline, numpy array -> PIL image ->
# resized, normalised CHW tensor.
cls_transform = transforms.Compose(
    [
        transforms.ToPILImage(),
        transforms.Resize((CLS_IMG_SIZE, CLS_IMG_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(mean=list(_IMAGENET_MEAN), std=list(_IMAGENET_STD)),
    ]
)

# ==========================================
# 4. PREDICTION PIPELINE
# ==========================================
# Class id of "No Tumor" in CLASSES; no overlay is drawn for this prediction.
_NO_TUMOR_ID = 0

# Overlay colour (RGB, as delivered by gr.Image) per predicted tumour class:
# Glioma=Red, Meningioma=Blue, Pituitary=Green.
_MASK_COLORS = {1: (255, 0, 0), 2: (0, 0, 255), 3: (0, 255, 0)}
_DEFAULT_MASK_COLOR = (255, 0, 0)


def _ensure_rgb(image):
    """Return *image* as a contiguous 3-channel array (gray/RGBA converted)."""
    if image.ndim == 2:
        image = image[:, :, None]
    if image.shape[2] == 1:
        # Replicate the single gray channel; both networks expect 3 channels.
        image = np.repeat(image, 3, axis=2)
    elif image.shape[2] == 4:
        # Drop the alpha channel.
        image = image[:, :, :3]
    return np.ascontiguousarray(image)


def _classify(image):
    """Run the classifier; return ({label: probability}, top class id)."""
    batch = cls_transform(image).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        probs = torch.softmax(cls_model(batch), dim=1)[0].cpu()
    confidences = {CLASSES[i]: p for i, p in enumerate(probs.tolist())}
    return confidences, int(probs.argmax())


def _segment(image):
    """Run the segmenter; return a boolean mask at the image's (H, W) size."""
    h, w = image.shape[:2]
    batch = seg_transform(image=image)["image"].unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        logits = seg_model(batch)
    mask = (torch.sigmoid(logits) > 0.5).squeeze().cpu().numpy().astype(np.uint8)
    # Nearest-neighbour keeps the mask strictly binary. The default bilinear
    # interpolation produced fractional edge values that were then dropped by
    # the `== 1` test and the uint8 cast, shrinking the overlay and contours.
    mask = cv2.resize(mask, (w, h), interpolation=cv2.INTER_NEAREST)
    return mask.astype(bool)


def _render_overlay(image, mask, class_id):
    """Blend *mask* onto a copy of *image* in the class colour and outline it."""
    if not mask.any():
        return image.copy()

    color = _MASK_COLORS.get(class_id, _DEFAULT_MASK_COLOR)
    overlay = image.copy()
    overlay[mask] = color

    # Outside the mask overlay == image, so only masked pixels are tinted.
    output = cv2.addWeighted(image, 0.65, overlay, 0.35, 0)

    # Draw outer contours for a sharper edge.
    contours, _ = cv2.findContours(
        mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
    )
    cv2.drawContours(output, contours, -1, color, 2)
    return output


def analyze_mri(image):
    """Classify an MRI scan and overlay the predicted tumour mask.

    Args:
        image: Image array from ``gr.Image(type="numpy")``. It is assumed to be
            uint8 (TODO confirm). 2-D grayscale, single-channel and RGBA arrays
            are converted to 3-channel RGB before inference.

    Returns:
        ``(output_image, confidences)``.

        - ``confidences`` maps every label in ``CLASSES`` to its softmax
          probability.
        - ``output_image`` is the RGB input with the segmentation mask blended
          in and outlined in the colour of the predicted tumour class.
        - When the top prediction is "No Tumor", or the mask is empty,
          ``output_image`` is an unannotated copy of the input.
        - ``(None, None)`` is returned when ``image`` is None.
    """
    if image is None:
        return None, None

    image = _ensure_rgb(image)

    # --- 1. Classification ---
    confidences, top_class_id = _classify(image)

    # Per the UI usage notes, a "No Tumor" prediction leaves the image blank.
    if top_class_id == _NO_TUMOR_ID:
        return image.copy(), confidences

    # --- 2. Segmentation + 3. Visualization ---
    mask = _segment(image)
    return _render_overlay(image, mask, top_class_id), confidences

# ==========================================
# 5. GRADIO UI LAYOUT
# ==========================================
# Stylesheet passed to gr.Blocks(css=...) for a clean, "medical" look:
# a centred #header block, dark heading text, and a blue primary button.
custom_css = (
    "\n"
    ".container {max-width: 1100px; margin: auto; padding-top: 20px;}\n"
    "#header {text-align: center; margin-bottom: 20px;}\n"
    "#header h1 {color: #2c3e50; font-family: 'Helvetica', sans-serif;}\n"
    ".gr-button-primary {background-color: #3498db !important; border: none;}\n"
)

# File suffixes accepted as example images (matched case-insensitively).
EXAMPLE_IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg", ".bmp", ".tif", ".tiff", ".webp")


def find_example_images(folder, per_dir=2):
    """Collect up to *per_dir* image paths from each directory under *folder*.

    Directories are walked in sorted order, and each directory's files are
    sorted by name, so the example gallery is stable across restarts.
    Non-image files (e.g. ``.DS_Store``, READMEs) are skipped because the
    examples feed an image input component.

    Args:
        folder: Root directory to walk. A missing directory yields ``[]``.
        per_dir: Maximum number of images taken from each directory.

    Returns:
        List of paths, each joined onto the directory it was found in.
    """
    found = []
    if not os.path.isdir(folder):
        return found
    for root, dirs, files in os.walk(folder):
        # Sorting dirs in place makes os.walk descend in a deterministic order.
        dirs.sort()
        images = sorted(f for f in files if f.lower().endswith(EXAMPLE_IMAGE_EXTENSIONS))
        found.extend(os.path.join(root, f) for f in images[:per_dir])
    return found


# Example images are taken from "test_images_hf" (e.g. unzipped test images),
# a couple per sub-folder, if that folder exists.
examples = find_example_images("test_images_hf")

# Create Interface
# Layout:
#   - top: header and usage notes;
#   - left column: upload box, analyze button and optional examples;
#   - right column: tabbed results;
#   - bottom: footer.
with gr.Blocks(css=custom_css, title="BrainInsightAI: Brain Tumor Analysis") as demo:

    with gr.Column(elem_id="header"):
        gr.Markdown("# 🧠Brain Tumor Diagnosis & Segmentation")
        gr.Markdown("Artificial Intelligence System for automated MRI analysis. Supports classification of **Glioma, Meningioma, and Pituitary** tumors with pixel-level segmentation.")

        # --- IMPORTANT NOTES SECTION ---
        gr.Markdown(
                """
                <div class="important-note">
                <h3>⚠️ Important Usage Notes:</h3>
                <ul>
                    <li><strong>Image Requirement:</strong> Please ensure you upload a clear <strong>MRI Brain Scan</strong> (T1-weighted contrast-enhanced recommended). Uploading non-MRI images (e.g., photos of people, animals) will yield incorrect results.</li>
                    <li><strong>No Tumor Logic:</strong> If the model predicts "No Tumor", the segmentation mask will remain blank (just the original image).</li>
                    <li><strong>Privacy:</strong> Images are processed in RAM and not stored permanently on this server.</li>
                </ul>
                </div>
                """
            )

    with gr.Row():
        # Left Column: Input
        with gr.Column():
            input_img = gr.Image(label="Upload MRI Scan", type="numpy", height=400)
            analyze_btn = gr.Button("🔍 Analyze Scan", variant="primary")

            # Examples section: only rendered when example images were found.
            if examples:
                gr.Examples(examples=examples, inputs=input_img)
            else:
                gr.Markdown("*Upload an image to start.*")

        # Right Column: Output
        with gr.Column():
            # Tabbed output for cleaner look
            with gr.Tabs():
                with gr.Tab("Visual Segmentation"):
                    output_img = gr.Image(label="Tumor Location", type="numpy")

                with gr.Tab("Diagnostic Confidence"):
                    output_lbl = gr.Label(label="Predicted Pathology", num_top_classes=4)

    # Footer
    gr.Markdown("---")
    gr.Markdown("**Note:** This is an AI research prototype for testing purpose of our model.")
    # The bold span must be closed with "**"; unterminated, Markdown shows
    # the asterisks literally.
    gr.Markdown("**Developed by Bhargavi Tippareddy**")

    # Logic: analyze_mri(image) -> (annotated image, {label: confidence}),
    # routed to the two output tabs in that order.
    analyze_btn.click(
        fn=analyze_mri,
        inputs=input_img,
        outputs=[output_img, output_lbl]
    )

if __name__ == "__main__":
    demo.launch(share=True)