Update app.py
app.py CHANGED
@@ -6,11 +6,19 @@ import gradio as gr
 from huggingface_hub import hf_hub_download
 import os
 
+# ===============================
+# 1. LOAD MODEL FROM HF
+# ===============================
+
 model_path = hf_hub_download(
     repo_id="harikrishnaaa321/cnn_attention_model",
     filename="cnn_attention_best.pth"
 )
 
+# ===============================
+# 2. MODEL ARCHITECTURE
+# ===============================
+
 class SEBlock(nn.Module):
     def __init__(self, channels, reduction=8):
         super(SEBlock, self).__init__()
@@ -57,6 +65,10 @@ class CNN_Attention_Model(nn.Module):
         return x
 
 
+# ===============================
+# 3. LOAD MODEL
+# ===============================
+
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 model = CNN_Attention_Model(num_classes=4).to(device)
 state_dict = torch.load(model_path, map_location=device)
@@ -64,24 +76,34 @@ model.load_state_dict(state_dict, strict=False)
 model.eval()
 
 labels = ["Glioma", "Meningioma", "Pituitary", "Normal"]
-
+
+# Direct images (not folder)
 example_images = {
-    "Glioma":
-    "Meningioma":
-    "Pituitary":
-    "Normal":
+    "Glioma": "./glioma.jpg",
+    "Meningioma": "./meningioma.jpg",
+    "Pituitary": "./pituitary.jpg",
+    "Normal": "./notumor.jpg"
 }
 
+# ===============================
+# 4. TRANSFORMS
+# ===============================
+
 transform = transforms.Compose([
     transforms.Resize((224, 224)),
     transforms.ToTensor(),
 ])
 
+
+# ===============================
+# 5. PREDICT FUNCTION
+# ===============================
+
 def predict_tumor(image):
     # Convert to grayscale
     image = image.convert("L")
 
-    #
+    # Make 2-channel input
     img_tensor = transform(image).repeat(2, 1, 1).unsqueeze(0).to(device)
 
     # Forward pass
@@ -93,30 +115,64 @@ def predict_tumor(image):
     pred_label = labels[pred_idx]
     confidences = {labels[i]: float(probs[i]) for i in range(len(labels))}
 
-    #
+    # ===============================
+    # Create reference panel (CREAM BG)
+    # ===============================
+
     example_imgs = [Image.open(example_images[label]).resize((224, 224)) for label in labels]
+
     combined_width = 224 * 4
-    combined_height = 224 +
-
+    combined_height = 224 + 40
+
+    combined_image = Image.new("RGB", (combined_width, combined_height), "#f8eecf")  # cream
+
     draw = ImageDraw.Draw(combined_image)
     font = ImageFont.load_default()
 
     for i, img in enumerate(example_imgs):
-
-
+        x = 224 * i
+        combined_image.paste(img, (x, 0))
+
+        # Highlight predicted class with red border
         if i == pred_idx:
-            draw.rectangle([
-
+            draw.rectangle([x, 0, x + 223, 223], outline="#ff4444", width=5)
+
+        # Draw label text
         text = labels[i]
         bbox = draw.textbbox((0, 0), text, font=font)
-
-
-
-        text_y = 224 + 5
-        draw.text((text_x, text_y), text, fill="white", font=font)  # White text for dark theme
+        tw = bbox[2] - bbox[0]
+        tx = x + (224 - tw) // 2
+        draw.text((tx, 228), text, fill="black", font=font)
 
     return pred_label, confidences, combined_image
 
+
+# ===============================
+# 6. CREAM UI THEME CSS
+# ===============================
+
+custom_css = """
+:root {
+    --body-background-fill: #f8eecf;
+    --block-background-fill: #fffaf0;
+    --border-color: #d5c7a1;
+    --button-primary-background-fill: #d6b77a;
+    --button-primary-text-color: #000000;
+}
+
+body, .gradio-container {
+    background: #f8eecf !important;
+}
+
+.gr-button {
+    border-radius: 8px !important;
+}
+"""
+
+# ===============================
+# 7. GRADIO INTERFACE
+# ===============================
+
 interface = gr.Interface(
     fn=predict_tumor,
     inputs=gr.Image(type="pil", label="Upload MRI Scan"),
@@ -125,34 +181,10 @@ interface = gr.Interface(
         gr.Label(label="Confidence Scores"),
         gr.Image(label="Reference Tumor Images")
     ],
-    title="🧠 Brain Tumor Classification",
-    description="Upload an MRI scan
-
-        primary_hue="blue",
-        secondary_hue="slate",
-    ).set(
-        body_background_fill="#0f0f0f",
-        body_background_fill_dark="#0f0f0f",
-        block_background_fill="#1a1a1a",
-        block_background_fill_dark="#1a1a1a",
-        block_border_width="1px",
-        block_label_background_fill="#1a1a1a",
-        block_label_background_fill_dark="#1a1a1a",
-        block_label_text_color="#60a5fa",
-        block_label_text_color_dark="#60a5fa",
-        block_title_text_color="#3b82f6",
-        block_title_text_color_dark="#3b82f6",
-        body_text_color="#e5e7eb",
-        body_text_color_dark="#e5e7eb",
-        input_background_fill="#2d2d2d",
-        input_background_fill_dark="#2d2d2d",
-        button_primary_background_fill="#3b82f6",
-        button_primary_background_fill_hover="#2563eb",
-        # Label text inside components (like confidence scores)
-        color_accent="#60a5fa",
-        color_accent_soft="#3b82f6",
-    ),
+    title="🧠 Brain Tumor Classification (Attention CNN)",
+    description="Upload an MRI scan. Model predicts tumor type and shows reference examples.",
+    css=custom_css
 )
 
 if __name__ == "__main__":
-    interface.launch()
+    interface.launch()
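
The one non-obvious step in the updated predict_tumor is the input preprocessing: the uploaded scan is converted to grayscale and the single channel is repeated to give a 2-channel tensor, which is apparently what the checkpointed CNN_Attention_Model expects. A minimal, standalone sketch (using a blank PIL image as a stand-in for an MRI scan) to confirm the resulting shape:

# Sketch only: reproduces the preprocessing from predict_tumor on a dummy image.
from PIL import Image
from torchvision import transforms

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

dummy = Image.new("L", (256, 256))                          # stand-in grayscale "scan"
img_tensor = transform(dummy).repeat(2, 1, 1).unsqueeze(0)  # duplicate channel, add batch dim
print(img_tensor.shape)                                     # torch.Size([1, 2, 224, 224])

The repeat(2, 1, 1) call only duplicates the grayscale channel rather than adding information, so any mismatch with the model's expected input channels would surface immediately as a size error in the first forward pass.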