ArchCoder commited on
Commit
c530021
·
verified ·
1 Parent(s): fa07b35

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +294 -198
app.py CHANGED
@@ -10,14 +10,22 @@ import torchvision.transforms as transforms
10
  import torchvision.transforms.functional as TF
11
  import random
12
  import os
13
- import zipfile
14
  import urllib.request
15
  import kagglehub
 
16
 
 
17
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
18
  model = None
 
 
 
19
 
20
- # Your Attention U-Net classes (from your code)
 
 
 
 
21
  class DoubleConv(nn.Module):
22
  def __init__(self, in_channels, out_channels):
23
  super(DoubleConv, self).__init__()
@@ -59,7 +67,7 @@ class AttentionBlock(nn.Module):
59
  x1 = self.W_x(x)
60
  psi = self.relu(g1 + x1)
61
  psi = self.psi(psi)
62
- return x * psi, psi # Return attention map as well
63
 
64
  class AttentionUNET(nn.Module):
65
  def __init__(self, in_channels=1, out_channels=1, features=[32, 64, 128, 256]):
@@ -70,15 +78,12 @@ class AttentionUNET(nn.Module):
70
  self.attentions = nn.ModuleList()
71
  self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
72
 
73
- # Down part
74
  for feature in features:
75
  self.downs.append(DoubleConv(in_channels, feature))
76
  in_channels = feature
77
 
78
- # Bottleneck
79
  self.bottleneck = DoubleConv(features[-1], features[-1]*2)
80
 
81
- # Up part
82
  for feature in reversed(features):
83
  self.ups.append(nn.ConvTranspose2d(feature*2, feature, kernel_size=2, stride=2))
