ArchCoder commited on
Commit
82c9eec
·
verified ·
1 Parent(s): d57f983

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -54
app.py CHANGED
@@ -42,7 +42,6 @@ class DoubleConv(nn.Module):
42
  return self.conv(x)
43
 
44
 
45
- # Also, make sure your AttentionBlock.forward() returns the attention map:
46
  class AttentionBlock(nn.Module):
47
  def __init__(self, F_g, F_l, F_int):
48
  super(AttentionBlock, self).__init__()
@@ -238,30 +237,6 @@ def preprocess_for_model(image):
238
 
239
  return transform(image).unsqueeze(0)
240
 
241
- def apply_tta(model, input_tensor):
242
- """Test-Time Augmentation"""
243
- augmentations = [
244
- lambda x: x, # Original
245
- lambda x: TF.hflip(x), # Horizontal flip
246
- lambda x: TF.vflip(x), # Vertical flip
247
- ]
248
-
249
- predictions = []
250
- for i, aug in enumerate(augmentations):
251
- aug_input = aug(input_tensor)
252
- pred, _ = model(aug_input)
253
- pred = torch.sigmoid(pred)
254
-
255
- # Reverse augmentation
256
- if i == 1: # Reverse hflip
257
- pred = TF.hflip(pred)
258
- elif i == 2: # Reverse vflip
259
- pred = TF.vflip(pred)
260
-
261
- predictions.append(pred)
262
-
263
- return torch.mean(torch.stack(predictions), dim=0)
264
-
265
  def generate_attention_heatmap(attention_maps):
266
  """Generate attention heatmap - Fixed version"""
267
  if not attention_maps:
@@ -290,9 +265,8 @@ def generate_attention_heatmap(attention_maps):
290
 
291
  return heatmap
292
 
293
-
294
  def analyze_image(image, ground_truth, filename):
295
- """Main analysis function"""
296
  if model is None:
297
  return None, "Model not loaded. Please restart the application."
298
 
@@ -303,18 +277,22 @@ def analyze_image(image, ground_truth, filename):
303
  # Preprocess
304
  input_tensor = preprocess_for_model(image).to(device)
305
 
306
- # Apply TTA
307
  with torch.no_grad():
308
- avg_pred = apply_tta(model, input_tensor)
309
- _, attention_maps = model(input_tensor)
310
-
311
- # Get binary mask
312
- binary_mask = (avg_pred > 0.5).squeeze().cpu().numpy()
 
 
 
 
313
 
314
- # Post-processing
315
  kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3,3))
316
- binary_mask = cv2.morphologyEx(binary_mask.astype(np.uint8), cv2.MORPH_OPEN, kernel)
317
- binary_mask = cv2.morphologyEx(binary_mask, cv2.MORPH_CLOSE, kernel)
318
 
319
  # Generate attention heatmap
320
  att_heatmap = generate_attention_heatmap(attention_maps)
@@ -340,33 +318,36 @@ def analyze_image(image, ground_truth, filename):
340
 
341
  # Predicted mask
342
  if ground_truth is not None:
343
- axes[0,2].imshow(binary_mask, cmap='gray')
344
  axes[0,2].set_title('Predicted Mask', fontsize=12, weight='bold')
345
  axes[0,2].axis('off')
346
 
347
  # Ground truth
348
  gt_array = np.array(ground_truth.resize((256, 256)))
349
- axes[1,0].imshow(gt_array, cmap='gray')
 
 
 
350
  axes[1,0].set_title('Ground Truth Mask', fontsize=12, weight='bold')
351
  axes[1,0].axis('off')
352
 
353
  # Overlay comparison
354
  overlay = np.array(image.convert('RGB').resize((256, 256)))
355
- overlay[binary_mask > 0] = [0, 255, 0] # Green for prediction
356
- overlay[gt_array > 128] = [255, 0, 0] # Red for ground truth
357
  axes[1,1].imshow(overlay)
358
  axes[1,1].set_title('Prediction (Green) vs GT (Red)', fontsize=12, weight='bold')
359
  axes[1,1].axis('off')
360
 
361
- # Calculate IoU
362
- pred_binary = binary_mask > 0
363
- gt_binary = gt_array > 128
364
- intersection = np.sum(pred_binary & gt_binary)
365
- union = np.sum(pred_binary | gt_binary)
366
  iou = intersection / (union + 1e-8)
