ArchCoder commited on
Commit
69c21ad
·
verified ·
1 Parent(s): 7b7ff94

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +164 -399
app.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import gradio as gr
2
  import torch
3
  import torch.nn as nn
@@ -6,26 +7,26 @@ import cv2
6
  from PIL import Image
7
  import matplotlib.pyplot as plt
8
  import io
9
- import torchvision.transforms as transforms
10
  import torchvision.transforms.functional as TF
11
- import random
12
- import os
13
  import urllib.request
14
- import kagglehub
 
15
  from glob import glob
 
16
 
17
- # Global variables - loaded once at startup
 
 
18
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
19
  model = None
20
  dataset_images = []
21
  dataset_masks = []
22
  dataset_loaded = False
23
 
24
- print("="*50)
25
- print("BRAIN TUMOR SEGMENTATION APPLICATION")
26
- print("="*50)
27
-
28
- # Your Attention U-Net classes (unchanged)
29
  class DoubleConv(nn.Module):
30
  def __init__(self, in_channels, out_channels):
31
  super(DoubleConv, self).__init__()
@@ -49,26 +50,24 @@ class AttentionBlock(nn.Module):
49
  nn.Conv2d(F_g, F_int, kernel_size=1, stride=1, padding=0, bias=True),
50
  nn.BatchNorm2d(F_int)
51
  )
52
-
53
  self.W_x = nn.Sequential(
54
  nn.Conv2d(F_l, F_int, kernel_size=1, stride=1, padding=0, bias=True),
55
  nn.BatchNorm2d(F_int)
56
  )
57
-
58
  self.psi = nn.Sequential(
59
  nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True),
60
  nn.BatchNorm2d(1),
61
  nn.Sigmoid()
62
  )
63
-
64
  self.relu = nn.ReLU(inplace=True)
65
-
66
  def forward(self, g, x):
67
  g1 = self.W_g(g)
68
  x1 = self.W_x(x)
69
  psi = self.relu(g1 + x1)
70
  psi = self.psi(psi)
71
- return x * psi, psi # Return both attended features AND attention map
 
72
 
73
  class AttentionUNET(nn.Module):
74
  def __init__(self, in_channels=1, out_channels=1, features=[32, 64, 128, 256]):
@@ -84,7 +83,7 @@ class AttentionUNET(nn.Module):
84
  in_channels = feature
85
 
86
  self.bottleneck = DoubleConv(features[-1], features[-1]*2)
87
-
88
  for feature in reversed(features):
89
  self.ups.append(nn.ConvTranspose2d(feature*2, feature, kernel_size=2, stride=2))
