ArchCoder commited on
Commit
f97533f
Β·
verified Β·
1 Parent(s): 69c21ad

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +243 -264
app.py CHANGED
@@ -1,4 +1,3 @@
1
- # full_app_with_heatmap.py
2
  import gradio as gr
3
  import torch
4
  import torch.nn as nn
@@ -11,22 +10,11 @@ from torchvision import transforms
11
  import torchvision.transforms.functional as TF
12
  import urllib.request
13
  import os
14
- import random
15
- from glob import glob
16
- import kagglehub # if you use dataset download in the app; remove if not needed
17
 
18
- # -------------------------
19
- # Setup / Globals
20
- # -------------------------
21
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
22
  model = None
23
- dataset_images = []
24
- dataset_masks = []
25
- dataset_loaded = False
26
 
27
- # -------------------------
28
- # Model classes (Attention U-Net)
29
- # -------------------------
30
  class DoubleConv(nn.Module):
31
  def __init__(self, in_channels, out_channels):
32
  super(DoubleConv, self).__init__()
@@ -38,11 +26,9 @@ class DoubleConv(nn.Module):
38
  nn.BatchNorm2d(out_channels),
39
  nn.ReLU(inplace=True),
40
  )
41
-
42
  def forward(self, x):
43
  return self.conv(x)
44
 
45
-
46
  class AttentionBlock(nn.Module):
47
  def __init__(self, F_g, F_l, F_int):
48
  super(AttentionBlock, self).__init__()
@@ -60,14 +46,12 @@ class AttentionBlock(nn.Module):
60
  nn.Sigmoid()
61
  )
62
  self.relu = nn.ReLU(inplace=True)
63
-
64
  def forward(self, g, x):
65
  g1 = self.W_g(g)
66
  x1 = self.W_x(x)
67
  psi = self.relu(g1 + x1)
68
  psi = self.psi(psi)
69
- return x * psi, psi # return attended skip, attention map
70
-
71
 
72
  class AttentionUNET(nn.Module):
73
  def __init__(self, in_channels=1, out_channels=1, features=[32, 64, 128, 256]):
@@ -78,12 +62,15 @@ class AttentionUNET(nn.Module):
78
  self.attentions = nn.ModuleList()
79
  self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
80
 
 
81
  for feature in features:
82
  self.downs.append(DoubleConv(in_channels, feature))
83
  in_channels = feature
84
 
 
85
  self.bottleneck = DoubleConv(features[-1], features[-1]*2)
86
 
 
87
  for feature in reversed(features):
88
  self.ups.append(nn.ConvTranspose2d(feature*2, feature, kernel_size=2, stride=2))