367
 
368
  # Dice score
369
- dice = (2 * intersection) / (np.sum(pred_binary) + np.sum(gt_binary) + 1e-8)
370
 
371
  axes[1,2].text(0.1, 0.6, f'IoU: {iou:.4f}', fontsize=16, weight='bold')
372
  axes[1,2].text(0.1, 0.4, f'Dice: {dice:.4f}', fontsize=16, weight='bold')
@@ -374,13 +355,13 @@ def analyze_image(image, ground_truth, filename):
374
  axes[1,2].set_ylim(0, 1)
375
  axes[1,2].axis('off')
376
  else:
377
- axes[1,0].imshow(binary_mask, cmap='gray')
378
  axes[1,0].set_title('Predicted Mask', fontsize=12, weight='bold')
379
  axes[1,0].axis('off')
380
 
381
  # Overlay
382
  overlay = np.array(image.convert('RGB').resize((256, 256)))
383
- overlay[binary_mask > 0] = [255, 0, 0]
384
  axes[1,1].imshow(overlay)
385
  axes[1,1].set_title('Prediction Overlay', fontsize=12, weight='bold')
386
  axes[1,1].axis('off')
@@ -396,8 +377,8 @@ def analyze_image(image, ground_truth, filename):
396
  result_image = Image.open(buf)
397
 
398
  # Generate analysis text
399
- tumor_pixels = np.sum(binary_mask)
400
- total_pixels = binary_mask.size
401
  tumor_percentage = (tumor_pixels / total_pixels) * 100
402
 
403
  analysis_text = f"""
@@ -410,7 +391,6 @@ def analyze_image(image, ground_truth, filename):
410
  - Tumor Pixels: {tumor_pixels:,}
411
 
412
  **Model Features:**
413
- - Test-Time Augmentation: Applied
414
  - Attention Visualization: Generated
415
  - Post-processing: Morphological cleanup
416
  """
@@ -425,7 +405,10 @@ def analyze_image(image, ground_truth, filename):
425
  return result_image, analysis_text
426
 
427
  except Exception as e:
428
- return None, f"Analysis failed: {str(e)}"
 
 
 
429
 
430
  # Initialize model and dataset at startup
431
  print("Initializing application components...")
@@ -475,7 +458,7 @@ with gr.Blocks(css=css, title="Brain Tumor Segmentation Analysis") as app:
475
 
476
  **Advanced Medical Image Analysis Tool**
477
 
