harikrishnaaa321 commited on
Commit
489e96c
·
verified ·
1 Parent(s): e5bcdd3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +78 -46
app.py CHANGED
@@ -6,11 +6,19 @@ import gradio as gr
6
  from huggingface_hub import hf_hub_download
7
  import os
8
 
 
 
 
 
9
  model_path = hf_hub_download(
10
  repo_id="harikrishnaaa321/cnn_attention_model",
11
  filename="cnn_attention_best.pth"
12
  )
13
 
 
 
 
 
14
  class SEBlock(nn.Module):
15
  def __init__(self, channels, reduction=8):
16
  super(SEBlock, self).__init__()
@@ -57,6 +65,10 @@ class CNN_Attention_Model(nn.Module):
57
  return x
58
 
59
 
 
 
 
 
60
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
61
  model = CNN_Attention_Model(num_classes=4).to(device)
62
  state_dict = torch.load(model_path, map_location=device)
@@ -64,24 +76,34 @@ model.load_state_dict(state_dict, strict=False)
64
  model.eval()
65
 
66
  labels = ["Glioma", "Meningioma", "Pituitary", "Normal"]
67
- EXAMPLE_IMAGE_DIR = "./examples" # Must contain glioma.jpg, meningioma.jpg, pituitary.jpg, notumor.jpg
 
68
  example_images = {
69
- "Glioma": os.path.join(EXAMPLE_IMAGE_DIR, "glioma.jpg"),
70
- "Meningioma": os.path.join(EXAMPLE_IMAGE_DIR, "meningioma.jpg"),
71
- "Pituitary": os.path.join(EXAMPLE_IMAGE_DIR, "pituitary.jpg"),
72
- "Normal": os.path.join(EXAMPLE_IMAGE_DIR, "notumor.jpg")
73
  }
74
 
 
 
 
 
75
  transform = transforms.Compose([
76
  transforms.Resize((224, 224)),
77
  transforms.ToTensor(),
78
  ])
79
 
 
 
 
 
 
80
  def predict_tumor(image):
81
  # Convert to grayscale
82
  image = image.convert("L")
83
 
84
- # Transform to tensor [1, 2, H, W]
85
  img_tensor = transform(image).repeat(2, 1, 1).unsqueeze(0).to(device)
86
 
87
  # Forward pass
@@ -93,30 +115,64 @@ def predict_tumor(image):
93
  pred_label = labels[pred_idx]
94
  confidences = {labels[i]: float(probs[i]) for i in range(len(labels))}
95
 
96
- # Create combined image with names below (dark background)
 
 
 
97
  example_imgs = [Image.open(example_images[label]).resize((224, 224)) for label in labels]
 
98
  combined_width = 224 * 4
99
- combined_height = 224 + 30
100
- combined_image = Image.new("RGB", (combined_width, combined_height), "#1a1a1a") # Dark background
 
 
101
  draw = ImageDraw.Draw(combined_image)
102
  font = ImageFont.load_default()
103
 
104
  for i, img in enumerate(example_imgs):
105
- combined_image.paste(img, (224 * i, 0))
106
- # Draw red border for prediction
 
 
107
  if i == pred_idx:
108
- draw.rectangle([224 * i, 0, 224 * (i + 1) - 1, 223], outline="#ff4444", width=5)
109
- # Draw label text below in white
 
110
  text = labels[i]
111
  bbox = draw.textbbox((0, 0), text, font=font)
112
- text_width = bbox[2] - bbox[0]
113
- text_height = bbox[3] - bbox[1]
114
- text_x = 224 * i + (224 - text_width) // 2
115
- text_y = 224 + 5
116
- draw.text((text_x, text_y), text, fill="white", font=font) # White text for dark theme
117
 
118
  return pred_label, confidences, combined_image
119
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
120
  interface = gr.Interface(
121
  fn=predict_tumor,
122
  inputs=gr.Image(type="pil", label="Upload MRI Scan"),
@@ -125,34 +181,10 @@ interface = gr.Interface(
125
  gr.Label(label="Confidence Scores"),
126
  gr.Image(label="Reference Tumor Images")
127
  ],
128
- title="🧠 Brain Tumor Classification",
129
- description="Upload an MRI scan (JPG/PNG). The model will predict the tumor type and show reference images with names.",
130
- theme=gr.themes.Soft(
131
- primary_hue="blue",
132
- secondary_hue="slate",
133
- ).set(
134
- body_background_fill="#0f0f0f",
135
- body_background_fill_dark="#0f0f0f",
136
- block_background_fill="#1a1a1a",
137
- block_background_fill_dark="#1a1a1a",
138
- block_border_width="1px",
139
- block_label_background_fill="#1a1a1a",
140
- block_label_background_fill_dark="#1a1a1a",
141
- block_label_text_color="#60a5fa",
142
- block_label_text_color_dark="#60a5fa",
143
- block_title_text_color="#3b82f6",
144
- block_title_text_color_dark="#3b82f6",
145
- body_text_color="#e5e7eb",
146
- body_text_color_dark="#e5e7eb",
147
- input_background_fill="#2d2d2d",
148
- input_background_fill_dark="#2d2d2d",
149
- button_primary_background_fill="#3b82f6",
150
- button_primary_background_fill_hover="#2563eb",
151
- # Label text inside components (like confidence scores)
152
- color_accent="#60a5fa",
153
- color_accent_soft="#3b82f6",
154
- ),
155
  )