89
  self.attentions.append(AttentionBlock(F_g=feature, F_l=feature, F_int=feature // 2))
@@ -93,291 +80,283 @@ class AttentionUNET(nn.Module):
93
 
94
  def forward(self, x):
95
  skip_connections = []
96
- attention_maps = []
97
-
98
  for down in self.downs:
99
  x = down(x)
100
  skip_connections.append(x)
101
  x = self.pool(x)
102
-
103
  x = self.bottleneck(x)
104
- skip_connections = skip_connections[::-1]
105
 
106
  for idx in range(0, len(self.ups), 2):
107
  x = self.ups[idx](x)
108
- skip_connection = skip_connections[idx // 2]
109
-
110
  if x.shape != skip_connection.shape:
111
  x = TF.resize(x, size=skip_connection.shape[2:])
112
-
113
- attended_skip, att_map = self.attentions[idx // 2](x, skip_connection)
114
- attention_maps.append(att_map)
115
- concat_skip = torch.cat((attended_skip, x), dim=1)
116
  x = self.ups[idx+1](concat_skip)
 
117
 
118
- return self.final_conv(x), attention_maps
119
-
120
- # -------------------------
121
- # Model download / load
122
- # -------------------------
123
- def download_and_load_model():
124
- global model
125
- print("Loading Attention U-Net model...")
126
-
127
  model_url = "https://huggingface.co/spaces/ArchCoder/the-op-segmenter/resolve/main/best_attention_model.pth.tar"
128
  model_path = "best_attention_model.pth.tar"
129
-
130
  if not os.path.exists(model_path):
131
- print("Downloading model weights...")
132
  try:
133
  urllib.request.urlretrieve(model_url, model_path)
 
134
  except Exception as e:
135
- print(f"Failed to download model: {e}")
136
- return False
 
 
 
137
 
138
- try:
139
- model = AttentionUNET(in_channels=1, out_channels=1).to(device)
140
- checkpoint = torch.load(model_path, map_location=device)
141
- # checkpoint format expected to have "state_dict"
142
- if "state_dict" in checkpoint:
143
- sd = checkpoint["state_dict"]
144
- else:
145
- sd = checkpoint
146
- # Try exact load; if mismatch, try strict=False and warn
147
  try:
 
 
 
 
 
 
 
 
 
 
 
148
  model.load_state_dict(sd)
149
- except Exception as ex:
150
- print(f"Warning: strict load failed: {ex}. Trying strict=False...")
151
- model.load_state_dict(sd, strict=False)
152
- model.eval()
153
- print("βœ“ Model loaded successfully!")
154
- return True
155
- except Exception as e:
156
- print(f"Failed to load model: {e}")
157
- model = None
158
- return False
159
-
160
- # -------------------------
161
- # Dataset utilities (optional)
162
- # -------------------------
163
- def download_and_load_dataset():
164
- global dataset_images, dataset_masks, dataset_loaded
165
- if dataset_loaded:
166
- return True
167
- try:
168
- print("Loading brain tumor dataset (kagglehub)...")
169
- dataset_path = kagglehub.dataset_download('nikhilroxtomar/brain-tumor-segmentation')
170
- images_dir = os.path.join(dataset_path, 'images')
171
- masks_dir = os.path.join(dataset_path, 'masks')
172
- if not os.path.exists(images_dir) or not os.path.exists(masks_dir):
173
- # fallback search
174
- all_files = glob(os.path.join(dataset_path, "**/*.png"), recursive=True) + \
175
- glob(os.path.join(dataset_path, "**/*.jpg"), recursive=True)
176
- dataset_images = [f for f in all_files if '/images/' in f or 'image' in f.lower()]
177
- dataset_masks = [f for f in all_files if '/masks/' in f or 'mask' in f.lower()]
178
- else:
179
- dataset_images = sorted(glob(os.path.join(images_dir, "*.*")))
180
- dataset_masks = sorted(glob(os.path.join(masks_dir, "*.*")))
181
- print(f"βœ“ Found {len(dataset_images)} images and {len(dataset_masks)} masks")
182
- dataset_loaded = True
183
- return True
184
- except Exception as e:
185
- print(f"Failed to load dataset: {e}")
186
- return False
187
-
188
- def get_random_sample():
189
- if not dataset_loaded:
190
- return None, None, "Dataset not loaded"
191
- if not dataset_images:
192
- return None, None, "No images found"
193
- idx = random.randint(0, len(dataset_images)-1)
194
- img_path = dataset_images[idx]
195
- img_name = os.path.basename(img_path)
196
- mask_path = None
197
- for mask in dataset_masks:
198
- if os.path.basename(mask) == img_name:
199
- mask_path = mask
200
- break
201
- try:
202
- image = Image.open(img_path).convert("L")
203
- mask = Image.open(mask_path).convert("L") if mask_path else None
204
- return image, mask, img_name
205
- except Exception as e:
206
- return None, None, f"Error loading sample: {e}"
207
 
208
- # -------------------------
209
- # Preprocessing & Heatmap utils
210
- # -------------------------
211
- def preprocess_for_model(image):
212
  if image.mode != 'L':
213
  image = image.convert('L')
214
- transform = transforms.Compose([
215
- transforms.Resize((256, 256)),
216
  transforms.ToTensor()
217
  ])
218
- return transform(image).unsqueeze(0)
219
-
220
- def generate_attention_heatmap(attention_maps):
221
- if not attention_maps:
222
- return np.zeros((256, 256, 3), dtype=np.uint8)
223
- resized_maps = []
224
- target_size = (256, 256)
225
- for att_map in attention_maps:
226
- att_np = att_map.squeeze().cpu().numpy()
227
- att_resized = cv2.resize(att_np, target_size)
228
- resized_maps.append(att_resized)
229
- combined_att = np.mean(resized_maps, axis=0)
230
- combined_att = (combined_att - combined_att.min()) / (combined_att.max() - combined_att.min() + 1e-8)
231
- heatmap = cv2.applyColorMap((combined_att * 255).astype(np.uint8), cv2.COLORMAP_JET)
232
- return heatmap # BGR (OpenCV)
233
-
234
- # -------------------------
235
- # Core: produce combined 1x5 image (preserve old 1-4 behavior)
236
- # -------------------------
237
- def results_with_heatmap(image, ground_truth=None, filename=None, threshold=0.5):
238
- if model is None:
239
- return None, "Model not loaded. Please restart the application."
240
  if image is None:
241
- return None, "Please select an image first."
242
-
243
- # Keep preprocessing & prediction exactly like your working code
244
- img_gray = image.convert('L') if image.mode != 'L' else image
245
- original_np = np.array(img_gray.resize((256, 256))).astype(np.uint8)
246
 
247
- # Preprocess for model
248
- prep = transforms.Compose([transforms.Resize((256, 256)), transforms.ToTensor()])
249
- input_tensor = prep(img_gray).unsqueeze(0).to(device)
250
-
251
- with torch.no_grad():
252
- out = model(input_tensor)
253
- # support both: model -> logits OR (logits, att_maps)
254
- if isinstance(out, (list, tuple)) and len(out) == 2:
255
- logits, attention_maps = out
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
256
  else:
257
- logits = out
258
- attention_maps = []
259
-
260
- pred_prob = torch.sigmoid(logits)
261
- pred_mask = (pred_prob > threshold).float()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
262
 
263
- pred_mask_np = pred_mask.cpu().squeeze().numpy() # (256,256)
264
- inv_pred_mask_np = np.where(pred_mask_np == 1, 0, 255).astype(np.uint8)
265
- tumor_only = np.where(pred_mask_np == 1, original_np, 255).astype(np.uint8)
 
266
 
267
- # ground truth handling (preserve old style)
268
- if ground_truth is not None:
269
- gt_gray = ground_truth.convert('L') if ground_truth.mode != 'L' else ground_truth
270
- mask_np = prep(gt_gray).cpu().squeeze().numpy()
271
- mask_vis = (mask_np > 0.5).astype(np.uint8)
272
- else:
273
- mask_vis = np.zeros_like(original_np)
274
-
275
- # Try to build attention heatmap; fallback to probability heatmap
276
- att_heat = generate_attention_heatmap(attention_maps)
277
- if att_heat is None or att_heat.size == 0:
278
- prob_np = pred_prob.cpu().squeeze().numpy()
279
- prob_resized = cv2.resize(prob_np, (256, 256))
280
- prob_norm = (prob_resized - prob_resized.min()) / (prob_resized.max() - prob_resized.min() + 1e-8)
281
- att_heat_bgr = cv2.applyColorMap((prob_norm * 255).astype(np.uint8), cv2.COLORMAP_JET)
282
- att_heat = att_heat_bgr
283
-
284
- # convert BGR->RGB for display
285
- try:
286
- att_heat = cv2.cvtColor(att_heat, cv2.COLOR_BGR2RGB)
287
- except Exception:
288
- pass
289
-
290
- # ensure dtype/shape
291
- if att_heat.dtype != np.uint8:
292
- att_heat = (att_heat * 255).astype(np.uint8) if att_heat.max() <= 1.0 else att_heat.astype(np.uint8)
293
- if att_heat.ndim == 2:
294
- att_heat = cv2.cvtColor(att_heat, cv2.COLOR_GRAY2RGB)
295
-
296
- # Create 1x5 figure
297
- fig, axes = plt.subplots(1, 5, figsize=(22, 5))
298
- fig.suptitle('Results + Heatmap', fontsize=16, weight='bold')
299
-
300
- axes[0].imshow(original_np, cmap='gray'); axes[0].set_title('Original Image'); axes[0].axis('off')
301
- axes[1].imshow(mask_vis, cmap='gray'); axes[1].set_title('Ground Truth Mask' if ground_truth is not None else 'GT (none)'); axes[1].axis('off')
302
- axes[2].imshow(inv_pred_mask_np, cmap='gray'); axes[2].set_title('Predicted Mask'); axes[2].axis('off')
303
- axes[3].imshow(tumor_only, cmap='gray'); axes[3].set_title('Tumor Only'); axes[3].axis('off')
304
- axes[4].imshow(att_heat); axes[4].set_title('Attention / Prob Heatmap'); axes[4].axis('off')
305
-
306
- plt.tight_layout()
307
-
308
- buf = io.BytesIO()
309
- plt.savefig(buf, format='png', dpi=150, bbox_inches='tight', facecolor='white')
310
- buf.seek(0)
311
- plt.close(fig)
312
- result_img = Image.open(buf).convert("RGB")
313
-
314
- tumor_pixels = int(np.sum(pred_mask_np))
315
- total_pixels = int(pred_mask_np.size)
316
- tumor_pct = (tumor_pixels / total_pixels) * 100 if total_pixels > 0 else 0.0
317
-
318
- analysis_text = f"""
319
- # Analysis Results
320
- **File:** {filename if filename else 'Uploaded Image'}
321
- - Tumor Area: {tumor_pct:.2f}%
322
- - Tumor Pixels: {tumor_pixels:,}
323
- - Max confidence: {float(pred_prob.max()):.4f}
324
- - Threshold used: {threshold}
325
- """
326
 
327
- return result_img, analysis_text
328
-
329
- # -------------------------
330
- # Initialize model & dataset at startup
331
- # -------------------------
332
- print("Initializing application components...")
333
- model_loaded = download_and_load_model()
334
- dataset_loaded_success = download_and_load_dataset()
335
- if not model_loaded:
336
- print("WARNING: Model failed to load!")
337
- if not dataset_loaded_success:
338
- print("WARNING: Dataset failed to load!")
339
- print("Application ready!")
340
-
341
- # -------------------------
342
- # Gradio UI
343
- # -------------------------
344
  css = """
345
- .gradio-container { max-width: 1400px !important; margin:auto !important; font-family: 'Segoe UI', Tahoma, Verdana; }
346
- .gr-button { border-radius: 6px !important; font-weight: 500 !important; }
 
 
 
 
 
 
 
 
 
 
 
347
  """
348
 
349
- with gr.Blocks(css=css, title="Brain Tumor Segmentation + Heatmap") as app:
350
- gr.Markdown("# Brain Tumor Segmentation β€” Attention U-Net\nPreserves original 1–4 outputs; adds 5th: heatmap.")
 
 
 
 
 
 
 
 
 
 
 
351
  with gr.Row():
352
  with gr.Column(scale=1):
353
- image_display = gr.Image(label="Selected Image", type="pil", height=300)
 
 
 
 
 
 
 
354
  with gr.Row():
355
- load_sample_btn = gr.Button("Load Random Sample", variant="primary")
356
- upload_btn = gr.UploadButton("Upload Image", file_types=["image"])
357
- analyze_btn = gr.Button("Analyze Image", variant="primary", size="lg")
358
- gr.Markdown(f"**Model Status:** {'βœ“ Loaded' if model_loaded else 'βœ— Failed'} \n**Dataset:** {'βœ“ Loaded' if dataset_loaded_success else 'βœ— Failed'}")
 
 
 
 
 
 
 
 
 
 
 
359
  with gr.Column(scale=2):
360
- gr.Markdown("### Results (1x5 panel)")
361
- result_display = gr.Image(label="Segmentation + Heatmap", type="pil", height=600)
362
- analysis_text = gr.Markdown("Upload or load a sample and click Analyze.")
363
-
364
- current_ground_truth = gr.State()
365
- current_filename = gr.State()
366
-
367
- def handle_sample_load():
368
- image, mask, filename = get_random_sample()
369
- return image, mask, filename
370
-
371
- def handle_upload(f):
372
- if f is not None:
373
- img = Image.open(f.name).convert("L")
374
- return img, None, os.path.basename(f.name)
375
- return None, None, ""
376
-
377
- load_sample_btn.click(fn=handle_sample_load, outputs=[image_display, current_ground_truth, current_filename])
378
- upload_btn.upload(fn=handle_upload, inputs=[upload_btn], outputs=[image_display, current_ground_truth, current_filename])
379
-
380
- analyze_btn.click(fn=results_with_heatmap, inputs=[image_display, current_ground_truth, current_filename], outputs=[result_display, analysis_text])
 
 
381
 
382
  if __name__ == "__main__":
 
383
  app.launch(server_name="0.0.0.0", server_port=7860, show_error=True, share=False)
 
 
1
  import gradio as gr
2
  import torch
3
  import torch.nn as nn
 
10
  import torchvision.transforms.functional as TF
11
  import urllib.request
12
  import os
 
 
 
13
 
 
 
 
14
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
15
  model = None
 
 
 
16
 
17
+ # ---- model classes (kept equivalent to your working code) ----
 
 
18
  class DoubleConv(nn.Module):
19
  def __init__(self, in_channels, out_channels):
20
  super(DoubleConv, self).__init__()
 
26
  nn.BatchNorm2d(out_channels),
27
  nn.ReLU(inplace=True),
28
  )
 
29
  def forward(self, x):
30
  return self.conv(x)
31
 
 
32
  class AttentionBlock(nn.Module):
33
  def __init__(self, F_g, F_l, F_int):
34
  super(AttentionBlock, self).__init__()
 
46
  nn.Sigmoid()
47
  )
48
  self.relu = nn.ReLU(inplace=True)
 
49
  def forward(self, g, x):
50
  g1 = self.W_g(g)
51
  x1 = self.W_x(x)
52
  psi = self.relu(g1 + x1)
53
  psi = self.psi(psi)
54
+ return x * psi # matches the old working code
 
55
 
56
  class AttentionUNET(nn.Module):
57
  def __init__(self, in_channels=1, out_channels=1, features=[32, 64, 128, 256]):
 
62
  self.attentions = nn.ModuleList()
63
  self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
64
 
65
+ # down
66
  for feature in features:
67
  self.downs.append(DoubleConv(in_channels, feature))
68
  in_channels = feature
69
 
70
+ # bottleneck
71
  self.bottleneck = DoubleConv(features[-1], features[-1]*2)
72
 
73
+ # up
74
  for feature in reversed(features):
75
  self.ups.append(nn.ConvTranspose2d(feature*2, feature, kernel_size=2, stride=2))
76
  self.attentions.append(AttentionBlock(F_g=feature, F_l=feature, F_int=feature // 2))
 
80
 
81
  def forward(self, x):
82
  skip_connections = []
 
 
83
  for down in self.downs:
84
  x = down(x)
85
  skip_connections.append(x)
86
  x = self.pool(x)
 
87
  x = self.bottleneck(x)
88
+ skip_connections = skip_connections[::-1] # reverse
89
 
90
  for idx in range(0, len(self.ups), 2):
91
  x = self.ups[idx](x)
92
+ skip_connection = skip_connections[idx//2]
 
93
  if x.shape != skip_connection.shape:
94
  x = TF.resize(x, size=skip_connection.shape[2:])
95
+ # attention applied exactly as in your working code
96
+ skip_connection = self.attentions[idx // 2](skip_connection, x)
97
+ concat_skip = torch.cat((skip_connection, x), dim=1)
 
98
  x = self.ups[idx+1](concat_skip)
99
+ return self.final_conv(x)
100
 
101
+ # ---- model download/load helpers (same as yours) ----
102
+ def download_model():
 
 
 
 
 
 
 
103
  model_url = "https://huggingface.co/spaces/ArchCoder/the-op-segmenter/resolve/main/best_attention_model.pth.tar"
104
  model_path = "best_attention_model.pth.tar"
 
105
  if not os.path.exists(model_path):
106
+ print("πŸ“₯ Downloading your trained model...")
107
  try:
108
  urllib.request.urlretrieve(model_url, model_path)
109
+ print("βœ… Model downloaded successfully!")
110
  except Exception as e:
111
+ print(f"❌ Failed to download model: {e}")
112
+ return None
113
+ else:
114
+ print("βœ… Model already exists!")
115
+ return model_path
116
 
117
+ def load_your_attention_model():
118
+ global model
119
+ if model is None:
 
 
 
 
 
 
120
  try:
121
+ print("πŸ”„ Loading your trained Attention U-Net model...")
122
+ model_path = download_model()
123
+ if model_path is None:
124
+ return None
125
+ model = AttentionUNET(in_channels=1, out_channels=1).to(device)
126
+ checkpoint = torch.load(model_path, map_location=device)
127
+ # checkpoint expected to have "state_dict" key as in your working code
128
+ if "state_dict" in checkpoint:
129
+ sd = checkpoint["state_dict"]
130
+ else:
131
+ sd = checkpoint
132
  model.load_state_dict(sd)
133
+ model.eval()
134
+ print("βœ… Your Attention U-Net model loaded successfully!")
135
+ except Exception as e:
136
+ print(f"❌ Error loading your model: {e}")
137
+ model = None
138
+ return model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
139
 
140
+ # ---- preprocessing (same as your Colab code) ----
141
+ def preprocess_for_your_model(image):
 
 
142
  if image.mode != 'L':
143
  image = image.convert('L')
144
+ val_test_transform = transforms.Compose([
145
+ transforms.Resize((256,256)),
146
  transforms.ToTensor()
147
  ])
148
+ return val_test_transform(image).unsqueeze(0) # Add batch dimension
149
+
150
+ # ---- main predict function (modified to add separate heatmap, no change to 1-4) ----
151
+ def predict_tumor(image):
152
+ """
153
+ Keeps the exact old 4-panel outputs the same, and adds a 5th panel with
154
+ the probability heatmap. The heatmap is computed from the sigmoid(probabilities)
155
+ and does not change any tensors used for predictions.
156
+ """
157
+ current_model = load_your_attention_model()
158
+ if current_model is None:
159
+ return None, "❌ Failed to load your trained model."
 
 
 
 
 
 
 
 
 
 
160
  if image is None:
161
+ return None, "⚠️ Please upload an image first."
 
 
 
 
162
 
163
+ try:
164
+ print("🧠 Processing with YOUR trained Attention U-Net...")
165
+
166
+ # Preprocess exactly like your Colab
167
+ input_tensor = preprocess_for_your_model(image).to(device) # [1,1,256,256]
168
+
169
+ # Forward and prediction (identical to your working code)
170
+ with torch.no_grad():
171
+ logits = current_model(input_tensor) # model returns logits tensor
172
+ pred_prob = torch.sigmoid(logits) # keep prob map for heatmap
173
+ pred_mask = (pred_prob > 0.5).float() # binary mask (same threshold as old code)
174
+
175
+ # Convert to numpy like old code
176
+ pred_mask_np = pred_mask.cpu().squeeze().numpy() # shape (256,256)
177
+ original_np = np.array(image.convert('L').resize((256, 256)))
178
+ inv_pred_mask_np = np.where(pred_mask_np == 1, 0, 255)
179
+ tumor_only = np.where(pred_mask_np == 1, original_np, 255)
180
+
181
+ # -------------------------
182
+ # Create heatmap (NO CHANGES to pred_mask or any prediction tensors)
183
+ # -------------------------
184
+ # Use the probability map (float) as the basis
185
+ pred_prob_np = pred_prob.cpu().squeeze().numpy() # float in [0,1]
186
+ # ensure same shape 256x256
187
+ if pred_prob_np.shape != (256, 256):
188
+ pred_prob_resized = cv2.resize(pred_prob_np, (256, 256))
189
  else:
190
+ pred_prob_resized = pred_prob_np.copy()
191
+
192
+ # Normalize to 0-1 and convert to uint8 for colormap
193
+ prob_norm = (pred_prob_resized - pred_prob_resized.min()) / (pred_prob_resized.max() - pred_prob_resized.min() + 1e-8)
194
+ prob_uint8 = (prob_norm * 255).astype(np.uint8)
195
+ prob_heatmap_bgr = cv2.applyColorMap(prob_uint8, cv2.COLORMAP_JET) # OpenCV BGR
196
+ # Convert BGR -> RGB for matplotlib/PIL visualization
197
+ prob_heatmap_rgb = cv2.cvtColor(prob_heatmap_bgr, cv2.COLOR_BGR2RGB)
198
+
199
+ # -------------------------
200
+ # Build the 5-panel figure
201
+ # Panels (left->right): Original | Pred segmentation (pred*255) | Inverted mask | Tumor only | Heatmap
202
+ # Panels 1-4 are produced exactly the same as your old code
203
+ # -------------------------
204
+ fig, axes = plt.subplots(1, 5, figsize=(24, 5))
205
+ fig.suptitle('🧠 Your Attention U-Net Results (with Heatmap)', fontsize=16, fontweight='bold')
206
+
207
+ # 1 Original (gray)
208
+ axes[0].imshow(original_np, cmap='gray')
209
+ axes[0].set_title("Original Image", fontsize=12, fontweight='bold')
210
+ axes[0].axis('off')
211
+
212
+ # 2 Tumor Segmentation (pred*255) β€” identical to old code's second panel
213
+ axes[1].imshow(pred_mask_np * 255, cmap='hot')
214
+ axes[1].set_title("Tumor Segmentation (pred Γ— 255)", fontsize=12, fontweight='bold')
215
+ axes[1].axis('off')
216
+
217
+ # 3 Inverted mask β€” identical
218
+ axes[2].imshow(inv_pred_mask_np, cmap='gray')
219
+ axes[2].set_title("Inverted Mask (visual)", fontsize=12, fontweight='bold')
220
+ axes[2].axis('off')
221
+
222
+ # 4 Tumor only (grayscale crop) β€” identical
223
+ axes[3].imshow(tumor_only, cmap='gray')
224
+ axes[3].set_title("Tumor Only", fontsize=12, fontweight='bold')
225
+ axes[3].axis('off')
226
+
227
+ # 5 Heatmap (RGB)
228
+ axes[4].imshow(prob_heatmap_rgb)
229
+ axes[4].set_title("Probability Heatmap (sigmoid)", fontsize=12, fontweight='bold')
230
+ axes[4].axis('off')
231
+
232
+ plt.tight_layout()
233
+
234
+ # Save plot
235
+ buf = io.BytesIO()
236
+ plt.savefig(buf, format='png', dpi=150, bbox_inches='tight', facecolor='white')
237
+ buf.seek(0)
238
+ plt.close()
239
+ result_image = Image.open(buf)
240
+
241
+ # Calculate statistics (like your Colab code)
242
+ tumor_pixels = int(np.sum(pred_mask_np))
243
+ total_pixels = int(pred_mask_np.size)
244
+ tumor_percentage = (tumor_pixels / total_pixels) * 100 if total_pixels > 0 else 0.0
245
+
246
+ # Confidence metrics (from the probability tensor)
247
+ max_confidence = float(pred_prob.max().item())
248
+ mean_confidence = float(pred_prob.mean().item())
249
+
250
+ analysis_text = f"""
251
+ ## 🧠 Your Attention U-Net Analysis Results
252
+ ### πŸ“Š Detection Summary:
253
+ - **Status**: {'πŸ”΄ TUMOR DETECTED' if tumor_pixels > 50 else '🟒 NO SIGNIFICANT TUMOR'}
254
+ - **Tumor Area**: {tumor_percentage:.2f}% of image
255
+ - **Tumor Pixels**: {tumor_pixels:,} pixels
256
+ - **Max Confidence**: {max_confidence:.4f}
257
+ - **Mean Confidence**: {mean_confidence:.4f}
258
+
259
+ ### πŸ”¬ Model Info:
260
+ - **Architecture**: YOUR Attention U-Net
261
+ - **Input**: Grayscale (single channel), resized to 256Γ—256
262
+ - **Threshold**: 0.5 (sigmoid > 0.5)
263
+ - **Device**: {device.type.upper()}
264
+
265
+ ### ⚠️ Disclaimer:
266
+ This is for research/education only. Validate results with medical professionals.
267
+ """
268
+ print(f"βœ… Your model analysis completed! Tumor area: {tumor_percentage:.2f}%")
269
+ return result_image, analysis_text
270
 
271
+ except Exception as e:
272
+ error_msg = f"❌ Error with your model: {str(e)}"
273
+ print(error_msg)
274
+ return None, error_msg
275
 
276
+ def clear_all():
277
+ return None, "Upload a brain MRI image to test YOUR trained Attention U-Net model"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
278
 
279
+ # ---- Gradio UI (kept as you had it, but wired to the new predict function) ----
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
280
  css = """
281
+ .gradio-container {
282
+ max-width: 1400px !important;
283
+ margin: auto !important;
284
+ }
285
+ #title {
286
+ text-align: center;
287
+ background: linear-gradient(135deg, #8B5CF6 0%, #7C3AED 100%);
288
+ color: white;
289
+ padding: 30px;
290
+ border-radius: 15px;
291
+ margin-bottom: 25px;
292
+ box-shadow: 0 8px 16px rgba(139, 92, 246, 0.3);
293
+ }
294
  """
295
 
296
+ with gr.Blocks(css=css, title="🧠 Your Attention U-Net Model", theme=gr.themes.Soft()) as app:
297
+ gr.HTML("""
298
+ <div id="title">
299
+ <h1>🧠 YOUR Attention U-Net Model</h1>
300
+ <p style="font-size: 18px; margin-top: 15px;">
301
+ Using Your Own Trained Model β€’ Dice: 0.8420 β€’ IoU: 0.7297
302
+ </p>
303
+ <p style="font-size: 14px; margin-top: 10px; opacity: 0.9;">
304
+ Loaded from: ArchCoder/the-op-segmenter HuggingFace Space
305
+ </p>
306
+ </div>
307
+ """)
308
+
309
  with gr.Row():
310
  with gr.Column(scale=1):
311
+ gr.Markdown("### πŸ“€ Upload Brain MRI")
312
+ image_input = gr.Image(
313
+ label="Brain MRI Scan",
314
+ type="pil",
315
+ sources=["upload", "webcam"],
316
+ height=350
317
+ )
318
+
319
  with gr.Row():
320
+ analyze_btn = gr.Button("πŸ” Analyze with YOUR Model", variant="primary", scale=2, size="lg")
321
+ clear_btn = gr.Button("πŸ—‘οΈ Clear", variant="secondary", scale=1)
322
+
323
+ gr.HTML("""
324
+ <div style="margin-top: 20px; padding: 20px; background: linear-gradient(135deg, #F3E8FF 0%, #EDE9FE 100%); border-radius: 10px; border-left: 4px solid #8B5CF6;">
325
+ <h4 style="color: #8B5CF6; margin-bottom: 15px;">πŸ† Your Model Features:</h4>
326
+ <ul style="margin: 10px 0; padding-left: 20px; line-height: 1.6;">
327
+ <li><strong>Personal Model:</strong> Your own trained Attention U-Net</li>
328
+ <li><strong>Proven Performance:</strong> 84.2% Dice Score, 72.97% IoU</li>
329
+ <li><strong>Attention Gates:</strong> Advanced feature selection</li>
330
+ <li><strong>Clean Output:</strong> Binary segmentation masks</li>
331
+ <li><strong>5-Panel View:</strong> Original, Segmentation, Inverted, Tumor-only, Heatmap</li>
332
+ </ul>
333
+ </div>
334
+ """)
335
  with gr.Column(scale=2):
336
+ gr.Markdown("### πŸ“Š Your Model Results")
337
+ output_image = gr.Image(
338
+ label="Your Attention U-Net Analysis",
339
+ type="pil",
340
+ height=500
341
+ )
342
+ analysis_output = gr.Markdown(
343
+ value="Upload a brain MRI image to test YOUR trained Attention U-Net model.",
344
+ elem_id="analysis"
345
+ )
346
+
347
+ analyze_btn.click(
348
+ fn=predict_tumor,
349
+ inputs=[image_input],
350
+ outputs=[output_image, analysis_output],
351
+ show_progress=True
352
+ )
353
+
354
+ clear_btn.click(
355
+ fn=clear_all,
356
+ inputs=[],
357
+ outputs=[image_input, analysis_output]
358
+ )
359
 
360
  if __name__ == "__main__":
361
+ print("πŸš€ Starting YOUR Attention U-Net Model System...")
362
  app.launch(server_name="0.0.0.0", server_port=7860, show_error=True, share=False)