478
- Features: Test-Time Augmentation, Attention Visualization, Dataset Integration
479
  """)
480
 
481
  # Status display
@@ -567,4 +550,4 @@ if __name__ == "__main__":
567
  server_port=7860,
568
  show_error=True,
569
  share=False
570
- )
 
42
  return self.conv(x)
43
 
44
 
 
45
  class AttentionBlock(nn.Module):
46
  def __init__(self, F_g, F_l, F_int):
47
  super(AttentionBlock, self).__init__()
 
237
 
238
  return transform(image).unsqueeze(0)
239
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
240
  def generate_attention_heatmap(attention_maps):
241
  """Generate attention heatmap - Fixed version"""
242
  if not attention_maps:
 
265
 
266
  return heatmap
267
 
 
268
  def analyze_image(image, ground_truth, filename):
269
+ """Main analysis function - FIXED VERSION"""
270
  if model is None:
271
  return None, "Model not loaded. Please restart the application."
272
 
 
277
  # Preprocess
278
  input_tensor = preprocess_for_model(image).to(device)
279
 
280
+ # Get prediction and attention maps
281
  with torch.no_grad():
282
+ # Get model output (prediction + attention maps)
283
+ model_output, attention_maps = model(input_tensor)
284
+
285
+ # Apply sigmoid and threshold to get binary mask
286
+ pred_mask = torch.sigmoid(model_output)
287
+ binary_mask = (pred_mask > 0.5).float()
288
+
289
+ # Convert to numpy for further processing
290
+ binary_mask_np = binary_mask.squeeze().cpu().numpy()
291
 
292
+ # Post-processing (morphological operations)
293
  kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3,3))
294
+ binary_mask_np = cv2.morphologyEx(binary_mask_np.astype(np.uint8), cv2.MORPH_OPEN, kernel)
295
+ binary_mask_np = cv2.morphologyEx(binary_mask_np, cv2.MORPH_CLOSE, kernel)
296
 
297
  # Generate attention heatmap
298
  att_heatmap = generate_attention_heatmap(attention_maps)
 
318
 
319
  # Predicted mask
320
  if ground_truth is not None:
321
+ axes[0,2].imshow(binary_mask_np, cmap='gray')
322
  axes[0,2].set_title('Predicted Mask', fontsize=12, weight='bold')
323
  axes[0,2].axis('off')
324
 
325
  # Ground truth
326
  gt_array = np.array(ground_truth.resize((256, 256)))
327
+ # Normalize ground truth to binary (0 or 1)
328
+ gt_binary = (gt_array > 128).astype(np.uint8)
329
+
330
+ axes[1,0].imshow(gt_binary, cmap='gray')
331
  axes[1,0].set_title('Ground Truth Mask', fontsize=12, weight='bold')
332
  axes[1,0].axis('off')
333
 
334
  # Overlay comparison
335
  overlay = np.array(image.convert('RGB').resize((256, 256)))
336
+ overlay[binary_mask_np > 0] = [0, 255, 0] # Green for prediction
337
+ overlay[gt_binary > 0] = [255, 0, 0] # Red for ground truth
338
  axes[1,1].imshow(overlay)
339
  axes[1,1].set_title('Prediction (Green) vs GT (Red)', fontsize=12, weight='bold')
340
  axes[1,1].axis('off')
341
 
342
+ # Calculate IoU and Dice
343
+ pred_binary = binary_mask_np > 0
344
+ gt_binary_bool = gt_binary > 0
345
+ intersection = np.sum(pred_binary & gt_binary_bool)
346
+ union = np.sum(pred_binary | gt_binary_bool)
347
  iou = intersection / (union + 1e-8)
348
 
349
  # Dice score
350
+ dice = (2 * intersection) / (np.sum(pred_binary) + np.sum(gt_binary_bool) + 1e-8)
351
 
352
  axes[1,2].text(0.1, 0.6, f'IoU: {iou:.4f}', fontsize=16, weight='bold')
353
  axes[1,2].text(0.1, 0.4, f'Dice: {dice:.4f}', fontsize=16, weight='bold')
 
355
  axes[1,2].set_ylim(0, 1)
356
  axes[1,2].axis('off')
357
  else:
358
+ axes[1,0].imshow(binary_mask_np, cmap='gray')
359
  axes[1,0].set_title('Predicted Mask', fontsize=12, weight='bold')
360
  axes[1,0].axis('off')
361
 
362
  # Overlay
363
  overlay = np.array(image.convert('RGB').resize((256, 256)))
364
+ overlay[binary_mask_np > 0] = [255, 0, 0]
365
  axes[1,1].imshow(overlay)
366
  axes[1,1].set_title('Prediction Overlay', fontsize=12, weight='bold')
367
  axes[1,1].axis('off')
 
377
  result_image = Image.open(buf)
378
 
379
  # Generate analysis text
380
+ tumor_pixels = np.sum(binary_mask_np)
381
+ total_pixels = binary_mask_np.size
382
  tumor_percentage = (tumor_pixels / total_pixels) * 100
383
 
384
  analysis_text = f"""
 
391
  - Tumor Pixels: {tumor_pixels:,}
392
 
393
  **Model Features:**
 
394
  - Attention Visualization: Generated
395
  - Post-processing: Morphological cleanup
396
  """
 
405
  return result_image, analysis_text
406
 
407
  except Exception as e:
408
+ import traceback
409
+ error_msg = f"Analysis failed: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
410
+ print(error_msg) # For debugging
411
+ return None, error_msg
412
 
413
  # Initialize model and dataset at startup
414
  print("Initializing application components...")
 
458
 
459
  **Advanced Medical Image Analysis Tool**
460
 
461
+ Features: Attention Visualization, Dataset Integration, Morphological Post-processing
462
  """)
463
 
464
  # Status display
 
550
  server_port=7860,
551
  show_error=True,
552
  share=False
553
+ )