ArchCoder committed
Commit 16f55d5 · verified · 1 Parent(s): 073ab95

Update app.py

Files changed (1):
  1. app.py +224 -693

app.py CHANGED
@@ -6,20 +6,19 @@ import cv2
  from PIL import Image
  import matplotlib.pyplot as plt
  import io
- from torchvision import transforms
  import torchvision.transforms.functional as TF
- import urllib.request
  import os
- import kagglehub
  import random
- from pathlib import Path
- import seaborn as sns

  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  model = None
- dataset_path = None

- # Define your Attention U-Net architecture (from your training code)
  class DoubleConv(nn.Module):
      def __init__(self, in_channels, out_channels):
          super(DoubleConv, self).__init__()
@@ -61,7 +60,7 @@ class AttentionBlock(nn.Module):
          x1 = self.W_x(x)
          psi = self.relu(g1 + x1)
          psi = self.psi(psi)
-         return x * psi, psi  # Return attention coefficients for visualization

  class AttentionUNET(nn.Module):
      def __init__(self, in_channels=1, out_channels=1, features=[32, 64, 128, 256]):
@@ -72,7 +71,7 @@ class AttentionUNET(nn.Module):
          self.attentions = nn.ModuleList()
          self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

-         # Down part of UNET
          for feature in features:
              self.downs.append(DoubleConv(in_channels, feature))
              in_channels = feature
@@ -80,7 +79,7 @@
          # Bottleneck
          self.bottleneck = DoubleConv(features[-1], features[-1]*2)

-         # Up part of UNET
          for feature in reversed(features):
              self.ups.append(nn.ConvTranspose2d(feature*2, feature, kernel_size=2, stride=2))
              self.attentions.append(AttentionBlock(F_g=feature, F_l=feature, F_int=feature // 2))
@@ -88,9 +87,9 @@

          self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

-     def forward(self, x, return_attention=False):
          skip_connections = []
-         attention_maps = []

          for down in self.downs:
              x = down(x)
@@ -107,30 +106,13 @@
              if x.shape != skip_connection.shape:
                  x = TF.resize(x, size=skip_connection.shape[2:])

-             skip_connection, attention_coeff = self.attentions[idx // 2](skip_connection, x)
-             if return_attention:
-                 attention_maps.append(attention_coeff)

-             concat_skip = torch.cat((skip_connection, x), dim=1)
              x = self.ups[idx+1](concat_skip)

-         output = self.final_conv(x)
-
-         if return_attention:
-             return output, attention_maps
-         return output
-
- def download_dataset():
-     """Download Brain Tumor Segmentation dataset from Kaggle"""
-     global dataset_path
-     try:
-         print("📥 Downloading Brain Tumor Segmentation dataset...")
-         dataset_path = kagglehub.dataset_download('nikhilroxtomar/brain-tumor-segmentation')
-         print(f"✅ Dataset downloaded to: {dataset_path}")
-         return dataset_path
-     except Exception as e:
-         print(f"❌ Failed to download dataset: {e}")
-         return None

  def download_model():
      """Download your trained model from HuggingFace"""
@@ -138,758 +120,307 @@ def download_model():
      model_path = "best_attention_model.pth.tar"

      if not os.path.exists(model_path):
-         print("📥 Downloading trained model...")
          try:
              urllib.request.urlretrieve(model_url, model_path)
              print("✅ Model downloaded successfully!")
          except Exception as e:
              print(f"❌ Failed to download model: {e}")
              return None
-     else:
-         print("✅ Model already exists!")
-
      return model_path

- def load_attention_model():
-     """Load trained Attention U-Net model"""
      global model
      if model is None:
          try:
-             print("🔄 Loading Attention U-Net model...")

              model_path = download_model()
              if model_path is None:
                  return None

              model = AttentionUNET(in_channels=1, out_channels=1).to(device)
              checkpoint = torch.load(model_path, map_location=device, weights_only=True)
              model.load_state_dict(checkpoint["state_dict"])
              model.eval()

-             print("✅ Attention U-Net model loaded successfully!")
          except Exception as e:
-             print(f"❌ Error loading model: {e}")
              model = None
      return model

- def get_random_sample_from_dataset():
-     """Get a random sample image and ground truth mask from the dataset"""
-     global dataset_path

-     if dataset_path is None:
-         dataset_path = download_dataset()
-         if dataset_path is None:
-             return None, None

-     try:
-         images_path = Path(dataset_path) / "images"
-         masks_path = Path(dataset_path) / "masks"
-
-         if not images_path.exists() or not masks_path.exists():
-             print("❌ Dataset structure not found")
-             return None, None
-
-         # Get all image files
-         image_files = list(images_path.glob("*.jpg")) + list(images_path.glob("*.png")) + list(images_path.glob("*.tif"))
-
-         if not image_files:
-             print("❌ No image files found in dataset")
-             return None, None
-
-         # Select a random image
-         random_image_file = random.choice(image_files)
-         image_name = random_image_file.stem
-
-         # Find the corresponding mask
-         possible_mask_extensions = ['.jpg', '.png', '.tif', '.gif']
-         mask_file = None
-
-         for ext in possible_mask_extensions:
-             potential_mask = masks_path / f"{image_name}{ext}"
-             if potential_mask.exists():
-                 mask_file = potential_mask
-                 break
-
-         if mask_file is None:
-             print(f"❌ No corresponding mask found for {image_name}")
-             return None, None
-
-         # Load image and mask
-         image = Image.open(random_image_file).convert('L')
-         mask = Image.open(mask_file).convert('L')
-
-         print(f"✅ Loaded random sample: {image_name}")
-         return image, mask
-
-     except Exception as e:
-         print(f"❌ Error loading random sample: {e}")
-         return None, None

- def test_time_augmentation(model, image_tensor):
-     """Apply Test-Time Augmentation (TTA) for robust predictions"""
-     augmentations = [
-         lambda x: x,                                 # Original
-         lambda x: torch.flip(x, dims=[3]),           # Horizontal flip
-         lambda x: torch.flip(x, dims=[2]),           # Vertical flip
-         lambda x: torch.flip(x, dims=[2, 3]),        # Both flips
-         lambda x: torch.rot90(x, k=1, dims=[2, 3]),  # 90° rotation
-         lambda x: torch.rot90(x, k=3, dims=[2, 3]),  # 270° rotation
-     ]

-     reverse_augmentations = [
-         lambda x: x,                                 # Original
-         lambda x: torch.flip(x, dims=[3]),           # Reverse horizontal flip
-         lambda x: torch.flip(x, dims=[2]),           # Reverse vertical flip
-         lambda x: torch.flip(x, dims=[2, 3]),        # Reverse both flips
-         lambda x: torch.rot90(x, k=3, dims=[2, 3]),  # Reverse 90° rotation
-         lambda x: torch.rot90(x, k=1, dims=[2, 3]),  # Reverse 270° rotation
-     ]

      predictions = []

-     with torch.no_grad():
-         for aug, rev_aug in zip(augmentations, reverse_augmentations):
-             # Apply augmentation
-             aug_input = aug(image_tensor)
-
-             # Get prediction
-             pred = torch.sigmoid(model(aug_input))
-
-             # Reverse the augmentation on the prediction
-             pred = rev_aug(pred)
-
-             predictions.append(pred)

-     # Average all predictions
-     tta_prediction = torch.mean(torch.stack(predictions), dim=0)

-     return tta_prediction
-
- def generate_attention_heatmaps(model, image_tensor):
-     """Generate attention heatmaps for interpretability"""
-     with torch.no_grad():
-         pred, attention_maps = model(image_tensor, return_attention=True)
-
-     # Convert attention maps to numpy for visualization
-     heatmaps = []
-     for i, att_map in enumerate(attention_maps):
-         # Resize each attention map to match the input size
-         att_map_resized = TF.resize(att_map, (256, 256))
-         att_np = att_map_resized.cpu().squeeze().numpy()
-         heatmaps.append(att_np)
-
-     return heatmaps
-
- def preprocess_image(image):
-     """Preprocessing exactly like the training code"""
-     if image.mode != 'L':
-         image = image.convert('L')

-     val_test_transform = transforms.Compose([
-         transforms.Resize((256, 256)),
-         transforms.ToTensor()
-     ])

-     return val_test_transform(image).unsqueeze(0)

- def calculate_metrics(pred_mask, ground_truth_mask):
-     """Calculate Dice and IoU metrics"""
-     pred_binary = (pred_mask > 0.5).float()
-     gt_binary = (ground_truth_mask > 0.5).float()

-     # Dice coefficient
-     intersection = torch.sum(pred_binary * gt_binary)
-     dice = (2.0 * intersection) / (torch.sum(pred_binary) + torch.sum(gt_binary) + 1e-8)

-     # IoU
-     union = torch.sum(pred_binary) + torch.sum(gt_binary) - intersection
-     iou = intersection / (union + 1e-8)

-     return dice.item(), iou.item()

- def predict_with_enhancements(image, ground_truth=None, use_tta=True, show_attention=True):
-     """Enhanced prediction with TTA and attention visualization"""
-     current_model = load_attention_model()

      if current_model is None:
-         return None, "Failed to load trained model."

      if image is None:
-         return None, "⚠️ Please upload an image first."

      try:
-         print("🧠 Processing with enhanced Attention U-Net...")

          input_tensor = preprocess_image(image).to(device)

-         # Standard prediction
-         with torch.no_grad():
-             standard_pred = torch.sigmoid(current_model(input_tensor))
-
-         # Test-Time Augmentation
          if use_tta:
-             tta_pred = test_time_augmentation(current_model, input_tensor)
-             final_pred = tta_pred
-         else:
-             final_pred = standard_pred
-
-         # Generate attention heatmaps
-         attention_heatmaps = []
-         if show_attention:
-             attention_heatmaps = generate_attention_heatmaps(current_model, input_tensor)
-
-         # Convert predictions to binary
-         pred_mask_binary = (final_pred > 0.5).float()
-         pred_mask_np = pred_mask_binary.cpu().squeeze().numpy()
-         standard_mask_np = (standard_pred > 0.5).float().cpu().squeeze().numpy()
-
-         # Prepare images for visualization
-         original_np = np.array(image.convert('L').resize((256, 256)))
-
-         # Create the comprehensive visualization
-         if ground_truth is not None:
-             # With ground truth comparison
-             gt_np = np.array(ground_truth.convert('L').resize((256, 256)))
-             gt_binary = (gt_np > 127).astype(np.float32)  # Threshold ground truth
-
-             # Calculate metrics
-             gt_tensor = torch.tensor(gt_binary).unsqueeze(0).unsqueeze(0).to(device)
-             dice_score, iou_score = calculate_metrics(final_pred, gt_tensor)
-
-             # Create figure with ground truth comparison
-             n_cols = 6 if show_attention and attention_heatmaps else 5
-             fig, axes = plt.subplots(2, n_cols, figsize=(4*n_cols, 8))
-             fig.suptitle('🧠 Enhanced Attention U-Net Analysis with Ground Truth Comparison', fontsize=16, weight='bold')
-
-             # Top row - standard analysis
-             axes[0, 0].imshow(original_np, cmap='gray')
-             axes[0, 0].set_title('Original Image', fontsize=12, weight='bold')
-             axes[0, 0].axis('off')
-
-             axes[0, 1].imshow(standard_mask_np * 255, cmap='hot')
-             axes[0, 1].set_title('Standard Prediction', fontsize=12, weight='bold')
-             axes[0, 1].axis('off')
-
-             axes[0, 2].imshow(pred_mask_np * 255, cmap='hot')
-             axes[0, 2].set_title(f'{"TTA Enhanced" if use_tta else "Final Prediction"}', fontsize=12, weight='bold')
-             axes[0, 2].axis('off')
-
-             axes[0, 3].imshow(gt_binary * 255, cmap='hot')
-             axes[0, 3].set_title('Ground Truth', fontsize=12, weight='bold')
-             axes[0, 3].axis('off')
-
-             # Overlay comparison
-             overlay = original_np.copy()
-             overlay = np.stack([overlay, overlay, overlay], axis=-1)
-             overlay[pred_mask_np > 0.5] = [255, 0, 0]  # Red for prediction
-             overlay[gt_binary > 0.5] = [0, 255, 0]     # Green for ground truth
-             overlap = (pred_mask_np > 0.5) & (gt_binary > 0.5)
-             overlay[overlap] = [255, 255, 0]           # Yellow for overlap
-
-             axes[0, 4].imshow(overlay.astype(np.uint8))
-             axes[0, 4].set_title('Overlay (Red: Pred, Green: GT, Yellow: Match)', fontsize=10, weight='bold')
-             axes[0, 4].axis('off')
-
-             if show_attention and attention_heatmaps:
-                 # Show combined attention
-                 combined_attention = np.mean(attention_heatmaps, axis=0)
-                 axes[0, 5].imshow(combined_attention, cmap='jet', alpha=0.7)
-                 axes[0, 5].imshow(original_np, cmap='gray', alpha=0.3)
-                 axes[0, 5].set_title('Attention Heatmap', fontsize=12, weight='bold')
-                 axes[0, 5].axis('off')
-
-             # Bottom row - individual attention maps or detailed analysis
-             if show_attention and attention_heatmaps:
-                 for i, heatmap in enumerate(attention_heatmaps[:n_cols]):
-                     axes[1, i].imshow(heatmap, cmap='jet', alpha=0.7)
-                     axes[1, i].imshow(original_np, cmap='gray', alpha=0.3)
-                     axes[1, i].set_title(f'Attention Gate {i+1}', fontsize=10, weight='bold')
-                     axes[1, i].axis('off')
-             else:
-                 # Show tumor extraction and analysis
-                 tumor_only = np.where(pred_mask_np == 1, original_np, 255)
-                 inv_mask = np.where(pred_mask_np == 1, 0, 255)
-
-                 axes[1, 0].imshow(tumor_only, cmap='gray')
-                 axes[1, 0].set_title('Tumor Extraction', fontsize=12, weight='bold')
-                 axes[1, 0].axis('off')
-
-                 axes[1, 1].imshow(inv_mask, cmap='gray')
-                 axes[1, 1].set_title('Inverted Mask', fontsize=12, weight='bold')
-                 axes[1, 1].axis('off')
-
-                 # Difference map
-                 diff_map = np.abs(pred_mask_np - gt_binary)
-                 axes[1, 2].imshow(diff_map, cmap='Reds')
-                 axes[1, 2].set_title('Difference Map', fontsize=12, weight='bold')
-                 axes[1, 2].axis('off')
-
-                 # Clear the remaining axes
-                 for j in range(3, n_cols):
-                     axes[1, j].axis('off')
          else:
-             # Without ground truth
-             n_cols = 5 if show_attention and attention_heatmaps else 4
-             fig, axes = plt.subplots(2, n_cols, figsize=(4*n_cols, 8))
-             fig.suptitle('🧠 Enhanced Attention U-Net Analysis', fontsize=16, weight='bold')
-
-             # Top row
-             images = [original_np, standard_mask_np * 255, pred_mask_np * 255]
-             titles = ["Original Image", "Standard Prediction", f'{"TTA Enhanced" if use_tta else "Final Prediction"}']
-             cmaps = ['gray', 'hot', 'hot']
-
-             for i in range(3):
-                 axes[0, i].imshow(images[i], cmap=cmaps[i])
-                 axes[0, i].set_title(titles[i], fontsize=12, weight='bold')
-                 axes[0, i].axis('off')
-
-             # Tumor extraction
-             tumor_only = np.where(pred_mask_np == 1, original_np, 255)
-             axes[0, 3].imshow(tumor_only, cmap='gray')
-             axes[0, 3].set_title('Tumor Extraction', fontsize=12, weight='bold')
-             axes[0, 3].axis('off')
-
-             if show_attention and attention_heatmaps:
-                 combined_attention = np.mean(attention_heatmaps, axis=0)
-                 axes[0, 4].imshow(combined_attention, cmap='jet', alpha=0.7)
-                 axes[0, 4].imshow(original_np, cmap='gray', alpha=0.3)
-                 axes[0, 4].set_title('Combined Attention', fontsize=12, weight='bold')
-                 axes[0, 4].axis('off')
-
-             # Bottom row - individual attention maps
-             if show_attention and attention_heatmaps:
-                 for i, heatmap in enumerate(attention_heatmaps[:n_cols]):
-                     axes[1, i].imshow(heatmap, cmap='jet', alpha=0.7)
-                     axes[1, i].imshow(original_np, cmap='gray', alpha=0.3)
-                     axes[1, i].set_title(f'Attention Gate {i+1}', fontsize=10, weight='bold')
-                     axes[1, i].axis('off')
-             else:
-                 # Clear the bottom row
-                 for j in range(n_cols):
-                     axes[1, j].axis('off')

          plt.tight_layout()

-         # Save the result
          buf = io.BytesIO()
-         plt.savefig(buf, format='png', dpi=150, bbox_inches='tight', facecolor='white')
          buf.seek(0)
          plt.close()

          result_image = Image.open(buf)

-         # Calculate statistics
-         tumor_pixels = np.sum(pred_mask_np)
-         total_pixels = pred_mask_np.size
          tumor_percentage = (tumor_pixels / total_pixels) * 100

-         max_confidence = torch.max(final_pred).item()
-         mean_confidence = torch.mean(final_pred).item()
-
-         # Enhanced analysis text
          analysis_text = f"""
- ## 🧠 Enhanced Attention U-Net Analysis Results
-
- ### 📊 Detection Summary
- - **Status**: {'🔴 TUMOR DETECTED' if tumor_pixels > 50 else '🟢 NO SIGNIFICANT TUMOR'}
- - **Tumor Coverage**: {tumor_percentage:.2f}% of brain region
- - **Tumor Pixels**: {tumor_pixels:,} pixels
- - **Max Confidence**: {max_confidence:.4f}
- - **Mean Confidence**: {mean_confidence:.4f}
  """
-
-         if ground_truth is not None:
              analysis_text += f"""
- ### 🎯 Ground Truth Comparison
- - **Dice Score**: {dice_score:.4f} {'✅ Excellent' if dice_score > 0.8 else '⚠️ Good' if dice_score > 0.6 else '❌ Poor'}
- - **IoU Score**: {iou_score:.4f} {'✅ Excellent' if iou_score > 0.7 else '⚠️ Good' if iou_score > 0.5 else '❌ Poor'}
- - **Model Accuracy**: {'High precision match' if dice_score > 0.8 else 'Reasonable match' if dice_score > 0.6 else 'Needs improvement'}
- """
-
-         analysis_text += f"""
- ### 🚀 Enhancement Features
- - **Test-Time Augmentation**: {'✅ Applied (6 augmentations averaged)' if use_tta else '❌ Disabled'}
- - **Attention Visualization**: {'✅ Generated attention heatmaps' if show_attention else '❌ Disabled'}
- - **Boundary Enhancement**: {'✅ TTA improves edge detection' if use_tta else '⚠️ Standard prediction only'}
- - **Interpretability**: {'✅ Attention gates show focus areas' if show_attention else '❌ Black box mode'}
-
- ### 🔬 Model Architecture
- - **Base Model**: Attention U-Net with skip connections
- - **Training Performance**: Dice: 0.8420, IoU: 0.7297, Accuracy: 98.90%
- - **Attention Gates**: 4 levels with soft attention mechanism
- - **Feature Channels**: [32, 64, 128, 256] progression
- - **Device**: {device.type.upper()}
-
- ### 📈 Enhanced Processing Pipeline
- - **Preprocessing**: Resize(256×256) + Normalization
- - **Augmentations**: Flips (H, V), Rotations (90°, 270°), Combined
- - **Attention Fusion**: Multi-scale attention coefficient extraction
- - **Post-processing**: Ensemble averaging + Binary thresholding (0.5)
-
- ### ⚠️ Medical Disclaimer
- This enhanced AI model is for **research and educational purposes only**.
- Results include advanced features for better accuracy and interpretability.
- Always consult medical professionals for clinical applications.
-
- ### 🏆 Research Contributions
- ✅ **Attention Gates**: Enhanced boundary detection through selective feature passing
- ✅ **Test-Time Augmentation**: Robust predictions via ensemble averaging
- ✅ **Interpretability**: Attention heatmaps for clinical trust and validation
- ✅ **Efficiency**: No retraining required, minimal computational overhead
  """

-         print(f"✅ Enhanced analysis completed! Tumor coverage: {tumor_percentage:.2f}%")
          return result_image, analysis_text

      except Exception as e:
-         error_msg = f"Error during enhanced analysis: {str(e)}"
-         print(error_msg)
-         return None, error_msg

- def load_random_sample():
-     """Load a random sample from the dataset"""
-     image, mask = get_random_sample_from_dataset()
      if image is None:
-         return None, None, "Failed to load random sample from dataset"
-     return image, mask, "✅ Random sample loaded from dataset"

- def clear_all():
-     return None, None, None, "Upload a brain MRI image or load a random sample to test the enhanced model"
-
- # Enhanced professional CSS
  css = """
- .gradio-container {
-     max-width: 1600px !important;
-     margin: auto !important;
-     font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
- }
-
- #title {
-     text-align: center;
-     background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
-     color: white;
-     padding: 40px;
-     border-radius: 20px;
-     margin-bottom: 30px;
-     box-shadow: 0 12px 24px rgba(102, 126, 234, 0.4);
- }
-
- .feature-box {
-     background: linear-gradient(135deg, #f093fb 0%, #f5576c 100%);
-     border-radius: 15px;
-     padding: 25px;
-     margin: 15px 0;
-     color: white;
-     box-shadow: 0 8px 16px rgba(240, 147, 251, 0.3);
- }
-
- .metric-card {
-     background: linear-gradient(135deg, #4facfe 0%, #00f2fe 100%);
-     border-radius: 12px;
-     padding: 20px;
-     text-align: center;
-     margin: 10px;
-     box-shadow: 0 6px 12px rgba(79, 172, 254, 0.3);
- }
-
- .enhancement-badge {
-     display: inline-block;
-     background: linear-gradient(45deg, #fa709a 0%, #fee140 100%);
-     color: white;
-     padding: 8px 16px;
-     border-radius: 25px;
-     margin: 5px;
-     font-weight: bold;
-     box-shadow: 0 4px 8px rgba(250, 112, 154, 0.3);
- }
  """

- # Create the enhanced Gradio interface
- with gr.Blocks(css=css, title="🧠 Enhanced Brain Tumor Segmentation", theme=gr.themes.Soft()) as app:
-
-     gr.HTML("""
-     <div id="title">
-         <h1>🧠 Enhanced Attention U-Net Brain Tumor Segmentation</h1>
-         <p style="font-size: 20px; margin-top: 20px; font-weight: 300;">
-             🚀 Advanced Medical AI with Test-Time Augmentation & Attention Visualization
-         </p>
-         <p style="font-size: 16px; margin-top: 15px; opacity: 0.9;">
-             📊 Performance: Dice 0.8420 • IoU 0.7297 • Accuracy 98.90% |
-             🔬 Research-Grade Interpretability & Robustness
-         </p>
-     </div>
-     """)

      with gr.Row():
          with gr.Column(scale=1):
-             gr.Markdown("### 📤 Input & Controls")
-
-             with gr.Tab("📸 Upload Image"):
-                 image_input = gr.Image(
-                     label="Brain MRI Scan",
-                     type="pil",
-                     sources=["upload", "webcam"],
-                     height=300
-                 )
-
-             with gr.Tab("🎲 Random Sample"):
-                 random_image = gr.Image(
-                     label="Sample Image",
-                     type="pil",
-                     height=300,
-                     interactive=False
-                 )
-                 random_ground_truth = gr.Image(
-                     label="Ground Truth Mask",
-                     type="pil",
-                     height=300,
-                     interactive=False
-                 )
-                 load_sample_btn = gr.Button("🎲 Load Random Sample", variant="secondary", size="lg")
-                 sample_status = gr.Textbox(label="Sample Status", interactive=False)
-
-             gr.Markdown("### ⚙️ Enhancement Options")
-
-             use_tta = gr.Checkbox(
-                 label="🔄 Test-Time Augmentation",
-                 value=True,
-                 info="Apply multiple augmentations for robust predictions"
-             )
-
-             show_attention = gr.Checkbox(
-                 label="🔥 Attention Visualization",
-                 value=True,
-                 info="Generate attention heatmaps for interpretability"
-             )

              with gr.Row():
-                 analyze_btn = gr.Button(
-                     "🧠 Analyze with Enhanced Model",
-                     variant="primary",
-                     scale=3,
-                     size="lg"
-                 )
-                 clear_btn = gr.Button("🗑️ Clear All", variant="secondary", scale=1)
-
-             gr.HTML("""
-             <div class="feature-box">
-                 <h4 style="margin-bottom: 15px;">🎯 Research Innovations</h4>
-                 <div class="enhancement-badge">Attention Gates</div>
-                 <div class="enhancement-badge">Test-Time Augmentation</div>
-                 <div class="enhancement-badge">Interpretability</div>
-                 <div class="enhancement-badge">Ground Truth Comparison</div>
-                 <p style="margin-top: 15px; font-size: 14px; opacity: 0.9;">
-                     Advanced medical AI combining accuracy, robustness, and clinical interpretability
-                 </p>
-             </div>
-             """)

          with gr.Column(scale=2):
-             gr.Markdown("### 📊 Enhanced Analysis Results")
-
-             output_image = gr.Image(
-                 label="Comprehensive Analysis Visualization",
-                 type="pil",
-                 height=600
-             )
-
-             with gr.Accordion("📈 Detailed Analysis Report", open=True):
-                 analysis_output = gr.Markdown(
-                     value="Upload a brain MRI image or load a random sample to test the enhanced Attention U-Net model.",
-                     elem_id="analysis"
-                 )
-
-     # Performance metrics section
-     gr.HTML("""
-     <div style="margin-top: 40px;">
-         <h3 style="text-align: center; color: #4a5568; margin-bottom: 25px;">📊 Model Performance & Research Contributions</h3>
-         <div style="display: grid; grid-template-columns: repeat(auto-fit, minmax(300px, 1fr)); gap: 20px; margin-bottom: 30px;">
-
-             <div class="metric-card">
-                 <h4 style="color: white; margin-bottom: 10px;">🎯 Segmentation Accuracy</h4>
-                 <div style="font-size: 24px; font-weight: bold; margin: 10px 0;">98.90%</div>
-                 <p style="font-size: 14px; opacity: 0.9;">Training accuracy on brain tumor dataset</p>
-             </div>
-
-             <div class="metric-card">
-                 <h4 style="color: white; margin-bottom: 10px;">📐 Dice Score</h4>
-                 <div style="font-size: 24px; font-weight: bold; margin: 10px 0;">0.8420</div>
-                 <p style="font-size: 14px; opacity: 0.9;">Overlap similarity coefficient</p>
-             </div>
-
-             <div class="metric-card">
-                 <h4 style="color: white; margin-bottom: 10px;">🔲 IoU Score</h4>
-                 <div style="font-size: 24px; font-weight: bold; margin: 10px 0;">0.7297</div>
-                 <p style="font-size: 14px; opacity: 0.9;">Intersection over Union metric</p>
-             </div>
-
-             <div class="metric-card">
-                 <h4 style="color: white; margin-bottom: 10px;">⚡ Enhancement Features</h4>
-                 <div style="font-size: 20px; font-weight: bold; margin: 10px 0;">TTA + Attention</div>
-                 <p style="font-size: 14px; opacity: 0.9;">Advanced robustness & interpretability</p>
-             </div>
-
-         </div>
-     </div>
-     """)
-
-     # Research contributions section
-     gr.HTML("""
-     <div style="margin-top: 30px; padding: 30px; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); border-radius: 20px; color: white;">
-         <h3 style="text-align: center; margin-bottom: 25px; color: white;">🚀 Novel Research Contributions</h3>
-
-         <div style="display: grid; grid-template-columns: 1fr 1fr; gap: 30px; margin-bottom: 20px;">
-
-             <div>
-                 <h4 style="margin-bottom: 15px; color: #ffd700;">🔍 1. Enhanced Boundary Detection</h4>
-                 <ul style="line-height: 1.8; margin-left: 20px;">
-                     <li><strong>Problem:</strong> Traditional U-Net passes noisy features through skip connections</li>
-                     <li><strong>Solution:</strong> Attention gates filter irrelevant encoder features</li>
-                     <li><strong>Impact:</strong> Cleaner boundaries, reduced false positives</li>
-                 </ul>
-             </div>
-
-             <div>
-                 <h4 style="margin-bottom: 15px; color: #ffd700;">🔄 2. Test-Time Augmentation</h4>
-                 <ul style="line-height: 1.8; margin-left: 20px;">
-                     <li><strong>Problem:</strong> Medical datasets are small, MRI scans vary across centers</li>
-                     <li><strong>Solution:</strong> Multiple augmentations averaged for robust predictions</li>
-                     <li><strong>Impact:</strong> Improved robustness without retraining</li>
-                 </ul>
-             </div>
-
-             <div>
-                 <h4 style="margin-bottom: 15px; color: #ffd700;">🔥 3. Attention Visualization</h4>
-                 <ul style="line-height: 1.8; margin-left: 20px;">
-                     <li><strong>Problem:</strong> Deep networks are "black boxes" for clinicians</li>
-                     <li><strong>Solution:</strong> Extract attention coefficients as interpretable heatmaps</li>
-                     <li><strong>Impact:</strong> Build clinical trust through transparency</li>
-                 </ul>
-             </div>
-
-             <div>
-                 <h4 style="margin-bottom: 15px; color: #ffd700;">⚡ 4. Efficient Implementation</h4>
-                 <ul style="line-height: 1.8; margin-left: 20px;">
-                     <li><strong>Problem:</strong> Complex architectures are hard to deploy</li>
-                     <li><strong>Solution:</strong> Low-overhead enhancements within existing backbone</li>
-                     <li><strong>Impact:</strong> Practical for real-world medical workflows</li>
-                 </ul>
-             </div>
-
-         </div>
-
-         <div style="text-align: center; padding-top: 20px; border-top: 2px solid rgba(255,255,255,0.3);">
-             <p style="font-size: 16px; font-weight: 600; margin-bottom: 10px;">
-                 🎯 Research Gap Addressed: Accuracy + Robustness + Interpretability
-             </p>
-             <p style="font-size: 14px; opacity: 0.9;">
-                 This combination tackles three major challenges in medical AI with minimal architectural changes
-             </p>
-         </div>
-     </div>
-     """)
-
-     # Dataset and disclaimer section
-     gr.HTML("""
-     <div style="margin-top: 30px; padding: 25px; background-color: #f7fafc; border-radius: 15px; border-left: 5px solid #667eea;">
-         <div style="display: grid; grid-template-columns: 1fr 1fr; gap: 30px;">
-
-             <div>
-                 <h4 style="color: #667eea; margin-bottom: 15px;">📚 Dataset Information</h4>
-                 <p><strong>Source:</strong> Brain Tumor Segmentation (Kaggle)</p>
-                 <p><strong>Author:</strong> nikhilroxtomar</p>
-                 <p><strong>Structure:</strong> Images + Ground Truth Masks</p>
-                 <p><strong>Format:</strong> Grayscale MRI scans</p>
-                 <p><strong>Use Case:</strong> Medical image segmentation research</p>
-                 <p><strong>Ground Truth:</strong> Available for metric calculation</p>
-             </div>
-
-             <div>
-                 <h4 style="color: #dc2626; margin-bottom: 15px;">⚠️ Medical Disclaimer</h4>
-                 <p style="color: #dc2626; font-weight: 600; line-height: 1.5;">
-                     This enhanced AI system is designed for <strong>research and educational purposes only</strong>.<br><br>
-
-                     While the model includes advanced features like attention visualization and test-time augmentation
-                     for improved accuracy and interpretability, all results must be validated by qualified medical professionals.<br><br>
-
-                     <strong>Not approved for clinical diagnosis or medical decision making.</strong>
-                 </p>
-             </div>
-
-         </div>
-
-         <hr style="margin: 25px 0; border: none; border-top: 2px solid #e2e8f0;">
-
-         <p style="text-align: center; color: #4a5568; margin: 15px 0; font-weight: 600;">
-             🔬 Research-Grade Medical AI • Enhanced Interpretability • Robust Predictions • Ground Truth Validation
-         </p>
-     </div>
-     """)

      # Event handlers
-     def analyze_with_ground_truth(image, gt_mask, use_tta, show_attention):
-         """Wrapper function to handle ground truth comparison"""
-         return predict_with_enhancements(image, gt_mask, use_tta, show_attention)
-
-     def analyze_uploaded_image(image, use_tta, show_attention):
-         """Wrapper function for uploaded images without ground truth"""
-         return predict_with_enhancements(image, None, use_tta, show_attention)
-
-     # Button event handlers
-     analyze_btn.click(
-         fn=lambda img, rand_img, rand_gt, tta, attention: (
-             analyze_with_ground_truth(rand_img, rand_gt, tta, attention)
-             if rand_img is not None
-             else analyze_uploaded_image(img, tta, attention)
-         ),
-         inputs=[image_input, random_image, random_ground_truth, use_tta, show_attention],
-         outputs=[output_image, analysis_output],
-         show_progress=True
      )

-     load_sample_btn.click(
-         fn=load_random_sample,
          inputs=[],
-         outputs=[random_image, random_ground_truth, sample_status],
-         show_progress=True
      )

-     clear_btn.click(
-         fn=clear_all,
          inputs=[],
-         outputs=[image_input, random_image, random_ground_truth, analysis_output]
      )

-     # Auto-load dataset on startup
-     gr.HTML("""
-     <script>
-     document.addEventListener('DOMContentLoaded', function() {
-         console.log('Enhanced Brain Tumor Segmentation App Loaded');
-         console.log('Features: TTA + Attention Visualization + Ground Truth Comparison');
-     });
-     </script>
-     """)
-
  if __name__ == "__main__":
-     print("🚀 Starting Enhanced Brain Tumor Segmentation System...")
-     print("📊 Model Performance: Dice 0.8420, IoU 0.7297, Accuracy 98.90%")
-     print("🔬 Research Features: Attention Gates + TTA + Interpretability")
-     print("📥 Auto-downloading dataset and model...")
-
-     # Initialize dataset download
-     print("📚 Initializing dataset...")
-     try:
-         dataset_path = download_dataset()
-         if dataset_path:
-             print(f"✅ Dataset ready at: {dataset_path}")
-         else:
-             print("⚠️ Dataset download failed, random samples unavailable")
-     except Exception as e:
-         print(f"⚠️ Dataset initialization error: {e}")
-
-     app.launch(
-         server_name="0.0.0.0",
-         server_port=7860,
-         show_error=True,
-         share=False
-     )
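The removed test_time_augmentation above paired each augmentation with its exact inverse before averaging the predictions. A quick property check of those pairs (an editor's sketch, assuming a square NCHW tensor; not part of the commit):

    import torch

    x = torch.randn(1, 1, 256, 256)
    pairs = [
        (lambda t: torch.flip(t, dims=[3]), lambda t: torch.flip(t, dims=[3])),  # horizontal flip is self-inverse
        (lambda t: torch.flip(t, dims=[2]), lambda t: torch.flip(t, dims=[2])),  # vertical flip is self-inverse
        (lambda t: torch.rot90(t, k=1, dims=[2, 3]), lambda t: torch.rot90(t, k=3, dims=[2, 3])),  # 90° undone by 270°
    ]
    for aug, rev in pairs:
        assert torch.equal(rev(aug(x)), x)  # each prediction can be mapped back exactly before averaging

The new-file side of the same hunks follows.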
 
@@ -6,20 +6,19 @@ import cv2
  from PIL import Image
  import matplotlib.pyplot as plt
  import io
  import torchvision.transforms.functional as TF
+ from torchvision import transforms
  import os
  import random
+ import urllib.request
+ import zipfile
+ import kagglehub

  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  model = None
+ DATASET_PATH = "brain_tumor_dataset"

+ # Your model classes (from previous code)
  class DoubleConv(nn.Module):
      def __init__(self, in_channels, out_channels):
          super(DoubleConv, self).__init__()
@@ -61,7 +60,7 @@ class AttentionBlock(nn.Module):
          x1 = self.W_x(x)
          psi = self.relu(g1 + x1)
          psi = self.psi(psi)
+         return x * psi, psi  # Return both attended features and attention map

  class AttentionUNET(nn.Module):
      def __init__(self, in_channels=1, out_channels=1, features=[32, 64, 128, 256]):
@@ -72,7 +71,7 @@ class AttentionUNET(nn.Module):
          self.attentions = nn.ModuleList()
          self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

+         # Down part
          for feature in features:
              self.downs.append(DoubleConv(in_channels, feature))
              in_channels = feature
@@ -80,7 +79,7 @@
          # Bottleneck
          self.bottleneck = DoubleConv(features[-1], features[-1]*2)

+         # Up part
          for feature in reversed(features):
              self.ups.append(nn.ConvTranspose2d(feature*2, feature, kernel_size=2, stride=2))
              self.attentions.append(AttentionBlock(F_g=feature, F_l=feature, F_int=feature // 2))
@@ -88,9 +87,9 @@

          self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

+     def forward(self, x):
          skip_connections = []
+         attention_maps = []  # To collect attention maps

          for down in self.downs:
              x = down(x)
@@ -107,30 +106,13 @@
              if x.shape != skip_connection.shape:
                  x = TF.resize(x, size=skip_connection.shape[2:])

+             attended, attn_map = self.attentions[idx // 2](x, skip_connection)  # Get attention map
+             attention_maps.append(attn_map)

+             concat_skip = torch.cat((attended, x), dim=1)
              x = self.ups[idx+1](concat_skip)

+         return self.final_conv(x), attention_maps

  def download_model():
      """Download your trained model from HuggingFace"""
@@ -138,758 +120,307 @@ def download_model():
      model_path = "best_attention_model.pth.tar"

      if not os.path.exists(model_path):
+         print("📥 Downloading your trained model...")
          try:
              urllib.request.urlretrieve(model_url, model_path)
              print("✅ Model downloaded successfully!")
          except Exception as e:
              print(f"❌ Failed to download model: {e}")
              return None

      return model_path

+ def load_model():
+     """Load your trained Attention U-Net model"""
      global model
      if model is None:
          try:
+             print("🔄 Loading your trained Attention U-Net model...")

+             # Download the model if needed
              model_path = download_model()
              if model_path is None:
                  return None

+             # Initialize your model architecture
              model = AttentionUNET(in_channels=1, out_channels=1).to(device)
+
+             # Load your trained weights
              checkpoint = torch.load(model_path, map_location=device, weights_only=True)
              model.load_state_dict(checkpoint["state_dict"])
              model.eval()

+             print("✅ Your Attention U-Net model loaded successfully!")
          except Exception as e:
+             print(f"❌ Error loading your model: {e}")
              model = None
      return model

+ def preprocess_image(image):
+     """Preprocessing like your Colab code"""
+     # Convert to grayscale
+     if image.mode != 'L':
+         image = image.convert('L')
+
+     # Use your exact transform
+     val_test_transform = transforms.Compose([
+         transforms.Resize((256, 256)),
+         transforms.ToTensor()
+     ])
+
+     return val_test_transform(image).unsqueeze(0)  # Add batch dimension

+ def post_process_mask(pred_mask_np):
+     """Post-processing with morphological operations (Novelty 1)"""
+     # Binarize
+     binary_mask = (pred_mask_np > 0.5).astype(np.uint8)
+
+     # Morphological opening to remove small noise
+     kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
+     binary_mask = cv2.morphologyEx(binary_mask, cv2.MORPH_OPEN, kernel)
+
+     # Morphological closing to fill gaps
+     binary_mask = cv2.morphologyEx(binary_mask, cv2.MORPH_CLOSE, kernel)
+
+     return binary_mask
+
+ def test_time_augmentation(input_tensor, model):
+     """Test-Time Augmentation (Novelty 2)"""
      predictions = []

+     # Original
+     pred, _ = model(input_tensor)
+     predictions.append(torch.sigmoid(pred))

+     # Horizontal flip
+     hflip = TF.hflip(input_tensor)
+     pred_h, _ = model(hflip)
+     pred_h = TF.hflip(pred_h)
+     predictions.append(torch.sigmoid(pred_h))

+     # Vertical flip
+     vflip = TF.vflip(input_tensor)
+     pred_v, _ = model(vflip)
+     pred_v = TF.vflip(pred_v)
+     predictions.append(torch.sigmoid(pred_v))

+     # Average the predictions
+     avg_pred = torch.mean(torch.stack(predictions), dim=0)
+     return avg_pred.squeeze().cpu().numpy()
+
+ def generate_attention_heatmap(attention_maps, size=(256, 256)):
+     """Generate attention heatmap visualization (Novelty 3)"""
+     # Average attention maps from all levels
+     avg_attn = torch.mean(torch.cat([TF.resize(m, size) for m in attention_maps]), dim=0)
+     attn_np = avg_attn.squeeze().cpu().numpy()

+     # Normalize and apply a colormap
+     attn_norm = (attn_np - attn_np.min()) / (attn_np.max() - attn_np.min() + 1e-8)
+     heatmap = plt.cm.hot(attn_norm)[:, :, :3] * 255
+     return heatmap.astype(np.uint8)

+ def download_dataset():
+     """Download and extract the dataset if not present"""
+     if not os.path.exists(DATASET_PATH):
+         print("📥 Downloading brain tumor dataset...")
+         try:
+             path = kagglehub.dataset_download('nikhilroxtomar/brain-tumor-segmentation')
+             print(f"Dataset downloaded to: {path}")
+
+             # Extract if zipped
+             for file in os.listdir(path):
+                 if file.endswith('.zip'):
+                     with zipfile.ZipFile(os.path.join(path, file), 'r') as zip_ref:
+                         zip_ref.extractall(DATASET_PATH)
+             print("✅ Dataset extracted!")
+             return True
+         except Exception as e:
+             print(f"❌ Failed to download dataset: {e}")
+             return False
+     print("✅ Dataset already exists!")
+     return True
+
+ def get_random_sample():
+     """Get a random image and mask from the dataset"""
+     if not os.path.exists(DATASET_PATH):
+         if not download_dataset():
+             return None, None
+
+     images_path = os.path.join(DATASET_PATH, "images")
+     masks_path = os.path.join(DATASET_PATH, "masks")
+
+     image_files = [f for f in os.listdir(images_path) if f.endswith(('.png', '.jpg'))]
+
+     if not image_files:
+         return None, None

+     random_file = random.choice(image_files)

+     img_path = os.path.join(images_path, random_file)
+     mask_path = os.path.join(masks_path, random_file)

+     if not os.path.exists(mask_path):
+         return None, None
+
+     return Image.open(img_path), Image.open(mask_path)

+ def predict_tumor(image, use_tta=True, show_attention=True, is_dataset_sample=False, ground_truth=None):
+     current_model = load_model()

      if current_model is None:
+         return None, "Failed to load your trained model."

      if image is None:
+         return None, "Please upload an image first."

      try:
+         print("Processing image...")

          input_tensor = preprocess_image(image).to(device)

+         # Use TTA if enabled
          if use_tta:
+             pred_np = test_time_augmentation(input_tensor, current_model)
+             attn_maps = None  # the TTA path does not collect attention maps
          else:
+             pred, attn_maps = current_model(input_tensor)
+             pred_np = torch.sigmoid(pred).squeeze().cpu().numpy()
+             attn_maps = attn_maps if show_attention else None
+
+         # Post-processing
+         binary_mask = post_process_mask(pred_np)
+
+         # Generate the attention heatmap if enabled
+         attention_heatmap = None
+         if show_attention and attn_maps:
+             attention_heatmap = generate_attention_heatmap(attn_maps)
+
+         # Create the visualization
+         fig, axes = plt.subplots(1, 3 + int(show_attention) + int(is_dataset_sample and ground_truth is not None), figsize=(20, 5))
+         fig.suptitle('Brain Tumor Segmentation Results', fontsize=16)
+
+         # Original image
+         original_np = np.array(image.resize((256, 256)))
+         axes[0].imshow(original_np, cmap='gray')
+         axes[0].set_title('Original Image')
+         axes[0].axis('off')
+
+         # Predicted mask
+         axes[1].imshow(binary_mask * 255, cmap='gray')
+         axes[1].set_title('Predicted Mask')
+         axes[1].axis('off')
+
+         # Overlay
+         base = cv2.cvtColor(original_np, cv2.COLOR_GRAY2RGB) if original_np.ndim == 2 else original_np.copy()
+         overlay = base.copy()
+         overlay[binary_mask == 1] = [255, 0, 0]  # Red for tumor
+         overlay = cv2.addWeighted(base, 0.7, overlay, 0.3, 0)
+         axes[2].imshow(overlay)
+         axes[2].set_title('Overlay')
+         axes[2].axis('off')
+
+         col = 3
+         if show_attention and attention_heatmap is not None:
+             axes[col].imshow(attention_heatmap)
+             axes[col].set_title('Attention Heatmap')
+             axes[col].axis('off')
+             col += 1
+
+         if is_dataset_sample and ground_truth is not None:
+             gt_np = np.array(ground_truth.resize((256, 256)))
+             axes[col].imshow(gt_np, cmap='gray')
+             axes[col].set_title('Ground Truth')
+             axes[col].axis('off')

          plt.tight_layout()

          buf = io.BytesIO()
+         plt.savefig(buf, format='png', dpi=150, bbox_inches='tight')
          buf.seek(0)
          plt.close()

          result_image = Image.open(buf)

+         # Statistics
+         tumor_pixels = np.sum(binary_mask)
+         total_pixels = binary_mask.size
          tumor_percentage = (tumor_pixels / total_pixels) * 100

          analysis_text = f"""
+ ### Segmentation Statistics
+ - Tumor Area Percentage: {tumor_percentage:.2f}%
+ - Tumor Pixels: {tumor_pixels}
+ - Total Pixels: {total_pixels}
+ - TTA Used: {'Yes' if use_tta else 'No'}
+ - Attention Visualization: {'Yes' if show_attention else 'No'}
  """
+
+         if is_dataset_sample and ground_truth is not None:
+             gt_np = np.array(ground_truth.resize((256, 256)))
+             intersection = np.logical_and(binary_mask, gt_np > 0).sum()
+             union = np.logical_or(binary_mask, gt_np > 0).sum()
+             iou = intersection / (union + 1e-8)
+             dice = (2 * intersection) / (binary_mask.sum() + (gt_np > 0).sum() + 1e-8)
+
              analysis_text += f"""
+ ### Comparison with Ground Truth
+ - IoU Score: {iou:.4f}
+ - Dice Score: {dice:.4f}
  """

          return result_image, analysis_text

      except Exception as e:
+         return None, f"Error: {str(e)}"

+ def test_random_sample():
+     image, mask = get_random_sample()
      if image is None:
+         return None, "Failed to load dataset sample. Please download the dataset first."
+     return predict_tumor(image, use_tta=True, show_attention=True, is_dataset_sample=True, ground_truth=mask)

+ # Custom CSS for a professional, minimalist look
  css = """
+ body, .gradio-container { font-family: 'Arial', sans-serif; color: #333; }
+ h1, h2, h3, h4 { color: #2c3e50; font-weight: 500; }
+ .button { background-color: #3498db; color: white; border: none; border-radius: 4px; padding: 10px 20px; font-size: 16px; cursor: pointer; }
+ .button:hover { background-color: #2980b9; }
+ .card { border: 1px solid #e0e0e0; border-radius: 8px; padding: 20px; background: white; box-shadow: 0 2px 4px rgba(0,0,0,0.1); }
  """

+ with gr.Blocks(css=css, title="Brain Tumor Segmentation") as app:
+     gr.Markdown("# Brain Tumor Segmentation Using Attention U-Net")

      with gr.Row():
          with gr.Column(scale=1):
+             gr.Markdown("### Input")
+             image_input = gr.Image(label="Upload Image", type="pil")

              with gr.Row():
+                 predict_btn = gr.Button("Predict")
+                 random_btn = gr.Button("Test Random Sample")
+                 download_btn = gr.Button("Download Dataset")

          with gr.Column(scale=2):
+             gr.Markdown("### Output")
+             output_image = gr.Image(label="Result")
+             analysis_output = gr.Textbox(label="Analysis", lines=10)

      # Event handlers
+     predict_btn.click(
+         fn=predict_tumor,
+         inputs=[image_input],
+         outputs=[output_image, analysis_output]
      )

+     random_btn.click(
+         fn=test_random_sample,
          inputs=[],
+         outputs=[output_image, analysis_output]
      )

+     download_btn.click(
+         fn=lambda: "✅ Dataset ready!" if download_dataset() else "❌ Dataset download failed.",
          inputs=[],
+         outputs=[analysis_output]
      )

  if __name__ == "__main__":
+     app.launch()
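To sanity-check the committed version end to end, a minimal sketch (an editor's addition, not part of the commit; it assumes the classes above are importable from app.py, the 256×256 shape matches preprocess_image, and the toy metric numbers are illustrative only):

    import numpy as np
    import torch

    model = AttentionUNET(in_channels=1, out_channels=1)
    model.eval()
    with torch.no_grad():
        logits, attn = model(torch.randn(1, 1, 256, 256))
    print(logits.shape, len(attn))  # expected: torch.Size([1, 1, 256, 256]) and 4 attention maps

    # Toy Dice/IoU check mirroring the numpy formulas in predict_tumor
    pred = np.array([[1, 1], [0, 0]])
    gt = np.array([[1, 0], [0, 0]])
    inter = np.logical_and(pred, gt).sum()             # 1
    union = np.logical_or(pred, gt).sum()              # 2
    print(inter / (union + 1e-8))                      # IoU = 0.5
    print(2 * inter / (pred.sum() + gt.sum() + 1e-8))  # Dice ≈ 0.6667

The shape check exercises the new forward contract (logits plus one attention map per decoder level), and the metric check confirms that Dice is always at least as large as IoU on the same masks.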