156
 
157
  if __name__ == "__main__":
158
- interface.launch()
 
6
  from huggingface_hub import hf_hub_download
7
  import os
8
 
9
+ # ===============================
10
+ # 1. LOAD MODEL FROM HF
11
+ # ===============================
12
+
13
  model_path = hf_hub_download(
14
  repo_id="harikrishnaaa321/cnn_attention_model",
15
  filename="cnn_attention_best.pth"
16
  )
17
 
18
+ # ===============================
19
+ # 2. MODEL ARCHITECTURE
20
+ # ===============================
21
+
22
  class SEBlock(nn.Module):
23
  def __init__(self, channels, reduction=8):
24
  super(SEBlock, self).__init__()
 
65
  return x
66
 
67
 
68
+ # ===============================
69
+ # 3. LOAD MODEL
70
+ # ===============================
71
+
72
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
73
  model = CNN_Attention_Model(num_classes=4).to(device)
74
  state_dict = torch.load(model_path, map_location=device)
 
76
  model.eval()
77
 
78
  labels = ["Glioma", "Meningioma", "Pituitary", "Normal"]
79
+
80
+ # Direct images (not folder)
81
  example_images = {
82
+ "Glioma": "./glioma.jpg",
83
+ "Meningioma": "./meningioma.jpg",
84
+ "Pituitary": "./pituitary.jpg",
85
+ "Normal": "./notumor.jpg"
86
  }
87
 
88
+ # ===============================
89
+ # 4. TRANSFORMS
90
+ # ===============================
91
+
92
  transform = transforms.Compose([
93
  transforms.Resize((224, 224)),
94
  transforms.ToTensor(),
95
  ])
96
 
97
+
98
+ # ===============================
99
+ # 5. PREDICT FUNCTION
100
+ # ===============================
101
+
102
  def predict_tumor(image):
103
  # Convert to grayscale
104
  image = image.convert("L")
105
 
106
+ # Make 2-channel input
107
  img_tensor = transform(image).repeat(2, 1, 1).unsqueeze(0).to(device)
108
 
109
  # Forward pass
 
115
  pred_label = labels[pred_idx]
116
  confidences = {labels[i]: float(probs[i]) for i in range(len(labels))}
117
 
118
+ # ===============================
119
+ # Create reference panel (CREAM BG)
120
+ # ===============================
121
+
122
  example_imgs = [Image.open(example_images[label]).resize((224, 224)) for label in labels]
123
+
124
  combined_width = 224 * 4
125
+ combined_height = 224 + 40
126
+
127
+ combined_image = Image.new("RGB", (combined_width, combined_height), "#f8eecf") # cream
128
+
129
  draw = ImageDraw.Draw(combined_image)
130
  font = ImageFont.load_default()
131
 
132
  for i, img in enumerate(example_imgs):
133
+ x = 224 * i
134
+ combined_image.paste(img, (x, 0))
135
+
136
+ # Highlight predicted class with red border
137
  if i == pred_idx:
138
+ draw.rectangle([x, 0, x + 223, 223], outline="#ff4444", width=5)
139
+
140
+ # Draw label text
141
  text = labels[i]
142
  bbox = draw.textbbox((0, 0), text, font=font)
143
+ tw = bbox[2] - bbox[0]
144
+ tx = x + (224 - tw) // 2
145
+ draw.text((tx, 228), text, fill="black", font=font)
 
 
146
 
147
  return pred_label, confidences, combined_image
148
 
149
+
150
+ # ===============================
151
+ # 6. CREAM UI THEME CSS
152
+ # ===============================
153
+
154
+ custom_css = """
155
+ :root {
156
+ --body-background-fill: #f8eecf;
157
+ --block-background-fill: #fffaf0;
158
+ --border-color: #d5c7a1;
159
+ --button-primary-background-fill: #d6b77a;
160
+ --button-primary-text-color: #000000;
161
+ }
162
+
163
+ body, .gradio-container {
164
+ background: #f8eecf !important;
165
+ }
166
+
167
+ .gr-button {
168
+ border-radius: 8px !important;
169
+ }
170
+ """
171
+
172
+ # ===============================
173
+ # 7. GRADIO INTERFACE
174
+ # ===============================
175
+
176
  interface = gr.Interface(
177
  fn=predict_tumor,
178
  inputs=gr.Image(type="pil", label="Upload MRI Scan"),
 
181
  gr.Label(label="Confidence Scores"),
182
  gr.Image(label="Reference Tumor Images")
183
  ],
184
+ title="🧠 Brain Tumor Classification (Attention CNN)",
185
+ description="Upload an MRI scan. Model predicts tumor type and shows reference examples.",
186
+ css=custom_css
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
187
  )
188
 
189
  if __name__ == "__main__":
190
+ interface.launch()