84
  self.attentions.append(AttentionBlock(F_g=feature, F_l=feature, F_int=feature // 2))
@@ -88,7 +93,7 @@ class AttentionUNET(nn.Module):
88
 
89
  def forward(self, x):
90
  skip_connections = []
91
- attention_maps = [] # To store attention maps
92
 
93
  for down in self.downs:
94
  x = down(x)
@@ -105,222 +110,265 @@ class AttentionUNET(nn.Module):
105
  if x.shape != skip_connection.shape:
106
  x = TF.resize(x, size=skip_connection.shape[2:])
107
 
108
- attended_skip, att_map = self.attentions[idx // 2](x, skip_connection) # Get attention map
109
- attention_maps.append(att_map) # Store attention map
110
  concat_skip = torch.cat((attended_skip, x), dim=1)
111
  x = self.ups[idx+1](concat_skip)
112
 
113
  return self.final_conv(x), attention_maps
114
 
115
- def download_model():
116
- """Download your trained model from HuggingFace"""
 
 
 
117
  model_url = "https://huggingface.co/spaces/ArchCoder/the-op-segmenter/resolve/main/best_attention_model.pth.tar"
118
  model_path = "best_attention_model.pth.tar"
119
 
 
120
  if not os.path.exists(model_path):
121
- print("📥 Downloading your trained model...")
122
  try:
123
  urllib.request.urlretrieve(model_url, model_path)
124
- print("✅ Model downloaded successfully!")
125
- except Exception as e:
126
- print(f"❌ Failed to download model: {e}")
127
- return None
128
- return model_path
129
-
130
- def load_your_attention_model():
131
- """Load YOUR trained Attention U-Net model"""
132
- global model
133
- if model is None:
134
- try:
135
- print("🔄 Loading your trained Attention U-Net model...")
136
-
137
- # Download model if needed
138
- model_path = download_model()
139
- if model_path is None:
140
- return None
141
-
142
- # Initialize your model architecture
143
- model = AttentionUNET(in_channels=1, out_channels=1).to(device)
144
-
145
- # Load your trained weights
146
- checkpoint = torch.load(model_path, map_location=device, weights_only=True)
147
- model.load_state_dict(checkpoint["state_dict"])
148
- model.eval()
149
-
150
- print("✅ Your Attention U-Net model loaded successfully!")
151
  except Exception as e:
152
- print(f" Error loading your model: {e}")
153
- model = None
154
- return model
155
-
156
- def download_dataset():
157
- """Download and extract the dataset using kagglehub"""
158
- dataset_path = kagglehub.dataset_download('nikhilroxtomar/brain-tumor-segmentation')
159
 
160
- # Extract if it's a zip
161
- extracted_path = "brain_tumor_dataset"
162
- if not os.path.exists(extracted_path):
163
- with zipfile.ZipFile(dataset_path, 'r') as zip_ref:
164
- zip_ref.extractall(extracted_path)
 
 
 
 
 
 
 
 
 
 
165
 
166
- images_path = os.path.join(extracted_path, 'images')
167
- masks_path = os.path.join(extracted_path, 'masks')
 
 
168
 
169
- return images_path, masks_path
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
170
 
171
- def load_random_sample():
172
- """Load a random image and mask from the dataset"""
173
- images_path, masks_path = download_dataset()
 
174
 
175
- image_files = [f for f in os.listdir(images_path) if f.endswith(('.png', '.jpg'))]
176
- if not image_files:
177
  return None, None, "No images found in dataset"
178
 
179
- random_file = random.choice(image_files)
180
- img_path = os.path.join(images_path, random_file)
181
- mask_path = os.path.join(masks_path, random_file)
182
 
183
- image = Image.open(img_path).convert("L")
184
- mask = Image.open(mask_path).convert("L") if os.path.exists(mask_path) else None
 
 
 
 
 
185
 
186
- return image, mask, random_file
 
 
 
 
 
187
 
188
- def preprocess_for_your_model(image):
189
- """Preprocessing exactly like your Colab code"""
190
  if image.mode != 'L':
191
  image = image.convert('L')
192
 
193
- val_test_transform = transforms.Compose([
194
  transforms.Resize((256,256)),
195
  transforms.ToTensor()
196
  ])
197
 
198
- return val_test_transform(image).unsqueeze(0)
199
 
200
  def apply_tta(model, input_tensor):
201
- """Test-Time Augmentation: Apply augmentations and average predictions"""
202
  augmentations = [
203
  lambda x: x, # Original
204
- lambda x: TF.rotate(x, 90), # 90 deg rotation
205
- lambda x: TF.rotate(x, -90), # -90 deg rotation
206
  lambda x: TF.hflip(x), # Horizontal flip
207
- lambda x: TF.vflip(x) # Vertical flip
208
  ]
209
 
210
  predictions = []
211
- for aug in augmentations:
212
  aug_input = aug(input_tensor)
213
- pred = torch.sigmoid(model(aug_input)[0]) # Get prediction
214
- # Reverse the augmentation for averaging
215
- if aug == augmentations[1]: # Reverse 90 deg
216
- pred = TF.rotate(pred, -90)
217
- elif aug == augmentations[2]: # Reverse -90 deg
218
- pred = TF.rotate(pred, 90)
219
- elif aug == augmentations[3]: # Reverse hflip
220
  pred = TF.hflip(pred)
221
- elif aug == augmentations[4]: # Reverse vflip
222
  pred = TF.vflip(pred)
 
223
  predictions.append(pred)
224
 
225
- # Average predictions
226
- avg_pred = torch.mean(torch.stack(predictions), dim=0)
227
- return avg_pred
228
 
229
  def generate_attention_heatmap(attention_maps):
230
- """Generate combined attention heatmap"""
231
  if not attention_maps:
232
- return np.zeros((256, 256))
233
 
234
- # Average attention maps from different levels
235
  combined_att = torch.mean(torch.stack(attention_maps), dim=0).squeeze().cpu().numpy()
236
  combined_att = cv2.resize(combined_att, (256, 256))
237
  combined_att = (combined_att - combined_att.min()) / (combined_att.max() - combined_att.min() + 1e-8)
238
  heatmap = cv2.applyColorMap((combined_att * 255).astype(np.uint8), cv2.COLORMAP_JET)
239
  return heatmap
240
 
241
- def predict_tumor(image, ground_truth=None, filename=None):
242
- current_model = load_your_attention_model()
 
 
243
 
244
- if current_model is None:
245
- return None, "Failed to load your trained model."
246
-
247
  if image is None:
248
- return None, "Please upload or load an image first."
249
 
250
  try:
251
  # Preprocess
252
- input_tensor = preprocess_for_your_model(image).to(device)
253
 
254
  # Apply TTA
255
- avg_pred = apply_tta(current_model, input_tensor)
 
 
256
 
257
  # Get binary mask
258
- binary_mask = (avg_pred > 0.5).float().squeeze().cpu().numpy()
259
 
260
  # Post-processing
261
- kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5,5))
262
  binary_mask = cv2.morphologyEx(binary_mask.astype(np.uint8), cv2.MORPH_OPEN, kernel)
263
  binary_mask = cv2.morphologyEx(binary_mask, cv2.MORPH_CLOSE, kernel)
264
 
265
- # Extract attention maps
266
- _, attention_maps = current_model(input_tensor)
267
  att_heatmap = generate_attention_heatmap(attention_maps)
268
 
269
  # Create visualization
270
- fig, axes = plt.subplots(2, 3, figsize=(18, 12))
271
- fig.suptitle('Brain Tumor Segmentation Analysis', fontsize=20)
 
 
 
 
272
 
273
- # Original
274
  axes[0,0].imshow(image, cmap='gray')
275
- axes[0,0].set_title('Original Image')
276
  axes[0,0].axis('off')
277
 
278
- # Attention Heatmap
279
- axes[0,1].imshow(np.array(image), cmap='gray')
280
- axes[0,1].imshow(att_heatmap, alpha=0.5)
281
- axes[0,1].set_title('Attention Heatmap')
282
  axes[0,1].axis('off')
283
 
284
- # Predicted Mask
285
- axes[0,2].imshow(binary_mask, cmap='gray')
286
- axes[0,2].set_title('Predicted Mask')
287
- axes[0,2].axis('off')
288
-
289
- # Ground Truth (if available)
290
  if ground_truth is not None:
291
- gt_np = np.array(ground_truth.resize((256, 256)))
292
- axes[1,0].imshow(gt_np, cmap='gray')
293
- axes[1,0].set_title('Ground Truth Mask')
 
 
 
 
 
294
  axes[1,0].axis('off')
295
 
296
- # Comparison Overlay
297
- overlay = np.array(image.convert('RGB'))
298
  overlay[binary_mask > 0] = [0, 255, 0] # Green for prediction
299
- overlay[gt_np > 0] = [255, 0, 0] # Red for ground truth
300
  axes[1,1].imshow(overlay)
301
- axes[1,1].set_title('Prediction (Green) vs GT (Red)')
302
  axes[1,1].axis('off')
303
 
304
- # IoU Calculation
305
- intersection = np.sum(binary_mask * (gt_np > 0))
306
- union = np.sum(binary_mask) + np.sum(gt_np > 0) - intersection
 
 
307
  iou = intersection / (union + 1e-8)
308
 
309
- axes[1,2].text(0.1, 0.5, f'IoU Score: {iou:.4f}', fontsize=20)
 
 
 
 
 
 
310
  axes[1,2].axis('off')
311
  else:
312
- # Overlay for prediction only
313
- overlay = np.array(image.convert('RGB'))
314
- overlay[binary_mask > 0] = [255, 0, 0]
315
- axes[1,0].imshow(overlay)
316
- axes[1,0].set_title('Prediction Overlay')
317
  axes[1,0].axis('off')
318
 
 
 
 
 
 
319
  axes[1,1].axis('off')
320
- axes[1,2].axis('off')
321
 
322
  plt.tight_layout()
323
 
 
324
  buf = io.BytesIO()
325
  plt.savefig(buf, format='png', dpi=150, bbox_inches='tight', facecolor='white')
326
  buf.seek(0)
@@ -328,128 +376,176 @@ def predict_tumor(image, ground_truth=None, filename=None):
328
 
329
  result_image = Image.open(buf)
330
 
331
- # Statistics
332
  tumor_pixels = np.sum(binary_mask)
333
  total_pixels = binary_mask.size
334
  tumor_percentage = (tumor_pixels / total_pixels) * 100
335
 
336
  analysis_text = f"""
337
- ## Brain Tumor Segmentation Results
 
 
338
 
339
- ### Detection Summary
340
- - Tumor Percentage: {tumor_percentage:.2f}%
341
- - Tumor Pixels: {tumor_pixels}
342
- - File: {filename if filename else 'Uploaded Image'}
343
 
344
- ### Model Information
345
- - Your Attention U-Net Model
346
  - Test-Time Augmentation: Applied
347
- - Attention Visualization: Included
348
- """
 
349
 
350
  if ground_truth is not None:
351
- analysis_text += f"\n- IoU with Ground Truth: {iou:.4f}"
 
 
 
 
352
 
353
  return result_image, analysis_text
354
 
355
  except Exception as e:
356
- return None, f"Error: {str(e)}"
 
 
 
 
 
 
 
 
 
 
357
 
358
- def clear_all():
359
- return None, None, None, "Upload or load an image for analysis"
360
 
361
- # Professional CSS (white, clean, professional)
362
  css = """
363
  .gradio-container {
364
- max-width: 1400px !important;
365
  margin: auto !important;
366
- background-color: white !important;
367
- font-family: 'Arial', sans-serif !important;
368
  }
369
- h1, h2, h3, h4 {
370
- color: #333333 !important;
 
371
  }
372
- button {
373
- background-color: #f0f0f0 !important;
374
- color: #333333 !important;
375
- border: 1px solid #dddddd !important;
376
- border-radius: 4px !important;
377
  }
378
- button.primary {
379
- background-color: #007bff !important;
380
- color: white !important;
381
  }
382
- .output-image {
383
- border: 1px solid #dddddd !important;
384
- border-radius: 4px !important;
385
  }
386
- .markdown {
387
- line-height: 1.6 !important;
388
- color: #555555 !important;
389
  }
390
  """
391
 
392
- # Create professional Gradio interface
393
- with gr.Blocks(css=css, title="Brain Tumor Segmentation Application") as app:
394
 
395
  gr.Markdown("""
396
  # Brain Tumor Segmentation Using Attention U-Net
397
- A professional tool for medical image analysis
 
 
 
398
  """)
399
 
 
 
 
 
 
 
 
 
400
  with gr.Row():
401
  with gr.Column(scale=1):
402
  gr.Markdown("### Input Selection")
403
 
404
- image_input = gr.Image(
405
- label="Upload Brain MRI",
 
406
  type="pil",
407
- sources=["upload", "webcam"],
408
  height=300
409
  )
410
 
411
- load_random_btn = gr.Button("Load Random Sample from Dataset", variant="primary")
412
-
413
  with gr.Row():
414
- analyze_btn = gr.Button("Analyze Image", variant="primary", scale=2)
415
- clear_btn = gr.Button("Clear", scale=1)
 
 
 
 
 
 
 
 
 
 
416
 
417
  with gr.Column(scale=2):
418
  gr.Markdown("### Analysis Results")
419
 
420
- output_image = gr.Image(
421
- label="Segmentation Results",
422
  type="pil",
423
- height=400
424
  )
425
 
426
- analysis_output = gr.Markdown(
427
- value="Select an input method to begin analysis."
428
  )
429
 
430
- # Hidden state for ground truth and filename
431
- ground_truth_state = gr.State()
432
- filename_state = gr.State()
433
-
434
  # Event handlers
435
- analyze_btn.click(
436
- fn=predict_tumor,
437
- inputs=[image_input, ground_truth_state, filename_state],
438
- outputs=[output_image, analysis_output]
 
 
 
 
 
 
 
 
 
439
  )
440
 
441
- load_random_btn.click(
442
- fn=load_random_sample,
443
- inputs=[],
444
- outputs=[image_input, ground_truth_state, filename_state, analysis_output]
445
  )
446
 
447
- clear_btn.click(
448
- fn=clear_all,
449
- inputs=[],
450
- outputs=[image_input, output_image, ground_truth_state, analysis_output]
451
  )
452
 
453
  if __name__ == "__main__":
454
- print("Starting Brain Tumor Segmentation Application...")
455
- app.launch()
 
 
 
 
 
 
 
 
 
10
  import torchvision.transforms.functional as TF
11
  import random
12
  import os
 
13
  import urllib.request
14
  import kagglehub
15
+ from glob import glob
16
 
17
+ # Global variables - loaded once at startup
18
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
19
  model = None
20
+ dataset_images = []
21
+ dataset_masks = []
22
+ dataset_loaded = False
23
 
24
+ print("="*50)
25
+ print("BRAIN TUMOR SEGMENTATION APPLICATION")
26
+ print("="*50)
27
+
28
+ # Your Attention U-Net classes (unchanged)
29
  class DoubleConv(nn.Module):
30
  def __init__(self, in_channels, out_channels):
31
  super(DoubleConv, self).__init__()
 
67
  x1 = self.W_x(x)
68
  psi = self.relu(g1 + x1)
69
  psi = self.psi(psi)
70
+ return x * psi, psi
71
 
72
  class AttentionUNET(nn.Module):
73
  def __init__(self, in_channels=1, out_channels=1, features=[32, 64, 128, 256]):
 
78
  self.attentions = nn.ModuleList()
79
  self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
80
 
 
81
  for feature in features:
82
  self.downs.append(DoubleConv(in_channels, feature))
83
  in_channels = feature
84
 
 
85
  self.bottleneck = DoubleConv(features[-1], features[-1]*2)
86
 
 
87
  for feature in reversed(features):
88
  self.ups.append(nn.ConvTranspose2d(feature*2, feature, kernel_size=2, stride=2))
89
  self.attentions.append(AttentionBlock(F_g=feature, F_l=feature, F_int=feature // 2))
 
93
 
94
  def forward(self, x):
95
  skip_connections = []
96
+ attention_maps = []
97
 
98
  for down in self.downs:
99
  x = down(x)
 
110
  if x.shape != skip_connection.shape:
111
  x = TF.resize(x, size=skip_connection.shape[2:])
112
 
113
+ attended_skip, att_map = self.attentions[idx // 2](x, skip_connection)
114
+ attention_maps.append(att_map)
115
  concat_skip = torch.cat((attended_skip, x), dim=1)
116
  x = self.ups[idx+1](concat_skip)
117
 
118
  return self.final_conv(x), attention_maps
119
 
120
+ def download_and_load_model():
121
+ """Download and load model once at startup"""
122
+ global model
123
+ print("Loading Attention U-Net model...")
124
+
125
  model_url = "https://huggingface.co/spaces/ArchCoder/the-op-segmenter/resolve/main/best_attention_model.pth.tar"
126
  model_path = "best_attention_model.pth.tar"
127
 
128
+ # Download model if needed
129
  if not os.path.exists(model_path):
130
+ print("Downloading model weights...")
131
  try:
132
  urllib.request.urlretrieve(model_url, model_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
133
  except Exception as e:
134
+ print(f"Failed to download model: {e}")
135
+ return False
 
 
 
 
 
136
 
137
+ # Load model
138
+ try:
139
+ model = AttentionUNET(in_channels=1, out_channels=1).to(device)
140
+ checkpoint = torch.load(model_path, map_location=device, weights_only=True)
141
+ model.load_state_dict(checkpoint["state_dict"])
142
+ model.eval()
143
+ print("✓ Model loaded successfully!")
144
+ return True
145
+ except Exception as e:
146
+ print(f"Failed to load model: {e}")
147
+ return False
148
+
149
+ def download_and_load_dataset():
150
+ """Download and load entire dataset once at startup"""
151
+ global dataset_images, dataset_masks, dataset_loaded
152
 
153
+ if dataset_loaded:
154
+ return True
155
+
156
+ print("Loading brain tumor dataset...")
157
 
158
+ try:
159
+ # Download dataset using kagglehub - returns directory path
160
+ dataset_path = kagglehub.dataset_download('nikhilroxtomar/brain-tumor-segmentation')
161
+ print(f"Dataset downloaded to: {dataset_path}")
162
+
163
+ # Find images and masks directories
164
+ images_dir = os.path.join(dataset_path, 'images')
165
+ masks_dir = os.path.join(dataset_path, 'masks')
166
+
167
+ # If direct path doesn't exist, search subdirectories
168
+ if not os.path.exists(images_dir):
169
+ # Search for images and masks directories
170
+ for root, dirs, files in os.walk(dataset_path):
171
+ if 'images' in dirs:
172
+ images_dir = os.path.join(root, 'images')
173
+ if 'masks' in dirs:
174
+ masks_dir = os.path.join(root, 'masks')
175
+
176
+ if not os.path.exists(images_dir) or not os.path.exists(masks_dir):
177
+ print("Could not find images/masks directories. Searching all files...")
178
+ # Fallback: find all image files
179
+ all_files = glob(os.path.join(dataset_path, "**/*.png"), recursive=True) + \
180
+ glob(os.path.join(dataset_path, "**/*.jpg"), recursive=True)
181
+
182
+ dataset_images = [f for f in all_files if '/images/' in f or 'image' in f.lower()]
183
+ dataset_masks = [f for f in all_files if '/masks/' in f or 'mask' in f.lower()]
184
+ else:
185
+ # Load image and mask file paths
186
+ dataset_images = glob(os.path.join(images_dir, "*.*"))
187
+ dataset_masks = glob(os.path.join(masks_dir, "*.*"))
188
+
189
+ dataset_images = sorted(dataset_images)
190
+ dataset_masks = sorted(dataset_masks)
191
+
192
+ print(f"✓ Found {len(dataset_images)} images and {len(dataset_masks)} masks")
193
+ dataset_loaded = True
194
+ return True
195
+
196
+ except Exception as e:
197
+ print(f"Failed to load dataset: {e}")
198
+ return False
199
 
200
+ def get_random_sample():
201
+ """Get a random image and corresponding mask from dataset"""
202
+ if not dataset_loaded:
203
+ return None, None, "Dataset not loaded"
204
 
205
+ if not dataset_images:
 
206
  return None, None, "No images found in dataset"
207
 
208
+ # Get random index
209
+ idx = random.randint(0, len(dataset_images) - 1)
210
+ img_path = dataset_images[idx]
211
 
212
+ # Find corresponding mask
213
+ img_name = os.path.basename(img_path)
214
+ mask_path = None
215
+ for mask in dataset_masks:
216
+ if os.path.basename(mask) == img_name:
217
+ mask_path = mask
218
+ break
219
 
220
+ try:
221
+ image = Image.open(img_path).convert("L")
222
+ mask = Image.open(mask_path).convert("L") if mask_path else None
223
+ return image, mask, img_name
224
+ except Exception as e:
225
+ return None, None, f"Error loading sample: {e}"
226
 
227
+ def preprocess_for_model(image):
228
+ """Preprocessing for your model"""
229
  if image.mode != 'L':
230
  image = image.convert('L')
231
 
232
+ transform = transforms.Compose([
233
  transforms.Resize((256,256)),
234
  transforms.ToTensor()
235
  ])
236
 
237
+ return transform(image).unsqueeze(0)
238
 
239
  def apply_tta(model, input_tensor):
240
+ """Test-Time Augmentation"""
241
  augmentations = [
242
  lambda x: x, # Original
 
 
243
  lambda x: TF.hflip(x), # Horizontal flip
244
+ lambda x: TF.vflip(x), # Vertical flip
245
  ]
246
 
247
  predictions = []
248
+ for i, aug in enumerate(augmentations):
249
  aug_input = aug(input_tensor)
250
+ pred, _ = model(aug_input)
251
+ pred = torch.sigmoid(pred)
252
+
253
+ # Reverse augmentation
254
+ if i == 1: # Reverse hflip
 
 
255
  pred = TF.hflip(pred)
256
+ elif i == 2: # Reverse vflip
257
  pred = TF.vflip(pred)
258
+
259
  predictions.append(pred)
260
 
261
+ return torch.mean(torch.stack(predictions), dim=0)
 
 
262
 
263
  def generate_attention_heatmap(attention_maps):
264
+ """Generate attention heatmap"""
265
  if not attention_maps:
266
+ return np.zeros((256, 256, 3))
267
 
268
+ # Combine attention maps
269
  combined_att = torch.mean(torch.stack(attention_maps), dim=0).squeeze().cpu().numpy()
270
  combined_att = cv2.resize(combined_att, (256, 256))
271
  combined_att = (combined_att - combined_att.min()) / (combined_att.max() - combined_att.min() + 1e-8)
272
  heatmap = cv2.applyColorMap((combined_att * 255).astype(np.uint8), cv2.COLORMAP_JET)
273
  return heatmap
274
 
275
+ def analyze_image(image, ground_truth, filename):
276
+ """Main analysis function"""
277
+ if model is None:
278
+ return None, "Model not loaded. Please restart the application."
279
 
 
 
 
280
  if image is None:
281
+ return None, "Please select an image first."
282
 
283
  try:
284
  # Preprocess
285
+ input_tensor = preprocess_for_model(image).to(device)
286
 
287
  # Apply TTA
288
+ with torch.no_grad():
289
+ avg_pred = apply_tta(model, input_tensor)
290
+ _, attention_maps = model(input_tensor)
291
 
292
  # Get binary mask
293
+ binary_mask = (avg_pred > 0.5).squeeze().cpu().numpy()
294
 
295
  # Post-processing
296
+ kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3,3))
297
  binary_mask = cv2.morphologyEx(binary_mask.astype(np.uint8), cv2.MORPH_OPEN, kernel)
298
  binary_mask = cv2.morphologyEx(binary_mask, cv2.MORPH_CLOSE, kernel)
299
 
300
+ # Generate attention heatmap
 
301
  att_heatmap = generate_attention_heatmap(attention_maps)
302
 
303
  # Create visualization
304
+ if ground_truth is not None:
305
+ fig, axes = plt.subplots(2, 3, figsize=(15, 10))
306
+ else:
307
+ fig, axes = plt.subplots(2, 2, figsize=(12, 10))
308
+
309
+ fig.suptitle('Brain Tumor Segmentation Analysis', fontsize=16, weight='bold')
310
 
311
+ # Original image
312
  axes[0,0].imshow(image, cmap='gray')
313
+ axes[0,0].set_title('Original Image', fontsize=12, weight='bold')
314
  axes[0,0].axis('off')
315
 
316
+ # Attention heatmap
317
+ axes[0,1].imshow(image, cmap='gray')
318
+ axes[0,1].imshow(att_heatmap, alpha=0.4)
319
+ axes[0,1].set_title('Attention Heatmap', fontsize=12, weight='bold')
320
  axes[0,1].axis('off')
321
 
322
+ # Predicted mask
 
 
 
 
 
323
  if ground_truth is not None:
324
+ axes[0,2].imshow(binary_mask, cmap='gray')
325
+ axes[0,2].set_title('Predicted Mask', fontsize=12, weight='bold')
326
+ axes[0,2].axis('off')
327
+
328
+ # Ground truth
329
+ gt_array = np.array(ground_truth.resize((256, 256)))
330
+ axes[1,0].imshow(gt_array, cmap='gray')
331
+ axes[1,0].set_title('Ground Truth Mask', fontsize=12, weight='bold')
332
  axes[1,0].axis('off')
333
 
334
+ # Overlay comparison
335
+ overlay = np.array(image.convert('RGB').resize((256, 256)))
336
  overlay[binary_mask > 0] = [0, 255, 0] # Green for prediction
337
+ overlay[gt_array > 128] = [255, 0, 0] # Red for ground truth
338
  axes[1,1].imshow(overlay)
339
+ axes[1,1].set_title('Prediction (Green) vs GT (Red)', fontsize=12, weight='bold')
340
  axes[1,1].axis('off')
341
 
342
+ # Calculate IoU
343
+ pred_binary = binary_mask > 0
344
+ gt_binary = gt_array > 128
345
+ intersection = np.sum(pred_binary & gt_binary)
346
+ union = np.sum(pred_binary | gt_binary)
347
  iou = intersection / (union + 1e-8)
348
 
349
+ # Dice score
350
+ dice = (2 * intersection) / (np.sum(pred_binary) + np.sum(gt_binary) + 1e-8)
351
+
352
+ axes[1,2].text(0.1, 0.6, f'IoU: {iou:.4f}', fontsize=16, weight='bold')
353
+ axes[1,2].text(0.1, 0.4, f'Dice: {dice:.4f}', fontsize=16, weight='bold')
354
+ axes[1,2].set_xlim(0, 1)
355
+ axes[1,2].set_ylim(0, 1)
356
  axes[1,2].axis('off')
357
  else:
358
+ axes[1,0].imshow(binary_mask, cmap='gray')
359
+ axes[1,0].set_title('Predicted Mask', fontsize=12, weight='bold')
 
 
 
360
  axes[1,0].axis('off')
361
 
362
+ # Overlay
363
+ overlay = np.array(image.convert('RGB').resize((256, 256)))
364
+ overlay[binary_mask > 0] = [255, 0, 0]
365
+ axes[1,1].imshow(overlay)
366
+ axes[1,1].set_title('Prediction Overlay', fontsize=12, weight='bold')
367
  axes[1,1].axis('off')
 
368
 
369
  plt.tight_layout()
370
 
371
+ # Save plot
372
  buf = io.BytesIO()
373
  plt.savefig(buf, format='png', dpi=150, bbox_inches='tight', facecolor='white')
374
  buf.seek(0)
 
376
 
377
  result_image = Image.open(buf)
378
 
379
+ # Generate analysis text
380
  tumor_pixels = np.sum(binary_mask)
381
  total_pixels = binary_mask.size
382
  tumor_percentage = (tumor_pixels / total_pixels) * 100
383
 
384
  analysis_text = f"""
385
+ # Analysis Results
386
+
387
+ **File:** {filename if filename else 'Uploaded Image'}
388
 
389
+ **Tumor Detection:**
390
+ - Tumor Area: {tumor_percentage:.2f}%
391
+ - Tumor Pixels: {tumor_pixels:,}
 
392
 
393
+ **Model Features:**
 
394
  - Test-Time Augmentation: Applied
395
+ - Attention Visualization: Generated
396
+ - Post-processing: Morphological cleanup
397
+ """
398
 
399
  if ground_truth is not None:
400
+ analysis_text += f"""
401
+ **Performance Metrics:**
402
+ - IoU Score: {iou:.4f}
403
+ - Dice Score: {dice:.4f}
404
+ """
405
 
406
  return result_image, analysis_text
407
 
408
  except Exception as e:
409
+ return None, f"Analysis failed: {str(e)}"
410
+
411
+ # Initialize model and dataset at startup
412
+ print("Initializing application components...")
413
+ model_loaded = download_and_load_model()
414
+ dataset_loaded_success = download_and_load_dataset()
415
+
416
+ if not model_loaded:
417
+ print("WARNING: Model failed to load!")
418
+ if not dataset_loaded_success:
419
+ print("WARNING: Dataset failed to load!")
420
 
421
+ print("Application ready!")
 
422
 
423
+ # Professional CSS
424
  css = """
425
  .gradio-container {
426
+ max-width: 1600px !important;
427
  margin: auto !important;
428
+ font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif !important;
 
429
  }
430
+ .gr-button {
431
+ border-radius: 6px !important;
432
+ font-weight: 500 !important;
433
  }
434
+ .gr-button-primary {
435
+ background: #2563eb !important;
436
+ border-color: #2563eb !important;
 
 
437
  }
438
+ .gr-button-secondary {
439
+ background: #6b7280 !important;
440
+ border-color: #6b7280 !important;
441
  }
442
+ h1, h2, h3 {
443
+ color: #1f2937 !important;
 
444
  }
445
+ .gr-form {
446
+ border: 1px solid #e5e7eb !important;
447
+ border-radius: 8px !important;
448
  }
449
  """
450
 
451
+ # Create Gradio interface
452
+ with gr.Blocks(css=css, title="Brain Tumor Segmentation Analysis") as app:
453
 
454
  gr.Markdown("""
455
  # Brain Tumor Segmentation Using Attention U-Net
456
+
457
+ **Advanced Medical Image Analysis Tool**
458
+
459
+ Features: Test-Time Augmentation, Attention Visualization, Dataset Integration
460
  """)
461
 
462
+ # Status display
463
+ with gr.Row():
464
+ with gr.Column():
465
+ status_text = f"Model Status: {'✓ Loaded' if model_loaded else '✗ Failed'} | Dataset Status: {'✓ Loaded' if dataset_loaded_success else '✗ Failed'}"
466
+ if dataset_loaded_success:
467
+ status_text += f" | Images: {len(dataset_images)} | Masks: {len(dataset_masks)}"
468
+ gr.Markdown(f"**{status_text}**")
469
+
470
  with gr.Row():
471
  with gr.Column(scale=1):
472
  gr.Markdown("### Input Selection")
473
 
474
+ # Image display
475
+ image_display = gr.Image(
476
+ label="Selected Image",
477
  type="pil",
 
478
  height=300
479
  )
480
 
481
+ # Control buttons
 
482
  with gr.Row():
483
+ load_sample_btn = gr.Button("Load Random Sample", variant="primary", scale=1)
484
+ upload_btn = gr.UploadButton("Upload Image", file_types=["image"], scale=1)
485
+
486
+ analyze_btn = gr.Button("Analyze Image", variant="primary", size="lg")
487
+
488
+ # Dataset info
489
+ gr.Markdown(f"""
490
+ **Dataset Information:**
491
+ - Total Images: {len(dataset_images) if dataset_loaded_success else 'N/A'}
492
+ - Total Masks: {len(dataset_masks) if dataset_loaded_success else 'N/A'}
493
+ - Source: nikhilroxtomar/brain-tumor-segmentation
494
+ """)
495
 
496
  with gr.Column(scale=2):
497
  gr.Markdown("### Analysis Results")
498
 
499
+ result_display = gr.Image(
500
+ label="Segmentation Analysis",
501
  type="pil",
502
+ height=500
503
  )
504
 
505
+ analysis_text = gr.Markdown(
506
+ value="Load an image and click 'Analyze Image' to begin."
507
  )
508
 
509
+ # Hidden states
510
+ current_ground_truth = gr.State()
511
+ current_filename = gr.State()
512
+
513
  # Event handlers
514
+ def handle_sample_load():
515
+ image, mask, filename = get_random_sample()
516
+ return image, mask, filename
517
+
518
+ def handle_upload(file):
519
+ if file is not None:
520
+ image = Image.open(file.name).convert("L")
521
+ return image, None, os.path.basename(file.name)
522
+ return None, None, ""
523
+
524
+ load_sample_btn.click(
525
+ fn=handle_sample_load,
526
+ outputs=[image_display, current_ground_truth, current_filename]
527
  )
528
 
529
+ upload_btn.upload(
530
+ fn=handle_upload,
531
+ inputs=[upload_btn],
532
+ outputs=[image_display, current_ground_truth, current_filename]
533
  )
534
 
535
+ analyze_btn.click(
536
+ fn=analyze_image,
537
+ inputs=[image_display, current_ground_truth, current_filename],
538
+ outputs=[result_display, analysis_text]
539
  )
540
 
541
  if __name__ == "__main__":
542
+ print("\n" + "="*50)
543
+ print("LAUNCHING BRAIN TUMOR SEGMENTATION APPLICATION")
544
+ print("="*50)
545
+
546
+ app.launch(
547
+ server_name="0.0.0.0",
548
+ server_port=7860,
549
+ show_error=True,
550
+ share=False
551
+ )