90
  self.attentions.append(AttentionBlock(F_g=feature, F_l=feature, F_int=feature // 2))
@@ -106,7 +105,7 @@ class AttentionUNET(nn.Module):
106
 
107
  for idx in range(0, len(self.ups), 2):
108
  x = self.ups[idx](x)
109
- skip_connection = skip_connections[idx//2]
110
 
111
  if x.shape != skip_connection.shape:
112
  x = TF.resize(x, size=skip_connection.shape[2:])
@@ -118,15 +117,16 @@ class AttentionUNET(nn.Module):
118
 
119
  return self.final_conv(x), attention_maps
120
 
 
 
 
121
  def download_and_load_model():
122
- """Download and load model once at startup"""
123
  global model
124
  print("Loading Attention U-Net model...")
125
-
126
  model_url = "https://huggingface.co/spaces/ArchCoder/the-op-segmenter/resolve/main/best_attention_model.pth.tar"
127
  model_path = "best_attention_model.pth.tar"
128
-
129
- # Download model if needed
130
  if not os.path.exists(model_path):
131
  print("Downloading model weights...")
132
  try:
@@ -134,90 +134,70 @@ def download_and_load_model():
134
  except Exception as e:
135
  print(f"Failed to download model: {e}")
136
  return False
137
-
138
- # Load model
139
  try:
140
  model = AttentionUNET(in_channels=1, out_channels=1).to(device)
141
- checkpoint = torch.load(model_path, map_location=device, weights_only=True)
142
- model.load_state_dict(checkpoint["state_dict"])
 
 
 
 
 
 
 
 
 
 
143
  model.eval()
144
  print("✓ Model loaded successfully!")
145
  return True
146
  except Exception as e:
147
  print(f"Failed to load model: {e}")
 
148
  return False
149
 
 
 
 
150
  def download_and_load_dataset():
151
- """Download and load entire dataset once at startup"""
152
  global dataset_images, dataset_masks, dataset_loaded
153
-
154
  if dataset_loaded:
155
  return True
156
-
157
- print("Loading brain tumor dataset...")
158
-
159
  try:
160
- # Download dataset using kagglehub - returns directory path
161
  dataset_path = kagglehub.dataset_download('nikhilroxtomar/brain-tumor-segmentation')
162
- print(f"Dataset downloaded to: {dataset_path}")
163
-
164
- # Find images and masks directories
165
  images_dir = os.path.join(dataset_path, 'images')
166
  masks_dir = os.path.join(dataset_path, 'masks')
167
-
168
- # If direct path doesn't exist, search subdirectories
169
- if not os.path.exists(images_dir):
170
- # Search for images and masks directories
171
- for root, dirs, files in os.walk(dataset_path):
172
- if 'images' in dirs:
173
- images_dir = os.path.join(root, 'images')
174
- if 'masks' in dirs:
175
- masks_dir = os.path.join(root, 'masks')
176
-
177
  if not os.path.exists(images_dir) or not os.path.exists(masks_dir):
178
- print("Could not find images/masks directories. Searching all files...")
179
- # Fallback: find all image files
180
  all_files = glob(os.path.join(dataset_path, "**/*.png"), recursive=True) + \
181
- glob(os.path.join(dataset_path, "**/*.jpg"), recursive=True)
182
-
183
  dataset_images = [f for f in all_files if '/images/' in f or 'image' in f.lower()]
184
  dataset_masks = [f for f in all_files if '/masks/' in f or 'mask' in f.lower()]
185
  else:
186
- # Load image and mask file paths
187
- dataset_images = glob(os.path.join(images_dir, "*.*"))
188
- dataset_masks = glob(os.path.join(masks_dir, "*.*"))
189
-
190
- dataset_images = sorted(dataset_images)
191
- dataset_masks = sorted(dataset_masks)
192
-
193
  print(f"✓ Found {len(dataset_images)} images and {len(dataset_masks)} masks")
194
  dataset_loaded = True
195
  return True
196
-
197
  except Exception as e:
198
  print(f"Failed to load dataset: {e}")
199
  return False
200
 
201
  def get_random_sample():
202
- """Get a random image and corresponding mask from dataset"""
203
  if not dataset_loaded:
204
  return None, None, "Dataset not loaded"
205
-
206
  if not dataset_images:
207
- return None, None, "No images found in dataset"
208
-
209
- # Get random index
210
- idx = random.randint(0, len(dataset_images) - 1)
211
  img_path = dataset_images[idx]
212
-
213
- # Find corresponding mask
214
  img_name = os.path.basename(img_path)
215
  mask_path = None
216
  for mask in dataset_masks:
217
  if os.path.basename(mask) == img_name:
218
  mask_path = mask
219
  break
220
-
221
  try:
222
  image = Image.open(img_path).convert("L")
223
  mask = Image.open(mask_path).convert("L") if mask_path else None
@@ -225,394 +205,179 @@ def get_random_sample():
225
  except Exception as e:
226
  return None, None, f"Error loading sample: {e}"
227
 
 
 
 
228
  def preprocess_for_model(image):
229
- """Preprocessing for your model - matches the working notebook"""
230
  if image.mode != 'L':
231
  image = image.convert('L')
232
-
233
  transform = transforms.Compose([
234
- transforms.Resize((256,256)),
235
  transforms.ToTensor()
236
  ])
237
-
238
  return transform(image).unsqueeze(0)
239
 
240
  def generate_attention_heatmap(attention_maps):
241
- """Generate attention heatmap"""
242
  if not attention_maps:
243
- return np.zeros((256, 256, 3))
244
-
245
- # Resize all attention maps to the same size (256x256) before combining
246
  resized_maps = []
247
  target_size = (256, 256)
248
-
249
  for att_map in attention_maps:
250
- # Convert to numpy and squeeze
251
  att_np = att_map.squeeze().cpu().numpy()
252
-
253
- # Resize to target size
254
  att_resized = cv2.resize(att_np, target_size)
255
  resized_maps.append(att_resized)
256
-
257
- # Now we can safely average the maps since they're all the same size
258
  combined_att = np.mean(resized_maps, axis=0)
259
-
260
- # Normalize to [0, 1]
261
  combined_att = (combined_att - combined_att.min()) / (combined_att.max() - combined_att.min() + 1e-8)
262
-
263
- # Apply colormap
264
  heatmap = cv2.applyColorMap((combined_att * 255).astype(np.uint8), cv2.COLORMAP_JET)
265
-
266
- return heatmap
267
-
268
- def analyze_image(image, ground_truth, filename, debug=True):
269
- """
270
- Replacement analyze_image that:
271
- - Accepts model returning either logits or (logits, attention_maps)
272
- - Prints detailed stats and shapes
273
- - Produces prob heatmap (no threshold) for debugging
274
- - Fixes broadcasting/color issues for visualization
275
- - Returns (PIL.Image, markdown_text)
276
- """
277
  if model is None:
278
  return None, "Model not loaded. Please restart the application."
279
-
280
  if image is None:
281
  return None, "Please select an image first."
282
 
283
- try:
284
- print("=" * 50)
285
- print("DEBUG: Starting analysis...")
286
- print(f"Input image mode: {image.mode}")
287
- print(f"Input image size: {image.size}")
288
-
289
- # Preprocess - same as your notebook/app
290
- input_tensor = preprocess_for_model(image).to(device) # shape [1,1,256,256]
291
- print(f"Input tensor shape: {input_tensor.shape}")
292
- print(f"Input tensor min/max: {input_tensor.min():.4f}/{input_tensor.max():.4f}")
293
-
294
- # Forward pass and robust unpacking (support both return styles)
295
- with torch.no_grad():
296
- out = model(input_tensor)
297
- # If model returned tuple/list: (logits, attention_maps)
298
- if isinstance(out, (list, tuple)) and len(out) == 2:
299
- logits, attention_maps = out
300
- else:
301
- # assume out is logits tensor and no attention maps were returned
302
- logits = out
303
- attention_maps = []
304
-
305
- # Ensure logits is a tensor
306
- if not torch.is_tensor(logits):
307
- raise RuntimeError("Model output is not a tensor. Check model forward() return type.")
308
-
309
- print(f"Model output (logits) shape: {logits.shape}")
310
- print(f"Model output min/max BEFORE sigmoid: {logits.min():.4f}/{logits.max():.4f}")
311
-
312
- # Probabilities (sigmoid)
313
- pred_prob = torch.sigmoid(logits)
314
- print(f"Pred prob min/max: {pred_prob.min():.4f}/{pred_prob.max():.4f}")
315
-
316
- # Convert to numpy for visualization; keep a float prob map for the heatmap
317
- pred_prob_np = pred_prob.cpu().squeeze().numpy() # shape (H, W)
318
- pred_mask_bin = (pred_prob_np > 0.5).astype(np.uint8) # default threshold 0.5
319
-
320
- print(f"Binary mask (0.5 threshold) sum: {pred_mask_bin.sum()}")
321
-
322
- # Debug: print attention maps shapes and stats
323
- if debug:
324
- print("Attention maps info:")
325
- for i, att in enumerate(attention_maps):
326
- try:
327
- att_np = att.squeeze().cpu().numpy()
328
- print(f" att[{i}] shape: {att_np.shape} min/max: {att_np.min():.4f}/{att_np.max():.4f}")
329
- except Exception as ex:
330
- print(f" att[{i}] inspect failed: {ex}")
331
-
332
- # Build prob heatmap (no threshold) for debugging
333
- try:
334
- prob_resized = cv2.resize(pred_prob_np, (256, 256)) if pred_prob_np.shape != (256, 256) else pred_prob_np
335
- prob_norm = (prob_resized - prob_resized.min()) / (prob_resized.max() - prob_resized.min() + 1e-8)
336
- prob_heatmap_bgr = cv2.applyColorMap((prob_norm * 255).astype(np.uint8), cv2.COLORMAP_JET)
337
- prob_heatmap = cv2.cvtColor(prob_heatmap_bgr, cv2.COLOR_BGR2RGB)
338
- except Exception:
339
- prob_heatmap = np.zeros((256, 256, 3), dtype=np.uint8)
340
-
341
- # Generate attention heatmap (reuse your function), convert BGR->RGB
342
- att_heatmap = generate_attention_heatmap(attention_maps)
343
- if att_heatmap is not None and att_heatmap.size != 0:
344
- try:
345
- att_heatmap = cv2.cvtColor(att_heatmap, cv2.COLOR_BGR2RGB)
346
- except Exception:
347
- pass
348
-
349
- # Prepare images (gray and rgb)
350
- original_gray = np.array(image.convert('L').resize((256, 256))).astype(np.uint8)
351
- original_rgb = np.array(image.convert('RGB').resize((256, 256))).astype(np.uint8)
352
-
353
- # Ensure binary mask dtype/shape consistency
354
- pred_mask_bin = (pred_mask_bin > 0).astype(np.uint8)
355
- inv_pred_mask_np = np.where(pred_mask_bin == 1, 0, 255).astype(np.uint8)
356
-
357
- tumor_only_gray = np.where(pred_mask_bin == 1, original_gray, 255).astype(np.uint8)
358
- tumor_only_rgb = original_rgb.copy()
359
- tumor_only_rgb[pred_mask_bin == 0] = 255
360
-
361
- # Decide grid: show prob heatmap next to attention so you can compare
362
- if ground_truth is not None:
363
- fig, axes = plt.subplots(3, 4, figsize=(16, 12)) # add an extra row for debug heatmap
364
- else:
365
- fig, axes = plt.subplots(3, 3, figsize=(15, 12))
366
-
367
- fig.suptitle('Brain Tumor Segmentation Analysis (debug)', fontsize=18, weight='bold')
368
-
369
- # Row 1
370
- axes[0,0].imshow(original_gray, cmap='gray'); axes[0,0].set_title('Original'); axes[0,0].axis('off')
371
- axes[0,1].imshow(original_rgb);
372
- if att_heatmap is not None and att_heatmap.size != 0:
373
- axes[0,1].imshow(att_heatmap, alpha=0.45)
374
- axes[0,1].set_title('Attention Heatmap (overlay)'); axes[0,1].axis('off')
375
- axes[0,2].imshow(inv_pred_mask_np, cmap='gray'); axes[0,2].set_title('Pred Mask (inv)'); axes[0,2].axis('off')
376
- if ground_truth is not None:
377
- axes[0,3].imshow(tumor_only_rgb); axes[0,3].set_title('Tumor Only (RGB)'); axes[0,3].axis('off')
378
-
379
- # Row 2
380
- if ground_truth is not None:
381
- # show GT and overlay and metrics
382
- val_test_transform = transforms.Compose([transforms.Resize((256,256)), transforms.ToTensor()])
383
- mask_np = val_test_transform(ground_truth).cpu().squeeze().numpy()
384
- mask_bin = (mask_np > 0.5).astype(np.uint8)
385
-
386
- axes[1,0].imshow(mask_bin, cmap='gray'); axes[1,0].set_title('Ground Truth Mask'); axes[1,0].axis('off')
387
- overlay = original_rgb.copy()
388
- overlay[pred_mask_bin == 1] = [0,255,0]
389
- overlay[mask_bin == 1] = [255,0,0]
390
- axes[1,1].imshow(overlay); axes[1,1].set_title('Prediction (G) vs GT (R)'); axes[1,1].axis('off')
391
-
392
- intersection = np.logical_and(pred_mask_bin, mask_bin).sum()
393
- union = np.logical_or(pred_mask_bin, mask_bin).sum()
394
- iou = intersection / (union + 1e-7)
395
- dice = (2 * intersection) / (pred_mask_bin.sum() + mask_bin.sum() + 1e-7)
396
-
397
- axes[1,2].text(0.1, 0.6, f'IoU: {iou:.4f}', fontsize=16, weight='bold')
398
- axes[1,2].text(0.1, 0.4, f'Dice: {dice:.4f}', fontsize=16, weight='bold')
399
- axes[1,2].axis('off'); axes[1,2].set_title('Metrics')
400
-
401
- axes[1,3].imshow(tumor_only_gray, cmap='gray'); axes[1,3].set_title('Segmented Tumor'); axes[1,3].axis('off')
402
- else:
403
- # No GT: second row shows predicted mask, tumor only and overlay
404
- axes[1,0].imshow(inv_pred_mask_np, cmap='gray'); axes[1,0].set_title('Predicted Mask'); axes[1,0].axis('off')
405
- axes[1,1].imshow(tumor_only_gray, cmap='gray'); axes[1,1].set_title('Tumor Only'); axes[1,1].axis('off')
406
- overlay = original_rgb.copy(); overlay[pred_mask_bin==1] = [255,0,0]
407
- axes[1,2].imshow(overlay); axes[1,2].set_title('Prediction Overlay'); axes[1,2].axis('off')
408
-
409
- # Row 3 (debug): probability heatmap + (optional) raw att channel thumbnails
410
- axes[2,0].imshow(original_rgb); axes[2,0].imshow(prob_heatmap, alpha=0.5); axes[2,0].set_title('Prob Heatmap (overlay)'); axes[2,0].axis('off')
411
- # show the plain probability heatmap
412
- axes[2,1].imshow(prob_heatmap); axes[2,1].set_title('Prob Heatmap (plain)'); axes[2,1].axis('off')
413
-
414
- # if we have attention maps, show up to two scaled maps for quick check
415
- if len(attention_maps) >= 1:
416
- try:
417
- att0 = attention_maps[0].squeeze().cpu().numpy()
418
- att0 = cv2.resize((att0 - att0.min())/(att0.max()-att0.min()+1e-8), (256,256))
419
- axes[2,2].imshow(att0, cmap='viridis'); axes[2,2].set_title('Att map 0 (rescaled)'); axes[2,2].axis('off')
420
- except Exception:
421
- axes[2,2].axis('off')
422
- else:
423
- axes[2,2].axis('off')
424
-
425
- # hide any unused axes (robust)
426
- for ax_row in axes.reshape(-1):
427
- if not hasattr(ax_row, 'has_data') or ax_row.images == []:
428
- ax_row.axis('off')
429
 
430
- plt.tight_layout()
 
 
431
 
432
- # Save plot to buffer and return as PIL image
433
- buf = io.BytesIO()
434
- plt.savefig(buf, format='png', dpi=150, bbox_inches='tight', facecolor='white')
435
- buf.seek(0)
436
- plt.close()
437
- result_image = Image.open(buf).convert("RGB")
438
-
439
- # Numeric analysis text
440
- tumor_pixels = int(pred_mask_bin.sum())
441
- total_pixels = int(pred_mask_bin.size)
442
- tumor_percentage = (tumor_pixels / total_pixels) * 100 if total_pixels > 0 else 0.0
443
-
444
- analysis_text = f"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
445
  # Analysis Results
446
-
447
  **File:** {filename if filename else 'Uploaded Image'}
448
-
449
- **Tumor Detection:**
450
- - Tumor Area: {tumor_percentage:.2f}%
451
  - Tumor Pixels: {tumor_pixels:,}
452
-
453
- **Model Features:**
454
- - Attention Visualization: Generated
455
- - Probability Heatmap: Generated
456
- """
457
-
458
- if ground_truth is not None:
459
- analysis_text += f"""
460
- **Performance Metrics:**
461
- - IoU Score: {iou:.4f}
462
- - Dice Score: {dice:.4f}
463
  """
464
 
465
- # Extra helpful hint when predictions are all zero
466
- if debug and pred_prob_np.max() < 0.5:
467
- analysis_text += "\n\n**Debug hint:** model probabilities are low (max < 0.5). Try lowering threshold (e.g. 0.3) or inspect model weights/loading."
468
 
469
- return result_image, analysis_text
470
-
471
- except Exception as e:
472
- import traceback
473
- error_msg = f"Analysis failed: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
474
- print(error_msg)
475
- return None, error_msg
476
-
477
-
478
- # Initialize model and dataset at startup
479
  print("Initializing application components...")
480
  model_loaded = download_and_load_model()
481
  dataset_loaded_success = download_and_load_dataset()
482
-
483
  if not model_loaded:
484
  print("WARNING: Model failed to load!")
485
  if not dataset_loaded_success:
486
  print("WARNING: Dataset failed to load!")
487
-
488
  print("Application ready!")
489
 
490
- # Professional CSS
 
 
491
  css = """
492
- .gradio-container {
493
- max-width: 1600px !important;
494
- margin: auto !important;
495
- font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif !important;
496
- }
497
- .gr-button {
498
- border-radius: 6px !important;
499
- font-weight: 500 !important;
500
- }
501
- .gr-button-primary {
502
- background: #2563eb !important;
503
- border-color: #2563eb !important;
504
- }
505
- .gr-button-secondary {
506
- background: #6b7280 !important;
507
- border-color: #6b7280 !important;
508
- }
509
- h1, h2, h3 {
510
- color: #1f2937 !important;
511
- }
512
- .gr-form {
513
- border: 1px solid #e5e7eb !important;
514
- border-radius: 8px !important;
515
- }
516
  """
517
 
518
- # Create Gradio interface
519
- with gr.Blocks(css=css, title="Brain Tumor Segmentation Analysis") as app:
520
-
521
- gr.Markdown("""
522
- # Brain Tumor Segmentation Using Attention U-Net
523
-
524
- **Advanced Medical Image Analysis Tool**
525
-
526
- Features: Attention Visualization, Dataset Integration, Morphological Post-processing
527
- """)
528
-
529
- # Status display
530
- with gr.Row():
531
- with gr.Column():
532
- status_text = f"Model Status: {'✓ Loaded' if model_loaded else '✗ Failed'} | Dataset Status: {'✓ Loaded' if dataset_loaded_success else '✗ Failed'}"
533
- if dataset_loaded_success:
534
- status_text += f" | Images: {len(dataset_images)} | Masks: {len(dataset_masks)}"
535
- gr.Markdown(f"**{status_text}**")
536
-
537
  with gr.Row():
538
  with gr.Column(scale=1):
539
- gr.Markdown("### Input Selection")
540
-
541
- # Image display
542
- image_display = gr.Image(
543
- label="Selected Image",
544
- type="pil",
545
- height=300
546
- )
547
-
548
- # Control buttons
549
  with gr.Row():
550
- load_sample_btn = gr.Button("Load Random Sample", variant="primary", scale=1)
551
- upload_btn = gr.UploadButton("Upload Image", file_types=["image"], scale=1)
552
-
553
  analyze_btn = gr.Button("Analyze Image", variant="primary", size="lg")
554
-
555
- # Dataset info
556
- gr.Markdown(f"""
557
- **Dataset Information:**
558
- - Total Images: {len(dataset_images) if dataset_loaded_success else 'N/A'}
559
- - Total Masks: {len(dataset_masks) if dataset_loaded_success else 'N/A'}
560
- - Source: nikhilroxtomar/brain-tumor-segmentation
561
- """)
562
-
563
  with gr.Column(scale=2):
564
- gr.Markdown("### Analysis Results")
565
-
566
- result_display = gr.Image(
567
- label="Segmentation Analysis",
568
- type="pil",
569
- height=500
570
- )
571
-
572
- analysis_text = gr.Markdown(
573
- value="Load an image and click 'Analyze Image' to begin."
574
- )
575
-
576
- # Hidden states
577
  current_ground_truth = gr.State()
578
  current_filename = gr.State()
579
-
580
- # Event handlers
581
  def handle_sample_load():
582
  image, mask, filename = get_random_sample()
583
  return image, mask, filename
584
-
585
- def handle_upload(file):
586
- if file is not None:
587
- image = Image.open(file.name).convert("L")
588
- return image, None, os.path.basename(file.name)
589
  return None, None, ""
590
-
591
- load_sample_btn.click(
592
- fn=handle_sample_load,
593
- outputs=[image_display, current_ground_truth, current_filename]
594
- )
595
-
596
- upload_btn.upload(
597
- fn=handle_upload,
598
- inputs=[upload_btn],
599
- outputs=[image_display, current_ground_truth, current_filename]
600
- )
601
-
602
- analyze_btn.click(
603
- fn=analyze_image,
604
- inputs=[image_display, current_ground_truth, current_filename],
605
- outputs=[result_display, analysis_text]
606
- )
607
 
608
  if __name__ == "__main__":
609
- print("\n" + "="*50)
610
- print("LAUNCHING BRAIN TUMOR SEGMENTATION APPLICATION")
611
- print("="*50)
612
-
613
- app.launch(
614
- server_name="0.0.0.0",
615
- server_port=7860,
616
- show_error=True,
617
- share=False
618
- )
 
1
+ # full_app_with_heatmap.py
2
  import gradio as gr
3
  import torch
4
  import torch.nn as nn
 
7
  from PIL import Image
8
  import matplotlib.pyplot as plt
9
  import io
10
+ from torchvision import transforms
11
  import torchvision.transforms.functional as TF
 
 
12
  import urllib.request
13
+ import os
14
+ import random
15
  from glob import glob
16
+ import kagglehub # if you use dataset download in the app; remove if not needed
17
 
18
+ # -------------------------
19
+ # Setup / Globals
20
+ # -------------------------
21
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
22
  model = None
23
  dataset_images = []
24
  dataset_masks = []
25
  dataset_loaded = False
26
 
27
+ # -------------------------
28
+ # Model classes (Attention U-Net)
29
+ # -------------------------
 
 
30
  class DoubleConv(nn.Module):
31
  def __init__(self, in_channels, out_channels):
32
  super(DoubleConv, self).__init__()
 
50
  nn.Conv2d(F_g, F_int, kernel_size=1, stride=1, padding=0, bias=True),
51
  nn.BatchNorm2d(F_int)
52
  )
 
53
  self.W_x = nn.Sequential(
54
  nn.Conv2d(F_l, F_int, kernel_size=1, stride=1, padding=0, bias=True),
55
  nn.BatchNorm2d(F_int)
56
  )
 
57
  self.psi = nn.Sequential(
58
  nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True),
59
  nn.BatchNorm2d(1),
60
  nn.Sigmoid()
61
  )
 
62
  self.relu = nn.ReLU(inplace=True)
63
+
64
  def forward(self, g, x):
65
  g1 = self.W_g(g)
66
  x1 = self.W_x(x)
67
  psi = self.relu(g1 + x1)
68
  psi = self.psi(psi)
69
+ return x * psi, psi # return attended skip, attention map
70
+
71
 
72
  class AttentionUNET(nn.Module):
73
  def __init__(self, in_channels=1, out_channels=1, features=[32, 64, 128, 256]):
 
83
  in_channels = feature
84
 
85
  self.bottleneck = DoubleConv(features[-1], features[-1]*2)
86
+
87
  for feature in reversed(features):
88
  self.ups.append(nn.ConvTranspose2d(feature*2, feature, kernel_size=2, stride=2))
89
  self.attentions.append(AttentionBlock(F_g=feature, F_l=feature, F_int=feature // 2))
 
105
 
106
  for idx in range(0, len(self.ups), 2):
107
  x = self.ups[idx](x)
108
+ skip_connection = skip_connections[idx // 2]
109
 
110
  if x.shape != skip_connection.shape:
111
  x = TF.resize(x, size=skip_connection.shape[2:])
 
117
 
118
  return self.final_conv(x), attention_maps
119
 
120
+ # -------------------------
121
+ # Model download / load
122
+ # -------------------------
123
  def download_and_load_model():
 
124
  global model
125
  print("Loading Attention U-Net model...")
126
+
127
  model_url = "https://huggingface.co/spaces/ArchCoder/the-op-segmenter/resolve/main/best_attention_model.pth.tar"
128
  model_path = "best_attention_model.pth.tar"
129
+
 
130
  if not os.path.exists(model_path):
131
  print("Downloading model weights...")
132
  try:
 
134
  except Exception as e:
135
  print(f"Failed to download model: {e}")
136
  return False
137
+
 
138
  try:
139
  model = AttentionUNET(in_channels=1, out_channels=1).to(device)
140
+ checkpoint = torch.load(model_path, map_location=device)
141
+ # checkpoint format expected to have "state_dict"
142
+ if "state_dict" in checkpoint:
143
+ sd = checkpoint["state_dict"]
144
+ else:
145
+ sd = checkpoint
146
+ # Try exact load; if mismatch, try strict=False and warn
147
+ try:
148
+ model.load_state_dict(sd)
149
+ except Exception as ex:
150
+ print(f"Warning: strict load failed: {ex}. Trying strict=False...")
151
+ model.load_state_dict(sd, strict=False)
152
  model.eval()
153
  print("✓ Model loaded successfully!")
154
  return True
155
  except Exception as e:
156
  print(f"Failed to load model: {e}")
157
+ model = None
158
  return False
159
 
160
+ # -------------------------
161
+ # Dataset utilities (optional)
162
+ # -------------------------
163
  def download_and_load_dataset():
 
164
  global dataset_images, dataset_masks, dataset_loaded
 
165
  if dataset_loaded:
166
  return True
 
 
 
167
  try:
168
+ print("Loading brain tumor dataset (kagglehub)...")
169
  dataset_path = kagglehub.dataset_download('nikhilroxtomar/brain-tumor-segmentation')
 
 
 
170
  images_dir = os.path.join(dataset_path, 'images')
171
  masks_dir = os.path.join(dataset_path, 'masks')
 
 
 
 
 
 
 
 
 
 
172
  if not os.path.exists(images_dir) or not os.path.exists(masks_dir):
173
+ # fallback search
 
174
  all_files = glob(os.path.join(dataset_path, "**/*.png"), recursive=True) + \
175
+ glob(os.path.join(dataset_path, "**/*.jpg"), recursive=True)
 
176
  dataset_images = [f for f in all_files if '/images/' in f or 'image' in f.lower()]
177
  dataset_masks = [f for f in all_files if '/masks/' in f or 'mask' in f.lower()]
178
  else:
179
+ dataset_images = sorted(glob(os.path.join(images_dir, "*.*")))
180
+ dataset_masks = sorted(glob(os.path.join(masks_dir, "*.*")))
 
 
 
 
 
181
  print(f"✓ Found {len(dataset_images)} images and {len(dataset_masks)} masks")
182
  dataset_loaded = True
183
  return True
 
184
  except Exception as e:
185
  print(f"Failed to load dataset: {e}")
186
  return False
187
 
188
  def get_random_sample():
 
189
  if not dataset_loaded:
190
  return None, None, "Dataset not loaded"
 
191
  if not dataset_images:
192
+ return None, None, "No images found"
193
+ idx = random.randint(0, len(dataset_images)-1)
 
 
194
  img_path = dataset_images[idx]
 
 
195
  img_name = os.path.basename(img_path)
196
  mask_path = None
197
  for mask in dataset_masks:
198
  if os.path.basename(mask) == img_name:
199
  mask_path = mask
200
  break
 
201
  try:
202
  image = Image.open(img_path).convert("L")
203
  mask = Image.open(mask_path).convert("L") if mask_path else None
 
205
  except Exception as e:
206
  return None, None, f"Error loading sample: {e}"
207
 
208
+ # -------------------------
209
+ # Preprocessing & Heatmap utils
210
+ # -------------------------
211
  def preprocess_for_model(image):
 
212
  if image.mode != 'L':
213
  image = image.convert('L')
 
214
  transform = transforms.Compose([
215
+ transforms.Resize((256, 256)),
216
  transforms.ToTensor()
217
  ])
 
218
  return transform(image).unsqueeze(0)
219
 
220
  def generate_attention_heatmap(attention_maps):
 
221
  if not attention_maps:
222
+ return np.zeros((256, 256, 3), dtype=np.uint8)
 
 
223
  resized_maps = []
224
  target_size = (256, 256)
 
225
  for att_map in attention_maps:
 
226
  att_np = att_map.squeeze().cpu().numpy()
 
 
227
  att_resized = cv2.resize(att_np, target_size)
228
  resized_maps.append(att_resized)
 
 
229
  combined_att = np.mean(resized_maps, axis=0)
 
 
230
  combined_att = (combined_att - combined_att.min()) / (combined_att.max() - combined_att.min() + 1e-8)
 
 
231
  heatmap = cv2.applyColorMap((combined_att * 255).astype(np.uint8), cv2.COLORMAP_JET)
232
+ return heatmap # BGR (OpenCV)
233
+
234
+ # -------------------------
235
+ # Core: produce combined 1x5 image (preserve old 1-4 behavior)
236
+ # -------------------------
237
+ def results_with_heatmap(image, ground_truth=None, filename=None, threshold=0.5):
 
 
 
 
 
 
238
  if model is None:
239
  return None, "Model not loaded. Please restart the application."
 
240
  if image is None:
241
  return None, "Please select an image first."
242
 
243
+ # Keep preprocessing & prediction exactly like your working code
244
+ img_gray = image.convert('L') if image.mode != 'L' else image
245
+ original_np = np.array(img_gray.resize((256, 256))).astype(np.uint8)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
246
 
247
+ # Preprocess for model
248
+ prep = transforms.Compose([transforms.Resize((256, 256)), transforms.ToTensor()])
249
+ input_tensor = prep(img_gray).unsqueeze(0).to(device)
250
 
251
+ with torch.no_grad():
252
+ out = model(input_tensor)
253
+ # support both: model -> logits OR (logits, att_maps)
254
+ if isinstance(out, (list, tuple)) and len(out) == 2:
255
+ logits, attention_maps = out
256
+ else:
257
+ logits = out
258
+ attention_maps = []
259
+
260
+ pred_prob = torch.sigmoid(logits)
261
+ pred_mask = (pred_prob > threshold).float()
262
+
263
+ pred_mask_np = pred_mask.cpu().squeeze().numpy() # (256,256)
264
+ inv_pred_mask_np = np.where(pred_mask_np == 1, 0, 255).astype(np.uint8)
265
+ tumor_only = np.where(pred_mask_np == 1, original_np, 255).astype(np.uint8)
266
+
267
+ # ground truth handling (preserve old style)
268
+ if ground_truth is not None:
269
+ gt_gray = ground_truth.convert('L') if ground_truth.mode != 'L' else ground_truth
270
+ mask_np = prep(gt_gray).cpu().squeeze().numpy()
271
+ mask_vis = (mask_np > 0.5).astype(np.uint8)
272
+ else:
273
+ mask_vis = np.zeros_like(original_np)
274
+
275
+ # Try to build attention heatmap; fallback to probability heatmap
276
+ att_heat = generate_attention_heatmap(attention_maps)
277
+ if att_heat is None or att_heat.size == 0:
278
+ prob_np = pred_prob.cpu().squeeze().numpy()
279
+ prob_resized = cv2.resize(prob_np, (256, 256))
280
+ prob_norm = (prob_resized - prob_resized.min()) / (prob_resized.max() - prob_resized.min() + 1e-8)
281
+ att_heat_bgr = cv2.applyColorMap((prob_norm * 255).astype(np.uint8), cv2.COLORMAP_JET)
282
+ att_heat = att_heat_bgr
283
+
284
+ # convert BGR->RGB for display
285
+ try:
286
+ att_heat = cv2.cvtColor(att_heat, cv2.COLOR_BGR2RGB)
287
+ except Exception:
288
+ pass
289
+
290
+ # ensure dtype/shape
291
+ if att_heat.dtype != np.uint8:
292
+ att_heat = (att_heat * 255).astype(np.uint8) if att_heat.max() <= 1.0 else att_heat.astype(np.uint8)
293
+ if att_heat.ndim == 2:
294
+ att_heat = cv2.cvtColor(att_heat, cv2.COLOR_GRAY2RGB)
295
+
296
+ # Create 1x5 figure
297
+ fig, axes = plt.subplots(1, 5, figsize=(22, 5))
298
+ fig.suptitle('Results + Heatmap', fontsize=16, weight='bold')
299
+
300
+ axes[0].imshow(original_np, cmap='gray'); axes[0].set_title('Original Image'); axes[0].axis('off')
301
+ axes[1].imshow(mask_vis, cmap='gray'); axes[1].set_title('Ground Truth Mask' if ground_truth is not None else 'GT (none)'); axes[1].axis('off')
302
+ axes[2].imshow(inv_pred_mask_np, cmap='gray'); axes[2].set_title('Predicted Mask'); axes[2].axis('off')
303
+ axes[3].imshow(tumor_only, cmap='gray'); axes[3].set_title('Tumor Only'); axes[3].axis('off')
304
+ axes[4].imshow(att_heat); axes[4].set_title('Attention / Prob Heatmap'); axes[4].axis('off')
305
+
306
+ plt.tight_layout()
307
+
308
+ buf = io.BytesIO()
309
+ plt.savefig(buf, format='png', dpi=150, bbox_inches='tight', facecolor='white')
310
+ buf.seek(0)
311
+ plt.close(fig)
312
+ result_img = Image.open(buf).convert("RGB")
313
+
314
+ tumor_pixels = int(np.sum(pred_mask_np))
315
+ total_pixels = int(pred_mask_np.size)
316
+ tumor_pct = (tumor_pixels / total_pixels) * 100 if total_pixels > 0 else 0.0
317
+
318
+ analysis_text = f"""
319
  # Analysis Results
 
320
  **File:** {filename if filename else 'Uploaded Image'}
321
+ - Tumor Area: {tumor_pct:.2f}%
 
 
322
  - Tumor Pixels: {tumor_pixels:,}
323
+ - Max confidence: {float(pred_prob.max()):.4f}
324
+ - Threshold used: {threshold}
 
 
 
 
 
 
 
 
 
325
  """
326
 
327
+ return result_img, analysis_text
 
 
328
 
329
+ # -------------------------
330
+ # Initialize model & dataset at startup
331
+ # -------------------------
 
 
 
 
 
 
 
332
  print("Initializing application components...")
333
  model_loaded = download_and_load_model()
334
  dataset_loaded_success = download_and_load_dataset()
 
335
  if not model_loaded:
336
  print("WARNING: Model failed to load!")
337
  if not dataset_loaded_success:
338
  print("WARNING: Dataset failed to load!")
 
339
  print("Application ready!")
340
 
341
+ # -------------------------
342
+ # Gradio UI
343
+ # -------------------------
344
  css = """
345
+ .gradio-container { max-width: 1400px !important; margin:auto !important; font-family: 'Segoe UI', Tahoma, Verdana; }
346
+ .gr-button { border-radius: 6px !important; font-weight: 500 !important; }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
347
  """
348
 
349
+ with gr.Blocks(css=css, title="Brain Tumor Segmentation + Heatmap") as app:
350
+ gr.Markdown("# Brain Tumor Segmentation Attention U-Net\nPreserves original 1–4 outputs; adds 5th: heatmap.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
351
  with gr.Row():
352
  with gr.Column(scale=1):
353
+ image_display = gr.Image(label="Selected Image", type="pil", height=300)
 
 
 
 
 
 
 
 
 
354
  with gr.Row():
355
+ load_sample_btn = gr.Button("Load Random Sample", variant="primary")
356
+ upload_btn = gr.UploadButton("Upload Image", file_types=["image"])
 
357
  analyze_btn = gr.Button("Analyze Image", variant="primary", size="lg")
358
+ gr.Markdown(f"**Model Status:** {'✓ Loaded' if model_loaded else '✗ Failed'} \n**Dataset:** {'✓ Loaded' if dataset_loaded_success else '✗ Failed'}")
 
 
 
 
 
 
 
 
359
  with gr.Column(scale=2):
360
+ gr.Markdown("### Results (1x5 panel)")
361
+ result_display = gr.Image(label="Segmentation + Heatmap", type="pil", height=600)
362
+ analysis_text = gr.Markdown("Upload or load a sample and click Analyze.")
363
+
 
 
 
 
 
 
 
 
 
364
  current_ground_truth = gr.State()
365
  current_filename = gr.State()
366
+
 
367
  def handle_sample_load():
368
  image, mask, filename = get_random_sample()
369
  return image, mask, filename
370
+
371
+ def handle_upload(f):
372
+ if f is not None:
373
+ img = Image.open(f.name).convert("L")
374
+ return img, None, os.path.basename(f.name)
375
  return None, None, ""
376
+
377
+ load_sample_btn.click(fn=handle_sample_load, outputs=[image_display, current_ground_truth, current_filename])
378
+ upload_btn.upload(fn=handle_upload, inputs=[upload_btn], outputs=[image_display, current_ground_truth, current_filename])
379
+
380
+ analyze_btn.click(fn=results_with_heatmap, inputs=[image_display, current_ground_truth, current_filename], outputs=[result_display, analysis_text])
 
 
 
 
 
 
 
 
 
 
 
 
381
 
382
  if __name__ == "__main__":
383
+ app.launch(server_name="0.0.0.0", server_port=7860, show_error=True, share=False)